diff --git a/.gitignore b/.gitignore index e1a4187b0..8dd3dfb0f 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ cmake-build-* cmake/DownloadProject/ dist/ src/query/frontend/opencypher/generated/ +src/query/v2/frontend/opencypher/generated/ tags ve/ ve3/ @@ -50,15 +51,21 @@ src/distributed/pull_produce_rpc_messages.hpp src/distributed/storage_gc_rpc_messages.hpp src/distributed/token_sharing_rpc_messages.hpp src/distributed/updates_rpc_messages.hpp +src/query/v2/frontend/ast/ast.hpp src/query/frontend/ast/ast.hpp src/query/distributed/frontend/ast/ast_serialization.hpp +src/query/v2/distributed/frontend/ast/ast_serialization.hpp src/durability/distributed/state_delta.hpp src/durability/single_node/state_delta.hpp src/durability/single_node_ha/state_delta.hpp src/query/frontend/semantic/symbol.hpp +src/query/v2/frontend/semantic/symbol.hpp src/query/distributed/frontend/semantic/symbol_serialization.hpp +src/query/v2/distributed/frontend/semantic/symbol_serialization.hpp src/query/distributed/plan/ops.hpp +src/query/v2/distributed/plan/ops.hpp src/query/plan/operator.hpp +src/query/v2/plan/operator.hpp src/raft/log_entry.hpp src/raft/raft_rpc_messages.hpp src/raft/snapshot_metadata.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e6160d972..efc653b9a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(storage/v2) add_subdirectory(storage/v3) add_subdirectory(integrations) add_subdirectory(query) +add_subdirectory(query/v2) add_subdirectory(slk) add_subdirectory(rpc) add_subdirectory(auth) diff --git a/src/query/v2/CMakeLists.txt b/src/query/v2/CMakeLists.txt new file mode 100644 index 000000000..3efb0c91e --- /dev/null +++ b/src/query/v2/CMakeLists.txt @@ -0,0 +1,104 @@ +define_add_lcp(add_lcp_query lcp_query_v2_cpp_files generated_lcp_query_v2_files) + +add_lcp_query(frontend/ast/ast.lcp) +add_lcp_query(frontend/semantic/symbol.lcp) +add_lcp_query(plan/operator.lcp) + 
+add_custom_target(generate_lcp_query_v2 DEPENDS ${generated_lcp_query_v2_files}) + +set(mg_query_v2_sources + ${lcp_query_v2_cpp_files} + common.cpp + cypher_query_interpreter.cpp + dump.cpp + frontend/ast/cypher_main_visitor.cpp + frontend/ast/pretty_print.cpp + frontend/parsing.cpp + frontend/semantic/required_privileges.cpp + frontend/semantic/symbol_generator.cpp + frontend/stripped.cpp + interpret/awesome_memgraph_functions.cpp + interpret/eval.cpp + interpreter.cpp + metadata.cpp + plan/operator.cpp + plan/preprocess.cpp + plan/pretty_print.cpp + plan/profile.cpp + plan/read_write_type_checker.cpp + plan/rewrite/index_lookup.cpp + plan/rule_based_planner.cpp + plan/variable_start_planner.cpp + procedure/mg_procedure_impl.cpp + procedure/mg_procedure_helpers.cpp + procedure/module.cpp + procedure/py_module.cpp + serialization/property_value.cpp + stream/streams.cpp + stream/sources.cpp + stream/common.cpp + trigger.cpp + trigger_context.cpp + typed_value.cpp) + +find_package(Boost REQUIRED) + +add_library(mg-query-v2 STATIC ${mg_query_v2_sources}) +add_dependencies(mg-query-v2 generate_lcp_query_v2) +target_include_directories(mg-query-v2 PUBLIC ${CMAKE_SOURCE_DIR}/include) +target_link_libraries(mg-query-v2 dl cppitertools Boost::headers) +target_link_libraries(mg-query-v2 mg-integrations-pulsar mg-integrations-kafka mg-storage-v3 mg-license mg-utils mg-kvstore mg-memory) + +if(NOT "${MG_PYTHON_PATH}" STREQUAL "") + set(Python3_ROOT_DIR "${MG_PYTHON_PATH}") +endif() + +if("${MG_PYTHON_VERSION}" STREQUAL "") + find_package(Python3 3.5 REQUIRED COMPONENTS Development) +else() + find_package(Python3 "${MG_PYTHON_VERSION}" EXACT REQUIRED COMPONENTS Development) +endif() + +target_link_libraries(mg-query-v2 Python3::Python) + +# Generate Antlr openCypher parser +set(opencypher_frontend ${CMAKE_CURRENT_SOURCE_DIR}/frontend/opencypher) +set(opencypher_generated ${opencypher_frontend}/generated) +set(opencypher_lexer_grammar 
${opencypher_frontend}/grammar/MemgraphCypherLexer.g4) +set(opencypher_parser_grammar ${opencypher_frontend}/grammar/MemgraphCypher.g4) + +set(antlr_opencypher_generated_src + ${opencypher_generated}/MemgraphCypherLexer.cpp + ${opencypher_generated}/MemgraphCypher.cpp + ${opencypher_generated}/MemgraphCypherBaseVisitor.cpp + ${opencypher_generated}/MemgraphCypherVisitor.cpp +) +set(antlr_opencypher_generated_include + ${opencypher_generated}/MemgraphCypherLexer.h + ${opencypher_generated}/MemgraphCypher.h + ${opencypher_generated}/MemgraphCypherBaseVisitor.h + ${opencypher_generated}/MemgraphCypherVisitor.h +) + +add_custom_command( + OUTPUT ${antlr_opencypher_generated_src} ${antlr_opencypher_generated_include} + COMMAND ${CMAKE_COMMAND} -E make_directory ${opencypher_generated} + COMMAND + java -jar ${CMAKE_SOURCE_DIR}/libs/antlr-4.9.2-complete.jar + -Dlanguage=Cpp -visitor -package antlropencypher + -o ${opencypher_generated} + ${opencypher_lexer_grammar} ${opencypher_parser_grammar} + WORKING_DIRECTORY "${CMAKE_BINARY_DIR}" + DEPENDS + ${opencypher_lexer_grammar} ${opencypher_parser_grammar} + ${opencypher_frontend}/grammar/CypherLexer.g4 + ${opencypher_frontend}/grammar/Cypher.g4) + +add_custom_target(generate_opencypher_parser_v2 + DEPENDS ${antlr_opencypher_generated_src} ${antlr_opencypher_generated_include}) + +add_library(antlr_opencypher_parser_lib_v2 STATIC ${antlr_opencypher_generated_src}) +add_dependencies(antlr_opencypher_parser_lib_v2 generate_opencypher_parser_v2) +target_link_libraries(antlr_opencypher_parser_lib_v2 antlr4) + +target_link_libraries(mg-query-v2 antlr_opencypher_parser_lib_v2) diff --git a/src/query/v2/auth_checker.hpp b/src/query/v2/auth_checker.hpp new file mode 100644 index 000000000..48d755a16 --- /dev/null +++ b/src/query/v2/auth_checker.hpp @@ -0,0 +1,29 @@ +// Copyright 2022 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "query/v2/frontend/ast/ast.hpp" + +namespace memgraph::query::v2 { +class AuthChecker { + public: + virtual bool IsUserAuthorized(const std::optional<std::string> &username, + const std::vector<query::v2::AuthQuery::Privilege> &privileges) const = 0; +}; + +class AllowEverythingAuthChecker final : public query::v2::AuthChecker { + bool IsUserAuthorized(const std::optional<std::string> & /*username*/, + const std::vector<query::v2::AuthQuery::Privilege> & /*privileges*/) const override { + return true; + } +}; +} // namespace memgraph::query::v2 diff --git a/src/query/v2/common.cpp b/src/query/v2/common.cpp new file mode 100644 index 000000000..4ca63f6b0 --- /dev/null +++ b/src/query/v2/common.cpp @@ -0,0 +1,76 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "query/v2/common.hpp" + +namespace memgraph::query::v2 { + +namespace impl { + +bool TypedValueCompare(const TypedValue &a, const TypedValue &b) { + // in ordering null comes after everything else + // at the same time Null is not less that null + // first deal with Null < Whatever case + if (a.IsNull()) return false; + // now deal with NotNull < Null case + if (b.IsNull()) return true; + + // comparisons are from this point legal only between values of + // the same type, or int+float combinations + if ((a.type() != b.type() && !(a.IsNumeric() && b.IsNumeric()))) + throw QueryRuntimeException("Can't compare value of type {} to value of type {}.", a.type(), b.type()); + + switch (a.type()) { + case TypedValue::Type::Bool: + return !a.ValueBool() && b.ValueBool(); + case TypedValue::Type::Int: + if (b.type() == TypedValue::Type::Double) + return a.ValueInt() < b.ValueDouble(); + else + return a.ValueInt() < b.ValueInt(); + case TypedValue::Type::Double: + if (b.type() == TypedValue::Type::Int) + return a.ValueDouble() < b.ValueInt(); + else + return a.ValueDouble() < b.ValueDouble(); + case TypedValue::Type::String: + // NOLINTNEXTLINE(modernize-use-nullptr) + return a.ValueString() < b.ValueString(); + case TypedValue::Type::Date: + // NOLINTNEXTLINE(modernize-use-nullptr) + return a.ValueDate() < b.ValueDate(); + case TypedValue::Type::LocalTime: + // NOLINTNEXTLINE(modernize-use-nullptr) + return a.ValueLocalTime() < b.ValueLocalTime(); + case TypedValue::Type::LocalDateTime: + // NOLINTNEXTLINE(modernize-use-nullptr) + return a.ValueLocalDateTime() < b.ValueLocalDateTime(); + case TypedValue::Type::Duration: + // NOLINTNEXTLINE(modernize-use-nullptr) + return a.ValueDuration() < b.ValueDuration(); + case TypedValue::Type::List: + case TypedValue::Type::Map: + case TypedValue::Type::Vertex: + case TypedValue::Type::Edge: + case TypedValue::Type::Path: + throw QueryRuntimeException("Comparison is not defined for values of type {}.", a.type()); + case 
TypedValue::Type::Null: + LOG_FATAL("Invalid type"); + } +} + +} // namespace impl + +int64_t QueryTimestamp() { + return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()) + .count(); +} +} // namespace memgraph::query::v2 diff --git a/src/query/v2/common.hpp b/src/query/v2/common.hpp new file mode 100644 index 000000000..e79ca996c --- /dev/null +++ b/src/query/v2/common.hpp @@ -0,0 +1,111 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +#pragma once + +#include <concepts> +#include <cstdint> +#include <string> +#include <string_view> + +#include "query/v2/db_accessor.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/semantic/symbol.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/view.hpp" +#include "utils/logging.hpp" + +namespace memgraph::query::v2 { + +namespace impl { +bool TypedValueCompare(const TypedValue &a, const TypedValue &b); +} // namespace impl + +/// Custom Comparator type for comparing vectors of TypedValues. +/// +/// Does lexicographical ordering of elements based on the above +/// defined TypedValueCompare, and also accepts a vector of Orderings +/// the define how respective elements compare. 
+class TypedValueVectorCompare final { + public: + TypedValueVectorCompare() {} + explicit TypedValueVectorCompare(const std::vector<Ordering> &ordering) : ordering_(ordering) {} + + template <class TAllocator> + bool operator()(const std::vector<TypedValue, TAllocator> &c1, const std::vector<TypedValue, TAllocator> &c2) const { + // ordering is invalid if there are more elements in the collections + // then there are in the ordering_ vector + MG_ASSERT(c1.size() <= ordering_.size() && c2.size() <= ordering_.size(), + "Collections contain more elements then there are orderings"); + + auto c1_it = c1.begin(); + auto c2_it = c2.begin(); + auto ordering_it = ordering_.begin(); + for (; c1_it != c1.end() && c2_it != c2.end(); c1_it++, c2_it++, ordering_it++) { + if (impl::TypedValueCompare(*c1_it, *c2_it)) return *ordering_it == Ordering::ASC; + if (impl::TypedValueCompare(*c2_it, *c1_it)) return *ordering_it == Ordering::DESC; + } + + // at least one collection is exhausted + // c1 is less then c2 iff c1 reached the end but c2 didn't + return (c1_it == c1.end()) && (c2_it != c2.end()); + } + + // TODO: Remove this, member is public + const auto &ordering() const { return ordering_; } + + std::vector<Ordering> ordering_; +}; + +/// Raise QueryRuntimeException if the value for symbol isn't of expected type. +inline void ExpectType(const Symbol &symbol, const TypedValue &value, TypedValue::Type expected) { + if (value.type() != expected) + throw QueryRuntimeException("Expected a {} for '{}', but got {}.", expected, symbol.name(), value.type()); +} + +template <typename T> +concept AccessorWithSetProperty = requires(T accessor, const storage::v3::PropertyId key, + const storage::v3::PropertyValue new_value) { + { accessor.SetProperty(key, new_value) } -> std::same_as<storage::v3::Result<storage::v3::PropertyValue>>; +}; + +/// Set a property `value` mapped with given `key` on a `record`. 
+/// +/// @throw QueryRuntimeException if value cannot be set as a property value +template <AccessorWithSetProperty T> +storage::v3::PropertyValue PropsSetChecked(T *record, const storage::v3::PropertyId &key, const TypedValue &value) { + try { + auto maybe_old_value = record->SetProperty(key, storage::v3::PropertyValue(value)); + if (maybe_old_value.HasError()) { + switch (maybe_old_value.GetError()) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set properties on a deleted object."); + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Can't set property because properties on edges are disabled."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a property."); + } + } + return std::move(*maybe_old_value); + } catch (const TypedValueException &) { + throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type()); + } +} + +int64_t QueryTimestamp(); +} // namespace memgraph::query::v2 diff --git a/src/query/v2/config.hpp b/src/query/v2/config.hpp new file mode 100644 index 000000000..13b0539cc --- /dev/null +++ b/src/query/v2/config.hpp @@ -0,0 +1,32 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
namespace memgraph::query::v2 {
/// Settings for the query interpreter.
///
/// Every member carries a default member initializer so that a
/// default-initialized `InterpreterConfig` never holds indeterminate values;
/// the type remains an aggregate.
struct InterpreterConfig {
  struct Query {
    /// When false, queries containing LOAD CSV are rejected.
    bool allow_load_csv{true};
  } query;

  // The default execution timeout is 10 minutes.
  double execution_timeout_sec{600.0};
  // The same as \ref memgraph::storage::v3::replication::ReplicationClientConfig
  std::chrono::seconds replication_replica_check_frequency{1};

  std::string default_kafka_bootstrap_servers;
  std::string default_pulsar_service_url;
  // Previously left uninitialized (reading them after `InterpreterConfig c;`
  // was undefined behaviour). Zero is only a deterministic placeholder:
  // callers are still expected to set the values they need.
  uint32_t stream_transaction_conflict_retries{0};
  std::chrono::milliseconds stream_transaction_retry_interval{0};
};
}  // namespace memgraph::query::v2
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <type_traits> + +#include "query/v2/common.hpp" +#include "query/v2/frontend/semantic/symbol_table.hpp" +#include "query/v2/metadata.hpp" +#include "query/v2/parameters.hpp" +#include "query/v2/plan/profile.hpp" +#include "query/v2/trigger.hpp" +#include "utils/async_timer.hpp" + +namespace memgraph::query::v2 { + +struct EvaluationContext { + /// Memory for allocations during evaluation of a *single* Pull call. + /// + /// Although the assigned memory may live longer than the duration of a Pull + /// (e.g. memory is the same as the whole execution memory), you have to treat + /// it as if the lifetime is only valid during the Pull. 
+ utils::MemoryResource *memory{utils::NewDeleteResource()}; + int64_t timestamp{-1}; + Parameters parameters; + /// All properties indexable via PropertyIx + std::vector<storage::v3::PropertyId> properties; + /// All labels indexable via LabelIx + std::vector<storage::v3::LabelId> labels; + /// All counters generated by `counter` function, mutable because the function + /// modifies the values + mutable std::unordered_map<std::string, int64_t> counters; +}; + +inline std::vector<storage::v3::PropertyId> NamesToProperties(const std::vector<std::string> &property_names, + DbAccessor *dba) { + std::vector<storage::v3::PropertyId> properties; + properties.reserve(property_names.size()); + for (const auto &name : property_names) { + properties.push_back(dba->NameToProperty(name)); + } + return properties; +} + +inline std::vector<storage::v3::LabelId> NamesToLabels(const std::vector<std::string> &label_names, DbAccessor *dba) { + std::vector<storage::v3::LabelId> labels; + labels.reserve(label_names.size()); + for (const auto &name : label_names) { + labels.push_back(dba->NameToLabel(name)); + } + return labels; +} + +struct ExecutionContext { + DbAccessor *db_accessor{nullptr}; + SymbolTable symbol_table; + EvaluationContext evaluation_context; + std::atomic<bool> *is_shutting_down{nullptr}; + bool is_profile_query{false}; + std::chrono::duration<double> profile_execution_time; + plan::ProfilingStats stats; + plan::ProfilingStats *stats_root{nullptr}; + ExecutionStats execution_stats; + TriggerContextCollector *trigger_context_collector{nullptr}; + utils::AsyncTimer timer; +}; + +static_assert(std::is_move_assignable_v<ExecutionContext>, "ExecutionContext must be move assignable!"); +static_assert(std::is_move_constructible_v<ExecutionContext>, "ExecutionContext must be move constructible!"); + +inline bool MustAbort(const ExecutionContext &context) noexcept { + return (context.is_shutting_down != nullptr && context.is_shutting_down->load(std::memory_order_acquire)) 
|| + context.timer.IsExpired(); +} + +inline plan::ProfilingStatsWithTotalTime GetStatsWithTotalTime(const ExecutionContext &context) { + return plan::ProfilingStatsWithTotalTime{context.stats, context.profile_execution_time}; +} + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/cypher_query_interpreter.cpp b/src/query/v2/cypher_query_interpreter.cpp new file mode 100644 index 000000000..42e2b3cf3 --- /dev/null +++ b/src/query/v2/cypher_query_interpreter.cpp @@ -0,0 +1,158 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/cypher_query_interpreter.hpp" + +// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_HIDDEN_bool(query_cost_planner, true, "Use the cost-estimating query planner."); +// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_VALIDATED_int32(query_plan_cache_ttl, 60, "Time to live for cached query plans, in seconds.", + FLAG_IN_RANGE(0, std::numeric_limits<int32_t>::max())); + +namespace memgraph::query::v2 { +CachedPlan::CachedPlan(std::unique_ptr<LogicalPlan> plan) : plan_(std::move(plan)) {} + +ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::string, storage::v3::PropertyValue> ¶ms, + utils::SkipList<QueryCacheEntry> *cache, utils::SpinLock *antlr_lock, + const InterpreterConfig::Query &query_config) { + // Strip the query for caching purposes. 
The process of stripping a query + // "normalizes" it by replacing any literals with new parameters. This + // results in just the *structure* of the query being taken into account for + // caching. + frontend::StrippedQuery stripped_query{query_string}; + + // Copy over the parameters that were introduced during stripping. + Parameters parameters{stripped_query.literals()}; + + // Check that all user-specified parameters are provided. + for (const auto ¶m_pair : stripped_query.parameters()) { + auto it = params.find(param_pair.second); + + if (it == params.end()) { + throw query::v2::UnprovidedParameterError("Parameter ${} not provided.", param_pair.second); + } + + parameters.Add(param_pair.first, it->second); + } + + // Cache the query's AST if it isn't already. + auto hash = stripped_query.hash(); + auto accessor = cache->access(); + auto it = accessor.find(hash); + std::unique_ptr<frontend::opencypher::Parser> parser; + + // Return a copy of both the AST storage and the query. + CachedQuery result; + bool is_cacheable = true; + + auto get_information_from_cache = [&](const auto &cached_query) { + result.ast_storage.properties_ = cached_query.ast_storage.properties_; + result.ast_storage.labels_ = cached_query.ast_storage.labels_; + result.ast_storage.edge_types_ = cached_query.ast_storage.edge_types_; + + result.query = cached_query.query->Clone(&result.ast_storage); + result.required_privileges = cached_query.required_privileges; + }; + + if (it == accessor.end()) { + { + std::unique_lock<utils::SpinLock> guard(*antlr_lock); + + try { + parser = std::make_unique<frontend::opencypher::Parser>(stripped_query.query()); + } catch (const SyntaxException &e) { + // There is a syntax exception in the stripped query. Re-run the parser + // on the original query to get an appropriate error messsage. + parser = std::make_unique<frontend::opencypher::Parser>(query_string); + + // If an exception was not thrown here, the stripper messed something + // up. 
+ LOG_FATAL("The stripped query can't be parsed, but the original can."); + } + } + + // Convert the ANTLR4 parse tree into an AST. + AstStorage ast_storage; + frontend::ParsingContext context{true}; + frontend::CypherMainVisitor visitor(context, &ast_storage); + + visitor.visit(parser->tree()); + + if (visitor.GetQueryInfo().has_load_csv && !query_config.allow_load_csv) { + throw utils::BasicException("Load CSV not allowed on this instance because it was disabled by a config."); + } + + if (visitor.GetQueryInfo().is_cacheable) { + CachedQuery cached_query{std::move(ast_storage), visitor.query(), + query::v2::GetRequiredPrivileges(visitor.query())}; + it = accessor.insert({hash, std::move(cached_query)}).first; + + get_information_from_cache(it->second); + } else { + result.ast_storage.properties_ = ast_storage.properties_; + result.ast_storage.labels_ = ast_storage.labels_; + result.ast_storage.edge_types_ = ast_storage.edge_types_; + + result.query = visitor.query()->Clone(&result.ast_storage); + result.required_privileges = query::v2::GetRequiredPrivileges(visitor.query()); + + is_cacheable = false; + } + } else { + get_information_from_cache(it->second); + } + + return ParsedQuery{query_string, + params, + std::move(parameters), + std::move(stripped_query), + std::move(result.ast_storage), + result.query, + std::move(result.required_privileges), + is_cacheable}; +} + +std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage ast_storage, CypherQuery *query, const Parameters ¶meters, + DbAccessor *db_accessor, + const std::vector<Identifier *> &predefined_identifiers) { + auto vertex_counts = plan::MakeVertexCountCache(db_accessor); + auto symbol_table = MakeSymbolTable(query, predefined_identifiers); + auto planning_context = plan::MakePlanningContext(&ast_storage, &symbol_table, query, &vertex_counts); + auto [root, cost] = plan::MakeLogicalPlan(&planning_context, parameters, FLAGS_query_cost_planner); + return 
std::make_unique<SingleNodeLogicalPlan>(std::move(root), cost, std::move(ast_storage), + std::move(symbol_table)); +} + +std::shared_ptr<CachedPlan> CypherQueryToPlan(uint64_t hash, AstStorage ast_storage, CypherQuery *query, + const Parameters ¶meters, utils::SkipList<PlanCacheEntry> *plan_cache, + DbAccessor *db_accessor, + const std::vector<Identifier *> &predefined_identifiers) { + std::optional<utils::SkipList<PlanCacheEntry>::Accessor> plan_cache_access; + if (plan_cache) { + plan_cache_access.emplace(plan_cache->access()); + auto it = plan_cache_access->find(hash); + if (it != plan_cache_access->end()) { + if (it->second->IsExpired()) { + plan_cache_access->remove(hash); + } else { + return it->second; + } + } + } + + auto plan = std::make_shared<CachedPlan>( + MakeLogicalPlan(std::move(ast_storage), query, parameters, db_accessor, predefined_identifiers)); + if (plan_cache_access) { + plan_cache_access->insert({hash, plan}); + } + return plan; +} +} // namespace memgraph::query::v2 diff --git a/src/query/v2/cypher_query_interpreter.hpp b/src/query/v2/cypher_query_interpreter.hpp new file mode 100644 index 000000000..423eafdde --- /dev/null +++ b/src/query/v2/cypher_query_interpreter.hpp @@ -0,0 +1,152 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include "query/v2/config.hpp" +#include "query/v2/frontend/ast/cypher_main_visitor.hpp" +#include "query/v2/frontend/opencypher/parser.hpp" +#include "query/v2/frontend/semantic/required_privileges.hpp" +#include "query/v2/frontend/semantic/symbol_generator.hpp" +#include "query/v2/frontend/stripped.hpp" +#include "query/v2/plan/planner.hpp" +#include "utils/flag_validation.hpp" +#include "utils/timer.hpp" + +// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables) +DECLARE_bool(query_cost_planner); +// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables) +DECLARE_int32(query_plan_cache_ttl); + +namespace memgraph::query::v2 { + +// TODO: Maybe this should move to query/plan/planner. +/// Interface for accessing the root operator of a logical plan. +class LogicalPlan { + public: + explicit LogicalPlan() = default; + + virtual ~LogicalPlan() = default; + + LogicalPlan(const LogicalPlan &) = default; + LogicalPlan &operator=(const LogicalPlan &) = default; + LogicalPlan(LogicalPlan &&) = default; + LogicalPlan &operator=(LogicalPlan &&) = default; + + virtual const plan::LogicalOperator &GetRoot() const = 0; + virtual double GetCost() const = 0; + virtual const SymbolTable &GetSymbolTable() const = 0; + virtual const AstStorage &GetAstStorage() const = 0; +}; + +class CachedPlan { + public: + explicit CachedPlan(std::unique_ptr<LogicalPlan> plan); + + const auto &plan() const { return plan_->GetRoot(); } + double cost() const { return plan_->GetCost(); } + const auto &symbol_table() const { return plan_->GetSymbolTable(); } + const auto &ast_storage() const { return plan_->GetAstStorage(); } + + bool IsExpired() const { + // NOLINTNEXTLINE (modernize-use-nullptr) + return cache_timer_.Elapsed() > std::chrono::seconds(FLAGS_query_plan_cache_ttl); + }; + + private: + std::unique_ptr<LogicalPlan> plan_; + utils::Timer cache_timer_; +}; + +struct CachedQuery { + AstStorage ast_storage; + Query *query; + 
std::vector<AuthQuery::Privilege> required_privileges; +}; + +struct QueryCacheEntry { + bool operator==(const QueryCacheEntry &other) const { return first == other.first; } + bool operator<(const QueryCacheEntry &other) const { return first < other.first; } + bool operator==(const uint64_t &other) const { return first == other; } + bool operator<(const uint64_t &other) const { return first < other; } + + uint64_t first; + // TODO: Maybe store the query string here and use it as a key with the hash + // so that we eliminate the risk of hash collisions. + CachedQuery second; +}; + +struct PlanCacheEntry { + bool operator==(const PlanCacheEntry &other) const { return first == other.first; } + bool operator<(const PlanCacheEntry &other) const { return first < other.first; } + bool operator==(const uint64_t &other) const { return first == other; } + bool operator<(const uint64_t &other) const { return first < other; } + + uint64_t first; + // TODO: Maybe store the query string here and use it as a key with the hash + // so that we eliminate the risk of hash collisions. + std::shared_ptr<CachedPlan> second; +}; + +/** + * A container for data related to the parsing of a query. 
+ */ +struct ParsedQuery { + std::string query_string; + std::map<std::string, storage::v3::PropertyValue> user_parameters; + Parameters parameters; + frontend::StrippedQuery stripped_query; + AstStorage ast_storage; + Query *query; + std::vector<AuthQuery::Privilege> required_privileges; + bool is_cacheable{true}; +}; + +ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::string, storage::v3::PropertyValue> ¶ms, + utils::SkipList<QueryCacheEntry> *cache, utils::SpinLock *antlr_lock, + const InterpreterConfig::Query &query_config); + +class SingleNodeLogicalPlan final : public LogicalPlan { + public: + SingleNodeLogicalPlan(std::unique_ptr<plan::LogicalOperator> root, double cost, AstStorage storage, + const SymbolTable &symbol_table) + : root_(std::move(root)), cost_(cost), storage_(std::move(storage)), symbol_table_(symbol_table) {} + + const plan::LogicalOperator &GetRoot() const override { return *root_; } + double GetCost() const override { return cost_; } + const SymbolTable &GetSymbolTable() const override { return symbol_table_; } + const AstStorage &GetAstStorage() const override { return storage_; } + + private: + std::unique_ptr<plan::LogicalOperator> root_; + double cost_; + AstStorage storage_; + SymbolTable symbol_table_; +}; + +std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage ast_storage, CypherQuery *query, const Parameters ¶meters, + DbAccessor *db_accessor, + const std::vector<Identifier *> &predefined_identifiers); + +/** + * Return the parsed *Cypher* query's AST cached logical plan, or create and + * cache a fresh one if it doesn't yet exist. + * @param predefined_identifiers optional identifiers you want to inject into a query. + * If an identifier is not defined in a scope, we check the predefined identifiers. + * If an identifier is contained there, we inject it at that place and remove it, + * because a predefined identifier can be used only in one scope. 
+ */ +std::shared_ptr<CachedPlan> CypherQueryToPlan(uint64_t hash, AstStorage ast_storage, CypherQuery *query, + const Parameters ¶meters, utils::SkipList<PlanCacheEntry> *plan_cache, + DbAccessor *db_accessor, + const std::vector<Identifier *> &predefined_identifiers = {}); + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/db_accessor.hpp b/src/query/v2/db_accessor.hpp new file mode 100644 index 000000000..90ea6d431 --- /dev/null +++ b/src/query/v2/db_accessor.hpp @@ -0,0 +1,384 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <optional> + +#include <cppitertools/filter.hpp> +#include <cppitertools/imap.hpp> + +#include "query/v2/exceptions.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/result.hpp" + +/////////////////////////////////////////////////////////// +// Our communication layer and query engine don't mix +// very well on Centos because OpenSSL version avaialable +// on Centos 7 include libkrb5 which has brilliant macros +// called TRUE and FALSE. For more detailed explanation go +// to memgraph.cpp. +// +// Because of the replication storage now uses some form of +// communication so we have some unwanted macros. +// This cannot be avoided by simple include orderings so we +// simply undefine those macros as we're sure that libkrb5 +// won't and can't be used anywhere in the query engine. 
+#include "storage/v3/storage.hpp" + +#undef FALSE +#undef TRUE +/////////////////////////////////////////////////////////// + +#include "storage/v3/view.hpp" +#include "utils/bound.hpp" +#include "utils/exceptions.hpp" + +namespace memgraph::query::v2 { + +class VertexAccessor; + +class EdgeAccessor final { + public: + storage::v3::EdgeAccessor impl_; + + public: + explicit EdgeAccessor(storage::v3::EdgeAccessor impl) : impl_(std::move(impl)) {} + + bool IsVisible(storage::v3::View view) const { return impl_.IsVisible(view); } + + storage::v3::EdgeTypeId EdgeType() const { return impl_.EdgeType(); } + + auto Properties(storage::v3::View view) const { return impl_.Properties(view); } + + storage::v3::Result<storage::v3::PropertyValue> GetProperty(storage::v3::View view, + storage::v3::PropertyId key) const { + return impl_.GetProperty(key, view); + } + + storage::v3::Result<storage::v3::PropertyValue> SetProperty(storage::v3::PropertyId key, + const storage::v3::PropertyValue &value) { + return impl_.SetProperty(key, value); + } + + storage::v3::Result<storage::v3::PropertyValue> RemoveProperty(storage::v3::PropertyId key) { + return SetProperty(key, storage::v3::PropertyValue()); + } + + storage::v3::Result<std::map<storage::v3::PropertyId, storage::v3::PropertyValue>> ClearProperties() { + return impl_.ClearProperties(); + } + + VertexAccessor To() const; + + VertexAccessor From() const; + + bool IsCycle() const; + + int64_t CypherId() const { return impl_.Gid().AsInt(); } + + storage::v3::Gid Gid() const noexcept { return impl_.Gid(); } + + bool operator==(const EdgeAccessor &e) const noexcept { return impl_ == e.impl_; } + + bool operator!=(const EdgeAccessor &e) const noexcept { return !(*this == e); } +}; + +class VertexAccessor final { + public: + storage::v3::VertexAccessor impl_; + + static EdgeAccessor MakeEdgeAccessor(const storage::v3::EdgeAccessor impl) { return EdgeAccessor(impl); } + + public: + explicit VertexAccessor(storage::v3::VertexAccessor 
impl) : impl_(impl) {} + + bool IsVisible(storage::v3::View view) const { return impl_.IsVisible(view); } + + auto Labels(storage::v3::View view) const { return impl_.Labels(view); } + + storage::v3::Result<bool> AddLabel(storage::v3::LabelId label) { return impl_.AddLabel(label); } + + storage::v3::Result<bool> RemoveLabel(storage::v3::LabelId label) { return impl_.RemoveLabel(label); } + + storage::v3::Result<bool> HasLabel(storage::v3::View view, storage::v3::LabelId label) const { + return impl_.HasLabel(label, view); + } + + auto Properties(storage::v3::View view) const { return impl_.Properties(view); } + + storage::v3::Result<storage::v3::PropertyValue> GetProperty(storage::v3::View view, + storage::v3::PropertyId key) const { + return impl_.GetProperty(key, view); + } + + storage::v3::Result<storage::v3::PropertyValue> SetProperty(storage::v3::PropertyId key, + const storage::v3::PropertyValue &value) { + return impl_.SetProperty(key, value); + } + + storage::v3::Result<storage::v3::PropertyValue> RemoveProperty(storage::v3::PropertyId key) { + return SetProperty(key, storage::v3::PropertyValue()); + } + + storage::v3::Result<std::map<storage::v3::PropertyId, storage::v3::PropertyValue>> ClearProperties() { + return impl_.ClearProperties(); + } + + auto InEdges(storage::v3::View view, const std::vector<storage::v3::EdgeTypeId> &edge_types) const + -> storage::v3::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.InEdges(view)))> { + auto maybe_edges = impl_.InEdges(view, edge_types); + if (maybe_edges.HasError()) return maybe_edges.GetError(); + return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges)); + } + + auto InEdges(storage::v3::View view) const { return InEdges(view, {}); } + + auto InEdges(storage::v3::View view, const std::vector<storage::v3::EdgeTypeId> &edge_types, + const VertexAccessor &dest) const + -> storage::v3::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.InEdges(view)))> { + auto maybe_edges = impl_.InEdges(view, edge_types, 
&dest.impl_); + if (maybe_edges.HasError()) return maybe_edges.GetError(); + return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges)); + } + + auto OutEdges(storage::v3::View view, const std::vector<storage::v3::EdgeTypeId> &edge_types) const + -> storage::v3::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.OutEdges(view)))> { + auto maybe_edges = impl_.OutEdges(view, edge_types); + if (maybe_edges.HasError()) return maybe_edges.GetError(); + return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges)); + } + + auto OutEdges(storage::v3::View view) const { return OutEdges(view, {}); } + + auto OutEdges(storage::v3::View view, const std::vector<storage::v3::EdgeTypeId> &edge_types, + const VertexAccessor &dest) const + -> storage::v3::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.OutEdges(view)))> { + auto maybe_edges = impl_.OutEdges(view, edge_types, &dest.impl_); + if (maybe_edges.HasError()) return maybe_edges.GetError(); + return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges)); + } + + storage::v3::Result<size_t> InDegree(storage::v3::View view) const { return impl_.InDegree(view); } + + storage::v3::Result<size_t> OutDegree(storage::v3::View view) const { return impl_.OutDegree(view); } + + int64_t CypherId() const { return impl_.Gid().AsInt(); } + + storage::v3::Gid Gid() const noexcept { return impl_.Gid(); } + + bool operator==(const VertexAccessor &v) const noexcept { + static_assert(noexcept(impl_ == v.impl_)); + return impl_ == v.impl_; + } + + bool operator!=(const VertexAccessor &v) const noexcept { return !(*this == v); } +}; + +inline VertexAccessor EdgeAccessor::To() const { return VertexAccessor(impl_.ToVertex()); } + +inline VertexAccessor EdgeAccessor::From() const { return VertexAccessor(impl_.FromVertex()); } + +inline bool EdgeAccessor::IsCycle() const { return To() == From(); } + +class DbAccessor final { + storage::v3::Storage::Accessor *accessor_; + + class VerticesIterable final { + storage::v3::VerticesIterable 
iterable_; + + public: + class Iterator final { + storage::v3::VerticesIterable::Iterator it_; + + public: + explicit Iterator(storage::v3::VerticesIterable::Iterator it) : it_(it) {} + + VertexAccessor operator*() const { return VertexAccessor(*it_); } + + Iterator &operator++() { + ++it_; + return *this; + } + + bool operator==(const Iterator &other) const { return it_ == other.it_; } + + bool operator!=(const Iterator &other) const { return !(other == *this); } + }; + + explicit VerticesIterable(storage::v3::VerticesIterable iterable) : iterable_(std::move(iterable)) {} + + Iterator begin() { return Iterator(iterable_.begin()); } + + Iterator end() { return Iterator(iterable_.end()); } + }; + + public: + explicit DbAccessor(storage::v3::Storage::Accessor *accessor) : accessor_(accessor) {} + + std::optional<VertexAccessor> FindVertex(storage::v3::Gid gid, storage::v3::View view) { + auto maybe_vertex = accessor_->FindVertex(gid, view); + if (maybe_vertex) return VertexAccessor(*maybe_vertex); + return std::nullopt; + } + + void FinalizeTransaction() { accessor_->FinalizeTransaction(); } + + VerticesIterable Vertices(storage::v3::View view) { return VerticesIterable(accessor_->Vertices(view)); } + + VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label) { + return VerticesIterable(accessor_->Vertices(label, view)); + } + + VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label, storage::v3::PropertyId property) { + return VerticesIterable(accessor_->Vertices(label, property, view)); + } + + VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label, storage::v3::PropertyId property, + const storage::v3::PropertyValue &value) { + return VerticesIterable(accessor_->Vertices(label, property, value, view)); + } + + VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label, storage::v3::PropertyId property, + const std::optional<utils::Bound<storage::v3::PropertyValue>> &lower, + const 
std::optional<utils::Bound<storage::v3::PropertyValue>> &upper) { + return VerticesIterable(accessor_->Vertices(label, property, lower, upper, view)); + } + + VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); } + + storage::v3::Result<EdgeAccessor> InsertEdge(VertexAccessor *from, VertexAccessor *to, + const storage::v3::EdgeTypeId &edge_type) { + auto maybe_edge = accessor_->CreateEdge(&from->impl_, &to->impl_, edge_type); + if (maybe_edge.HasError()) return storage::v3::Result<EdgeAccessor>(maybe_edge.GetError()); + return EdgeAccessor(*maybe_edge); + } + + storage::v3::Result<std::optional<EdgeAccessor>> RemoveEdge(EdgeAccessor *edge) { + auto res = accessor_->DeleteEdge(&edge->impl_); + if (res.HasError()) { + return res.GetError(); + } + + const auto &value = res.GetValue(); + if (!value) { + return std::optional<EdgeAccessor>{}; + } + + return std::make_optional<EdgeAccessor>(*value); + } + + storage::v3::Result<std::optional<std::pair<VertexAccessor, std::vector<EdgeAccessor>>>> DetachRemoveVertex( + VertexAccessor *vertex_accessor) { + using ReturnType = std::pair<VertexAccessor, std::vector<EdgeAccessor>>; + + auto res = accessor_->DetachDeleteVertex(&vertex_accessor->impl_); + if (res.HasError()) { + return res.GetError(); + } + + const auto &value = res.GetValue(); + if (!value) { + return std::optional<ReturnType>{}; + } + + const auto &[vertex, edges] = *value; + + std::vector<EdgeAccessor> deleted_edges; + deleted_edges.reserve(edges.size()); + std::transform(edges.begin(), edges.end(), std::back_inserter(deleted_edges), + [](const auto &deleted_edge) { return EdgeAccessor{deleted_edge}; }); + + return std::make_optional<ReturnType>(vertex, std::move(deleted_edges)); + } + + storage::v3::Result<std::optional<VertexAccessor>> RemoveVertex(VertexAccessor *vertex_accessor) { + auto res = accessor_->DeleteVertex(&vertex_accessor->impl_); + if (res.HasError()) { + return res.GetError(); + } + + const auto &value = 
// Name <-> id mapping. Each call forwards unchanged to the storage accessor.

/// Id of the property called `name`, as returned by the storage accessor.
storage::v3::PropertyId NameToProperty(const std::string_view name) { return accessor_->NameToProperty(name); }

/// Id of the label called `name`, as returned by the storage accessor.
storage::v3::LabelId NameToLabel(const std::string_view name) { return accessor_->NameToLabel(name); }

/// Id of the edge type called `name`, as returned by the storage accessor.
storage::v3::EdgeTypeId NameToEdgeType(const std::string_view name) { return accessor_->NameToEdgeType(name); }

/// Name of `prop`. The reference is the one handed out by the storage
/// accessor; its lifetime is not controlled here.
const std::string &PropertyToName(storage::v3::PropertyId prop) const { return accessor_->PropertyToName(prop); }

/// Name of `label`; same reference caveat as PropertyToName.
const std::string &LabelToName(storage::v3::LabelId label) const { return accessor_->LabelToName(label); }

/// Name of the edge type `type`; same reference caveat as PropertyToName.
const std::string &EdgeTypeToName(storage::v3::EdgeTypeId type) const { return accessor_->EdgeTypeToName(type); }

// Transaction control. Forwarded unchanged to the storage accessor.

void AdvanceCommand() { accessor_->AdvanceCommand(); }

/// @return the storage accessor's commit result, which carries a
///         ConstraintViolation on error.
utils::BasicResult<storage::v3::ConstraintViolation, void> Commit() { return accessor_->Commit(); }

void Abort() { accessor_->Abort(); }

// Index queries.

bool LabelIndexExists(storage::v3::LabelId label) const { return accessor_->LabelIndexExists(label); }

bool LabelPropertyIndexExists(storage::v3::LabelId label, storage::v3::PropertyId prop) const {
  return accessor_->LabelPropertyIndexExists(label, prop);
}

// Vertex counts. Every overload returns the value of the matching
// ApproximateVertexCount call on the storage accessor.

int64_t VerticesCount() const { return accessor_->ApproximateVertexCount(); }

int64_t VerticesCount(storage::v3::LabelId label) const { return accessor_->ApproximateVertexCount(label); }

int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property) const {
  return accessor_->ApproximateVertexCount(label, property);
}

int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
                      const storage::v3::PropertyValue &value) const {
  return accessor_->ApproximateVertexCount(label, property, value);
}
std::optional<utils::Bound<storage::v3::PropertyValue>> &lower, + const std::optional<utils::Bound<storage::v3::PropertyValue>> &upper) const { + return accessor_->ApproximateVertexCount(label, property, lower, upper); + } + + storage::v3::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); } + + storage::v3::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); } +}; + +} // namespace memgraph::query::v2 + +namespace std { + +template <> +struct hash<memgraph::query::v2::VertexAccessor> { + size_t operator()(const memgraph::query::v2::VertexAccessor &v) const { + return std::hash<decltype(v.impl_)>{}(v.impl_); + } +}; + +template <> +struct hash<memgraph::query::v2::EdgeAccessor> { + size_t operator()(const memgraph::query::v2::EdgeAccessor &e) const { + return std::hash<decltype(e.impl_)>{}(e.impl_); + } +}; + +} // namespace std diff --git a/src/query/v2/discard_value_stream.hpp b/src/query/v2/discard_value_stream.hpp new file mode 100644 index 000000000..8703aa470 --- /dev/null +++ b/src/query/v2/discard_value_stream.hpp @@ -0,0 +1,24 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include <vector> + +#include "query/v2/typed_value.hpp" + +namespace memgraph::query::v2 { +struct DiscardValueResultStream final { + void Result(const std::vector<query::v2::TypedValue> & /*values*/) { + // do nothing + } +}; +} // namespace memgraph::query::v2 diff --git a/src/query/v2/dump.cpp b/src/query/v2/dump.cpp new file mode 100644 index 000000000..e155c600d --- /dev/null +++ b/src/query/v2/dump.cpp @@ -0,0 +1,541 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/dump.hpp" + +#include <iomanip> +#include <limits> +#include <map> +#include <optional> +#include <ostream> +#include <utility> +#include <vector> + +#include <fmt/format.h> + +#include "query/v2/db_accessor.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/stream.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/storage.hpp" +#include "utils/algorithm.hpp" +#include "utils/logging.hpp" +#include "utils/string.hpp" +#include "utils/temporal.hpp" + +namespace memgraph::query::v2 { + +namespace { + +// Property that is used to make a difference among vertices. It is added to +// property set of vertices to match edges and removed after the entire graph +// is built. +const char *kInternalPropertyId = "__mg_id__"; + +// Label that is attached to each vertex and is used for easier creation of +// index on internal property id. 
/// Wraps `value` in backticks and doubles every backtick found inside it.
/// Used for label, edge type and property names.
std::string EscapeName(const std::string_view value) {
  constexpr char kBacktick = '`';

  std::string escaped;
  escaped.reserve(value.size() + 2);
  escaped.push_back(kBacktick);
  for (const char ch : value) {
    escaped.push_back(ch);
    if (ch == kBacktick) {
      // An embedded backtick is written twice.
      escaped.push_back(kBacktick);
    }
  }
  escaped.push_back(kBacktick);
  return escaped;
}

/// Writes `value` to `*os` with `max_digits10` significant digits.
void DumpPreciseDouble(std::ostream *os, double value) {
  // Format through a scratch stream so the precision configured on `*os` is
  // left untouched.
  std::ostringstream scratch;
  scratch.precision(std::numeric_limits<double>::max_digits10);
  scratch << value;
  *os << scratch.str();
}
storage::v3::PropertyValue::Type::Null: + *os << "Null"; + return; + case storage::v3::PropertyValue::Type::Bool: + *os << (value.ValueBool() ? "true" : "false"); + return; + case storage::v3::PropertyValue::Type::String: + *os << utils::Escape(value.ValueString()); + return; + case storage::v3::PropertyValue::Type::Int: + *os << value.ValueInt(); + return; + case storage::v3::PropertyValue::Type::Double: + DumpPreciseDouble(os, value.ValueDouble()); + return; + case storage::v3::PropertyValue::Type::List: { + *os << "["; + const auto &list = value.ValueList(); + utils::PrintIterable(*os, list, ", ", [](auto &os, const auto &item) { DumpPropertyValue(&os, item); }); + *os << "]"; + return; + } + case storage::v3::PropertyValue::Type::Map: { + *os << "{"; + const auto &map = value.ValueMap(); + utils::PrintIterable(*os, map, ", ", [](auto &os, const auto &kv) { + os << EscapeName(kv.first) << ": "; + DumpPropertyValue(&os, kv.second); + }); + *os << "}"; + return; + } + case storage::v3::PropertyValue::Type::TemporalData: { + DumpTemporalData(*os, value.ValueTemporalData()); + return; + } + } +} + +void DumpProperties(std::ostream *os, query::v2::DbAccessor *dba, + const std::map<storage::v3::PropertyId, storage::v3::PropertyValue> &store, + std::optional<int64_t> property_id = std::nullopt) { + *os << "{"; + if (property_id) { + *os << kInternalPropertyId << ": " << *property_id; + if (store.size() > 0) *os << ", "; + } + utils::PrintIterable(*os, store, ", ", [&dba](auto &os, const auto &kv) { + os << EscapeName(dba->PropertyToName(kv.first)) << ": "; + DumpPropertyValue(&os, kv.second); + }); + *os << "}"; +} + +void DumpVertex(std::ostream *os, query::v2::DbAccessor *dba, const query::v2::VertexAccessor &vertex) { + *os << "CREATE ("; + *os << ":" << kInternalVertexLabel; + auto maybe_labels = vertex.Labels(storage::v3::View::OLD); + if (maybe_labels.HasError()) { + switch (maybe_labels.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw 
query::v2::QueryRuntimeException("Trying to get labels from a deleted node."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get labels from a node that doesn't exist."); + case storage::v3::Error::SERIALIZATION_ERROR: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + throw query::v2::QueryRuntimeException("Unexpected error when getting labels."); + } + } + for (const auto &label : *maybe_labels) { + *os << ":" << EscapeName(dba->LabelToName(label)); + } + *os << " "; + auto maybe_props = vertex.Properties(storage::v3::View::OLD); + if (maybe_props.HasError()) { + switch (maybe_props.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get properties from a deleted object."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get properties from a node that doesn't exist."); + case storage::v3::Error::SERIALIZATION_ERROR: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + throw query::v2::QueryRuntimeException("Unexpected error when getting properties."); + } + } + DumpProperties(os, dba, *maybe_props, vertex.CypherId()); + *os << ");"; +} + +void DumpEdge(std::ostream *os, query::v2::DbAccessor *dba, const query::v2::EdgeAccessor &edge) { + *os << "MATCH "; + *os << "(u:" << kInternalVertexLabel << "), "; + *os << "(v:" << kInternalVertexLabel << ")"; + *os << " WHERE "; + *os << "u." << kInternalPropertyId << " = " << edge.From().CypherId(); + *os << " AND "; + *os << "v." 
<< kInternalPropertyId << " = " << edge.To().CypherId() << " "; + *os << "CREATE (u)-["; + *os << ":" << EscapeName(dba->EdgeTypeToName(edge.EdgeType())); + auto maybe_props = edge.Properties(storage::v3::View::OLD); + if (maybe_props.HasError()) { + switch (maybe_props.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get properties from a deleted object."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get properties from an edge that doesn't exist."); + case storage::v3::Error::SERIALIZATION_ERROR: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + throw query::v2::QueryRuntimeException("Unexpected error when getting properties."); + } + } + if (maybe_props->size() > 0) { + *os << " "; + DumpProperties(os, dba, *maybe_props); + } + *os << "]->(v);"; +} + +void DumpLabelIndex(std::ostream *os, query::v2::DbAccessor *dba, const storage::v3::LabelId label) { + *os << "CREATE INDEX ON :" << EscapeName(dba->LabelToName(label)) << ";"; +} + +void DumpLabelPropertyIndex(std::ostream *os, query::v2::DbAccessor *dba, storage::v3::LabelId label, + storage::v3::PropertyId property) { + *os << "CREATE INDEX ON :" << EscapeName(dba->LabelToName(label)) << "(" << EscapeName(dba->PropertyToName(property)) + << ");"; +} + +void DumpExistenceConstraint(std::ostream *os, query::v2::DbAccessor *dba, storage::v3::LabelId label, + storage::v3::PropertyId property) { + *os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label)) << ") ASSERT EXISTS (u." 
+ << EscapeName(dba->PropertyToName(property)) << ");"; +} + +void DumpUniqueConstraint(std::ostream *os, query::v2::DbAccessor *dba, storage::v3::LabelId label, + const std::set<storage::v3::PropertyId> &properties) { + *os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label)) << ") ASSERT "; + utils::PrintIterable(*os, properties, ", ", [&dba](auto &stream, const auto &property) { + stream << "u." << EscapeName(dba->PropertyToName(property)); + }); + *os << " IS UNIQUE;"; +} + +} // namespace + +PullPlanDump::PullPlanDump(DbAccessor *dba) + : dba_(dba), + vertices_iterable_(dba->Vertices(storage::v3::View::OLD)), + pull_chunks_{// Dump all label indices + CreateLabelIndicesPullChunk(), + // Dump all label property indices + CreateLabelPropertyIndicesPullChunk(), + // Dump all existence constraints + CreateExistenceConstraintsPullChunk(), + // Dump all unique constraints + CreateUniqueConstraintsPullChunk(), + // Create internal index for faster edge creation + CreateInternalIndexPullChunk(), + // Dump all vertices + CreateVertexPullChunk(), + // Dump all edges + CreateEdgePullChunk(), + // Drop the internal index + CreateDropInternalIndexPullChunk(), + // Internal index cleanup + CreateInternalIndexCleanupPullChunk()} {} + +bool PullPlanDump::Pull(AnyStream *stream, std::optional<int> n) { + // Iterate all functions that stream some results. + // Each function should return number of results it streamed after it + // finishes. If the function did not finish streaming all the results, + // std::nullopt should be returned because n results have already been sent. 
+ while (current_chunk_index_ < pull_chunks_.size() && (!n || *n > 0)) { + const auto maybe_streamed_count = pull_chunks_[current_chunk_index_](stream, n); + + if (!maybe_streamed_count) { + // n wasn't large enough to stream all the results from the current chunk + break; + } + + if (n) { + // chunk finished streaming its results + // subtract number of results streamed in current pull + // so we know how many results we need to stream from future + // chunks. + *n -= *maybe_streamed_count; + } + + ++current_chunk_index_; + } + return current_chunk_index_ == pull_chunks_.size(); +} + +PullPlanDump::PullChunk PullPlanDump::CreateLabelIndicesPullChunk() { + // Dump all label indices + return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> { + // Delay the construction of indices vectors + if (!indices_info_) { + indices_info_.emplace(dba_->ListAllIndices()); + } + const auto &label = indices_info_->label; + + size_t local_counter = 0; + while (global_index < label.size() && (!n || local_counter < *n)) { + std::ostringstream os; + DumpLabelIndex(&os, dba_, label[global_index]); + stream->Result({TypedValue(os.str())}); + + ++global_index; + ++local_counter; + } + + if (global_index == label.size()) { + return local_counter; + } + + return std::nullopt; + }; +} + +PullPlanDump::PullChunk PullPlanDump::CreateLabelPropertyIndicesPullChunk() { + return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> { + // Delay the construction of indices vectors + if (!indices_info_) { + indices_info_.emplace(dba_->ListAllIndices()); + } + const auto &label_property = indices_info_->label_property; + + size_t local_counter = 0; + while (global_index < label_property.size() && (!n || local_counter < *n)) { + std::ostringstream os; + const auto &label_property_index = label_property[global_index]; + DumpLabelPropertyIndex(&os, dba_, label_property_index.first, 
label_property_index.second); + stream->Result({TypedValue(os.str())}); + + ++global_index; + ++local_counter; + } + + if (global_index == label_property.size()) { + return local_counter; + } + + return std::nullopt; + }; +} + +PullPlanDump::PullChunk PullPlanDump::CreateExistenceConstraintsPullChunk() { + return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> { + // Delay the construction of constraint vectors + if (!constraints_info_) { + constraints_info_.emplace(dba_->ListAllConstraints()); + } + + const auto &existence = constraints_info_->existence; + size_t local_counter = 0; + while (global_index < existence.size() && (!n || local_counter < *n)) { + const auto &constraint = existence[global_index]; + std::ostringstream os; + DumpExistenceConstraint(&os, dba_, constraint.first, constraint.second); + stream->Result({TypedValue(os.str())}); + + ++global_index; + ++local_counter; + } + + if (global_index == existence.size()) { + return local_counter; + } + + return std::nullopt; + }; +} + +PullPlanDump::PullChunk PullPlanDump::CreateUniqueConstraintsPullChunk() { + return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> { + // Delay the construction of constraint vectors + if (!constraints_info_) { + constraints_info_.emplace(dba_->ListAllConstraints()); + } + + const auto &unique = constraints_info_->unique; + size_t local_counter = 0; + while (global_index < unique.size() && (!n || local_counter < *n)) { + const auto &constraint = unique[global_index]; + std::ostringstream os; + DumpUniqueConstraint(&os, dba_, constraint.first, constraint.second); + stream->Result({TypedValue(os.str())}); + + ++global_index; + ++local_counter; + } + + if (global_index == unique.size()) { + return local_counter; + } + + return std::nullopt; + }; +} + +PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexPullChunk() { + return [this](AnyStream *stream, std::optional<int>) 
mutable -> std::optional<size_t> { + if (vertices_iterable_.begin() != vertices_iterable_.end()) { + std::ostringstream os; + os << "CREATE INDEX ON :" << kInternalVertexLabel << "(" << kInternalPropertyId << ");"; + stream->Result({TypedValue(os.str())}); + internal_index_created_ = true; + return 1; + } + return 0; + }; +} + +PullPlanDump::PullChunk PullPlanDump::CreateVertexPullChunk() { + return [this, maybe_current_iter = std::optional<VertexAccessorIterableIterator>{}]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> { + // Delay the call of begin() function + // If multiple begins are called before an iteration, + // one iteration will make the rest of iterators be in undefined + // states. + if (!maybe_current_iter) { + maybe_current_iter.emplace(vertices_iterable_.begin()); + } + + auto ¤t_iter{*maybe_current_iter}; + + size_t local_counter = 0; + while (current_iter != vertices_iterable_.end() && (!n || local_counter < *n)) { + std::ostringstream os; + DumpVertex(&os, dba_, *current_iter); + stream->Result({TypedValue(os.str())}); + ++local_counter; + ++current_iter; + } + if (current_iter == vertices_iterable_.end()) { + return local_counter; + } + + return std::nullopt; + }; +} + +PullPlanDump::PullChunk PullPlanDump::CreateEdgePullChunk() { + return [this, maybe_current_vertex_iter = std::optional<VertexAccessorIterableIterator>{}, + // we need to save the iterable which contains list of accessor so + // our saved iterator is valid in the next run + maybe_edge_iterable = std::shared_ptr<EdgeAccessorIterable>{nullptr}, + maybe_current_edge_iter = std::optional<EdgeAccessorIterableIterator>{}]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> { + // Delay the call of begin() function + // If multiple begins are called before an iteration, + // one iteration will make the rest of iterators be in undefined + // states. 
+ if (!maybe_current_vertex_iter) { + maybe_current_vertex_iter.emplace(vertices_iterable_.begin()); + } + + // Bind by reference so that advancing the vertex iterator updates the + // captured state and the next pull resumes at the same vertex. + auto &current_vertex_iter{*maybe_current_vertex_iter}; + size_t local_counter = 0U; + for (; current_vertex_iter != vertices_iterable_.end() && (!n || local_counter < *n); ++current_vertex_iter) { + const auto &vertex = *current_vertex_iter; + // If we have a saved iterable from a previous pull + // we need to use the same iterable + if (!maybe_edge_iterable) { + maybe_edge_iterable = std::make_shared<EdgeAccessorIterable>(vertex.OutEdges(storage::v3::View::OLD)); + } + auto &maybe_edges = *maybe_edge_iterable; + MG_ASSERT(maybe_edges.HasValue(), "Invalid database state!"); + auto current_edge_iter = maybe_current_edge_iter ? *maybe_current_edge_iter : maybe_edges->begin(); + for (; current_edge_iter != maybe_edges->end() && (!n || local_counter < *n); ++current_edge_iter) { + std::ostringstream os; + DumpEdge(&os, dba_, *current_edge_iter); + stream->Result({TypedValue(os.str())}); + + ++local_counter; + } + + // The row limit was hit in the middle of this vertex's edges: remember + // the position and report that the chunk is not finished yet. + if (current_edge_iter != maybe_edges->end()) { + maybe_current_edge_iter.emplace(current_edge_iter); + return std::nullopt; + } + + maybe_current_edge_iter = std::nullopt; + maybe_edge_iterable = nullptr; + } + + if (current_vertex_iter == vertices_iterable_.end()) { + return local_counter; + } + + return std::nullopt; + }; +} + +// Emits the DROP INDEX statement only if the internal index was created +// earlier in this dump; returns the number of rows streamed (1 or 0). +PullPlanDump::PullChunk PullPlanDump::CreateDropInternalIndexPullChunk() { + return [this](AnyStream *stream, std::optional<int>) -> std::optional<size_t> { + if (internal_index_created_) { + std::ostringstream os; + os << "DROP INDEX ON :" << kInternalVertexLabel << "(" << kInternalPropertyId << ");"; + stream->Result({TypedValue(os.str())}); + return 1; + } + return 0; + }; +} + +PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexCleanupPullChunk() { + return [this](AnyStream *stream, std::optional<int>) -> std::optional<size_t> { + if (internal_index_created_) { + std::ostringstream os; + os << "MATCH (u) REMOVE u:" << kInternalVertexLabel << ", u."
<< kInternalPropertyId << ";"; + stream->Result({TypedValue(os.str())}); + return 1; + } + return 0; + }; +} + +void DumpDatabaseToCypherQueries(query::v2::DbAccessor *dba, AnyStream *stream) { PullPlanDump(dba).Pull(stream, {}); } + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/dump.hpp b/src/query/v2/dump.hpp new file mode 100644 index 000000000..de8018724 --- /dev/null +++ b/src/query/v2/dump.hpp @@ -0,0 +1,66 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <ostream> + +#include "query/v2/db_accessor.hpp" +#include "query/v2/stream.hpp" +#include "storage/v3/storage.hpp" + +namespace memgraph::query::v2 { + +void DumpDatabaseToCypherQueries(query::v2::DbAccessor *dba, AnyStream *stream); + +struct PullPlanDump { + explicit PullPlanDump(query::v2::DbAccessor *dba); + + /// Pull the dump results lazily + /// @return true if all results were returned, false otherwise + bool Pull(AnyStream *stream, std::optional<int> n); + + private: + query::v2::DbAccessor *dba_ = nullptr; + + std::optional<storage::v3::IndicesInfo> indices_info_ = std::nullopt; + std::optional<storage::v3::ConstraintsInfo> constraints_info_ = std::nullopt; + + using VertexAccessorIterable = decltype(std::declval<query::v2::DbAccessor>().Vertices(storage::v3::View::OLD)); + using VertexAccessorIterableIterator = decltype(std::declval<VertexAccessorIterable>().begin()); + + using EdgeAccessorIterable = 
decltype(std::declval<VertexAccessor>().OutEdges(storage::v3::View::OLD)); + using EdgeAccessorIterableIterator = decltype(std::declval<EdgeAccessorIterable>().GetValue().begin()); + + VertexAccessorIterable vertices_iterable_; + bool internal_index_created_ = false; + + size_t current_chunk_index_ = 0; + + using PullChunk = std::function<std::optional<size_t>(AnyStream *stream, std::optional<int> n)>; + // We define every part of the dump query in a self-contained function. + // Each function is responsible for keeping track of its execution status. + // If a function did finish its execution, it should return the number of results + // it streamed so we know how many rows should be pulled from the next + // function, otherwise std::nullopt is returned. + std::vector<PullChunk> pull_chunks_; + + PullChunk CreateLabelIndicesPullChunk(); + PullChunk CreateLabelPropertyIndicesPullChunk(); + PullChunk CreateExistenceConstraintsPullChunk(); + PullChunk CreateUniqueConstraintsPullChunk(); + PullChunk CreateInternalIndexPullChunk(); + PullChunk CreateVertexPullChunk(); + PullChunk CreateEdgePullChunk(); + PullChunk CreateDropInternalIndexPullChunk(); + PullChunk CreateInternalIndexCleanupPullChunk(); +}; +} // namespace memgraph::query::v2 diff --git a/src/query/v2/exceptions.hpp b/src/query/v2/exceptions.hpp new file mode 100644 index 000000000..e0802a6cc --- /dev/null +++ b/src/query/v2/exceptions.hpp @@ -0,0 +1,227 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt.
+ +#pragma once + +#include "utils/exceptions.hpp" + +#include <fmt/format.h> + +namespace memgraph::query::v2 { + +/** + * @brief Base class of all query language related exceptions. All exceptions + * derived from this one will be interpreted as ClientError-s, i. e. if client + * executes same query again without making modifications to the database data, + * query will fail again. + */ +class QueryException : public utils::BasicException { + using utils::BasicException::BasicException; +}; + +class LexingException : public QueryException { + public: + using QueryException::QueryException; + LexingException() : QueryException("") {} +}; + +class SyntaxException : public QueryException { + public: + using QueryException::QueryException; + SyntaxException() : QueryException("") {} +}; + +// TODO: Figure out what information to put in exception. +// Error reporting is tricky since we get stripped query and position of error +// in original query is not same as position of error in stripped query. Most +// correct approach would be to do semantic analysis with original query even +// for already hashed queries, but that has obvious performance issues. Other +// approach would be to report some of the semantic errors in runtime of the +// query and only report line numbers of semantic errors (not position in the +// line) if multiple line strings are not allowed by grammar. We could also +// print whole line that contains error instead of specifying line number. 
+class SemanticException : public QueryException { + public: + using QueryException::QueryException; + SemanticException() : QueryException("") {} +}; + +class UnboundVariableError : public SemanticException { + public: + explicit UnboundVariableError(const std::string &name) : SemanticException("Unbound variable: " + name + ".") {} +}; + +class RedeclareVariableError : public SemanticException { + public: + explicit RedeclareVariableError(const std::string &name) : SemanticException("Redeclaring variable: " + name + ".") {} +}; + +class TypeMismatchError : public SemanticException { + public: + TypeMismatchError(const std::string &name, const std::string &datum, const std::string &expected) + : SemanticException(fmt::format("Type mismatch: {} already defined as {}, expected {}.", name, datum, expected)) { + } +}; + +class UnprovidedParameterError : public QueryException { + public: + using QueryException::QueryException; +}; + +class ProfileInMulticommandTxException : public QueryException { + public: + using QueryException::QueryException; + ProfileInMulticommandTxException() : QueryException("PROFILE not allowed in multicommand transactions.") {} +}; + +class IndexInMulticommandTxException : public QueryException { + public: + using QueryException::QueryException; + IndexInMulticommandTxException() : QueryException("Index manipulation not allowed in multicommand transactions.") {} +}; + +class ConstraintInMulticommandTxException : public QueryException { + public: + using QueryException::QueryException; + ConstraintInMulticommandTxException() + : QueryException( + "Constraint manipulation not allowed in multicommand " + "transactions.") {} +}; + +class InfoInMulticommandTxException : public QueryException { + public: + using QueryException::QueryException; + InfoInMulticommandTxException() : QueryException("Info reporting not allowed in multicommand transactions.") {} +}; + +/** + * An exception for an illegal operation that can not be detected + * before the 
query starts executing over data. + */ +class QueryRuntimeException : public QueryException { + public: + using QueryException::QueryException; +}; + +// This one is inherited from BasicException and will be treated as +// TransientError, i. e. client will be encouraged to retry execution because it +// could succeed if executed again. +class HintedAbortError : public utils::BasicException { + public: + using utils::BasicException::BasicException; + HintedAbortError() + : utils::BasicException( + "Transaction was asked to abort, most likely because it was " + "executing longer than time specified by " + "--query-execution-timeout-sec flag.") {} +}; + +class ExplicitTransactionUsageException : public QueryRuntimeException { + public: + using QueryRuntimeException::QueryRuntimeException; +}; + +/** + * An exception for serialization error + */ +class TransactionSerializationException : public QueryException { + public: + using QueryException::QueryException; + TransactionSerializationException() + : QueryException( + "Cannot resolve conflicting transactions. You can retry this transaction when the conflicting transaction " + "is finished") {} +}; + +class ReconstructionException : public QueryException { + public: + ReconstructionException() + : QueryException( + "Record invalid after WITH clause. Most likely deleted by a " + "preceeding DELETE.") {} +}; + +class RemoveAttachedVertexException : public QueryRuntimeException { + public: + RemoveAttachedVertexException() + : QueryRuntimeException( + "Failed to remove node because of it's existing " + "connections. 
Consider using DETACH DELETE.") {} +}; + +class UserModificationInMulticommandTxException : public QueryException { + public: + UserModificationInMulticommandTxException() + : QueryException("Authentication clause not allowed in multicommand transactions.") {} +}; + +class InvalidArgumentsException : public QueryException { + public: + InvalidArgumentsException(const std::string &argument_name, const std::string &message) + : QueryException(fmt::format("Invalid arguments sent: {} - {}", argument_name, message)) {} +}; + +class ReplicationModificationInMulticommandTxException : public QueryException { + public: + ReplicationModificationInMulticommandTxException() + : QueryException("Replication clause not allowed in multicommand transactions.") {} +}; + +class LockPathModificationInMulticommandTxException : public QueryException { + public: + LockPathModificationInMulticommandTxException() + : QueryException("Lock path query not allowed in multicommand transactions.") {} +}; + +class FreeMemoryModificationInMulticommandTxException : public QueryException { + public: + FreeMemoryModificationInMulticommandTxException() + : QueryException("Free memory query not allowed in multicommand transactions.") {} +}; + +class TriggerModificationInMulticommandTxException : public QueryException { + public: + TriggerModificationInMulticommandTxException() + : QueryException("Trigger queries not allowed in multicommand transactions.") {} +}; + +class StreamQueryInMulticommandTxException : public QueryException { + public: + StreamQueryInMulticommandTxException() + : QueryException("Stream queries are not allowed in multicommand transactions.") {} +}; + +class IsolationLevelModificationInMulticommandTxException : public QueryException { + public: + IsolationLevelModificationInMulticommandTxException() + : QueryException("Isolation level cannot be modified in multicommand transactions.") {} +}; + +class CreateSnapshotInMulticommandTxException final : public QueryException { + public: 
+ CreateSnapshotInMulticommandTxException() + : QueryException("Snapshot cannot be created in multicommand transactions.") {} +}; + +class SettingConfigInMulticommandTxException final : public QueryException { + public: + SettingConfigInMulticommandTxException() + : QueryException("Settings cannot be changed or fetched in multicommand transactions.") {} +}; + +class VersionInfoInMulticommandTxException : public QueryException { + public: + VersionInfoInMulticommandTxException() + : QueryException("Version info query not allowed in multicommand transactions.") {} +}; + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/ast/ast.lcp b/src/query/v2/frontend/ast/ast.lcp new file mode 100644 index 000000000..b858ab71f --- /dev/null +++ b/src/query/v2/frontend/ast/ast.lcp @@ -0,0 +1,2676 @@ +;; Copyright 2022 Memgraph Ltd. +;; +;; Use of this software is governed by the Business Source License +;; included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +;; License, and you may not use this file except in compliance with the Business Source License. +;; +;; As of the Change Date specified in that file, in accordance with +;; the Business Source License, use of this software will be governed +;; by the Apache License, Version 2.0, included in the file +;; licenses/APL.txt. 
+ +#>cpp +#pragma once + +#include <memory> +#include <unordered_map> +#include <variant> +#include <vector> + +#include "query/v2/frontend/ast/ast_visitor.hpp" +#include "query/v2/frontend/semantic/symbol.hpp" +#include "query/v2/interpret/awesome_memgraph_functions.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/property_value.hpp" +#include "utils/typeinfo.hpp" + +cpp<# + +(lcp:namespace memgraph) +(lcp:namespace query) +(lcp:namespace v2) + +(defun slk-save-ast-pointer (member) + #>cpp + query::v2::SaveAstPointer(self.${member}, builder); + cpp<#) + +(defun slk-load-ast-pointer (type) + (lambda (member) + #>cpp + self->${member} = query::v2::LoadAstPointer<query::v2::${type}>(storage, reader); + cpp<#)) + +(defun slk-save-ast-vector (member) + #>cpp + size_t size = self.${member}.size(); + slk::Save(size, builder); + for (const auto *val : self.${member}) { + query::v2::SaveAstPointer(val, builder); + } + cpp<#) + +(defun slk-load-ast-vector (type) + (lambda (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + self->${member}.resize(size); + for (size_t i = 0; i < size; ++i) { + self->${member}[i] = query::v2::LoadAstPointer<query::v2::${type}>(storage, reader); + } + cpp<#)) + +(defun slk-save-property-map (member) + #>cpp + size_t size = self.${member}.size(); + slk::Save(size, builder); + for (const auto &entry : self.${member}) { + slk::Save(entry.first, builder); + query::v2::SaveAstPointer(entry.second, builder); + } + cpp<#) + +(defun slk-load-property-map (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + for (size_t i = 0; i < size; ++i) { + query::v2::PropertyIx key; + slk::Load(&key, reader, storage); + auto *value = query::v2::LoadAstPointer<query::v2::Expression>(storage, reader); + self->${member}.emplace(key, value); + } + cpp<#) + +(defun clone-property-map (source dest) + #>cpp + for (const auto &entry : ${source}) { + PropertyIx key = storage->GetPropertyIx(entry.first.name); + ${dest}[key] = 
entry.second->Clone(storage); + } + cpp<#) + +(defun slk-save-expression-map (member) + #>cpp + size_t size = self.${member}.size(); + slk::Save(size, builder); + for (const auto &entry : self.${member}) { + query::v2::SaveAstPointer(entry.first, builder); + query::v2::SaveAstPointer(entry.second, builder); + } + cpp<#) + +(defun slk-load-expression-map (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + for (size_t i = 0; i < size; ++i) { + auto *key = query::v2::LoadAstPointer<query::v2::Expression>(storage, reader); + auto *value = query::v2::LoadAstPointer<query::v2::Expression>(storage, reader); + self->${member}.emplace(key, value); + } + cpp<#) + +(defun clone-expression-map (source dest) + #>cpp + for (const auto &[key, value] : ${source}) { + ${dest}[key->Clone(storage)] = value->Clone(storage); + } + cpp<#) + +(defun slk-load-name-ix (name-type) + (lambda (member) + #>cpp + self->${member} = storage->Get${name-type}Ix(self->name).ix; + cpp<#)) + +;; Emits code that rebuilds every element of `dest` by requesting the index for +;; the corresponding name in `source` from `storage`. +(defun clone-name-ix-vector (name-type) + (lambda (source dest) + #>cpp + ${dest}.resize(${source}.size()); + // Index with size_t so the counter has the same type as size(). + for (size_t i = 0; i < ${dest}.size(); ++i) { + ${dest}[i] = storage->Get${name-type}Ix(${source}[i].name); + } + cpp<#)) + +;; The following index structs serve as a decoupling point of AST from +;; concrete database types. All the names are collected in AstStorage, and can +;; be indexed through these instances. This means that we can create a vector +;; of concrete database types in the same order as all of the names and use the +;; same index to get the correct behaviour. Additionally, each index is +;; accompanied with the duplicated name found at the same index. The primary +;; reason for this duplication is simplifying the Clone and serialization API. +;; When an old index is being cloned or deserialized into a new AstStorage, we +;; request the new `ix` from the new AstStorage for the same `name`.
If we +;; didn't do this, we would have to duplicate the old storage, which would +;; require having access to that storage. This in turn would complicate the +;; client code. +(lcp:define-struct label-ix () + ((name "std::string") + (ix :int64_t + :dont-save t + :slk-load (slk-load-name-ix "Label"))) + (:serialize (:slk :load-args '((storage "query::v2::AstStorage *"))))) + +(lcp:define-struct property-ix () + ((name "std::string") + (ix :int64_t + :dont-save t + :slk-load (slk-load-name-ix "Property"))) + (:serialize (:slk :load-args '((storage "query::v2::AstStorage *"))))) + +(lcp:define-struct edge-type-ix () + ((name "std::string") + (ix :int64_t + :dont-save t + :slk-load (slk-load-name-ix "EdgeType"))) + (:serialize (:slk :load-args '((storage "query::v2::AstStorage *"))))) + +#>cpp +inline bool operator==(const LabelIx &a, const LabelIx &b) { + return a.ix == b.ix && a.name == b.name; +} + +inline bool operator!=(const LabelIx &a, const LabelIx &b) { return !(a == b); } + +inline bool operator==(const PropertyIx &a, const PropertyIx &b) { + return a.ix == b.ix && a.name == b.name; +} + +inline bool operator!=(const PropertyIx &a, const PropertyIx &b) { + return !(a == b); +} + +inline bool operator==(const EdgeTypeIx &a, const EdgeTypeIx &b) { + return a.ix == b.ix && a.name == b.name; +} + +inline bool operator!=(const EdgeTypeIx &a, const EdgeTypeIx &b) { + return !(a == b); +} +cpp<# + +(lcp:pop-namespace) ;; namespace v2 +(lcp:pop-namespace) ;; namespace query +(lcp:pop-namespace) ;; namespace memgraph + +#>cpp +namespace std { + +template <> +struct hash<memgraph::query::v2::LabelIx> { + size_t operator()(const memgraph::query::v2::LabelIx &label) const { return label.ix; } +}; + +template <> +struct hash<memgraph::query::v2::PropertyIx> { + size_t operator()(const memgraph::query::v2::PropertyIx &prop) const { return prop.ix; } +}; + +template <> +struct hash<memgraph::query::v2::EdgeTypeIx> { + size_t operator()(const memgraph::query::v2::EdgeTypeIx 
&edge_type) const { + return edge_type.ix; + } +}; + +} // namespace std +cpp<# + +(lcp:namespace memgraph) +(lcp:namespace query) +(lcp:namespace v2) + +#>cpp + +class Tree; + +// It would be better to call this AstTree, but we already have a class Tree, +// which could be renamed to Node or AstTreeNode, but we also have a class +// called NodeAtom... +class AstStorage { + public: + AstStorage() = default; + AstStorage(const AstStorage &) = delete; + AstStorage &operator=(const AstStorage &) = delete; + AstStorage(AstStorage &&) = default; + AstStorage &operator=(AstStorage &&) = default; + + // Constructs a T, stores it in storage_ (which owns it) and returns a + // non-owning pointer to the new object. + template <typename T, typename... Args> + T *Create(Args &&... args) { + T *ptr = new T(std::forward<Args>(args)...); + std::unique_ptr<T> tmp(ptr); + storage_.emplace_back(std::move(tmp)); + return ptr; + } + + LabelIx GetLabelIx(const std::string &name) { + return LabelIx{name, FindOrAddName(name, &labels_)}; + } + + PropertyIx GetPropertyIx(const std::string &name) { + return PropertyIx{name, FindOrAddName(name, &properties_)}; + } + + EdgeTypeIx GetEdgeTypeIx(const std::string &name) { + return EdgeTypeIx{name, FindOrAddName(name, &edge_types_)}; + } + + std::vector<std::string> labels_; + std::vector<std::string> edge_types_; + std::vector<std::string> properties_; + + // Public only for serialization access + std::vector<std::unique_ptr<Tree>> storage_; + + private: + // Returns the position of `name` in `names`, appending it first if it is + // not present yet. + int64_t FindOrAddName(const std::string &name, + std::vector<std::string> *names) { + // Index with size_t to avoid a signed/unsigned comparison with size(). + for (size_t i = 0; i < names->size(); ++i) { + if ((*names)[i] == name) { + return static_cast<int64_t>(i); + } + } + names->push_back(name); + return static_cast<int64_t>(names->size() - 1); + } +}; +cpp<# + +(lcp:define-class tree () + () + (:abstractp t) + (:public + #>cpp + Tree() = default; + virtual ~Tree() = default; + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk :load-args '((storage "query::v2::AstStorage *")))) + (:clone :return-type (lambda (typename) + (format nil "~A*" typename)) + :args '((storage "AstStorage *")) +
:init-object (lambda (var typename) + (format nil "~A* ~A = storage->Create<~A>();" + typename var typename)))) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;;; Expressions +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +(lcp:define-class expression (tree "::utils::Visitable<HierarchicalTreeVisitor>" + "::utils::Visitable<ExpressionVisitor<TypedValue>>" + "::utils::Visitable<ExpressionVisitor<void>>") + () + (:abstractp t) + (:public + #>cpp + using utils::Visitable<HierarchicalTreeVisitor>::Accept; + using utils::Visitable<ExpressionVisitor<TypedValue>>::Accept; + using utils::Visitable<ExpressionVisitor<void>>::Accept; + + Expression() = default; + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk :ignore-other-base-classes t)) + (:clone :ignore-other-base-classes t) + (:type-info :ignore-other-base-classes t)) + +(lcp:define-class where (tree "::utils::Visitable<HierarchicalTreeVisitor>") + ((expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:public + #>cpp + using utils::Visitable<HierarchicalTreeVisitor>::Accept; + + Where() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + explicit Where(Expression *expression) : expression_(expression) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk :ignore-other-base-classes t)) + (:clone :ignore-other-base-classes t) + (:type-info :ignore-other-base-classes t)) + +(lcp:define-class binary-operator (expression) + ((expression1 "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (expression2 "Expression *" :initval "nullptr" :scope :public + 
:slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:abstractp t) + (:public + #>cpp + BinaryOperator() = default; + cpp<#) + (:protected + #>cpp + BinaryOperator(Expression *expression1, Expression *expression2) + : expression1_(expression1), expression2_(expression2) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class unary-operator (expression) + ((expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:abstractp t) + (:public + #>cpp + UnaryOperator() = default; + cpp<#) + (:protected + #>cpp + explicit UnaryOperator(Expression *expression) : expression_(expression) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(macrolet ((define-binary-operators () + `(lcp:cpp-list + ,@(loop for op in + '(or-operator xor-operator and-operator addition-operator + subtraction-operator multiplication-operator division-operator + mod-operator not-equal-operator equal-operator less-operator + greater-operator less-equal-operator greater-equal-operator + in-list-operator subscript-operator) + collecting + `(lcp:define-class ,op (binary-operator) + () + (:public + #>cpp + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + using BinaryOperator::BinaryOperator; + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)))))) + (define-binary-operators)) + +(macrolet ((define-unary-operators () + `(lcp:cpp-list + ,@(loop for op in + '(not-operator unary-plus-operator + unary-minus-operator is-null-operator) + collecting + 
`(lcp:define-class ,op (unary-operator) + () + (:public + #>cpp + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + using UnaryOperator::UnaryOperator; + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)))))) + (define-unary-operators)) + +(lcp:define-class aggregation (binary-operator) + ((op "Op" :scope :public) + (symbol-pos :int32_t :initval -1 :scope :public + :documentation "Symbol table position of the symbol this Aggregation is mapped to.")) + (:public + (lcp:define-enum op + (count min max sum avg collect-list collect-map) + (:serialize)) + #>cpp + Aggregation() = default; + + static const constexpr char *const kCount = "COUNT"; + static const constexpr char *const kMin = "MIN"; + static const constexpr char *const kMax = "MAX"; + static const constexpr char *const kSum = "SUM"; + static const constexpr char *const kAvg = "AVG"; + static const constexpr char *const kCollect = "COLLECT"; + + static std::string OpToString(Op op) { + const char *op_strings[] = {kCount, kMin, kMax, kSum, + kAvg, kCollect, kCollect}; + return op_strings[static_cast<int>(op)]; + } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + if (expression1_) expression1_->Accept(visitor); + if (expression2_) expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + Aggregation *MapTo(const Symbol &symbol) { + symbol_pos_ = symbol.position(); + return this; + } + cpp<#) + (:protected + #>cpp + // Use only for serialization. + explicit Aggregation(Op op) : op_(op) {} + + /// Aggregation's first expression is the value being aggregated. 
The second + /// expression is the key used only in COLLECT_MAP. + Aggregation(Expression *expression1, Expression *expression2, Op op) + : BinaryOperator(expression1, expression2), op_(op) { + // COUNT without expression denotes COUNT(*) in cypher. + DMG_ASSERT(expression1 || op == Aggregation::Op::COUNT, + "All aggregations, except COUNT require expression"); + DMG_ASSERT((expression2 == nullptr) ^ (op == Aggregation::Op::COLLECT_MAP), + "The second expression is obligatory in COLLECT_MAP and " + "invalid otherwise"); + } + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class list-slicing-operator (expression) + ((list "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (lower-bound "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (upper-bound "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:public + #>cpp + ListSlicingOperator() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + bool cont = list_->Accept(visitor); + if (cont && lower_bound_) { + cont = lower_bound_->Accept(visitor); + } + if (cont && upper_bound_) { + upper_bound_->Accept(visitor); + } + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + ListSlicingOperator(Expression *list, Expression *lower_bound, + Expression *upper_bound) + : list_(list), lower_bound_(lower_bound), upper_bound_(upper_bound) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class if-operator (expression) + ((condition "Expression *" :scope :public + :slk-save 
#'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "None of the expressions should be nullptr. If there is no else_expression, you should make it null PrimitiveLiteral.") + (then-expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (else-expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:public + #>cpp + IfOperator() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + condition_->Accept(visitor) && then_expression_->Accept(visitor) && + else_expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + IfOperator(Expression *condition, Expression *then_expression, + Expression *else_expression) + : condition_(condition), + then_expression_(then_expression), + else_expression_(else_expression) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class base-literal (expression) + () + (:abstractp t) + (:public + #>cpp + BaseLiteral() = default; + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class primitive-literal (base-literal) + ((value "::storage::v3::PropertyValue" :scope :public) + (token-position :int32_t :scope :public :initval -1 + :documentation "This field contains token position of literal used to create PrimitiveLiteral object. 
If PrimitiveLiteral object is not created from query, leave its value at -1.")) + (:public + #>cpp + PrimitiveLiteral() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + DEFVISITABLE(HierarchicalTreeVisitor); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:protected + #>cpp + template <typename T> + explicit PrimitiveLiteral(T value) : value_(value) {} + template <typename T> + PrimitiveLiteral(T value, int token_position) + : value_(value), token_position_(token_position) {} + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class list-literal (base-literal) + ((elements "std::vector<Expression *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Expression"))) + (:public + #>cpp + ListLiteral() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + // Visits the elements in order and stops at the first one whose Accept + // returns false. + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + for (auto *expr_ptr : elements_) + if (!expr_ptr->Accept(visitor)) break; + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + explicit ListLiteral(const std::vector<Expression *> &elements) + : elements_(elements) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class map-literal (base-literal) + ((elements "std::unordered_map<PropertyIx, Expression *>" + :slk-save #'slk-save-property-map + :slk-load #'slk-load-property-map + :clone #'clone-property-map + :scope :public)) + (:public + #>cpp + MapLiteral() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + // Visits the mapped expressions and stops at the first one whose Accept + // returns false. + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + // Iterate by const reference: a by-value loop variable would copy the + // PropertyIx key (and its std::string name) for every entry. + for (const auto &pair : elements_) + if (!pair.second->Accept(visitor)) break; + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + explicit MapLiteral( + const
std::unordered_map<PropertyIx, Expression *> &elements) + : elements_(elements) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class identifier (expression) + ((name "std::string" :scope :public) + (user-declared :bool :initval "true" :scope :public) + (symbol-pos :int32_t :initval -1 :scope :public + :documentation "Symbol table position of the symbol this Identifier is mapped to.")) + (:public + #>cpp + Identifier() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + DEFVISITABLE(HierarchicalTreeVisitor); + + Identifier *MapTo(const Symbol &symbol) { + symbol_pos_ = symbol.position(); + return this; + } + + explicit Identifier(const std::string &name) : name_(name) {} + Identifier(const std::string &name, bool user_declared) + : name_(name), user_declared_(user_declared) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class property-lookup (expression) + ((expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (property "PropertyIx" :scope :public + :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, storage); + cpp<#) + :clone (lambda (source dest) + #>cpp + ${dest} = storage->GetPropertyIx(${source}.name); + cpp<#))) + (:public + #>cpp + PropertyLookup() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + PropertyLookup(Expression *expression, PropertyIx property) + : expression_(expression), property_(property) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + 
+(lcp:define-class labels-test (expression) + ((expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (labels "std::vector<LabelIx>" :scope :public + :slk-load (lambda (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + self->${member}.resize(size); + for (size_t i = 0; i < size; ++i) { + slk::Load(&self->${member}[i], reader, storage); + } + cpp<#) + :clone (clone-name-ix-vector "Label"))) + (:public + #>cpp + LabelsTest() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + LabelsTest(Expression *expression, const std::vector<LabelIx> &labels) + : expression_(expression), labels_(labels) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class function (expression) + ((arguments "std::vector<Expression *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Expression")) + (function-name "std::string" :scope :public) + (function "std::function<TypedValue(const TypedValue *, int64_t, + const FunctionContext &)>" + :scope :public + :dont-save t + :clone :copy + :slk-load (lambda (member) + #>cpp + self->${member} = query::v2::NameToFunction(self->function_name_); + cpp<#))) + (:public + #>cpp + Function() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + for (auto *argument : arguments_) { + if (!argument->Accept(visitor)) break; + } + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + Function(const std::string &function_name, + const 
std::vector<Expression *> &arguments) + : arguments_(arguments), + function_name_(function_name), + function_(NameToFunction(function_name_)) { + if (!function_) { + throw SemanticException("Function '{}' doesn't exist.", function_name); + } + } + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class reduce (expression) + ((accumulator "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier") + :documentation "Identifier for the accumulating variable") + (initializer "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Expression which produces the initial accumulator value.") + (identifier "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier") + :documentation "Identifier for the list element.") + (list "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Expression which produces a list to be reduced.") + (expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Expression which does the reduction, i.e. 
produces the new accumulator value.")) + (:public + #>cpp + Reduce() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + accumulator_->Accept(visitor) && initializer_->Accept(visitor) && + identifier_->Accept(visitor) && list_->Accept(visitor) && + expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + Reduce(Identifier *accumulator, Expression *initializer, Identifier *identifier, + Expression *list, Expression *expression) + : accumulator_(accumulator), + initializer_(initializer), + identifier_(identifier), + list_(list), + expression_(expression) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class coalesce (expression) + ((expressions "std::vector<Expression *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Expression") + :documentation "A list of expressions to evaluate. 
None of the expressions should be nullptr.")) + (:public + #>cpp + Coalesce() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + for (auto *expr : expressions_) { + if (!expr->Accept(visitor)) break; + } + } + return visitor.PostVisit(*this); + } + cpp<# + ) + (:private + #>cpp + explicit Coalesce(const std::vector<Expression *> &expressions) + : expressions_(expressions) {} + + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class extract (expression) + ((identifier "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier") + :documentation "Identifier for the list element.") + (list "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Expression which produces a list which will be extracted.") + (expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Expression which produces the new value for list element.")) + (:public + #>cpp + Extract() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor) && list_->Accept(visitor) && + expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + Extract(Identifier *identifier, Expression *list, Expression *expression) + : identifier_(identifier), list_(list), expression_(expression) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class all (expression) + ((identifier "Identifier *" :initval "nullptr" 
:scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier")) + (list-expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (where "Where *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Where"))) + (:public + #>cpp + All() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor) && list_expression_->Accept(visitor) && + where_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + All(Identifier *identifier, Expression *list_expression, Where *where) + : identifier_(identifier), + list_expression_(list_expression), + where_(where) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +;; TODO: This is pretty much copy pasted from All. Consider merging Reduce, +;; All, Any, None and Single into something like a higher-order function call +;; which takes a list argument and a function which is applied on list elements. 
+;; Single: expression node with three child pointers -- an identifier, a list
+;; expression and a Where node. All three default to nullptr.
+(lcp:define-class single (expression)
+  ((identifier "Identifier *" :initval "nullptr" :scope :public
+               :slk-save #'slk-save-ast-pointer
+               :slk-load (slk-load-ast-pointer "Identifier"))
+   (list-expression "Expression *" :initval "nullptr" :scope :public
+                    :slk-save #'slk-save-ast-pointer
+                    :slk-load (slk-load-ast-pointer "Expression"))
+   (where "Where *" :initval "nullptr" :scope :public
+          :slk-save #'slk-save-ast-pointer
+          :slk-load (slk-load-ast-pointer "Where")))
+  (:public
+    #>cpp
+    Single() = default;
+
+    DEFVISITABLE(ExpressionVisitor<TypedValue>);
+    DEFVISITABLE(ExpressionVisitor<void>);
+
+    // Hierarchical traversal. When PreVisit allows descent the children are
+    // visited in order (identifier, list expression, where) and the walk stops
+    // at the first child whose Accept returns false. PostVisit always runs and
+    // its result is returned.
+    bool Accept(HierarchicalTreeVisitor &visitor) override {
+      const bool descend = visitor.PreVisit(*this);
+      if (descend && identifier_->Accept(visitor) && list_expression_->Accept(visitor)) {
+        where_->Accept(visitor);
+      }
+      return visitor.PostVisit(*this);
+    }
+    cpp<#)
+  (:protected
+    #>cpp
+    // Stores the three child pointers as given.
+    Single(Identifier *identifier, Expression *list_expression, Where *where)
+        : identifier_(identifier), list_expression_(list_expression), where_(where) {}
+    cpp<#)
+  (:private
+    #>cpp
+    friend class AstStorage;
+    cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+;; TODO: This is pretty much copy pasted from All. Consider merging Reduce,
+;; All, Any, None and Single into something like a higher-order function call
+;; which takes a list argument and a function which is applied on list elements.
+;; Any: expression node with three child pointers -- an identifier, a list
+;; expression and a Where node. All three default to nullptr.
+(lcp:define-class any (expression)
+  ((identifier "Identifier *" :initval "nullptr" :scope :public
+               :slk-save #'slk-save-ast-pointer
+               :slk-load (slk-load-ast-pointer "Identifier"))
+   (list-expression "Expression *" :initval "nullptr" :scope :public
+                    :slk-save #'slk-save-ast-pointer
+                    :slk-load (slk-load-ast-pointer "Expression"))
+   (where "Where *" :initval "nullptr" :scope :public
+          :slk-save #'slk-save-ast-pointer
+          :slk-load (slk-load-ast-pointer "Where")))
+  (:public
+    #>cpp
+    Any() = default;
+
+    DEFVISITABLE(ExpressionVisitor<TypedValue>);
+    DEFVISITABLE(ExpressionVisitor<void>);
+
+    // Hierarchical traversal. When PreVisit allows descent the children are
+    // visited in order (identifier, list expression, where) and the walk stops
+    // at the first child whose Accept returns false. PostVisit always runs and
+    // its result is returned.
+    bool Accept(HierarchicalTreeVisitor &visitor) override {
+      const bool descend = visitor.PreVisit(*this);
+      if (descend && identifier_->Accept(visitor) && list_expression_->Accept(visitor)) {
+        where_->Accept(visitor);
+      }
+      return visitor.PostVisit(*this);
+    }
+    cpp<#)
+  (:protected
+    #>cpp
+    // Stores the three child pointers as given.
+    Any(Identifier *identifier, Expression *list_expression, Where *where)
+        : identifier_(identifier), list_expression_(list_expression), where_(where) {}
+    cpp<#)
+  (:private
+    #>cpp
+    friend class AstStorage;
+    cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+;; TODO: This is pretty much copy pasted from All. Consider merging Reduce,
+;; All, Any, None and Single into something like a higher-order function call
+;; which takes a list argument and a function which is applied on list elements.
+(lcp:define-class none (expression) + ((identifier "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier")) + (list-expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (where "Where *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Where"))) + (:public + #>cpp + None() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor) && list_expression_->Accept(visitor) && + where_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + None(Identifier *identifier, Expression *list_expression, Where *where) + : identifier_(identifier), + list_expression_(list_expression), + where_(where) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class parameter-lookup (expression) + ((token-position :int32_t :initval -1 :scope :public + :documentation "This field contains token position of *literal* used to create ParameterLookup object. 
If ParameterLookup object is not created from a literal leave this value at -1."))
+  (:public
+    #>cpp
+    ParameterLookup() = default;
+
+    DEFVISITABLE(ExpressionVisitor<TypedValue>);
+    DEFVISITABLE(ExpressionVisitor<void>);
+    DEFVISITABLE(HierarchicalTreeVisitor);
+    cpp<#)
+  (:protected
+    #>cpp
+    explicit ParameterLookup(int token_position)
+        : token_position_(token_position) {}
+    cpp<#)
+  (:private
+    #>cpp
+    friend class AstStorage;
+    cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+;; RegexMatch: expression node with two child expressions, the string operand
+;; and the regex operand.
+(lcp:define-class regex-match (expression)
+  ;; Both pointers default to nullptr, like the other AST pointer slots in this
+  ;; file, so the defaulted constructor never leaves them indeterminate.
+  ((string-expr "Expression *" :initval "nullptr" :scope :public
+                :slk-save #'slk-save-ast-pointer
+                :slk-load (slk-load-ast-pointer "Expression"))
+   (regex "Expression *" :initval "nullptr" :scope :public
+          :slk-save #'slk-save-ast-pointer
+          :slk-load (slk-load-ast-pointer "Expression")))
+  (:public
+    #>cpp
+    RegexMatch() = default;
+
+    DEFVISITABLE(ExpressionVisitor<TypedValue>);
+    DEFVISITABLE(ExpressionVisitor<void>);
+
+    // Visits string_expr_ and then regex_ while PreVisit allows descent; the
+    // regex is skipped if the string expression's Accept returns false.
+    // PostVisit always runs and supplies the result.
+    bool Accept(HierarchicalTreeVisitor &visitor) override {
+      if (visitor.PreVisit(*this)) {
+        string_expr_->Accept(visitor) && regex_->Accept(visitor);
+      }
+      return visitor.PostVisit(*this);
+    }
+    cpp<#)
+  (:private
+    #>cpp
+    friend class AstStorage;
+    RegexMatch(Expression *string_expr, Expression *regex)
+        : string_expr_(string_expr), regex_(regex) {}
+    cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+(lcp:define-class named-expression (tree "::utils::Visitable<HierarchicalTreeVisitor>"
+                                         "::utils::Visitable<ExpressionVisitor<TypedValue>>"
+                                         "::utils::Visitable<ExpressionVisitor<void>>")
+  ((name "std::string" :scope :public)
+   (expression "Expression *" :initval "nullptr" :scope :public
+               :slk-save #'slk-save-ast-pointer
+               :slk-load (slk-load-ast-pointer "Expression"))
+   (token-position :int32_t :initval -1 :scope :public
+                   :documentation "This field contains token position of first token in named expression used to create name_.
If NamedExpression object is not created from query or it is aliased leave this value at -1.") + (symbol-pos :int32_t :initval -1 :scope :public + :documentation "Symbol table position of the symbol this NamedExpression is mapped to.")) + (:public + #>cpp + using utils::Visitable<ExpressionVisitor<TypedValue>>::Accept; + using utils::Visitable<ExpressionVisitor<void>>::Accept; + using utils::Visitable<HierarchicalTreeVisitor>::Accept; + + NamedExpression() = default; + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + NamedExpression *MapTo(const Symbol &symbol) { + symbol_pos_ = symbol.position(); + return this; + } + cpp<#) + (:protected + #>cpp + explicit NamedExpression(const std::string &name) : name_(name) {} + NamedExpression(const std::string &name, Expression *expression) + : name_(name), expression_(expression) {} + NamedExpression(const std::string &name, Expression *expression, + int token_position) + : name_(name), expression_(expression), token_position_(token_position) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk :ignore-other-base-classes t)) + (:clone :ignore-other-base-classes t) + (:type-info :ignore-other-base-classes t)) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;;; END Expressions +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +(lcp:define-class pattern-atom (tree "::utils::Visitable<HierarchicalTreeVisitor>") + ((identifier "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier"))) + (:abstractp t) + (:public + #>cpp + using utils::Visitable<HierarchicalTreeVisitor>::Accept; + + PatternAtom() = default; + cpp<#) + (:protected + #>cpp + explicit 
PatternAtom(Identifier *identifier) : identifier_(identifier) {}
+    cpp<#)
+  (:private
+    #>cpp
+    friend class AstStorage;
+    cpp<#)
+  (:serialize (:slk :ignore-other-base-classes t))
+  (:clone :ignore-other-base-classes t)
+  (:type-info :ignore-other-base-classes t))
+
+(defun clone-variant-properties (source destination)
+  #>cpp
+    if (const auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&${source})) {
+      auto &new_obj_properties = std::get<std::unordered_map<PropertyIx, Expression *>>(${destination});
+      for (const auto &[property, value_expression] : *properties) {
+        PropertyIx key = storage->GetPropertyIx(property.name);
+        new_obj_properties[key] = value_expression->Clone(storage);
+      }
+    } else {
+      ${destination} = std::get<ParameterLookup *>(${source})->Clone(storage);
+    }
+  cpp<#)
+
+;; NodeAtom: pattern atom carrying a list of labels and a properties slot that
+;; holds either a PropertyIx -> Expression map or a single ParameterLookup.
+(lcp:define-class node-atom (pattern-atom)
+  ((labels "std::vector<LabelIx>" :scope :public
+           :slk-load (lambda (member)
+                       #>cpp
+                       size_t size = 0;
+                       slk::Load(&size, reader);
+                       self->${member}.resize(size);
+                       for (size_t i = 0; i < size; ++i) {
+                         slk::Load(&self->${member}[i], reader, storage);
+                       }
+                       cpp<#)
+           :clone (clone-name-ix-vector "Label"))
+   (properties "std::variant<std::unordered_map<PropertyIx, Expression *>, ParameterLookup*>"
+               :clone #'clone-variant-properties
+               :scope :public))
+  (:public
+    #>cpp
+    // Hierarchical traversal. While PreVisit allows descent, the identifier is
+    // visited first, then either every property value expression (stopping
+    // once a child's Accept returns false) or the ParameterLookup.
+    // PostVisit always runs and supplies the result.
+    bool Accept(HierarchicalTreeVisitor &visitor) override {
+      if (visitor.PreVisit(*this)) {
+        // Visit the identifier for both variant alternatives, matching
+        // EdgeAtom::Accept. It used to be visited only in the map branch, so
+        // a node whose properties are a ParameterLookup had its identifier
+        // skipped.
+        bool cont = identifier_->Accept(visitor);
+        if (auto* properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&properties_)) {
+          for (auto &property : *properties) {
+            if (cont) {
+              cont = property.second->Accept(visitor);
+            }
+          }
+        } else {
+          std::get<ParameterLookup*>(properties_)->Accept(visitor);
+        }
+      }
+      return visitor.PostVisit(*this);
+    }
+    cpp<#)
+  (:protected
+    #>cpp
+    using PatternAtom::PatternAtom;
+    cpp<#)
+  (:private
+    #>cpp
+    friend class AstStorage;
+    cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+(lcp:define-class edge-atom
(pattern-atom) + ((type "Type" :initval "Type::SINGLE" :scope :public) + (direction "Direction" :initval "Direction::BOTH" :scope :public) + (edge-types "std::vector<EdgeTypeIx>" :scope :public + :slk-load (lambda (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + self->${member}.resize(size); + for (size_t i = 0; i < size; ++i) { + slk::Load(&self->${member}[i], reader, storage); + } + cpp<#) + :clone (clone-name-ix-vector "EdgeType")) + (properties "std::variant<std::unordered_map<PropertyIx, Expression *>, ParameterLookup*>" + :scope :public + :slk-save #'slk-save-property-map + :slk-load #'slk-load-property-map + :clone #'clone-variant-properties) + (lower-bound "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Evaluates to lower bound in variable length expands.") + (upper-bound "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Evaluated to upper bound in variable length expands.") + (filter-lambda "Lambda" :scope :public + :documentation "Filter lambda for variable length expands. Can have an empty expression, but identifiers must be valid, because an optimization pass may inline other expressions into this lambda." + :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, storage); + cpp<#)) + (weight-lambda "Lambda" :scope :public + :documentation "Used in weighted shortest path. It must have valid expressions and identifiers. In all other expand types, it is empty." 
+ :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, storage); + cpp<#)) + (total-weight "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier") + :documentation "Variable where the total weight for weighted shortest path will be stored.")) + (:public + (lcp:define-enum type + (single depth-first breadth-first weighted-shortest-path) + (:serialize)) + (lcp:define-enum direction + (in out both) + (:serialize)) + (lcp:define-struct lambda () + ((inner-edge "Identifier *" :initval "nullptr" + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier") + :documentation "Argument identifier for the edge currently being traversed.") + (inner-node "Identifier *" :initval "nullptr" + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier") + :documentation "Argument identifier for the destination node of the edge.") + (expression "Expression *" :initval "nullptr" + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Evaluates the result of the lambda.")) + (:documentation "Lambda for use in filtering or weight calculation during variable expand.") + (:serialize (:slk :load-args '((storage "query::v2::AstStorage *")))) + (:clone :args '((storage "AstStorage *")))) + #>cpp + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + bool cont = identifier_->Accept(visitor); + if (auto *properties = std::get_if<std::unordered_map<query::v2::PropertyIx, query::v2::Expression *>>(&properties_)) { + for (auto &property : *properties) { + if (cont) { + cont = property.second->Accept(visitor); + } + } + } else { + std::get<ParameterLookup *>(properties_)->Accept(visitor); + } + if (cont && lower_bound_) { + cont = lower_bound_->Accept(visitor); + } + if (cont && upper_bound_) { + cont = upper_bound_->Accept(visitor); + } + if (cont && total_weight_) { + 
total_weight_->Accept(visitor);
+        }
+      }
+      return visitor.PostVisit(*this);
+    }
+
+    // Returns true for the DEPTH_FIRST, BREADTH_FIRST and
+    // WEIGHTED_SHORTEST_PATH expansion types and false for SINGLE.
+    bool IsVariable() const {
+      // No default label on purpose: the compiler can then warn when a new
+      // Type enumerator is added without being handled here.
+      switch (type_) {
+        case Type::DEPTH_FIRST:
+        case Type::BREADTH_FIRST:
+        case Type::WEIGHTED_SHORTEST_PATH:
+          return true;
+        case Type::SINGLE:
+          return false;
+      }
+      // Only reachable if type_ holds a value outside the Type enumerators.
+      // Without this return, control would flow off the end of a non-void
+      // function, which is undefined behaviour.
+      return false;
+    }
+    cpp<#)
+  (:protected
+    #>cpp
+    using PatternAtom::PatternAtom;
+    EdgeAtom(Identifier *identifier, Type type, Direction direction)
+        : PatternAtom(identifier), type_(type), direction_(direction) {}
+
+    // Creates an edge atom with the given type and direction, restricted to
+    // the given edge types.
+    EdgeAtom(Identifier *identifier, Type type, Direction direction,
+             const std::vector<EdgeTypeIx> &edge_types)
+        : PatternAtom(identifier),
+          type_(type),
+          direction_(direction),
+          edge_types_(edge_types) {}
+    cpp<#)
+  (:private
+    #>cpp
+    friend class AstStorage;
+    cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+(lcp:define-class pattern (tree "::utils::Visitable<HierarchicalTreeVisitor>")
+  ((identifier "Identifier *" :initval "nullptr" :scope :public
+               :slk-save #'slk-save-ast-pointer
+               :slk-load (slk-load-ast-pointer "Identifier"))
+   (atoms "std::vector<PatternAtom *>"
+          :scope :public
+          :slk-save #'slk-save-ast-vector
+          :slk-load (slk-load-ast-vector "PatternAtom")))
+  (:public
+    #>cpp
+    using utils::Visitable<HierarchicalTreeVisitor>::Accept;
+
+    Pattern() = default;
+
+    bool Accept(HierarchicalTreeVisitor &visitor) override {
+      if (visitor.PreVisit(*this)) {
+        bool cont = identifier_->Accept(visitor);
+        for (auto &part : atoms_) {
+          if (cont) {
+            cont = part->Accept(visitor);
+          }
+        }
+      }
+      return visitor.PostVisit(*this);
+    }
+    cpp<#)
+  (:private
+    #>cpp
+    friend class AstStorage;
+    cpp<#)
+  (:serialize (:slk :ignore-other-base-classes t))
+  (:clone :ignore-other-base-classes t)
+  (:type-info :ignore-other-base-classes t))
+
+(lcp:define-class clause (tree "::utils::Visitable<HierarchicalTreeVisitor>")
+  ()
+  (:abstractp t)
+  (:public
+    #>cpp
+    using utils::Visitable<HierarchicalTreeVisitor>::Accept;
+
+    Clause() = default;
+
cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk :ignore-other-base-classes t)) + (:clone :ignore-other-base-classes t) + (:type-info :ignore-other-base-classes t)) + +(lcp:define-class single-query (tree "::utils::Visitable<HierarchicalTreeVisitor>") + ((clauses "std::vector<Clause *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Clause"))) + (:public + #>cpp + using utils::Visitable<HierarchicalTreeVisitor>::Accept; + + SingleQuery() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + for (auto &clause : clauses_) { + if (!clause->Accept(visitor)) break; + } + } + return visitor.PostVisit(*this); + } + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk :ignore-other-base-classes t)) + (:clone :ignore-other-base-classes t) + (:type-info :ignore-other-base-classes t)) + +(lcp:define-class cypher-union (tree "::utils::Visitable<HierarchicalTreeVisitor>") + ((single-query "SingleQuery *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "SingleQuery")) + (distinct :bool :initval "false" :scope :public) + (union-symbols "std::vector<Symbol>" :scope :public + :documentation "Holds symbols that are created during symbol generation phase. 
These symbols are used when UNION/UNION ALL combines single query results.")) + (:public + #>cpp + using utils::Visitable<HierarchicalTreeVisitor>::Accept; + + CypherUnion() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + single_query_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + explicit CypherUnion(bool distinct) : distinct_(distinct) {} + CypherUnion(bool distinct, SingleQuery *single_query, + std::vector<Symbol> union_symbols) + : single_query_(single_query), + distinct_(distinct), + union_symbols_(union_symbols) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk :ignore-other-base-classes t)) + (:clone :ignore-other-base-classes t) + (:type-info :ignore-other-base-classes t)) + +(lcp:define-class query (tree "::utils::Visitable<QueryVisitor<void>>") + () + (:abstractp t) + (:public + #>cpp + using utils::Visitable<QueryVisitor<void>>::Accept; + + Query() = default; + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk :ignore-other-base-classes t)) + (:clone :ignore-other-base-classes t) + (:type-info :ignore-other-base-classes t)) + +(lcp:define-class cypher-query (query) + ((single-query "SingleQuery *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "SingleQuery") + :documentation "First and potentially only query.") + (cypher-unions "std::vector<CypherUnion *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "CypherUnion") + :documentation "Contains remaining queries that should form and union with `single_query_`.") + (memory-limit "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (memory-scale "size_t" :initval "1024U" :scope :public)) + (:public + #>cpp + CypherQuery() = default; + + 
DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class explain-query (query) + ((cypher-query "CypherQuery *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "CypherQuery") + :documentation "The CypherQuery to explain.")) + (:public + #>cpp + ExplainQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class profile-query (query) + ((cypher-query "CypherQuery *" + :initval "nullptr" + :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "CypherQuery") + :documentation "The CypherQuery to profile.")) + (:public + #>cpp + ProfileQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class index-query (query) + ((action "Action" :scope :public) + (label "LabelIx" :scope :public + :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, storage); + cpp<#) + :clone (lambda (source dest) + #>cpp + ${dest} = storage->GetLabelIx(${source}.name); + cpp<#)) + (properties "std::vector<PropertyIx>" :scope :public + :slk-load (lambda (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + self->${member}.resize(size); + for (size_t i = 0; i < size; ++i) { + slk::Load(&self->${member}[i], reader, storage); + } + cpp<#) + :clone (clone-name-ix-vector "Property"))) + (:public + (lcp:define-enum action + (create drop) + (:serialize)) + + #>cpp + IndexQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:protected + #>cpp + IndexQuery(Action action, LabelIx label, std::vector<PropertyIx> properties) + : action_(action), label_(label), properties_(properties) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + 
(:serialize (:slk))
+  (:clone))
+
+(lcp:define-class create (clause)
+  ((patterns "std::vector<Pattern *>"
+             :scope :public
+             :slk-save #'slk-save-ast-vector
+             :slk-load (slk-load-ast-vector "Pattern")))
+  (:public
+    #>cpp
+    Create() = default;
+
+    bool Accept(HierarchicalTreeVisitor &visitor) override {
+      if (visitor.PreVisit(*this)) {
+        for (auto &pattern : patterns_) {
+          if (!pattern->Accept(visitor)) break;
+        }
+      }
+      return visitor.PostVisit(*this);
+    }
+    cpp<#)
+  (:protected
+    #>cpp
+    explicit Create(std::vector<Pattern *> patterns) : patterns_(patterns) {}
+    cpp<#)
+  (:private
+    #>cpp
+    friend class AstStorage;
+    cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+;; CallProcedure: clause holding a procedure name, its argument expressions,
+;; the result field names with their identifiers, an optional memory limit
+;; expression with its scale, and a write flag.
+(lcp:define-class call-procedure (clause)
+  ((procedure-name "std::string" :scope :public)
+   (arguments "std::vector<Expression *>"
+              :scope :public
+              :slk-save #'slk-save-ast-vector
+              :slk-load (slk-load-ast-vector "Expression"))
+   (result-fields "std::vector<std::string>" :scope :public)
+   (result-identifiers "std::vector<Identifier *>"
+                       :scope :public
+                       :slk-save #'slk-save-ast-vector
+                       :slk-load (slk-load-ast-vector "Identifier"))
+   (memory-limit "Expression *" :initval "nullptr" :scope :public
+                 :slk-save #'slk-save-ast-pointer
+                 :slk-load (slk-load-ast-pointer "Expression"))
+   (memory-scale "size_t" :initval "1024U" :scope :public)
+   ;; Defaulted to false so that the defaulted constructor does not leave the
+   ;; flag with an indeterminate value.
+   (is_write :bool :initval "false" :scope :public))
+  (:public
+    #>cpp
+    CallProcedure() = default;
+
+    // Hierarchical traversal. While PreVisit allows descent, the argument
+    // expressions are visited first and then the result identifiers; the walk
+    // stops at the first child whose Accept returns false. PostVisit always
+    // runs and supplies the result.
+    bool Accept(HierarchicalTreeVisitor &visitor) override {
+      if (visitor.PreVisit(*this)) {
+        bool cont = true;
+        for (auto &arg : arguments_) {
+          if (!arg->Accept(visitor)) {
+            cont = false;
+            break;
+          }
+        }
+        if (cont) {
+          for (auto &ident : result_identifiers_) {
+            // Last group of children: nothing reads a continue flag after
+            // this loop, so a plain break is sufficient.
+            if (!ident->Accept(visitor)) break;
+          }
+        }
+      }
+      return visitor.PostVisit(*this);
+    }
+    cpp<#)
+  (:private
+    #>cpp
+    friend class AstStorage;
+    cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+(lcp:define-class match (clause)
+  ((patterns "std::vector<Pattern *>"
+             :scope :public
+             :slk-save
#'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Pattern")) + (where "Where *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Where")) + (optional :bool :initval "false" :scope :public)) + (:public + #>cpp + Match() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + bool cont = true; + for (auto &pattern : patterns_) { + if (!pattern->Accept(visitor)) { + cont = false; + break; + } + } + if (cont && where_) { + where_->Accept(visitor); + } + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + explicit Match(bool optional) : optional_(optional) {} + Match(bool optional, Where *where, std::vector<Pattern *> patterns) + : patterns_(patterns), where_(where), optional_(optional) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-enum ordering + (asc desc) + (:documentation "Defines the order for sorting values (ascending or descending).") + (:serialize)) + +(lcp:define-struct sort-item () + ((ordering "Ordering" :scope :public) + (expression "Expression *" + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:serialize (:slk :load-args '((storage "query::v2::AstStorage *")))) + (:clone :args '((storage "AstStorage *")))) + +(lcp:define-struct return-body () + ((distinct :bool :initval "false" + :documentation "True if distinct results should be produced.") + (all-identifiers :bool :initval "false" + :documentation "True if asterisk was found in the return body.") + (named-expressions "std::vector<NamedExpression *>" + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "NamedExpression") + :documentation "Expressions which are used to produce results.") + (order-by "std::vector<SortItem>" + :slk-load (lambda (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + self->${member}.resize(size); + for (size_t i 
= 0; i < size; ++i) { + slk::Load(&self->${member}[i], reader, storage); + } + cpp<#) + :documentation "Expressions used for ordering the results.") + (skip "Expression *" :initval "nullptr" + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Optional expression on how many results to skip.") + (limit "Expression *" :initval "nullptr" + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Optional expression on how many results to produce.")) + (:documentation "Contents common to @c Return and @c With clauses.") + (:serialize (:slk :load-args '((storage "query::v2::AstStorage *")))) + (:clone :args '((storage "AstStorage *")))) + +(lcp:define-class return (clause) + ((body "ReturnBody" :scope :public + :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, storage); + cpp<#))) + (:public + #>cpp + Return() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + bool cont = true; + for (auto &expr : body_.named_expressions) { + if (!expr->Accept(visitor)) { + cont = false; + break; + } + } + if (cont) { + for (auto &order_by : body_.order_by) { + if (!order_by.expression->Accept(visitor)) { + cont = false; + break; + } + } + } + if (cont && body_.skip) cont = body_.skip->Accept(visitor); + if (cont && body_.limit) cont = body_.limit->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + explicit Return(ReturnBody &body) : body_(body) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class with (clause) + ((body "ReturnBody" :scope :public + :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, storage); + cpp<#)) + (where "Where *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Where"))) + (:public + #>cpp + With() = default; 
+ + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + bool cont = true; + for (auto &expr : body_.named_expressions) { + if (!expr->Accept(visitor)) { + cont = false; + break; + } + } + if (cont) { + for (auto &order_by : body_.order_by) { + if (!order_by.expression->Accept(visitor)) { + cont = false; + break; + } + } + } + if (cont && where_) cont = where_->Accept(visitor); + if (cont && body_.skip) cont = body_.skip->Accept(visitor); + if (cont && body_.limit) cont = body_.limit->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + With(ReturnBody &body, Where *where) : body_(body), where_(where) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class delete (clause) + ((expressions "std::vector<Expression *>" + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Expression") + :scope :public) + (detach :bool :initval "false" :scope :public)) + (:public + #>cpp + Delete() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + for (auto &expr : expressions_) { + if (!expr->Accept(visitor)) break; + } + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + Delete(bool detach, std::vector<Expression *> expressions) + : expressions_(expressions), detach_(detach) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class set-property (clause) + ((property-lookup "PropertyLookup *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "PropertyLookup")) + (expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:public + #>cpp + SetProperty() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if 
(visitor.PreVisit(*this)) { + property_lookup_->Accept(visitor) && expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + SetProperty(PropertyLookup *property_lookup, Expression *expression) + : property_lookup_(property_lookup), expression_(expression) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class set-properties (clause) + ((identifier "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier")) + (expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (update :bool :initval "false" :scope :public)) + (:public + #>cpp + SetProperties() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor) && expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + SetProperties(Identifier *identifier, Expression *expression, + bool update = false) + : identifier_(identifier), expression_(expression), update_(update) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class set-labels (clause) + ((identifier "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier")) + (labels "std::vector<LabelIx>" :scope :public + :slk-load (lambda (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + self->${member}.resize(size); + for (size_t i = 0; i < size; ++i) { + slk::Load(&self->${member}[i], reader, storage); + } + cpp<#) + :clone (clone-name-ix-vector "Label"))) + (:public + #>cpp + SetLabels() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + 
identifier_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + SetLabels(Identifier *identifier, const std::vector<LabelIx> &labels) + : identifier_(identifier), labels_(labels) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class remove-property (clause) + ((property-lookup "PropertyLookup *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "PropertyLookup"))) + (:public + #>cpp + RemoveProperty() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + property_lookup_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + explicit RemoveProperty(PropertyLookup *property_lookup) + : property_lookup_(property_lookup) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class remove-labels (clause) + ((identifier "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier")) + (labels "std::vector<LabelIx>" :scope :public + :slk-load (lambda (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + self->${member}.resize(size); + for (size_t i = 0; i < size; ++i) { + slk::Load(&self->${member}[i], reader, storage); + } + cpp<#) + :clone (clone-name-ix-vector "Label"))) + (:public + #>cpp + RemoveLabels() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + RemoveLabels(Identifier *identifier, const std::vector<LabelIx> &labels) + : identifier_(identifier), labels_(labels) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class merge (clause) + 
((pattern "Pattern *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Pattern")) + (on-match "std::vector<Clause *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Clause")) + (on-create "std::vector<Clause *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Clause"))) + (:public + #>cpp + Merge() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + bool cont = pattern_->Accept(visitor); + if (cont) { + for (auto &set : on_match_) { + if (!set->Accept(visitor)) { + cont = false; + break; + } + } + } + if (cont) { + for (auto &set : on_create_) { + if (!set->Accept(visitor)) { + cont = false; + break; + } + } + } + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + Merge(Pattern *pattern, std::vector<Clause *> on_match, + std::vector<Clause *> on_create) + : pattern_(pattern), on_match_(on_match), on_create_(on_create) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class unwind (clause) + ((named-expression "NamedExpression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "NamedExpression"))) + (:public + #>cpp + Unwind() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + named_expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + explicit Unwind(NamedExpression *named_expression) + : named_expression_(named_expression) { + DMG_ASSERT(named_expression, "Unwind cannot take nullptr for named_expression"); + } + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class auth-query (query) + ((action "Action" :scope :public) + (user "std::string" :scope :public) + (role 
"std::string" :scope :public) + (user-or-role "std::string" :scope :public) + (password "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (privileges "std::vector<Privilege>" :scope :public)) + (:public + (lcp:define-enum action + (create-role drop-role show-roles create-user set-password drop-user + show-users set-role clear-role grant-privilege deny-privilege + revoke-privilege show-privileges show-role-for-user + show-users-for-role) + (:serialize)) + (lcp:define-enum privilege + (create delete match merge set remove index stats auth constraint + dump replication durability read_file free_memory trigger config stream module_read module_write + websocket) + (:serialize)) + #>cpp + AuthQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:protected + #>cpp + AuthQuery(Action action, std::string user, std::string role, + std::string user_or_role, Expression *password, + std::vector<Privilege> privileges) + : action_(action), + user_(user), + role_(role), + user_or_role_(user_or_role), + password_(password), + privileges_(privileges) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +;; TODO: Generate this via LCP +#>cpp +/// Constant that holds all available privileges. 
+const std::vector<AuthQuery::Privilege> kPrivilegesAll = { + AuthQuery::Privilege::CREATE, AuthQuery::Privilege::DELETE, + AuthQuery::Privilege::MATCH, AuthQuery::Privilege::MERGE, + AuthQuery::Privilege::SET, AuthQuery::Privilege::REMOVE, + AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS, + AuthQuery::Privilege::AUTH, + AuthQuery::Privilege::CONSTRAINT, AuthQuery::Privilege::DUMP, + AuthQuery::Privilege::REPLICATION, + AuthQuery::Privilege::READ_FILE, + AuthQuery::Privilege::DURABILITY, + AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, + AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, + AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, + AuthQuery::Privilege::WEBSOCKET}; +cpp<# + +(lcp:define-class info-query (query) + ((info-type "InfoType" :scope :public)) + (:public + (lcp:define-enum info-type + (storage index constraint) + (:serialize)) + + #>cpp + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-struct constraint () + ((type "Type") + (label "LabelIx" :scope :public + :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, storage); + cpp<#) + :clone (lambda (source dest) + #>cpp + ${dest} = storage->GetLabelIx(${source}.name); + cpp<#)) + (properties "std::vector<PropertyIx>" :scope :public + :slk-load (lambda (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + self->${member}.resize(size); + for (size_t i = 0; i < size; ++i) { + slk::Load(&self->${member}[i], reader, storage); + } + cpp<#) + :clone (clone-name-ix-vector "Property"))) + (:public + (lcp:define-enum type (exists unique node-key) + (:serialize (:lcp)))) + (:serialize (:slk :load-args '((storage "query::v2::AstStorage *")))) + (:clone :args '((storage "AstStorage *")))) + +(lcp:define-class constraint-query (query) + ((action-type "ActionType" :scope :public) + (constraint "Constraint" :scope :public + :slk-load (lambda (member) + #>cpp + 
slk::Load(&self->${member}, reader, storage); + cpp<#))) + (:public + (lcp:define-enum action-type + (create drop) + (:serialize)) + + #>cpp + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class dump-query (query) () + (:public + #>cpp + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class replication-query (query) + ((action "Action" :scope :public) + (role "ReplicationRole" :scope :public) + (replica_name "std::string" :scope :public) + (socket_address "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (port "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (sync_mode "SyncMode" :scope :public) + (timeout "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + + (:public + (lcp:define-enum action + (set-replication-role show-replication-role register-replica + drop-replica show-replicas) + (:serialize)) + (lcp:define-enum replication-role + (main replica) + (:serialize)) + (lcp:define-enum sync-mode + (sync async) + (:serialize)) + (lcp:define-enum replica-state + (ready replicating recovery invalid) + (:serialize)) + #>cpp + ReplicationQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class lock-path-query (query) + ((action "Action" :scope :public)) + + (:public + (lcp:define-enum action + (lock-path unlock-path) + (:serialize)) + #>cpp + LockPathQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class load-csv (clause) + ((file "Expression *" :scope :public) + (with_header "bool" :scope 
:public) + (ignore_bad "bool" :scope :public) + (delimiter "Expression *" :initval "nullptr" :scope :public) + (quote "Expression *" :initval "nullptr" :scope :public) + (row_var "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier"))) + + (:public + #>cpp + LoadCsv() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + row_var_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + explicit LoadCsv(Expression *file, bool with_header, bool ignore_bad, Expression *delimiter, + Expression* quote, Identifier* row_var) + : file_(file), + with_header_(with_header), + ignore_bad_(ignore_bad), + delimiter_(delimiter), + quote_(quote), + row_var_(row_var) { + DMG_ASSERT(row_var, "LoadCsv cannot take nullptr for identifier"); + } + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class free-memory-query (query) () + (:public + #>cpp + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class trigger-query (query) + ((action "Action" :scope :public) + (event_type "EventType" :scope :public) + (trigger_name "std::string" :scope :public) + (before_commit "bool" :scope :public) + (statement "std::string" :scope :public)) + + (:public + (lcp:define-enum action + (create-trigger drop-trigger show-triggers) + (:serialize)) + (lcp:define-enum event-type + (any vertex_create edge_create create vertex_delete edge_delete delete vertex_update edge_update update) + (:serialize)) + #>cpp + TriggerQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class isolation-level-query (query) + ((isolation_level "IsolationLevel" :scope :public) + (isolation_level_scope "IsolationLevelScope" :scope :public)) + + 
(:public + (lcp:define-enum isolation-level + (snapshot-isolation read-committed read-uncommitted) + (:serialize)) + (lcp:define-enum isolation-level-scope + (next session global) + (:serialize)) + #>cpp + IsolationLevelQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class create-snapshot-query (query) () + (:public + #>cpp + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:serialize (:slk)) + (:clone)) + +(defun clone-variant-topic-names (source destination) + #>cpp + if (auto *topic_expression = std::get_if<Expression*>(&${source})) { + if (*topic_expression == nullptr) { + ${destination} = nullptr; + } else { + ${destination} = (*topic_expression)->Clone(storage); + } + } else { + ${destination} = std::get<std::vector<std::string>>(${source}); + } + cpp<#) + +(lcp:define-class stream-query (query) + ((action "Action" :scope :public) + (type "Type" :scope :public) + (stream_name "std::string" :scope :public) + + (batch_limit "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (timeout "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + + (transform_name "std::string" :scope :public) + (batch_interval "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (batch_size "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + + (topic_names "std::variant<Expression*, std::vector<std::string>>" :initval "nullptr" + :clone #'clone-variant-topic-names + :scope :public) + (consumer_group "std::string" :scope :public) + (bootstrap_servers "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load 
(slk-load-ast-pointer "Expression"))
+
+   (service_url "Expression *" :initval "nullptr" :scope :public
+    :slk-save #'slk-save-ast-pointer
+    :slk-load (slk-load-ast-pointer "Expression"))
+
+   (configs "std::unordered_map<Expression *, Expression *>" :scope :public
+    :slk-save #'slk-save-expression-map
+    :slk-load #'slk-load-expression-map
+    :clone #'clone-expression-map)
+
+   (credentials "std::unordered_map<Expression *, Expression *>" :scope :public
+    :slk-save #'slk-save-expression-map
+    :slk-load #'slk-load-expression-map
+    :clone #'clone-expression-map))
+
+  (:public
+   (lcp:define-enum action
+     (create-stream drop-stream start-stream stop-stream start-all-streams stop-all-streams show-streams check-stream)
+     (:serialize))
+   (lcp:define-enum type
+     (kafka pulsar)
+     (:serialize))
+   #>cpp
+   StreamQuery() = default;
+
+   DEFVISITABLE(QueryVisitor<void>);
+   cpp<#)
+  (:private
+   #>cpp
+   friend class AstStorage;
+   cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+;; Settings query. `action` selects one of show-setting, show-all-settings or
+;; set-setting; the name and value are expression pointers that default to
+;; nullptr.
+(lcp:define-class setting-query (query)
+  ((action "Action" :scope :public)
+   (setting_name "Expression *" :initval "nullptr" :scope :public)
+   (setting_value "Expression *" :initval "nullptr" :scope :public))
+
+  (:public
+   (lcp:define-enum action
+     (show-setting show-all-settings set-setting)
+     (:serialize))
+   #>cpp
+   SettingQuery() = default;
+
+   DEFVISITABLE(QueryVisitor<void>);
+   cpp<#)
+  (:private
+   #>cpp
+   friend class AstStorage;
+   cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+;; Version query. It has no members and is only made visitable through
+;; QueryVisitor<void>.
+(lcp:define-class version-query (query) ()
+  (:public
+   #>cpp
+   DEFVISITABLE(QueryVisitor<void>);
+   cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+;; Foreach clause: a named expression plus a list of nested clauses.
+;; Instances are built through the protected constructor, which is reachable
+;; from the befriended AstStorage.
+(lcp:define-class foreach (clause)
+  ((named_expression "NamedExpression *" :initval "nullptr" :scope :public
+    :slk-save #'slk-save-ast-pointer
+    ;; Load as NamedExpression so the loaded pointer type matches the member
+    ;; type; this is the same pairing the `unwind` clause uses for its
+    ;; NamedExpression member.
+    :slk-load (slk-load-ast-pointer "NamedExpression"))
+   (clauses "std::vector<Clause *>"
+    :scope :public
+    :slk-save #'slk-save-ast-vector
+    :slk-load (slk-load-ast-vector "Clause")))
+  (:public
+   #>cpp
+   Foreach() = default;
+
+   // Visits the named expression and then every nested clause. The results of
+   // the child Accept calls are not inspected, so the traversal is never cut
+   // short by a child.
+   bool Accept(HierarchicalTreeVisitor &visitor) override {
+     if (visitor.PreVisit(*this)) {
+       named_expression_->Accept(visitor);
+       for (auto &clause : clauses_) {
+         clause->Accept(visitor);
+       }
+     }
+     return visitor.PostVisit(*this);
+   }
+   cpp<#)
+  (:protected
+   #>cpp
+   Foreach(NamedExpression *expression, std::vector<Clause *> clauses)
+       : named_expression_(expression), clauses_(clauses) {}
+   cpp<#)
+  (:private
+   #>cpp
+   friend class AstStorage;
+   cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+(lcp:pop-namespace) ;; namespace v2
+(lcp:pop-namespace) ;; namespace query
+(lcp:pop-namespace) ;; namespace memgraph
diff --git a/src/query/v2/frontend/ast/ast_visitor.hpp b/src/query/v2/frontend/ast/ast_visitor.hpp
new file mode 100644
index 000000000..3cd7f9074
--- /dev/null
+++ b/src/query/v2/frontend/ast/ast_visitor.hpp
@@ -0,0 +1,133 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#pragma once
+
+#include "utils/visitor.hpp"
+
+namespace memgraph::query::v2 {
+
+// Forward declares for Tree visitors.
+class CypherQuery;
+class SingleQuery;
+class CypherUnion;
+class NamedExpression;
+class Identifier;
+class PropertyLookup;
+class LabelsTest;
+class Aggregation;
+class Function;
+class Reduce;
+class Coalesce;
+class Extract;
+class All;
+class Single;
+class Any;
+class None;
+class ParameterLookup;
+class CallProcedure;
+class Create;
+class Match;
+class Return;
+class With;
+class Pattern;
+class NodeAtom;
+class EdgeAtom;
+class PrimitiveLiteral;
+class ListLiteral;
+class MapLiteral;
+class OrOperator;
+class XorOperator;
+class AndOperator;
+class NotOperator;
+class AdditionOperator;
+class SubtractionOperator;
+class MultiplicationOperator;
+class DivisionOperator;
+class ModOperator;
+class UnaryPlusOperator;
+class UnaryMinusOperator;
+class IsNullOperator;
+class NotEqualOperator;
+class EqualOperator;
+class LessOperator;
+class GreaterOperator;
+class LessEqualOperator;
+class GreaterEqualOperator;
+class InListOperator;
+class SubscriptOperator;
+class ListSlicingOperator;
+class IfOperator;
+class Delete;
+class Where;
+class SetProperty;
+class SetProperties;
+class SetLabels;
+class RemoveProperty;
+class RemoveLabels;
+class Merge;
+class Unwind;
+class AuthQuery;
+class ExplainQuery;
+class ProfileQuery;
+class IndexQuery;
+class InfoQuery;
+class ConstraintQuery;
+class RegexMatch;
+class DumpQuery;
+class ReplicationQuery;
+class LockPathQuery;
+class LoadCsv;
+class FreeMemoryQuery;
+class TriggerQuery;
+class IsolationLevelQuery;
+class CreateSnapshotQuery;
+class StreamQuery;
+class SettingQuery;
+class VersionQuery;
+class Foreach;
+
+// Node types handed to utils::CompositeVisitor. The AST classes in this list
+// drive the visitor through PreVisit/PostVisit from their Accept overrides.
+using TreeCompositeVisitor = utils::CompositeVisitor<
+    SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
+    SubtractionOperator, MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
+    LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator, InListOperator, SubscriptOperator,
+    ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral,
+    PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None, CallProcedure,
+    Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
+    RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv, Foreach>;
+
+// Node types handed to utils::LeafVisitor.
+using TreeLeafVisitor = utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>;
+
+// Visitor interface taken by the `Accept(HierarchicalTreeVisitor &)` overrides
+// of the AST classes. It combines the composite and leaf visitor bases; the
+// using-declarations re-export PreVisit/PostVisit from the composite base and
+// Visit/ReturnType from the leaf base.
+class HierarchicalTreeVisitor : public TreeCompositeVisitor, public TreeLeafVisitor {
+ public:
+  using TreeCompositeVisitor::PostVisit;
+  using TreeCompositeVisitor::PreVisit;
+  using TreeLeafVisitor::Visit;
+  using typename TreeLeafVisitor::ReturnType;
+};
+
+// utils::Visitor over the expression node types listed below, parameterised on
+// the result type TResult.
+template <class TResult>
+class ExpressionVisitor
+    : public utils::Visitor<
+          TResult, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
+          SubtractionOperator, MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
+          LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator, InListOperator, SubscriptOperator,
+          ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral,
+          MapLiteral, PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any,
+          None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch> {};
+
+// utils::Visitor over the top-level query types listed below. The query classes
+// opt in with DEFVISITABLE(QueryVisitor<void>).
+template <class TResult>
+class QueryVisitor
+    : public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, InfoQuery,
+                            ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, FreeMemoryQuery, TriggerQuery,
+                            IsolationLevelQuery, CreateSnapshotQuery, StreamQuery, SettingQuery, VersionQuery> {};
+
+}  // namespace memgraph::query::v2
diff --git a/src/query/v2/frontend/ast/cypher_main_visitor.cpp b/src/query/v2/frontend/ast/cypher_main_visitor.cpp
new file mode 100644
index
000000000..8a74fbbeb
--- /dev/null
+++ b/src/query/v2/frontend/ast/cypher_main_visitor.cpp
@@ -0,0 +1,2362 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#include "query/v2/frontend/ast/cypher_main_visitor.hpp"
+
+#include <algorithm>
+#include <climits>
+#include <codecvt>
+#include <cstring>
+#include <iterator>
+#include <limits>
+#include <optional>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include <boost/preprocessor/cat.hpp>
+
+#include "query/v2/exceptions.hpp"
+#include "query/v2/frontend/ast/ast.hpp"
+#include "query/v2/frontend/ast/ast_visitor.hpp"
+#include "query/v2/frontend/parsing.hpp"
+#include "query/v2/interpret/awesome_memgraph_functions.hpp"
+#include "query/v2/procedure/module.hpp"
+#include "query/v2/stream/common.hpp"
+#include "utils/exceptions.hpp"
+#include "utils/logging.hpp"
+#include "utils/string.hpp"
+#include "utils/typeinfo.hpp"
+
+namespace memgraph::query::v2::frontend {
+
+const std::string CypherMainVisitor::kAnonPrefix = "anon";
+
+namespace {
+/// Translates a memory-limit parse context into an (expression, scale) pair.
+///
+/// @param memory_limit_ctx parse context of the limit; asserted to be non-null.
+/// @param visitor visitor the limit literal is accepted with.
+/// @return std::nullopt when the limit is UNLIMITED; otherwise the visited
+///         literal paired with 1024 * 1024 for MB or 1024 for KB.
+template <typename TVisitor>
+std::optional<std::pair<memgraph::query::v2::Expression *, size_t>> VisitMemoryLimit(
+    MemgraphCypher::MemoryLimitContext *memory_limit_ctx, TVisitor *visitor) {
+  MG_ASSERT(memory_limit_ctx);
+  if (memory_limit_ctx->UNLIMITED()) {
+    return std::nullopt;
+  }
+
+  auto memory_limit = memory_limit_ctx->literal()->accept(visitor);
+  size_t memory_scale = 1024U;
+  if (memory_limit_ctx->MB()) {
+    memory_scale = 1024U * 1024U;
+  } else {
+    // Anything that is not MB is asserted to be KB.
+    MG_ASSERT(memory_limit_ctx->KB());
+    memory_scale = 1024U;
+  }
+
+  return std::make_pair(memory_limit, memory_scale);
+}
+
+/// Projects every token pointer in `tokens` to a string with
+/// `string_projection` and joins the results with `separator` via utils::Join.
+std::string JoinTokens(const auto &tokens, const auto &string_projection, const auto &separator) {
+  std::vector<std::string> tokens_string;
+  tokens_string.reserve(tokens.size());
+  for (auto *token : tokens) {
+    tokens_string.emplace_back(string_projection(token));
+  }
+  return utils::Join(tokens_string, separator);
+}
+
+/// Visits each symbolic name with `visitor` and joins the resulting strings
+/// with `separator` ("." by default).
+///
+/// The vector is taken by const reference so that a call does not copy it;
+/// temporaries returned by the parse-context accessors still bind to it.
+std::string JoinSymbolicNames(antlr4::tree::ParseTreeVisitor *visitor,
+                              const std::vector<MemgraphCypher::SymbolicNameContext *> &symbolicNames,
+                              const std::string &separator = ".") {
+  return JoinTokens(
+      symbolicNames, [&](auto *token) { return token->accept(visitor).template as<std::string>(); }, separator);
+}
+
+/// Joins the symbolic names inside each `symbolicNameWithMinus` element of
+/// `ctx` with "-", then joins those groups with ".".
+std::string JoinSymbolicNamesWithDotsAndMinus(antlr4::tree::ParseTreeVisitor &visitor,
+                                              MemgraphCypher::SymbolicNameWithDotsAndMinusContext &ctx) {
+  return JoinTokens(
+      ctx.symbolicNameWithMinus(), [&](auto *token) { return JoinSymbolicNames(&visitor, token->symbolicName(), "-"); },
+      ".");
+}
+}  // namespace
+
+/// Builds an ExplainQuery around the CypherQuery produced by the second child
+/// of `ctx`, stores it in `query_` and returns it.
+antlrcpp::Any CypherMainVisitor::visitExplainQuery(MemgraphCypher::ExplainQueryContext *ctx) {
+  MG_ASSERT(ctx->children.size() == 2, "ExplainQuery should have exactly two children!");
+  auto *cypher_query = ctx->children[1]->accept(this).as<CypherQuery *>();
+  auto *explain_query = storage_->Create<ExplainQuery>();
+  explain_query->cypher_query_ = cypher_query;
+  query_ = explain_query;
+  return explain_query;
+}
+
+/// Builds a ProfileQuery around the CypherQuery produced by the second child
+/// of `ctx`, stores it in `query_` and returns it.
+antlrcpp::Any CypherMainVisitor::visitProfileQuery(MemgraphCypher::ProfileQueryContext *ctx) {
+  MG_ASSERT(ctx->children.size() == 2, "ProfileQuery should have exactly two children!");
+  auto *cypher_query = ctx->children[1]->accept(this).as<CypherQuery *>();
+  auto *profile_query = storage_->Create<ProfileQuery>();
+  profile_query->cypher_query_ = cypher_query;
+  query_ = profile_query;
+  return profile_query;
+}
+ +antlrcpp::Any CypherMainVisitor::visitInfoQuery(MemgraphCypher::InfoQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 2, "InfoQuery should have exactly two children!"); + auto *info_query = storage_->Create<InfoQuery>(); + query_ = info_query; + if (ctx->storageInfo()) { + info_query->info_type_ = InfoQuery::InfoType::STORAGE; + return info_query; + } else if (ctx->indexInfo()) { + info_query->info_type_ = InfoQuery::InfoType::INDEX; + return info_query; + } else if (ctx->constraintInfo()) { + info_query->info_type_ = InfoQuery::InfoType::CONSTRAINT; + return info_query; + } else { + throw utils::NotYetImplemented("Info query: '{}'", ctx->getText()); + } +} + +antlrcpp::Any CypherMainVisitor::visitConstraintQuery(MemgraphCypher::ConstraintQueryContext *ctx) { + auto *constraint_query = storage_->Create<ConstraintQuery>(); + MG_ASSERT(ctx->CREATE() || ctx->DROP()); + if (ctx->CREATE()) { + constraint_query->action_type_ = ConstraintQuery::ActionType::CREATE; + } else if (ctx->DROP()) { + constraint_query->action_type_ = ConstraintQuery::ActionType::DROP; + } + constraint_query->constraint_ = ctx->constraint()->accept(this).as<Constraint>(); + query_ = constraint_query; + return query_; +} + +antlrcpp::Any CypherMainVisitor::visitConstraint(MemgraphCypher::ConstraintContext *ctx) { + Constraint constraint; + MG_ASSERT(ctx->EXISTS() || ctx->UNIQUE() || (ctx->NODE() && ctx->KEY())); + if (ctx->EXISTS()) { + constraint.type = Constraint::Type::EXISTS; + } else if (ctx->UNIQUE()) { + constraint.type = Constraint::Type::UNIQUE; + } else if (ctx->NODE() && ctx->KEY()) { + constraint.type = Constraint::Type::NODE_KEY; + } + constraint.label = AddLabel(ctx->labelName()->accept(this)); + std::string node_name = ctx->nodeName->symbolicName()->accept(this); + for (const auto &var_ctx : ctx->constraintPropertyList()->variable()) { + std::string var_name = var_ctx->symbolicName()->accept(this); + if (var_name != node_name) { + throw SemanticException("All constraint 
variable should reference node '{}'", node_name); + } + } + for (const auto &prop_lookup : ctx->constraintPropertyList()->propertyLookup()) { + constraint.properties.push_back(prop_lookup->propertyKeyName()->accept(this)); + } + + return constraint; +} + +antlrcpp::Any CypherMainVisitor::visitCypherQuery(MemgraphCypher::CypherQueryContext *ctx) { + auto *cypher_query = storage_->Create<CypherQuery>(); + MG_ASSERT(ctx->singleQuery(), "Expected single query."); + cypher_query->single_query_ = ctx->singleQuery()->accept(this).as<SingleQuery *>(); + + // Check that union and union all dont mix + bool has_union = false; + bool has_union_all = false; + for (auto *child : ctx->cypherUnion()) { + if (child->ALL()) { + has_union_all = true; + } else { + has_union = true; + } + if (has_union && has_union_all) { + throw SemanticException("Invalid combination of UNION and UNION ALL."); + } + cypher_query->cypher_unions_.push_back(child->accept(this).as<CypherUnion *>()); + } + + if (auto *memory_limit_ctx = ctx->queryMemoryLimit()) { + const auto memory_limit_info = VisitMemoryLimit(memory_limit_ctx->memoryLimit(), this); + if (memory_limit_info) { + cypher_query->memory_limit_ = memory_limit_info->first; + cypher_query->memory_scale_ = memory_limit_info->second; + } + } + + query_ = cypher_query; + return cypher_query; +} + +antlrcpp::Any CypherMainVisitor::visitIndexQuery(MemgraphCypher::IndexQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "IndexQuery should have exactly one child!"); + auto *index_query = ctx->children[0]->accept(this).as<IndexQuery *>(); + query_ = index_query; + return index_query; +} + +antlrcpp::Any CypherMainVisitor::visitCreateIndex(MemgraphCypher::CreateIndexContext *ctx) { + auto *index_query = storage_->Create<IndexQuery>(); + index_query->action_ = IndexQuery::Action::CREATE; + index_query->label_ = AddLabel(ctx->labelName()->accept(this)); + if (ctx->propertyKeyName()) { + PropertyIx name_key = ctx->propertyKeyName()->accept(this); + 
index_query->properties_ = {name_key}; + } + return index_query; +} + +antlrcpp::Any CypherMainVisitor::visitDropIndex(MemgraphCypher::DropIndexContext *ctx) { + auto *index_query = storage_->Create<IndexQuery>(); + index_query->action_ = IndexQuery::Action::DROP; + if (ctx->propertyKeyName()) { + PropertyIx key = ctx->propertyKeyName()->accept(this); + index_query->properties_ = {key}; + } + index_query->label_ = AddLabel(ctx->labelName()->accept(this)); + return index_query; +} + +antlrcpp::Any CypherMainVisitor::visitAuthQuery(MemgraphCypher::AuthQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "AuthQuery should have exactly one child!"); + auto *auth_query = ctx->children[0]->accept(this).as<AuthQuery *>(); + query_ = auth_query; + return auth_query; +} + +antlrcpp::Any CypherMainVisitor::visitDumpQuery(MemgraphCypher::DumpQueryContext *ctx) { + auto *dump_query = storage_->Create<DumpQuery>(); + query_ = dump_query; + return dump_query; +} + +antlrcpp::Any CypherMainVisitor::visitReplicationQuery(MemgraphCypher::ReplicationQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "ReplicationQuery should have exactly one child!"); + auto *replication_query = ctx->children[0]->accept(this).as<ReplicationQuery *>(); + query_ = replication_query; + return replication_query; +} + +antlrcpp::Any CypherMainVisitor::visitSetReplicationRole(MemgraphCypher::SetReplicationRoleContext *ctx) { + auto *replication_query = storage_->Create<ReplicationQuery>(); + replication_query->action_ = ReplicationQuery::Action::SET_REPLICATION_ROLE; + if (ctx->MAIN()) { + if (ctx->WITH() || ctx->PORT()) { + throw SemanticException("Main can't set a port!"); + } + replication_query->role_ = ReplicationQuery::ReplicationRole::MAIN; + } else if (ctx->REPLICA()) { + replication_query->role_ = ReplicationQuery::ReplicationRole::REPLICA; + if (ctx->WITH() && ctx->PORT()) { + if (ctx->port->numberLiteral() && ctx->port->numberLiteral()->integerLiteral()) { + 
replication_query->port_ = ctx->port->accept(this); + } else { + throw SyntaxException("Port must be an integer literal!"); + } + } + } + return replication_query; +} +antlrcpp::Any CypherMainVisitor::visitShowReplicationRole(MemgraphCypher::ShowReplicationRoleContext *ctx) { + auto *replication_query = storage_->Create<ReplicationQuery>(); + replication_query->action_ = ReplicationQuery::Action::SHOW_REPLICATION_ROLE; + return replication_query; +} + +antlrcpp::Any CypherMainVisitor::visitRegisterReplica(MemgraphCypher::RegisterReplicaContext *ctx) { + auto *replication_query = storage_->Create<ReplicationQuery>(); + replication_query->action_ = ReplicationQuery::Action::REGISTER_REPLICA; + replication_query->replica_name_ = ctx->replicaName()->symbolicName()->accept(this).as<std::string>(); + if (ctx->SYNC()) { + replication_query->sync_mode_ = memgraph::query::v2::ReplicationQuery::SyncMode::SYNC; + if (ctx->WITH() && ctx->TIMEOUT()) { + if (ctx->timeout->numberLiteral()) { + // we accept both double and integer literals + replication_query->timeout_ = ctx->timeout->accept(this); + } else { + throw SemanticException("Timeout should be a integer or double literal!"); + } + } + } else if (ctx->ASYNC()) { + if (ctx->WITH() && ctx->TIMEOUT()) { + throw SyntaxException("Timeout can be set only for the SYNC replication mode!"); + } + replication_query->sync_mode_ = memgraph::query::v2::ReplicationQuery::SyncMode::ASYNC; + } + + if (!ctx->socketAddress()->literal()->StringLiteral()) { + throw SemanticException("Socket address should be a string literal!"); + } else { + replication_query->socket_address_ = ctx->socketAddress()->accept(this); + } + + return replication_query; +} + +antlrcpp::Any CypherMainVisitor::visitDropReplica(MemgraphCypher::DropReplicaContext *ctx) { + auto *replication_query = storage_->Create<ReplicationQuery>(); + replication_query->action_ = ReplicationQuery::Action::DROP_REPLICA; + replication_query->replica_name_ = 
ctx->replicaName()->symbolicName()->accept(this).as<std::string>(); + return replication_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowReplicas(MemgraphCypher::ShowReplicasContext *ctx) { + auto *replication_query = storage_->Create<ReplicationQuery>(); + replication_query->action_ = ReplicationQuery::Action::SHOW_REPLICAS; + return replication_query; +} + +antlrcpp::Any CypherMainVisitor::visitLockPathQuery(MemgraphCypher::LockPathQueryContext *ctx) { + auto *lock_query = storage_->Create<LockPathQuery>(); + if (ctx->LOCK()) { + lock_query->action_ = LockPathQuery::Action::LOCK_PATH; + } else if (ctx->UNLOCK()) { + lock_query->action_ = LockPathQuery::Action::UNLOCK_PATH; + } else { + throw SyntaxException("Expected LOCK or UNLOCK"); + } + + query_ = lock_query; + return lock_query; +} + +antlrcpp::Any CypherMainVisitor::visitLoadCsv(MemgraphCypher::LoadCsvContext *ctx) { + query_info_.has_load_csv = true; + + auto *load_csv = storage_->Create<LoadCsv>(); + // handle file name + if (ctx->csvFile()->literal()->StringLiteral()) { + load_csv->file_ = ctx->csvFile()->accept(this); + } else { + throw SemanticException("CSV file path should be a string literal"); + } + + // handle header options + // Don't have to check for ctx->HEADER(), as it's a mandatory token. + // Just need to check if ctx->WITH() is not nullptr - otherwise, we have a + // ctx->NO() and ctx->HEADER() present. 
+ load_csv->with_header_ = ctx->WITH() != nullptr; + + // handle skip bad row option + load_csv->ignore_bad_ = ctx->IGNORE() && ctx->BAD(); + + // handle delimiter + if (ctx->DELIMITER()) { + if (ctx->delimiter()->literal()->StringLiteral()) { + load_csv->delimiter_ = ctx->delimiter()->accept(this); + } else { + throw SemanticException("Delimiter should be a string literal"); + } + } + + // handle quote + if (ctx->QUOTE()) { + if (ctx->quote()->literal()->StringLiteral()) { + load_csv->quote_ = ctx->quote()->accept(this); + } else { + throw SemanticException("Quote should be a string literal"); + } + } + + // handle row variable + load_csv->row_var_ = storage_->Create<Identifier>(ctx->rowVar()->variable()->accept(this).as<std::string>()); + + return load_csv; +} + +antlrcpp::Any CypherMainVisitor::visitFreeMemoryQuery(MemgraphCypher::FreeMemoryQueryContext *ctx) { + auto *free_memory_query = storage_->Create<FreeMemoryQuery>(); + query_ = free_memory_query; + return free_memory_query; +} + +antlrcpp::Any CypherMainVisitor::visitTriggerQuery(MemgraphCypher::TriggerQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "TriggerQuery should have exactly one child!"); + auto *trigger_query = ctx->children[0]->accept(this).as<TriggerQuery *>(); + query_ = trigger_query; + return trigger_query; +} + +antlrcpp::Any CypherMainVisitor::visitCreateTrigger(MemgraphCypher::CreateTriggerContext *ctx) { + auto *trigger_query = storage_->Create<TriggerQuery>(); + trigger_query->action_ = TriggerQuery::Action::CREATE_TRIGGER; + trigger_query->trigger_name_ = ctx->triggerName()->symbolicName()->accept(this).as<std::string>(); + + auto *statement = ctx->triggerStatement(); + antlr4::misc::Interval interval{statement->start->getStartIndex(), statement->stop->getStopIndex()}; + trigger_query->statement_ = ctx->start->getInputStream()->getText(interval); + + trigger_query->event_type_ = [ctx] { + if (!ctx->ON()) { + return TriggerQuery::EventType::ANY; + } + + if (ctx->CREATE(1)) 
{ + if (ctx->emptyVertex()) { + return TriggerQuery::EventType::VERTEX_CREATE; + } + if (ctx->emptyEdge()) { + return TriggerQuery::EventType::EDGE_CREATE; + } + return TriggerQuery::EventType::CREATE; + } + + if (ctx->DELETE()) { + if (ctx->emptyVertex()) { + return TriggerQuery::EventType::VERTEX_DELETE; + } + if (ctx->emptyEdge()) { + return TriggerQuery::EventType::EDGE_DELETE; + } + return TriggerQuery::EventType::DELETE; + } + + if (ctx->UPDATE()) { + if (ctx->emptyVertex()) { + return TriggerQuery::EventType::VERTEX_UPDATE; + } + if (ctx->emptyEdge()) { + return TriggerQuery::EventType::EDGE_UPDATE; + } + return TriggerQuery::EventType::UPDATE; + } + + LOG_FATAL("Invalid token allowed for the query"); + }(); + + trigger_query->before_commit_ = ctx->BEFORE(); + + return trigger_query; +} + +antlrcpp::Any CypherMainVisitor::visitDropTrigger(MemgraphCypher::DropTriggerContext *ctx) { + auto *trigger_query = storage_->Create<TriggerQuery>(); + trigger_query->action_ = TriggerQuery::Action::DROP_TRIGGER; + trigger_query->trigger_name_ = ctx->triggerName()->symbolicName()->accept(this).as<std::string>(); + return trigger_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowTriggers(MemgraphCypher::ShowTriggersContext *ctx) { + auto *trigger_query = storage_->Create<TriggerQuery>(); + trigger_query->action_ = TriggerQuery::Action::SHOW_TRIGGERS; + return trigger_query; +} + +antlrcpp::Any CypherMainVisitor::visitIsolationLevelQuery(MemgraphCypher::IsolationLevelQueryContext *ctx) { + auto *isolation_level_query = storage_->Create<IsolationLevelQuery>(); + + isolation_level_query->isolation_level_scope_ = [scope = ctx->isolationLevelScope()]() { + if (scope->GLOBAL()) { + return IsolationLevelQuery::IsolationLevelScope::GLOBAL; + } + if (scope->SESSION()) { + return IsolationLevelQuery::IsolationLevelScope::SESSION; + } + return IsolationLevelQuery::IsolationLevelScope::NEXT; + }(); + + isolation_level_query->isolation_level_ = [level = ctx->isolationLevel()]() { 
+ if (level->SNAPSHOT()) { + return IsolationLevelQuery::IsolationLevel::SNAPSHOT_ISOLATION; + } + if (level->COMMITTED()) { + return IsolationLevelQuery::IsolationLevel::READ_COMMITTED; + } + return IsolationLevelQuery::IsolationLevel::READ_UNCOMMITTED; + }(); + + query_ = isolation_level_query; + return isolation_level_query; +} + +antlrcpp::Any CypherMainVisitor::visitCreateSnapshotQuery(MemgraphCypher::CreateSnapshotQueryContext *ctx) { + query_ = storage_->Create<CreateSnapshotQuery>(); + return query_; +} + +antlrcpp::Any CypherMainVisitor::visitStreamQuery(MemgraphCypher::StreamQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "StreamQuery should have exactly one child!"); + auto *stream_query = ctx->children[0]->accept(this).as<StreamQuery *>(); + query_ = stream_query; + return stream_query; +} + +antlrcpp::Any CypherMainVisitor::visitCreateStream(MemgraphCypher::CreateStreamContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "CreateStreamQuery should have exactly one child!"); + auto *stream_query = ctx->children[0]->accept(this).as<StreamQuery *>(); + query_ = stream_query; + return stream_query; +} + +namespace { +std::vector<std::string> TopicNamesFromSymbols( + antlr4::tree::ParseTreeVisitor &visitor, + const std::vector<MemgraphCypher::SymbolicNameWithDotsAndMinusContext *> &topic_name_symbols) { + MG_ASSERT(!topic_name_symbols.empty()); + std::vector<std::string> topic_names; + topic_names.reserve(topic_name_symbols.size()); + std::transform(topic_name_symbols.begin(), topic_name_symbols.end(), std::back_inserter(topic_names), + [&visitor](auto *topic_name) { return JoinSymbolicNamesWithDotsAndMinus(visitor, *topic_name); }); + return topic_names; +} + +template <typename T> +concept EnumUint8 = std::is_enum_v<T> && std::same_as<uint8_t, std::underlying_type_t<T>>; + +template <bool required, typename... 
ValueTypes> +void MapConfig(auto &memory, const EnumUint8 auto &enum_key, auto &destination) { + const auto key = static_cast<uint8_t>(enum_key); + if (!memory.contains(key)) { + if constexpr (required) { + throw SemanticException("Config {} is required.", ToString(enum_key)); + } else { + return; + } + } + + std::visit( + [&]<typename T>(T &&value) { + using ValueType = std::decay_t<T>; + if constexpr (utils::SameAsAnyOf<ValueType, ValueTypes...>) { + destination = std::forward<T>(value); + } else { + LOG_FATAL("Invalid type mapped"); + } + }, + std::move(memory[key])); + memory.erase(key); +} + +enum class CommonStreamConfigKey : uint8_t { TRANSFORM, BATCH_INTERVAL, BATCH_SIZE, END }; + +std::string_view ToString(const CommonStreamConfigKey key) { + switch (key) { + case CommonStreamConfigKey::TRANSFORM: + return "TRANSFORM"; + case CommonStreamConfigKey::BATCH_INTERVAL: + return "BATCH_INTERVAL"; + case CommonStreamConfigKey::BATCH_SIZE: + return "BATCH_SIZE"; + case CommonStreamConfigKey::END: + LOG_FATAL("Invalid config key used"); + } +} + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define GENERATE_STREAM_CONFIG_KEY_ENUM(stream, first_config, ...) 
\ + enum class BOOST_PP_CAT(stream, ConfigKey) : uint8_t { \ + first_config = static_cast<uint8_t>(CommonStreamConfigKey::END), \ + __VA_ARGS__ \ + }; + +GENERATE_STREAM_CONFIG_KEY_ENUM(Kafka, TOPICS, CONSUMER_GROUP, BOOTSTRAP_SERVERS, CONFIGS, CREDENTIALS); + +std::string_view ToString(const KafkaConfigKey key) { + switch (key) { + case KafkaConfigKey::TOPICS: + return "TOPICS"; + case KafkaConfigKey::CONSUMER_GROUP: + return "CONSUMER_GROUP"; + case KafkaConfigKey::BOOTSTRAP_SERVERS: + return "BOOTSTRAP_SERVERS"; + case KafkaConfigKey::CONFIGS: + return "CONFIGS"; + case KafkaConfigKey::CREDENTIALS: + return "CREDENTIALS"; + } +} + +void MapCommonStreamConfigs(auto &memory, StreamQuery &stream_query) { + MapConfig<true, std::string>(memory, CommonStreamConfigKey::TRANSFORM, stream_query.transform_name_); + MapConfig<false, Expression *>(memory, CommonStreamConfigKey::BATCH_INTERVAL, stream_query.batch_interval_); + MapConfig<false, Expression *>(memory, CommonStreamConfigKey::BATCH_SIZE, stream_query.batch_size_); +} +} // namespace + +antlrcpp::Any CypherMainVisitor::visitConfigKeyValuePair(MemgraphCypher::ConfigKeyValuePairContext *ctx) { + MG_ASSERT(ctx->literal().size() == 2); + return std::pair{ctx->literal(0)->accept(this).as<Expression *>(), ctx->literal(1)->accept(this).as<Expression *>()}; +} + +antlrcpp::Any CypherMainVisitor::visitConfigMap(MemgraphCypher::ConfigMapContext *ctx) { + std::unordered_map<Expression *, Expression *> map; + for (auto *key_value_pair : ctx->configKeyValuePair()) { + // If the queries are cached, then only the stripped query is parsed, so the actual keys cannot be determined + // here. That means duplicates cannot be checked. 
+ map.insert(key_value_pair->accept(this).as<std::pair<Expression *, Expression *>>()); + } + return map; +} + +antlrcpp::Any CypherMainVisitor::visitKafkaCreateStream(MemgraphCypher::KafkaCreateStreamContext *ctx) { + auto *stream_query = storage_->Create<StreamQuery>(); + stream_query->action_ = StreamQuery::Action::CREATE_STREAM; + stream_query->type_ = StreamQuery::Type::KAFKA; + stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>(); + + for (auto *create_config_ctx : ctx->kafkaCreateStreamConfig()) { + create_config_ctx->accept(this); + } + + MapConfig<true, std::vector<std::string>, Expression *>(memory_, KafkaConfigKey::TOPICS, stream_query->topic_names_); + MapConfig<false, std::string>(memory_, KafkaConfigKey::CONSUMER_GROUP, stream_query->consumer_group_); + MapConfig<false, Expression *>(memory_, KafkaConfigKey::BOOTSTRAP_SERVERS, stream_query->bootstrap_servers_); + MapConfig<false, std::unordered_map<Expression *, Expression *>>(memory_, KafkaConfigKey::CONFIGS, + stream_query->configs_); + MapConfig<false, std::unordered_map<Expression *, Expression *>>(memory_, KafkaConfigKey::CREDENTIALS, + stream_query->credentials_); + + MapCommonStreamConfigs(memory_, *stream_query); + + return stream_query; +} + +namespace { +void ThrowIfExists(const auto &map, const EnumUint8 auto &enum_key) { + const auto key = static_cast<uint8_t>(enum_key); + if (map.contains(key)) { + throw SemanticException("{} defined multiple times in the query", ToString(enum_key)); + } +} + +void GetTopicNames(auto &destination, MemgraphCypher::TopicNamesContext *topic_names_ctx, + antlr4::tree::ParseTreeVisitor &visitor) { + MG_ASSERT(topic_names_ctx != nullptr); + if (auto *symbolic_topic_names_ctx = topic_names_ctx->symbolicTopicNames()) { + destination = TopicNamesFromSymbols(visitor, symbolic_topic_names_ctx->symbolicNameWithDotsAndMinus()); + } else { + if (!topic_names_ctx->literal()->StringLiteral()) { + throw SemanticException("Topic 
names should be defined as a string literal or as symbolic names"); + } + destination = topic_names_ctx->accept(&visitor).as<Expression *>(); + } +} +} // namespace + +antlrcpp::Any CypherMainVisitor::visitKafkaCreateStreamConfig(MemgraphCypher::KafkaCreateStreamConfigContext *ctx) { + if (ctx->commonCreateStreamConfig()) { + return ctx->commonCreateStreamConfig()->accept(this); + } + + if (ctx->TOPICS()) { + ThrowIfExists(memory_, KafkaConfigKey::TOPICS); + static constexpr auto topics_key = static_cast<uint8_t>(KafkaConfigKey::TOPICS); + GetTopicNames(memory_[topics_key], ctx->topicNames(), *this); + return {}; + } + + if (ctx->CONSUMER_GROUP()) { + ThrowIfExists(memory_, KafkaConfigKey::CONSUMER_GROUP); + static constexpr auto consumer_group_key = static_cast<uint8_t>(KafkaConfigKey::CONSUMER_GROUP); + memory_[consumer_group_key] = JoinSymbolicNamesWithDotsAndMinus(*this, *ctx->consumerGroup); + return {}; + } + + if (ctx->CONFIGS()) { + ThrowIfExists(memory_, KafkaConfigKey::CONFIGS); + static constexpr auto configs_key = static_cast<uint8_t>(KafkaConfigKey::CONFIGS); + memory_.emplace(configs_key, ctx->configsMap->accept(this).as<std::unordered_map<Expression *, Expression *>>()); + return {}; + } + + if (ctx->CREDENTIALS()) { + ThrowIfExists(memory_, KafkaConfigKey::CREDENTIALS); + static constexpr auto credentials_key = static_cast<uint8_t>(KafkaConfigKey::CREDENTIALS); + memory_.emplace(credentials_key, + ctx->credentialsMap->accept(this).as<std::unordered_map<Expression *, Expression *>>()); + return {}; + } + + MG_ASSERT(ctx->BOOTSTRAP_SERVERS()); + ThrowIfExists(memory_, KafkaConfigKey::BOOTSTRAP_SERVERS); + if (!ctx->bootstrapServers->StringLiteral()) { + throw SemanticException("Bootstrap servers should be a string!"); + } + + const auto bootstrap_servers_key = static_cast<uint8_t>(KafkaConfigKey::BOOTSTRAP_SERVERS); + memory_[bootstrap_servers_key] = ctx->bootstrapServers->accept(this).as<Expression *>(); + return {}; +} + +namespace { 
+GENERATE_STREAM_CONFIG_KEY_ENUM(Pulsar, TOPICS, SERVICE_URL); + +std::string_view ToString(const PulsarConfigKey key) { + switch (key) { + case PulsarConfigKey::TOPICS: + return "TOPICS"; + case PulsarConfigKey::SERVICE_URL: + return "SERVICE_URL"; + } +} +} // namespace + +antlrcpp::Any CypherMainVisitor::visitPulsarCreateStream(MemgraphCypher::PulsarCreateStreamContext *ctx) { + auto *stream_query = storage_->Create<StreamQuery>(); + stream_query->action_ = StreamQuery::Action::CREATE_STREAM; + stream_query->type_ = StreamQuery::Type::PULSAR; + stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>(); + + for (auto *create_config_ctx : ctx->pulsarCreateStreamConfig()) { + create_config_ctx->accept(this); + } + + MapConfig<true, std::vector<std::string>, Expression *>(memory_, PulsarConfigKey::TOPICS, stream_query->topic_names_); + MapConfig<false, Expression *>(memory_, PulsarConfigKey::SERVICE_URL, stream_query->service_url_); + + MapCommonStreamConfigs(memory_, *stream_query); + + return stream_query; +} + +antlrcpp::Any CypherMainVisitor::visitPulsarCreateStreamConfig(MemgraphCypher::PulsarCreateStreamConfigContext *ctx) { + if (ctx->commonCreateStreamConfig()) { + return ctx->commonCreateStreamConfig()->accept(this); + } + + if (ctx->TOPICS()) { + ThrowIfExists(memory_, PulsarConfigKey::TOPICS); + const auto topics_key = static_cast<uint8_t>(PulsarConfigKey::TOPICS); + GetTopicNames(memory_[topics_key], ctx->topicNames(), *this); + return {}; + } + + MG_ASSERT(ctx->SERVICE_URL()); + ThrowIfExists(memory_, PulsarConfigKey::SERVICE_URL); + if (!ctx->serviceUrl->StringLiteral()) { + throw SemanticException("Service URL must be a string!"); + } + const auto service_url_key = static_cast<uint8_t>(PulsarConfigKey::SERVICE_URL); + memory_[service_url_key] = ctx->serviceUrl->accept(this).as<Expression *>(); + return {}; +} + +antlrcpp::Any 
CypherMainVisitor::visitCommonCreateStreamConfig(MemgraphCypher::CommonCreateStreamConfigContext *ctx) { + if (ctx->TRANSFORM()) { + ThrowIfExists(memory_, CommonStreamConfigKey::TRANSFORM); + const auto transform_key = static_cast<uint8_t>(CommonStreamConfigKey::TRANSFORM); + memory_[transform_key] = JoinSymbolicNames(this, ctx->transformationName->symbolicName()); + return {}; + } + + if (ctx->BATCH_INTERVAL()) { + ThrowIfExists(memory_, CommonStreamConfigKey::BATCH_INTERVAL); + if (!ctx->batchInterval->numberLiteral() || !ctx->batchInterval->numberLiteral()->integerLiteral()) { + throw SemanticException("Batch interval must be an integer literal!"); + } + const auto batch_interval_key = static_cast<uint8_t>(CommonStreamConfigKey::BATCH_INTERVAL); + memory_[batch_interval_key] = ctx->batchInterval->accept(this).as<Expression *>(); + return {}; + } + + MG_ASSERT(ctx->BATCH_SIZE()); + ThrowIfExists(memory_, CommonStreamConfigKey::BATCH_SIZE); + if (!ctx->batchSize->numberLiteral() || !ctx->batchSize->numberLiteral()->integerLiteral()) { + throw SemanticException("Batch size must be an integer literal!"); + } + const auto batch_size_key = static_cast<uint8_t>(CommonStreamConfigKey::BATCH_SIZE); + memory_[batch_size_key] = ctx->batchSize->accept(this).as<Expression *>(); + return {}; +} + +antlrcpp::Any CypherMainVisitor::visitDropStream(MemgraphCypher::DropStreamContext *ctx) { + auto *stream_query = storage_->Create<StreamQuery>(); + stream_query->action_ = StreamQuery::Action::DROP_STREAM; + stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>(); + return stream_query; +} + +antlrcpp::Any CypherMainVisitor::visitStartStream(MemgraphCypher::StartStreamContext *ctx) { + auto *stream_query = storage_->Create<StreamQuery>(); + stream_query->action_ = StreamQuery::Action::START_STREAM; + + if (ctx->BATCH_LIMIT()) { + if (!ctx->batchLimit->numberLiteral() || !ctx->batchLimit->numberLiteral()->integerLiteral()) { + throw 
SemanticException("Batch limit should be an integer literal!"); + } + stream_query->batch_limit_ = ctx->batchLimit->accept(this); + } + if (ctx->TIMEOUT()) { + if (!ctx->timeout->numberLiteral() || !ctx->timeout->numberLiteral()->integerLiteral()) { + throw SemanticException("Timeout should be an integer literal!"); + } + if (!ctx->BATCH_LIMIT()) { + throw SemanticException("Parameter TIMEOUT can only be defined if BATCH_LIMIT is defined"); + } + stream_query->timeout_ = ctx->timeout->accept(this); + } + + stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>(); + return stream_query; +} + +antlrcpp::Any CypherMainVisitor::visitStartAllStreams(MemgraphCypher::StartAllStreamsContext *ctx) { + auto *stream_query = storage_->Create<StreamQuery>(); + stream_query->action_ = StreamQuery::Action::START_ALL_STREAMS; + return stream_query; +} + +antlrcpp::Any CypherMainVisitor::visitStopStream(MemgraphCypher::StopStreamContext *ctx) { + auto *stream_query = storage_->Create<StreamQuery>(); + stream_query->action_ = StreamQuery::Action::STOP_STREAM; + stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>(); + return stream_query; +} + +antlrcpp::Any CypherMainVisitor::visitStopAllStreams(MemgraphCypher::StopAllStreamsContext *ctx) { + auto *stream_query = storage_->Create<StreamQuery>(); + stream_query->action_ = StreamQuery::Action::STOP_ALL_STREAMS; + return stream_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowStreams(MemgraphCypher::ShowStreamsContext *ctx) { + auto *stream_query = storage_->Create<StreamQuery>(); + stream_query->action_ = StreamQuery::Action::SHOW_STREAMS; + return stream_query; +} + +antlrcpp::Any CypherMainVisitor::visitCheckStream(MemgraphCypher::CheckStreamContext *ctx) { + auto *stream_query = storage_->Create<StreamQuery>(); + stream_query->action_ = StreamQuery::Action::CHECK_STREAM; + stream_query->stream_name_ = 
ctx->streamName()->symbolicName()->accept(this).as<std::string>(); + + if (ctx->BATCH_LIMIT()) { + if (!ctx->batchLimit->numberLiteral() || !ctx->batchLimit->numberLiteral()->integerLiteral()) { + throw SemanticException("Batch limit should be an integer literal!"); + } + stream_query->batch_limit_ = ctx->batchLimit->accept(this); + } + if (ctx->TIMEOUT()) { + if (!ctx->timeout->numberLiteral() || !ctx->timeout->numberLiteral()->integerLiteral()) { + throw SemanticException("Timeout should be an integer literal!"); + } + stream_query->timeout_ = ctx->timeout->accept(this); + } + return stream_query; +} + +antlrcpp::Any CypherMainVisitor::visitSettingQuery(MemgraphCypher::SettingQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "SettingQuery should have exactly one child!"); + auto *setting_query = ctx->children[0]->accept(this).as<SettingQuery *>(); + query_ = setting_query; + return setting_query; +} + +antlrcpp::Any CypherMainVisitor::visitSetSetting(MemgraphCypher::SetSettingContext *ctx) { + auto *setting_query = storage_->Create<SettingQuery>(); + setting_query->action_ = SettingQuery::Action::SET_SETTING; + + if (!ctx->settingName()->literal()->StringLiteral()) { + throw SemanticException("Setting name should be a string literal"); + } + + if (!ctx->settingValue()->literal()->StringLiteral()) { + throw SemanticException("Setting value should be a string literal"); + } + + setting_query->setting_name_ = ctx->settingName()->accept(this); + MG_ASSERT(setting_query->setting_name_); + + setting_query->setting_value_ = ctx->settingValue()->accept(this); + MG_ASSERT(setting_query->setting_value_); + return setting_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowSetting(MemgraphCypher::ShowSettingContext *ctx) { + auto *setting_query = storage_->Create<SettingQuery>(); + setting_query->action_ = SettingQuery::Action::SHOW_SETTING; + + if (!ctx->settingName()->literal()->StringLiteral()) { + throw SemanticException("Setting name should be a string 
literal"); + } + + setting_query->setting_name_ = ctx->settingName()->accept(this); + MG_ASSERT(setting_query->setting_name_); + + return setting_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowSettings(MemgraphCypher::ShowSettingsContext * /*ctx*/) { + auto *setting_query = storage_->Create<SettingQuery>(); + setting_query->action_ = SettingQuery::Action::SHOW_ALL_SETTINGS; + return setting_query; +} + +antlrcpp::Any CypherMainVisitor::visitVersionQuery(MemgraphCypher::VersionQueryContext * /*ctx*/) { + auto *version_query = storage_->Create<VersionQuery>(); + query_ = version_query; + return version_query; +} + +antlrcpp::Any CypherMainVisitor::visitCypherUnion(MemgraphCypher::CypherUnionContext *ctx) { + bool distinct = !ctx->ALL(); + auto *cypher_union = storage_->Create<CypherUnion>(distinct); + DMG_ASSERT(ctx->singleQuery(), "Expected single query."); + cypher_union->single_query_ = ctx->singleQuery()->accept(this).as<SingleQuery *>(); + return cypher_union; +} + +antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryContext *ctx) { + auto *single_query = storage_->Create<SingleQuery>(); + for (auto *child : ctx->clause()) { + antlrcpp::Any got = child->accept(this); + if (got.is<Clause *>()) { + single_query->clauses_.push_back(got.as<Clause *>()); + } else { + auto child_clauses = got.as<std::vector<Clause *>>(); + single_query->clauses_.insert(single_query->clauses_.end(), child_clauses.begin(), child_clauses.end()); + } + } + + // Check if ordering of clauses makes sense. + // + // TODO: should we forbid multiple consecutive set clauses? That case is + // little bit problematic because multiple barriers are needed. Multiple + // consecutive SET clauses are undefined behaviour in neo4j. 
+ bool has_update = false; + bool has_return = false; + bool has_optional_match = false; + bool has_call_procedure = false; + bool calls_write_procedure = false; + bool has_any_update = false; + bool has_load_csv = false; + + auto check_write_procedure = [&calls_write_procedure](const std::string_view clause) { + if (calls_write_procedure) { + throw SemanticException( + "{} can't be put after calling a writeable procedure, only RETURN clause can be put after.", clause); + } + }; + + for (Clause *clause : single_query->clauses_) { + const auto &clause_type = clause->GetTypeInfo(); + if (const auto *call_procedure = utils::Downcast<CallProcedure>(clause); call_procedure != nullptr) { + if (has_return) { + throw SemanticException("CALL can't be put after RETURN clause."); + } + check_write_procedure("CALL"); + has_call_procedure = true; + if (call_procedure->is_write_) { + calls_write_procedure = true; + has_update = true; + } + } else if (utils::IsSubtype(clause_type, Unwind::kType)) { + check_write_procedure("UNWIND"); + if (has_update || has_return) { + throw SemanticException("UNWIND can't be put after RETURN clause or after an update."); + } + } else if (utils::IsSubtype(clause_type, LoadCsv::kType)) { + if (has_load_csv) { + throw SemanticException("Can't have multiple LOAD CSV clauses in a single query."); + } + check_write_procedure("LOAD CSV"); + if (has_return) { + throw SemanticException("LOAD CSV can't be put after RETURN clause."); + } + has_load_csv = true; + } else if (auto *match = utils::Downcast<Match>(clause)) { + if (has_update || has_return) { + throw SemanticException("MATCH can't be put after RETURN clause or after an update."); + } + if (match->optional_) { + has_optional_match = true; + } else if (has_optional_match) { + throw SemanticException("MATCH can't be put after OPTIONAL MATCH."); + } + check_write_procedure("MATCH"); + } else if (utils::IsSubtype(clause_type, Create::kType) || utils::IsSubtype(clause_type, Delete::kType) || + 
utils::IsSubtype(clause_type, SetProperty::kType) || + utils::IsSubtype(clause_type, SetProperties::kType) || utils::IsSubtype(clause_type, SetLabels::kType) || + utils::IsSubtype(clause_type, RemoveProperty::kType) || + utils::IsSubtype(clause_type, RemoveLabels::kType) || utils::IsSubtype(clause_type, Merge::kType) || + utils::IsSubtype(clause_type, Foreach::kType)) { + if (has_return) { + throw SemanticException("Update clause can't be used after RETURN."); + } + check_write_procedure("Update clause"); + has_update = true; + has_any_update = true; + } else if (utils::IsSubtype(clause_type, Return::kType)) { + if (has_return) { + throw SemanticException("There can only be one RETURN in a clause."); + } + has_return = true; + } else if (utils::IsSubtype(clause_type, With::kType)) { + if (has_return) { + throw SemanticException("RETURN can't be put before WITH."); + } + check_write_procedure("WITH"); + has_update = has_return = has_optional_match = false; + } else { + DLOG_FATAL("Can't happen"); + } + } + bool is_standalone_call_procedure = has_call_procedure && single_query->clauses_.size() == 1U; + if (!has_update && !has_return && !is_standalone_call_procedure) { + throw SemanticException("Query should either create or update something, or return results!"); + } + + if (has_any_update && calls_write_procedure) { + throw SemanticException("Write procedures cannot be used in queries that contains any update clauses!"); + } + // Construct unique names for anonymous identifiers; + int id = 1; + for (auto **identifier : anonymous_identifiers) { + while (true) { + std::string id_name = kAnonPrefix + std::to_string(id++); + if (users_identifiers.find(id_name) == users_identifiers.end()) { + *identifier = storage_->Create<Identifier>(id_name, false); + break; + } + } + } + return single_query; +} + +antlrcpp::Any CypherMainVisitor::visitClause(MemgraphCypher::ClauseContext *ctx) { + if (ctx->cypherReturn()) { + return static_cast<Clause 
*>(ctx->cypherReturn()->accept(this).as<Return *>()); + } + if (ctx->cypherMatch()) { + return static_cast<Clause *>(ctx->cypherMatch()->accept(this).as<Match *>()); + } + if (ctx->create()) { + return static_cast<Clause *>(ctx->create()->accept(this).as<Create *>()); + } + if (ctx->cypherDelete()) { + return static_cast<Clause *>(ctx->cypherDelete()->accept(this).as<Delete *>()); + } + if (ctx->set()) { + // Different return type!!! + return ctx->set()->accept(this).as<std::vector<Clause *>>(); + } + if (ctx->remove()) { + // Different return type!!! + return ctx->remove()->accept(this).as<std::vector<Clause *>>(); + } + if (ctx->with()) { + return static_cast<Clause *>(ctx->with()->accept(this).as<With *>()); + } + if (ctx->merge()) { + return static_cast<Clause *>(ctx->merge()->accept(this).as<Merge *>()); + } + if (ctx->unwind()) { + return static_cast<Clause *>(ctx->unwind()->accept(this).as<Unwind *>()); + } + if (ctx->callProcedure()) { + return static_cast<Clause *>(ctx->callProcedure()->accept(this).as<CallProcedure *>()); + } + if (ctx->loadCsv()) { + return static_cast<Clause *>(ctx->loadCsv()->accept(this).as<LoadCsv *>()); + } + if (ctx->foreach ()) { + return static_cast<Clause *>(ctx->foreach ()->accept(this).as<Foreach *>()); + } + // TODO: implement other clauses. 
+ throw utils::NotYetImplemented("clause '{}'", ctx->getText()); + return 0; +} + +antlrcpp::Any CypherMainVisitor::visitCypherMatch(MemgraphCypher::CypherMatchContext *ctx) { + auto *match = storage_->Create<Match>(); + match->optional_ = !!ctx->OPTIONAL(); + if (ctx->where()) { + match->where_ = ctx->where()->accept(this); + } + match->patterns_ = ctx->pattern()->accept(this).as<std::vector<Pattern *>>(); + return match; +} + +antlrcpp::Any CypherMainVisitor::visitCreate(MemgraphCypher::CreateContext *ctx) { + auto *create = storage_->Create<Create>(); + create->patterns_ = ctx->pattern()->accept(this).as<std::vector<Pattern *>>(); + return create; +} + +antlrcpp::Any CypherMainVisitor::visitCallProcedure(MemgraphCypher::CallProcedureContext *ctx) { + // Don't cache queries which call procedures because the + // procedure definition can affect the behaviour of the visitor and + // the execution of the query. + // If a user recompiles and reloads the procedure with different result + // names, because of the cache, old result names will be expected while the + // procedure will return results mapped to new names. 
+ query_info_.is_cacheable = false; + + auto *call_proc = storage_->Create<CallProcedure>(); + MG_ASSERT(!ctx->procedureName()->symbolicName().empty()); + call_proc->procedure_name_ = JoinSymbolicNames(this, ctx->procedureName()->symbolicName()); + call_proc->arguments_.reserve(ctx->expression().size()); + for (auto *expr : ctx->expression()) { + call_proc->arguments_.push_back(expr->accept(this)); + } + + if (auto *memory_limit_ctx = ctx->procedureMemoryLimit()) { + const auto memory_limit_info = VisitMemoryLimit(memory_limit_ctx->memoryLimit(), this); + if (memory_limit_info) { + call_proc->memory_limit_ = memory_limit_info->first; + call_proc->memory_scale_ = memory_limit_info->second; + } + } else { + // Default to 100 MB + call_proc->memory_limit_ = storage_->Create<PrimitiveLiteral>(TypedValue(100)); + call_proc->memory_scale_ = 1024U * 1024U; + } + + const auto &maybe_found = + procedure::FindProcedure(procedure::gModuleRegistry, call_proc->procedure_name_, utils::NewDeleteResource()); + if (!maybe_found) { + throw SemanticException("There is no procedure named '{}'.", call_proc->procedure_name_); + } + call_proc->is_write_ = maybe_found->second->info.is_write; + + auto *yield_ctx = ctx->yieldProcedureResults(); + if (!yield_ctx) { + if (!maybe_found->second->results.empty()) { + throw SemanticException( + "CALL without YIELD may only be used on procedures which do not " + "return any result fields."); + } + // When we return, we will release the lock on modules. This means that + // someone may reload the procedure and change the result signature. But to + // keep the implementation simple, we ignore the case as the rest of the + // code doesn't really care whether we yield or not, so it should not break. 
+ return call_proc; + } + if (yield_ctx->getTokens(MemgraphCypher::ASTERISK).empty()) { + call_proc->result_fields_.reserve(yield_ctx->procedureResult().size()); + call_proc->result_identifiers_.reserve(yield_ctx->procedureResult().size()); + for (auto *result : yield_ctx->procedureResult()) { + MG_ASSERT(result->variable().size() == 1 || result->variable().size() == 2); + call_proc->result_fields_.push_back(result->variable()[0]->accept(this).as<std::string>()); + std::string result_alias; + if (result->variable().size() == 2) { + result_alias = result->variable()[1]->accept(this).as<std::string>(); + } else { + result_alias = result->variable()[0]->accept(this).as<std::string>(); + } + call_proc->result_identifiers_.push_back(storage_->Create<Identifier>(result_alias)); + } + } else { + const auto &maybe_found = + procedure::FindProcedure(procedure::gModuleRegistry, call_proc->procedure_name_, utils::NewDeleteResource()); + if (!maybe_found) { + throw SemanticException("There is no procedure named '{}'.", call_proc->procedure_name_); + } + const auto &[module, proc] = *maybe_found; + call_proc->result_fields_.reserve(proc->results.size()); + call_proc->result_identifiers_.reserve(proc->results.size()); + for (const auto &[result_name, desc] : proc->results) { + bool is_deprecated = desc.second; + if (is_deprecated) continue; + call_proc->result_fields_.emplace_back(result_name); + call_proc->result_identifiers_.push_back(storage_->Create<Identifier>(std::string(result_name))); + } + // When we leave the scope, we will release the lock on modules. This means + // that someone may reload the procedure and change its result signature. We + // are fine with this, because if new result fields were added then we yield + // the subset of those and that will appear to a user as if they used the + // procedure before reload. Any subsequent `CALL ... YIELD *` will fetch the + // new fields as well. 
In case the result signature has had some result + // fields removed, then the query execution will report an error that we are + // yielding missing fields. The user can then just retry the query. + } + + return call_proc; +} + +/** + * @return std::string + */ +antlrcpp::Any CypherMainVisitor::visitUserOrRoleName(MemgraphCypher::UserOrRoleNameContext *ctx) { + return ctx->symbolicName()->accept(this).as<std::string>(); +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitCreateRole(MemgraphCypher::CreateRoleContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::CREATE_ROLE; + auth->role_ = ctx->role->accept(this).as<std::string>(); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitDropRole(MemgraphCypher::DropRoleContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::DROP_ROLE; + auth->role_ = ctx->role->accept(this).as<std::string>(); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitShowRoles(MemgraphCypher::ShowRolesContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::SHOW_ROLES; + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitCreateUser(MemgraphCypher::CreateUserContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::CREATE_USER; + auth->user_ = ctx->user->accept(this).as<std::string>(); + if (ctx->password) { + if (!ctx->password->StringLiteral() && !ctx->literal()->CYPHERNULL()) { + throw SyntaxException("Password should be a string literal or null."); + } + auth->password_ = ctx->password->accept(this); + } + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitSetPassword(MemgraphCypher::SetPasswordContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + 
auth->action_ = AuthQuery::Action::SET_PASSWORD; + auth->user_ = ctx->user->accept(this).as<std::string>(); + if (!ctx->password->StringLiteral() && !ctx->literal()->CYPHERNULL()) { + throw SyntaxException("Password should be a string literal or null."); + } + auth->password_ = ctx->password->accept(this); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitDropUser(MemgraphCypher::DropUserContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::DROP_USER; + auth->user_ = ctx->user->accept(this).as<std::string>(); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitShowUsers(MemgraphCypher::ShowUsersContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::SHOW_USERS; + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitSetRole(MemgraphCypher::SetRoleContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::SET_ROLE; + auth->user_ = ctx->user->accept(this).as<std::string>(); + auth->role_ = ctx->role->accept(this).as<std::string>(); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitClearRole(MemgraphCypher::ClearRoleContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::CLEAR_ROLE; + auth->user_ = ctx->user->accept(this).as<std::string>(); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitGrantPrivilege(MemgraphCypher::GrantPrivilegeContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::GRANT_PRIVILEGE; + auth->user_or_role_ = ctx->userOrRole->accept(this).as<std::string>(); + if (ctx->privilegeList()) { + for (auto *privilege : ctx->privilegeList()->privilege()) { + auth->privileges_.push_back(privilege->accept(this)); + } + } else 
{ + /* grant all privileges */ + auth->privileges_ = kPrivilegesAll; + } + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitDenyPrivilege(MemgraphCypher::DenyPrivilegeContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::DENY_PRIVILEGE; + auth->user_or_role_ = ctx->userOrRole->accept(this).as<std::string>(); + if (ctx->privilegeList()) { + for (auto *privilege : ctx->privilegeList()->privilege()) { + auth->privileges_.push_back(privilege->accept(this)); + } + } else { + /* deny all privileges */ + auth->privileges_ = kPrivilegesAll; + } + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitRevokePrivilege(MemgraphCypher::RevokePrivilegeContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::REVOKE_PRIVILEGE; + auth->user_or_role_ = ctx->userOrRole->accept(this).as<std::string>(); + if (ctx->privilegeList()) { + for (auto *privilege : ctx->privilegeList()->privilege()) { + auth->privileges_.push_back(privilege->accept(this)); + } + } else { + /* revoke all privileges */ + auth->privileges_ = kPrivilegesAll; + } + return auth; +} + +/** + * @return AuthQuery::Privilege + */ +antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext *ctx) { + if (ctx->CREATE()) return AuthQuery::Privilege::CREATE; + if (ctx->DELETE()) return AuthQuery::Privilege::DELETE; + if (ctx->MATCH()) return AuthQuery::Privilege::MATCH; + if (ctx->MERGE()) return AuthQuery::Privilege::MERGE; + if (ctx->SET()) return AuthQuery::Privilege::SET; + if (ctx->REMOVE()) return AuthQuery::Privilege::REMOVE; + if (ctx->INDEX()) return AuthQuery::Privilege::INDEX; + if (ctx->STATS()) return AuthQuery::Privilege::STATS; + if (ctx->AUTH()) return AuthQuery::Privilege::AUTH; + if (ctx->CONSTRAINT()) return AuthQuery::Privilege::CONSTRAINT; + if (ctx->DUMP()) return AuthQuery::Privilege::DUMP; + if 
(ctx->REPLICATION()) return AuthQuery::Privilege::REPLICATION; + if (ctx->READ_FILE()) return AuthQuery::Privilege::READ_FILE; + if (ctx->FREE_MEMORY()) return AuthQuery::Privilege::FREE_MEMORY; + if (ctx->TRIGGER()) return AuthQuery::Privilege::TRIGGER; + if (ctx->CONFIG()) return AuthQuery::Privilege::CONFIG; + if (ctx->DURABILITY()) return AuthQuery::Privilege::DURABILITY; + if (ctx->STREAM()) return AuthQuery::Privilege::STREAM; + if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ; + if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE; + if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET; + LOG_FATAL("Should not get here - unknown privilege!"); +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitShowPrivileges(MemgraphCypher::ShowPrivilegesContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::SHOW_PRIVILEGES; + auth->user_or_role_ = ctx->userOrRole->accept(this).as<std::string>(); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitShowRoleForUser(MemgraphCypher::ShowRoleForUserContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::SHOW_ROLE_FOR_USER; + auth->user_ = ctx->user->accept(this).as<std::string>(); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) { + AuthQuery *auth = storage_->Create<AuthQuery>(); + auth->action_ = AuthQuery::Action::SHOW_USERS_FOR_ROLE; + auth->role_ = ctx->role->accept(this).as<std::string>(); + return auth; +} + +antlrcpp::Any CypherMainVisitor::visitCypherReturn(MemgraphCypher::CypherReturnContext *ctx) { + auto *return_clause = storage_->Create<Return>(); + return_clause->body_ = ctx->returnBody()->accept(this); + if (ctx->DISTINCT()) { + return_clause->body_.distinct = true; + } + return return_clause; +} + +antlrcpp::Any 
CypherMainVisitor::visitReturnBody(MemgraphCypher::ReturnBodyContext *ctx) { + ReturnBody body; + if (ctx->order()) { + body.order_by = ctx->order()->accept(this).as<std::vector<SortItem>>(); + } + if (ctx->skip()) { + body.skip = static_cast<Expression *>(ctx->skip()->accept(this)); + } + if (ctx->limit()) { + body.limit = static_cast<Expression *>(ctx->limit()->accept(this)); + } + std::tie(body.all_identifiers, body.named_expressions) = + ctx->returnItems()->accept(this).as<std::pair<bool, std::vector<NamedExpression *>>>(); + return body; +} + +antlrcpp::Any CypherMainVisitor::visitReturnItems(MemgraphCypher::ReturnItemsContext *ctx) { + std::vector<NamedExpression *> named_expressions; + for (auto *item : ctx->returnItem()) { + named_expressions.push_back(item->accept(this)); + } + return std::pair<bool, std::vector<NamedExpression *>>(ctx->getTokens(MemgraphCypher::ASTERISK).size(), + named_expressions); +} + +antlrcpp::Any CypherMainVisitor::visitReturnItem(MemgraphCypher::ReturnItemContext *ctx) { + auto *named_expr = storage_->Create<NamedExpression>(); + named_expr->expression_ = ctx->expression()->accept(this); + MG_ASSERT(named_expr->expression_); + if (ctx->variable()) { + named_expr->name_ = std::string(ctx->variable()->accept(this).as<std::string>()); + users_identifiers.insert(named_expr->name_); + } else { + if (in_with_ && !utils::IsSubtype(*named_expr->expression_, Identifier::kType)) { + throw SemanticException("Only variables can be non-aliased in WITH."); + } + named_expr->name_ = std::string(ctx->getText()); + named_expr->token_position_ = ctx->expression()->getStart()->getTokenIndex(); + } + return named_expr; +} + +antlrcpp::Any CypherMainVisitor::visitOrder(MemgraphCypher::OrderContext *ctx) { + std::vector<SortItem> order_by; + for (auto *sort_item : ctx->sortItem()) { + order_by.push_back(sort_item->accept(this)); + } + return order_by; +} + +antlrcpp::Any CypherMainVisitor::visitSortItem(MemgraphCypher::SortItemContext *ctx) { + return 
SortItem{ctx->DESC() || ctx->DESCENDING() ? Ordering::DESC : Ordering::ASC, ctx->expression()->accept(this)}; +} + +antlrcpp::Any CypherMainVisitor::visitNodePattern(MemgraphCypher::NodePatternContext *ctx) { + auto *node = storage_->Create<NodeAtom>(); + if (ctx->variable()) { + std::string variable = ctx->variable()->accept(this); + node->identifier_ = storage_->Create<Identifier>(variable); + users_identifiers.insert(variable); + } else { + anonymous_identifiers.push_back(&node->identifier_); + } + if (ctx->nodeLabels()) { + node->labels_ = ctx->nodeLabels()->accept(this).as<std::vector<LabelIx>>(); + } + if (ctx->properties()) { + // This can return either properties or parameters + if (ctx->properties()->mapLiteral()) { + node->properties_ = ctx->properties()->accept(this).as<std::unordered_map<PropertyIx, Expression *>>(); + } else { + node->properties_ = ctx->properties()->accept(this).as<ParameterLookup *>(); + } + } + return node; +} + +antlrcpp::Any CypherMainVisitor::visitNodeLabels(MemgraphCypher::NodeLabelsContext *ctx) { + std::vector<LabelIx> labels; + for (auto *node_label : ctx->nodeLabel()) { + labels.push_back(AddLabel(node_label->accept(this))); + } + return labels; +} + +antlrcpp::Any CypherMainVisitor::visitProperties(MemgraphCypher::PropertiesContext *ctx) { + if (ctx->mapLiteral()) { + return ctx->mapLiteral()->accept(this); + } + // If child is not mapLiteral that means child is params. 
+ MG_ASSERT(ctx->parameter()); + return ctx->parameter()->accept(this); +} + +antlrcpp::Any CypherMainVisitor::visitMapLiteral(MemgraphCypher::MapLiteralContext *ctx) { + std::unordered_map<PropertyIx, Expression *> map; + for (int i = 0; i < static_cast<int>(ctx->propertyKeyName().size()); ++i) { + PropertyIx key = ctx->propertyKeyName()[i]->accept(this); + Expression *value = ctx->expression()[i]->accept(this); + if (!map.insert({key, value}).second) { + throw SemanticException("Same key can't appear twice in a map literal."); + } + } + return map; +} + +antlrcpp::Any CypherMainVisitor::visitListLiteral(MemgraphCypher::ListLiteralContext *ctx) { + std::vector<Expression *> expressions; + for (auto expr_ctx_ptr : ctx->expression()) expressions.push_back(expr_ctx_ptr->accept(this)); + return expressions; +} + +antlrcpp::Any CypherMainVisitor::visitPropertyKeyName(MemgraphCypher::PropertyKeyNameContext *ctx) { + return AddProperty(visitChildren(ctx)); +} + +antlrcpp::Any CypherMainVisitor::visitSymbolicName(MemgraphCypher::SymbolicNameContext *ctx) { + if (ctx->EscapedSymbolicName()) { + auto quoted_name = ctx->getText(); + DMG_ASSERT(quoted_name.size() >= 2U && quoted_name[0] == '`' && quoted_name.back() == '`', + "Can't happen. Grammar ensures this"); + // Remove enclosing backticks. + std::string escaped_name = quoted_name.substr(1, static_cast<int>(quoted_name.size()) - 2); + // Unescape remaining backticks. + std::string name; + bool escaped = false; + for (auto c : escaped_name) { + if (escaped) { + if (c == '`') { + name.push_back('`'); + escaped = false; + } else { + DLOG_FATAL("Can't happen. 
Grammar ensures that."); + } + } else if (c == '`') { + escaped = true; + } else { + name.push_back(c); + } + } + return name; + } + if (ctx->UnescapedSymbolicName()) { + return std::string(ctx->getText()); + } + return ctx->getText(); +} + +antlrcpp::Any CypherMainVisitor::visitPattern(MemgraphCypher::PatternContext *ctx) { + std::vector<Pattern *> patterns; + for (auto *pattern_part : ctx->patternPart()) { + patterns.push_back(pattern_part->accept(this)); + } + return patterns; +} + +antlrcpp::Any CypherMainVisitor::visitPatternPart(MemgraphCypher::PatternPartContext *ctx) { + Pattern *pattern = ctx->anonymousPatternPart()->accept(this); + if (ctx->variable()) { + std::string variable = ctx->variable()->accept(this); + pattern->identifier_ = storage_->Create<Identifier>(variable); + users_identifiers.insert(variable); + } else { + anonymous_identifiers.push_back(&pattern->identifier_); + } + return pattern; +} + +antlrcpp::Any CypherMainVisitor::visitPatternElement(MemgraphCypher::PatternElementContext *ctx) { + if (ctx->patternElement()) { + return ctx->patternElement()->accept(this); + } + auto pattern = storage_->Create<Pattern>(); + pattern->atoms_.push_back(ctx->nodePattern()->accept(this).as<NodeAtom *>()); + for (auto *pattern_element_chain : ctx->patternElementChain()) { + std::pair<PatternAtom *, PatternAtom *> element = pattern_element_chain->accept(this); + pattern->atoms_.push_back(element.first); + pattern->atoms_.push_back(element.second); + } + return pattern; +} + +antlrcpp::Any CypherMainVisitor::visitPatternElementChain(MemgraphCypher::PatternElementChainContext *ctx) { + return std::pair<PatternAtom *, PatternAtom *>(ctx->relationshipPattern()->accept(this).as<EdgeAtom *>(), + ctx->nodePattern()->accept(this).as<NodeAtom *>()); +} + +antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(MemgraphCypher::RelationshipPatternContext *ctx) { + auto *edge = storage_->Create<EdgeAtom>(); + + auto relationshipDetail = ctx->relationshipDetail(); + 
auto *variableExpansion = relationshipDetail ? relationshipDetail->variableExpansion() : nullptr; + edge->type_ = EdgeAtom::Type::SINGLE; + if (variableExpansion) + std::tie(edge->type_, edge->lower_bound_, edge->upper_bound_) = + variableExpansion->accept(this).as<std::tuple<EdgeAtom::Type, Expression *, Expression *>>(); + + if (ctx->leftArrowHead() && !ctx->rightArrowHead()) { + edge->direction_ = EdgeAtom::Direction::IN; + } else if (!ctx->leftArrowHead() && ctx->rightArrowHead()) { + edge->direction_ = EdgeAtom::Direction::OUT; + } else { + // <-[]-> and -[]- is the same thing as far as we understand openCypher + // grammar. + edge->direction_ = EdgeAtom::Direction::BOTH; + } + + if (!relationshipDetail) { + anonymous_identifiers.push_back(&edge->identifier_); + return edge; + } + + if (relationshipDetail->name) { + std::string variable = relationshipDetail->name->accept(this); + edge->identifier_ = storage_->Create<Identifier>(variable); + users_identifiers.insert(variable); + } else { + anonymous_identifiers.push_back(&edge->identifier_); + } + + if (relationshipDetail->relationshipTypes()) { + edge->edge_types_ = ctx->relationshipDetail()->relationshipTypes()->accept(this).as<std::vector<EdgeTypeIx>>(); + } + + auto relationshipLambdas = relationshipDetail->relationshipLambda(); + if (variableExpansion) { + if (relationshipDetail->total_weight && edge->type_ != EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) + throw SemanticException( + "Variable for total weight is allowed only with weighted shortest " + "path expansion."); + auto visit_lambda = [this](auto *lambda) { + EdgeAtom::Lambda edge_lambda; + std::string traversed_edge_variable = lambda->traversed_edge->accept(this); + edge_lambda.inner_edge = storage_->Create<Identifier>(traversed_edge_variable); + std::string traversed_node_variable = lambda->traversed_node->accept(this); + edge_lambda.inner_node = storage_->Create<Identifier>(traversed_node_variable); + edge_lambda.expression = 
lambda->expression()->accept(this); + return edge_lambda; + }; + auto visit_total_weight = [&]() { + if (relationshipDetail->total_weight) { + std::string total_weight_name = relationshipDetail->total_weight->accept(this); + edge->total_weight_ = storage_->Create<Identifier>(total_weight_name); + } else { + anonymous_identifiers.push_back(&edge->total_weight_); + } + }; + switch (relationshipLambdas.size()) { + case 0: + if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) + throw SemanticException( + "Lambda for calculating weights is mandatory with weighted " + "shortest path expansion."); + // In variable expansion inner variables are mandatory. + anonymous_identifiers.push_back(&edge->filter_lambda_.inner_edge); + anonymous_identifiers.push_back(&edge->filter_lambda_.inner_node); + break; + case 1: + if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) { + // For wShortest, the first (and required) lambda is used for weight + // calculation. + edge->weight_lambda_ = visit_lambda(relationshipLambdas[0]); + visit_total_weight(); + // Add mandatory inner variables for filter lambda. + anonymous_identifiers.push_back(&edge->filter_lambda_.inner_edge); + anonymous_identifiers.push_back(&edge->filter_lambda_.inner_node); + } else { + // Other variable expands only have the filter lambda. 
+ edge->filter_lambda_ = visit_lambda(relationshipLambdas[0]); + } + break; + case 2: + if (edge->type_ != EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) + throw SemanticException("Only one filter lambda can be supplied."); + edge->weight_lambda_ = visit_lambda(relationshipLambdas[0]); + visit_total_weight(); + edge->filter_lambda_ = visit_lambda(relationshipLambdas[1]); + break; + default: + throw SemanticException("Only one filter lambda can be supplied."); + } + } else if (!relationshipLambdas.empty()) { + throw SemanticException("Filter lambda is only allowed in variable length expansion."); + } + + auto properties = relationshipDetail->properties(); + switch (properties.size()) { + case 0: + break; + case 1: { + if (properties[0]->mapLiteral()) { + edge->properties_ = properties[0]->accept(this).as<std::unordered_map<PropertyIx, Expression *>>(); + break; + } + MG_ASSERT(properties[0]->parameter()); + edge->properties_ = properties[0]->accept(this).as<ParameterLookup *>(); + break; + } + default: + throw SemanticException("Only one property map can be supplied for edge."); + } + + return edge; +} + +antlrcpp::Any CypherMainVisitor::visitRelationshipDetail(MemgraphCypher::RelationshipDetailContext *) { + DLOG_FATAL("Should never be called. See documentation in hpp."); + return 0; +} + +antlrcpp::Any CypherMainVisitor::visitRelationshipLambda(MemgraphCypher::RelationshipLambdaContext *) { + DLOG_FATAL("Should never be called. 
See documentation in hpp."); + return 0; +} + +antlrcpp::Any CypherMainVisitor::visitRelationshipTypes(MemgraphCypher::RelationshipTypesContext *ctx) { + std::vector<EdgeTypeIx> types; + for (auto *edge_type : ctx->relTypeName()) { + types.push_back(AddEdgeType(edge_type->accept(this))); + } + return types; +} + +antlrcpp::Any CypherMainVisitor::visitVariableExpansion(MemgraphCypher::VariableExpansionContext *ctx) { + DMG_ASSERT(ctx->expression().size() <= 2U, "Expected 0, 1 or 2 bounds in range literal."); + + EdgeAtom::Type edge_type = EdgeAtom::Type::DEPTH_FIRST; + if (!ctx->getTokens(MemgraphCypher::BFS).empty()) + edge_type = EdgeAtom::Type::BREADTH_FIRST; + else if (!ctx->getTokens(MemgraphCypher::WSHORTEST).empty()) + edge_type = EdgeAtom::Type::WEIGHTED_SHORTEST_PATH; + Expression *lower = nullptr; + Expression *upper = nullptr; + + if (ctx->expression().size() == 0U) { + // Case -[*]- + } else if (ctx->expression().size() == 1U) { + auto dots_tokens = ctx->getTokens(MemgraphCypher::DOTS); + Expression *bound = ctx->expression()[0]->accept(this); + if (!dots_tokens.size()) { + // Case -[*bound]- + if (edge_type != EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) lower = bound; + upper = bound; + } else if (dots_tokens[0]->getSourceInterval().startsAfter(ctx->expression()[0]->getSourceInterval())) { + // Case -[*bound..]- + lower = bound; + } else { + // Case -[*..bound]- + upper = bound; + } + } else { + // Case -[*lbound..rbound]- + lower = ctx->expression()[0]->accept(this); + upper = ctx->expression()[1]->accept(this); + } + if (lower && edge_type == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) + throw SemanticException("Lower bound is not allowed in weighted shortest path expansion."); + + return std::make_tuple(edge_type, lower, upper); +} + +antlrcpp::Any CypherMainVisitor::visitExpression(MemgraphCypher::ExpressionContext *ctx) { + return static_cast<Expression *>(ctx->expression12()->accept(this)); +} + +// OR. 
+antlrcpp::Any CypherMainVisitor::visitExpression12(MemgraphCypher::Expression12Context *ctx) { + return LeftAssociativeOperatorExpression(ctx->expression11(), ctx->children, {MemgraphCypher::OR}); +} + +// XOR. +antlrcpp::Any CypherMainVisitor::visitExpression11(MemgraphCypher::Expression11Context *ctx) { + return LeftAssociativeOperatorExpression(ctx->expression10(), ctx->children, {MemgraphCypher::XOR}); +} + +// AND. +antlrcpp::Any CypherMainVisitor::visitExpression10(MemgraphCypher::Expression10Context *ctx) { + return LeftAssociativeOperatorExpression(ctx->expression9(), ctx->children, {MemgraphCypher::AND}); +} + +// NOT. +antlrcpp::Any CypherMainVisitor::visitExpression9(MemgraphCypher::Expression9Context *ctx) { + return PrefixUnaryOperator(ctx->expression8(), ctx->children, {MemgraphCypher::NOT}); +} + +// Comparisons. +// Expresion 1 < 2 < 3 is converted to 1 < 2 && 2 < 3 and then binary operator +// ast node is constructed for each operator. +antlrcpp::Any CypherMainVisitor::visitExpression8(MemgraphCypher::Expression8Context *ctx) { + if (!ctx->partialComparisonExpression().size()) { + // There is no comparison operators. We generate expression7. + return ctx->expression7()->accept(this); + } + + // There is at least one comparison. We need to generate code for each of + // them. We don't call visitPartialComparisonExpression but do everything in + // this function and call expression7-s directly. Since every expression7 + // can be generated twice (because it can appear in two comparisons) code + // generated by whole subtree of expression7 must not have any sideeffects. + // We handle chained comparisons as defined by mathematics, neo4j handles + // them in a very interesting, illogical and incomprehensible way. For + // example in neo4j: + // 1 < 2 < 3 -> true, + // 1 < 2 < 3 < 4 -> false, + // 5 > 3 < 5 > 3 -> true, + // 4 <= 5 < 7 > 6 -> false + // All of those comparisons evaluate to true in memgraph. 
+ std::vector<Expression *> children; + children.push_back(ctx->expression7()->accept(this)); + std::vector<size_t> operators; + auto partial_comparison_expressions = ctx->partialComparisonExpression(); + for (auto *child : partial_comparison_expressions) { + children.push_back(child->expression7()->accept(this)); + } + // First production is comparison operator. + for (auto *child : partial_comparison_expressions) { + operators.push_back(static_cast<antlr4::tree::TerminalNode *>(child->children[0])->getSymbol()->getType()); + } + + // Make all comparisons. + Expression *first_operand = children[0]; + std::vector<Expression *> comparisons; + for (int i = 0; i < (int)operators.size(); ++i) { + auto *expr = children[i + 1]; + // TODO: first_operand should only do lookup if it is only calculated and + // not recalculated whole subexpression once again. SymbolGenerator should + // generate symbol for every expresion and then lookup would be possible. + comparisons.push_back(CreateBinaryOperatorByToken(operators[i], first_operand, expr)); + first_operand = expr; + } + + first_operand = comparisons[0]; + // Calculate logical and of results of comparisons. + for (int i = 1; i < (int)comparisons.size(); ++i) { + first_operand = storage_->Create<AndOperator>(first_operand, comparisons[i]); + } + return first_operand; +} + +antlrcpp::Any CypherMainVisitor::visitPartialComparisonExpression( + MemgraphCypher::PartialComparisonExpressionContext *) { + DLOG_FATAL("Should never be called. See documentation in hpp."); + return 0; +} + +// Addition and subtraction. +antlrcpp::Any CypherMainVisitor::visitExpression7(MemgraphCypher::Expression7Context *ctx) { + return LeftAssociativeOperatorExpression(ctx->expression6(), ctx->children, + {MemgraphCypher::PLUS, MemgraphCypher::MINUS}); +} + +// Multiplication, division, modding. 
+antlrcpp::Any CypherMainVisitor::visitExpression6(MemgraphCypher::Expression6Context *ctx) { + return LeftAssociativeOperatorExpression(ctx->expression5(), ctx->children, + {MemgraphCypher::ASTERISK, MemgraphCypher::SLASH, MemgraphCypher::PERCENT}); +} + +// Power. +antlrcpp::Any CypherMainVisitor::visitExpression5(MemgraphCypher::Expression5Context *ctx) { + if (ctx->expression4().size() > 1U) { + // TODO: implement power operator. In neo4j power is left associative and + // int^int -> float. + throw utils::NotYetImplemented("power (^) operator"); + } + return visitChildren(ctx); +} + +// Unary minus and plus. +antlrcpp::Any CypherMainVisitor::visitExpression4(MemgraphCypher::Expression4Context *ctx) { + return PrefixUnaryOperator(ctx->expression3a(), ctx->children, {MemgraphCypher::PLUS, MemgraphCypher::MINUS}); +} + +// IS NULL, IS NOT NULL, STARTS WITH, .. +antlrcpp::Any CypherMainVisitor::visitExpression3a(MemgraphCypher::Expression3aContext *ctx) { + Expression *expression = ctx->expression3b()->accept(this); + + for (auto *op : ctx->stringAndNullOperators()) { + if (op->IS() && op->NOT() && op->CYPHERNULL()) { + expression = + static_cast<Expression *>(storage_->Create<NotOperator>(storage_->Create<IsNullOperator>(expression))); + } else if (op->IS() && op->CYPHERNULL()) { + expression = static_cast<Expression *>(storage_->Create<IsNullOperator>(expression)); + } else if (op->IN()) { + expression = + static_cast<Expression *>(storage_->Create<InListOperator>(expression, op->expression3b()->accept(this))); + } else if (utils::StartsWith(op->getText(), "=~")) { + auto *regex_match = storage_->Create<RegexMatch>(); + regex_match->string_expr_ = expression; + regex_match->regex_ = op->expression3b()->accept(this); + expression = regex_match; + } else { + std::string function_name; + if (op->STARTS() && op->WITH()) { + function_name = kStartsWith; + } else if (op->ENDS() && op->WITH()) { + function_name = kEndsWith; + } else if (op->CONTAINS()) { + function_name 
= kContains; + } else { + throw utils::NotYetImplemented("function '{}'", op->getText()); + } + auto expression2 = op->expression3b()->accept(this); + std::vector<Expression *> args = {expression, expression2}; + expression = static_cast<Expression *>(storage_->Create<Function>(function_name, args)); + } + } + return expression; +} +antlrcpp::Any CypherMainVisitor::visitStringAndNullOperators(MemgraphCypher::StringAndNullOperatorsContext *) { + DLOG_FATAL("Should never be called. See documentation in hpp."); + return 0; +} + +antlrcpp::Any CypherMainVisitor::visitExpression3b(MemgraphCypher::Expression3bContext *ctx) { + Expression *expression = ctx->expression2a()->accept(this); + for (auto *list_op : ctx->listIndexingOrSlicing()) { + if (list_op->getTokens(MemgraphCypher::DOTS).size() == 0U) { + // If there is no '..' then we need to create list indexing operator. + expression = storage_->Create<SubscriptOperator>(expression, list_op->expression()[0]->accept(this)); + } else if (!list_op->lower_bound && !list_op->upper_bound) { + throw SemanticException("List slicing operator requires at least one bound."); + } else { + Expression *lower_bound_ast = + list_op->lower_bound ? static_cast<Expression *>(list_op->lower_bound->accept(this)) : nullptr; + Expression *upper_bound_ast = + list_op->upper_bound ? static_cast<Expression *>(list_op->upper_bound->accept(this)) : nullptr; + expression = storage_->Create<ListSlicingOperator>(expression, lower_bound_ast, upper_bound_ast); + } + } + return expression; +} + +antlrcpp::Any CypherMainVisitor::visitListIndexingOrSlicing(MemgraphCypher::ListIndexingOrSlicingContext *) { + DLOG_FATAL("Should never be called. 
See documentation in hpp."); + return 0; +} + +antlrcpp::Any CypherMainVisitor::visitExpression2a(MemgraphCypher::Expression2aContext *ctx) { + Expression *expression = ctx->expression2b()->accept(this); + if (ctx->nodeLabels()) { + auto labels = ctx->nodeLabels()->accept(this).as<std::vector<LabelIx>>(); + expression = storage_->Create<LabelsTest>(expression, labels); + } + return expression; +} + +antlrcpp::Any CypherMainVisitor::visitExpression2b(MemgraphCypher::Expression2bContext *ctx) { + Expression *expression = ctx->atom()->accept(this); + for (auto *lookup : ctx->propertyLookup()) { + PropertyIx key = lookup->accept(this); + auto property_lookup = storage_->Create<PropertyLookup>(expression, key); + expression = property_lookup; + } + return expression; +} + +antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) { + if (ctx->literal()) { + return ctx->literal()->accept(this); + } else if (ctx->parameter()) { + return static_cast<Expression *>(ctx->parameter()->accept(this).as<ParameterLookup *>()); + } else if (ctx->parenthesizedExpression()) { + return static_cast<Expression *>(ctx->parenthesizedExpression()->accept(this)); + } else if (ctx->variable()) { + std::string variable = ctx->variable()->accept(this); + users_identifiers.insert(variable); + return static_cast<Expression *>(storage_->Create<Identifier>(variable)); + } else if (ctx->functionInvocation()) { + return static_cast<Expression *>(ctx->functionInvocation()->accept(this)); + } else if (ctx->COALESCE()) { + std::vector<Expression *> exprs; + for (auto *expr_context : ctx->expression()) { + exprs.emplace_back(expr_context->accept(this).as<Expression *>()); + } + return static_cast<Expression *>(storage_->Create<Coalesce>(std::move(exprs))); + } else if (ctx->COUNT()) { + // Here we handle COUNT(*). COUNT(expression) is handled in + // visitFunctionInvocation with other aggregations. This is visible in + // functionInvocation and atom producions in opencypher grammar. 
+ return static_cast<Expression *>(storage_->Create<Aggregation>(nullptr, nullptr, Aggregation::Op::COUNT)); + } else if (ctx->ALL()) { + auto *ident = + storage_->Create<Identifier>(ctx->filterExpression()->idInColl()->variable()->accept(this).as<std::string>()); + Expression *list_expr = ctx->filterExpression()->idInColl()->expression()->accept(this); + if (!ctx->filterExpression()->where()) { + throw SyntaxException("ALL(...) requires a WHERE predicate."); + } + Where *where = ctx->filterExpression()->where()->accept(this); + return static_cast<Expression *>(storage_->Create<All>(ident, list_expr, where)); + } else if (ctx->SINGLE()) { + auto *ident = + storage_->Create<Identifier>(ctx->filterExpression()->idInColl()->variable()->accept(this).as<std::string>()); + Expression *list_expr = ctx->filterExpression()->idInColl()->expression()->accept(this); + if (!ctx->filterExpression()->where()) { + throw SyntaxException("SINGLE(...) requires a WHERE predicate."); + } + Where *where = ctx->filterExpression()->where()->accept(this); + return static_cast<Expression *>(storage_->Create<Single>(ident, list_expr, where)); + } else if (ctx->ANY()) { + auto *ident = + storage_->Create<Identifier>(ctx->filterExpression()->idInColl()->variable()->accept(this).as<std::string>()); + Expression *list_expr = ctx->filterExpression()->idInColl()->expression()->accept(this); + if (!ctx->filterExpression()->where()) { + throw SyntaxException("ANY(...) requires a WHERE predicate."); + } + Where *where = ctx->filterExpression()->where()->accept(this); + return static_cast<Expression *>(storage_->Create<Any>(ident, list_expr, where)); + } else if (ctx->NONE()) { + auto *ident = + storage_->Create<Identifier>(ctx->filterExpression()->idInColl()->variable()->accept(this).as<std::string>()); + Expression *list_expr = ctx->filterExpression()->idInColl()->expression()->accept(this); + if (!ctx->filterExpression()->where()) { + throw SyntaxException("NONE(...) 
requires a WHERE predicate."); + } + Where *where = ctx->filterExpression()->where()->accept(this); + return static_cast<Expression *>(storage_->Create<None>(ident, list_expr, where)); + } else if (ctx->REDUCE()) { + auto *accumulator = + storage_->Create<Identifier>(ctx->reduceExpression()->accumulator->accept(this).as<std::string>()); + Expression *initializer = ctx->reduceExpression()->initial->accept(this); + auto *ident = + storage_->Create<Identifier>(ctx->reduceExpression()->idInColl()->variable()->accept(this).as<std::string>()); + Expression *list = ctx->reduceExpression()->idInColl()->expression()->accept(this); + Expression *expr = ctx->reduceExpression()->expression().back()->accept(this); + return static_cast<Expression *>(storage_->Create<Reduce>(accumulator, initializer, ident, list, expr)); + } else if (ctx->caseExpression()) { + return static_cast<Expression *>(ctx->caseExpression()->accept(this)); + } else if (ctx->extractExpression()) { + auto *ident = + storage_->Create<Identifier>(ctx->extractExpression()->idInColl()->variable()->accept(this).as<std::string>()); + Expression *list = ctx->extractExpression()->idInColl()->expression()->accept(this); + Expression *expr = ctx->extractExpression()->expression()->accept(this); + return static_cast<Expression *>(storage_->Create<Extract>(ident, list, expr)); + } + // TODO: Implement this. We don't support comprehensions, filtering... at + // the moment. 
+ throw utils::NotYetImplemented("atom expression '{}'", ctx->getText()); +} + +antlrcpp::Any CypherMainVisitor::visitParameter(MemgraphCypher::ParameterContext *ctx) { + return storage_->Create<ParameterLookup>(ctx->getStart()->getTokenIndex()); +} + +antlrcpp::Any CypherMainVisitor::visitLiteral(MemgraphCypher::LiteralContext *ctx) { + if (ctx->CYPHERNULL() || ctx->StringLiteral() || ctx->booleanLiteral() || ctx->numberLiteral()) { + int token_position = ctx->getStart()->getTokenIndex(); + if (ctx->CYPHERNULL()) { + return static_cast<Expression *>(storage_->Create<PrimitiveLiteral>(TypedValue(), token_position)); + } else if (context_.is_query_cached) { + // Instead of generating PrimitiveLiteral, we generate a + // ParameterLookup, so that the AST can be cached. This allows for + // varying literals, which are then looked up in the parameters table + // (even though they are not user provided). Note, that NULL always + // generates a PrimitiveLiteral. + return static_cast<Expression *>(storage_->Create<ParameterLookup>(token_position)); + } else if (ctx->StringLiteral()) { + return static_cast<Expression *>(storage_->Create<PrimitiveLiteral>( + visitStringLiteral(ctx->StringLiteral()->getText()).as<std::string>(), token_position)); + } else if (ctx->booleanLiteral()) { + return static_cast<Expression *>( + storage_->Create<PrimitiveLiteral>(ctx->booleanLiteral()->accept(this).as<bool>(), token_position)); + } else if (ctx->numberLiteral()) { + return static_cast<Expression *>( + storage_->Create<PrimitiveLiteral>(ctx->numberLiteral()->accept(this).as<TypedValue>(), token_position)); + } + LOG_FATAL("Expected to handle all cases above"); + } else if (ctx->listLiteral()) { + return static_cast<Expression *>( + storage_->Create<ListLiteral>(ctx->listLiteral()->accept(this).as<std::vector<Expression *>>())); + } else { + return static_cast<Expression *>(storage_->Create<MapLiteral>( + ctx->mapLiteral()->accept(this).as<std::unordered_map<PropertyIx, Expression 
*>>())); + } + return visitChildren(ctx); +} + +antlrcpp::Any CypherMainVisitor::visitParenthesizedExpression(MemgraphCypher::ParenthesizedExpressionContext *ctx) { + return static_cast<Expression *>(ctx->expression()->accept(this)); +} + +antlrcpp::Any CypherMainVisitor::visitNumberLiteral(MemgraphCypher::NumberLiteralContext *ctx) { + if (ctx->integerLiteral()) { + return TypedValue(ctx->integerLiteral()->accept(this).as<int64_t>()); + } else if (ctx->doubleLiteral()) { + return TypedValue(ctx->doubleLiteral()->accept(this).as<double>()); + } else { + // This should never happen, except grammar changes and we don't notice + // change in this production. + DLOG_FATAL("can't happen"); + throw std::exception(); + } +} + +antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::FunctionInvocationContext *ctx) { + if (ctx->DISTINCT()) { + throw utils::NotYetImplemented("DISTINCT function call"); + } + std::string function_name = ctx->functionName()->accept(this); + std::vector<Expression *> expressions; + for (auto *expression : ctx->expression()) { + expressions.push_back(expression->accept(this)); + } + if (expressions.size() == 1U) { + if (function_name == Aggregation::kCount) { + return static_cast<Expression *>(storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::COUNT)); + } + if (function_name == Aggregation::kMin) { + return static_cast<Expression *>(storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::MIN)); + } + if (function_name == Aggregation::kMax) { + return static_cast<Expression *>(storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::MAX)); + } + if (function_name == Aggregation::kSum) { + return static_cast<Expression *>(storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::SUM)); + } + if (function_name == Aggregation::kAvg) { + return static_cast<Expression *>(storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::AVG)); + } + if 
(function_name == Aggregation::kCollect) { + return static_cast<Expression *>( + storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::COLLECT_LIST)); + } + } + + if (expressions.size() == 2U && function_name == Aggregation::kCollect) { + return static_cast<Expression *>( + storage_->Create<Aggregation>(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP)); + } + + auto is_user_defined_function = [](const std::string &function_name) { + // Dots are present only in user-defined functions, since modules are case-sensitive, so must be user-defined + // functions. Builtin functions should be case insensitive. + return function_name.find('.') != std::string::npos; + }; + + // Don't cache queries which call user-defined functions. User-defined function's return + // types can vary depending on whether the module is reloaded, therefore the cache would + // be invalid. + if (is_user_defined_function(function_name)) { + query_info_.is_cacheable = false; + } + + return static_cast<Expression *>(storage_->Create<Function>(function_name, expressions)); +} + +antlrcpp::Any CypherMainVisitor::visitFunctionName(MemgraphCypher::FunctionNameContext *ctx) { + auto function_name = ctx->getText(); + // Dots are present only in user-defined functions, since modules are case-sensitive, so must be user-defined + // functions. Builtin functions should be case insensitive. 
+ if (function_name.find('.') != std::string::npos) { + return function_name; + } + return utils::ToUpperCase(function_name); +} + +antlrcpp::Any CypherMainVisitor::visitDoubleLiteral(MemgraphCypher::DoubleLiteralContext *ctx) { + return ParseDoubleLiteral(ctx->getText()); +} + +antlrcpp::Any CypherMainVisitor::visitIntegerLiteral(MemgraphCypher::IntegerLiteralContext *ctx) { + return ParseIntegerLiteral(ctx->getText()); +} + +antlrcpp::Any CypherMainVisitor::visitStringLiteral(const std::string &escaped) { return ParseStringLiteral(escaped); } + +antlrcpp::Any CypherMainVisitor::visitBooleanLiteral(MemgraphCypher::BooleanLiteralContext *ctx) { + if (ctx->getTokens(MemgraphCypher::TRUE).size()) { + return true; + } + if (ctx->getTokens(MemgraphCypher::FALSE).size()) { + return false; + } + DLOG_FATAL("Shouldn't happend"); + throw std::exception(); +} + +antlrcpp::Any CypherMainVisitor::visitCypherDelete(MemgraphCypher::CypherDeleteContext *ctx) { + auto *del = storage_->Create<Delete>(); + if (ctx->DETACH()) { + del->detach_ = true; + } + for (auto *expression : ctx->expression()) { + del->expressions_.push_back(expression->accept(this)); + } + return del; +} + +antlrcpp::Any CypherMainVisitor::visitWhere(MemgraphCypher::WhereContext *ctx) { + auto *where = storage_->Create<Where>(); + where->expression_ = ctx->expression()->accept(this); + return where; +} + +antlrcpp::Any CypherMainVisitor::visitSet(MemgraphCypher::SetContext *ctx) { + std::vector<Clause *> set_items; + for (auto *set_item : ctx->setItem()) { + set_items.push_back(set_item->accept(this)); + } + return set_items; +} + +antlrcpp::Any CypherMainVisitor::visitSetItem(MemgraphCypher::SetItemContext *ctx) { + // SetProperty + if (ctx->propertyExpression()) { + auto *set_property = storage_->Create<SetProperty>(); + set_property->property_lookup_ = ctx->propertyExpression()->accept(this); + set_property->expression_ = ctx->expression()->accept(this); + return static_cast<Clause *>(set_property); + } + + 
// SetProperties either assignment or update + if (ctx->getTokens(MemgraphCypher::EQ).size() || ctx->getTokens(MemgraphCypher::PLUS_EQ).size()) { + auto *set_properties = storage_->Create<SetProperties>(); + set_properties->identifier_ = storage_->Create<Identifier>(ctx->variable()->accept(this).as<std::string>()); + set_properties->expression_ = ctx->expression()->accept(this); + if (ctx->getTokens(MemgraphCypher::PLUS_EQ).size()) { + set_properties->update_ = true; + } + return static_cast<Clause *>(set_properties); + } + + // SetLabels + auto *set_labels = storage_->Create<SetLabels>(); + set_labels->identifier_ = storage_->Create<Identifier>(ctx->variable()->accept(this).as<std::string>()); + set_labels->labels_ = ctx->nodeLabels()->accept(this).as<std::vector<LabelIx>>(); + return static_cast<Clause *>(set_labels); +} + +antlrcpp::Any CypherMainVisitor::visitRemove(MemgraphCypher::RemoveContext *ctx) { + std::vector<Clause *> remove_items; + for (auto *remove_item : ctx->removeItem()) { + remove_items.push_back(remove_item->accept(this)); + } + return remove_items; +} + +antlrcpp::Any CypherMainVisitor::visitRemoveItem(MemgraphCypher::RemoveItemContext *ctx) { + // RemoveProperty + if (ctx->propertyExpression()) { + auto *remove_property = storage_->Create<RemoveProperty>(); + remove_property->property_lookup_ = ctx->propertyExpression()->accept(this); + return static_cast<Clause *>(remove_property); + } + + // RemoveLabels + auto *remove_labels = storage_->Create<RemoveLabels>(); + remove_labels->identifier_ = storage_->Create<Identifier>(ctx->variable()->accept(this).as<std::string>()); + remove_labels->labels_ = ctx->nodeLabels()->accept(this).as<std::vector<LabelIx>>(); + return static_cast<Clause *>(remove_labels); +} + +antlrcpp::Any CypherMainVisitor::visitPropertyExpression(MemgraphCypher::PropertyExpressionContext *ctx) { + Expression *expression = ctx->atom()->accept(this); + for (auto *lookup : ctx->propertyLookup()) { + PropertyIx key = 
lookup->accept(this); + auto property_lookup = storage_->Create<PropertyLookup>(expression, key); + expression = property_lookup; + } + // It is guaranteed by grammar that there is at least one propertyLookup. + return static_cast<PropertyLookup *>(expression); +} + +antlrcpp::Any CypherMainVisitor::visitCaseExpression(MemgraphCypher::CaseExpressionContext *ctx) { + Expression *test_expression = ctx->test ? ctx->test->accept(this).as<Expression *>() : nullptr; + auto alternatives = ctx->caseAlternatives(); + // Reverse alternatives so that tree of IfOperators can be built bottom-up. + std::reverse(alternatives.begin(), alternatives.end()); + Expression *else_expression = ctx->else_expression ? ctx->else_expression->accept(this).as<Expression *>() + : storage_->Create<PrimitiveLiteral>(TypedValue()); + for (auto *alternative : alternatives) { + Expression *condition = + test_expression ? storage_->Create<EqualOperator>(test_expression, alternative->when_expression->accept(this)) + : alternative->when_expression->accept(this).as<Expression *>(); + Expression *then_expression = alternative->then_expression->accept(this); + else_expression = storage_->Create<IfOperator>(condition, then_expression, else_expression); + } + return else_expression; +} + +antlrcpp::Any CypherMainVisitor::visitCaseAlternatives(MemgraphCypher::CaseAlternativesContext *) { + DLOG_FATAL("Should never be called. 
See documentation in hpp."); + return 0; +} + +antlrcpp::Any CypherMainVisitor::visitWith(MemgraphCypher::WithContext *ctx) { + auto *with = storage_->Create<With>(); + in_with_ = true; + with->body_ = ctx->returnBody()->accept(this); + in_with_ = false; + if (ctx->DISTINCT()) { + with->body_.distinct = true; + } + if (ctx->where()) { + with->where_ = ctx->where()->accept(this); + } + return with; +} + +antlrcpp::Any CypherMainVisitor::visitMerge(MemgraphCypher::MergeContext *ctx) { + auto *merge = storage_->Create<Merge>(); + merge->pattern_ = ctx->patternPart()->accept(this); + for (auto &merge_action : ctx->mergeAction()) { + auto set = merge_action->set()->accept(this).as<std::vector<Clause *>>(); + if (merge_action->MATCH()) { + merge->on_match_.insert(merge->on_match_.end(), set.begin(), set.end()); + } else { + DMG_ASSERT(merge_action->CREATE(), "Expected ON MATCH or ON CREATE"); + merge->on_create_.insert(merge->on_create_.end(), set.begin(), set.end()); + } + } + return merge; +} + +antlrcpp::Any CypherMainVisitor::visitUnwind(MemgraphCypher::UnwindContext *ctx) { + auto *named_expr = storage_->Create<NamedExpression>(); + named_expr->expression_ = ctx->expression()->accept(this); + named_expr->name_ = std::string(ctx->variable()->accept(this).as<std::string>()); + return storage_->Create<Unwind>(named_expr); +} + +antlrcpp::Any CypherMainVisitor::visitFilterExpression(MemgraphCypher::FilterExpressionContext *) { + LOG_FATAL("Should never be called. 
See documentation in hpp."); + return 0; +} + +antlrcpp::Any CypherMainVisitor::visitForeach(MemgraphCypher::ForeachContext *ctx) { + auto *for_each = storage_->Create<Foreach>(); + + auto *named_expr = storage_->Create<NamedExpression>(); + named_expr->expression_ = ctx->expression()->accept(this); + named_expr->name_ = std::string(ctx->variable()->accept(this).as<std::string>()); + for_each->named_expression_ = named_expr; + + for (auto *update_clause_ctx : ctx->updateClause()) { + if (auto *set = update_clause_ctx->set(); set) { + auto set_items = visitSet(set).as<std::vector<Clause *>>(); + std::copy(set_items.begin(), set_items.end(), std::back_inserter(for_each->clauses_)); + } else if (auto *remove = update_clause_ctx->remove(); remove) { + auto remove_items = visitRemove(remove).as<std::vector<Clause *>>(); + std::copy(remove_items.begin(), remove_items.end(), std::back_inserter(for_each->clauses_)); + } else if (auto *merge = update_clause_ctx->merge(); merge) { + for_each->clauses_.push_back(visitMerge(merge).as<Merge *>()); + } else if (auto *create = update_clause_ctx->create(); create) { + for_each->clauses_.push_back(visitCreate(create).as<Create *>()); + } else if (auto *cypher_delete = update_clause_ctx->cypherDelete(); cypher_delete) { + for_each->clauses_.push_back(visitCypherDelete(cypher_delete).as<Delete *>()); + } else { + auto *nested_for_each = update_clause_ctx->foreach (); + MG_ASSERT(nested_for_each != nullptr, "Unexpected clause in FOREACH"); + for_each->clauses_.push_back(visitForeach(nested_for_each).as<Foreach *>()); + } + } + + return for_each; +} + +LabelIx CypherMainVisitor::AddLabel(const std::string &name) { return storage_->GetLabelIx(name); } + +PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return storage_->GetPropertyIx(name); } + +EdgeTypeIx CypherMainVisitor::AddEdgeType(const std::string &name) { return storage_->GetEdgeTypeIx(name); } + +} // namespace memgraph::query::v2::frontend diff --git 
a/src/query/v2/frontend/ast/cypher_main_visitor.hpp b/src/query/v2/frontend/ast/cypher_main_visitor.hpp new file mode 100644 index 000000000..f0d5ba78b --- /dev/null +++ b/src/query/v2/frontend/ast/cypher_main_visitor.hpp @@ -0,0 +1,886 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <string> +#include <unordered_set> +#include <utility> + +#include <antlr4-runtime.h> + +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/opencypher/generated/MemgraphCypherBaseVisitor.h" +#include "utils/exceptions.hpp" +#include "utils/logging.hpp" + +namespace memgraph::query::v2::frontend { + +using antlropencypher::MemgraphCypher; + +struct ParsingContext { + bool is_query_cached = false; +}; + +class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { + public: + explicit CypherMainVisitor(ParsingContext context, AstStorage *storage) : context_(context), storage_(storage) {} + + private: + Expression *CreateBinaryOperatorByToken(size_t token, Expression *e1, Expression *e2) { + switch (token) { + case MemgraphCypher::OR: + return storage_->Create<OrOperator>(e1, e2); + case MemgraphCypher::XOR: + return storage_->Create<XorOperator>(e1, e2); + case MemgraphCypher::AND: + return storage_->Create<AndOperator>(e1, e2); + case MemgraphCypher::PLUS: + return storage_->Create<AdditionOperator>(e1, e2); + case MemgraphCypher::MINUS: + return storage_->Create<SubtractionOperator>(e1, e2); + case MemgraphCypher::ASTERISK: + 
return storage_->Create<MultiplicationOperator>(e1, e2); + case MemgraphCypher::SLASH: + return storage_->Create<DivisionOperator>(e1, e2); + case MemgraphCypher::PERCENT: + return storage_->Create<ModOperator>(e1, e2); + case MemgraphCypher::EQ: + return storage_->Create<EqualOperator>(e1, e2); + case MemgraphCypher::NEQ1: + case MemgraphCypher::NEQ2: + return storage_->Create<NotEqualOperator>(e1, e2); + case MemgraphCypher::LT: + return storage_->Create<LessOperator>(e1, e2); + case MemgraphCypher::GT: + return storage_->Create<GreaterOperator>(e1, e2); + case MemgraphCypher::LTE: + return storage_->Create<LessEqualOperator>(e1, e2); + case MemgraphCypher::GTE: + return storage_->Create<GreaterEqualOperator>(e1, e2); + default: + throw utils::NotYetImplemented("binary operator"); + } + } + + Expression *CreateUnaryOperatorByToken(size_t token, Expression *e) { + switch (token) { + case MemgraphCypher::NOT: + return storage_->Create<NotOperator>(e); + case MemgraphCypher::PLUS: + return storage_->Create<UnaryPlusOperator>(e); + case MemgraphCypher::MINUS: + return storage_->Create<UnaryMinusOperator>(e); + default: + throw utils::NotYetImplemented("unary operator"); + } + } + + auto ExtractOperators(std::vector<antlr4::tree::ParseTree *> &all_children, + const std::vector<size_t> &allowed_operators) { + std::vector<size_t> operators; + for (auto *child : all_children) { + antlr4::tree::TerminalNode *operator_node = nullptr; + if ((operator_node = dynamic_cast<antlr4::tree::TerminalNode *>(child))) { + if (std::find(allowed_operators.begin(), allowed_operators.end(), operator_node->getSymbol()->getType()) != + allowed_operators.end()) { + operators.push_back(operator_node->getSymbol()->getType()); + } + } + } + return operators; + } + + /** + * Convert opencypher's n-ary production to ast binary operators. 
+ * + * @param _expressions Subexpressions of child for which we construct ast + * operators, for example expression6 if we want to create ast nodes for + * expression7. + */ + template <typename TExpression> + Expression *LeftAssociativeOperatorExpression(std::vector<TExpression *> _expressions, + std::vector<antlr4::tree::ParseTree *> all_children, + const std::vector<size_t> &allowed_operators) { + DMG_ASSERT(_expressions.size(), "can't happen"); + std::vector<Expression *> expressions; + auto operators = ExtractOperators(all_children, allowed_operators); + + for (auto *expression : _expressions) { + expressions.push_back(expression->accept(this)); + } + + Expression *first_operand = expressions[0]; + for (int i = 1; i < (int)expressions.size(); ++i) { + first_operand = CreateBinaryOperatorByToken(operators[i - 1], first_operand, expressions[i]); + } + return first_operand; + } + + template <typename TExpression> + Expression *PrefixUnaryOperator(TExpression *_expression, std::vector<antlr4::tree::ParseTree *> all_children, + const std::vector<size_t> &allowed_operators) { + DMG_ASSERT(_expression, "can't happen"); + auto operators = ExtractOperators(all_children, allowed_operators); + + Expression *expression = _expression->accept(this); + for (int i = (int)operators.size() - 1; i >= 0; --i) { + expression = CreateUnaryOperatorByToken(operators[i], expression); + } + return expression; + } + + /** + * @return CypherQuery* + */ + antlrcpp::Any visitCypherQuery(MemgraphCypher::CypherQueryContext *ctx) override; + + /** + * @return IndexQuery* + */ + antlrcpp::Any visitIndexQuery(MemgraphCypher::IndexQueryContext *ctx) override; + + /** + * @return ExplainQuery* + */ + antlrcpp::Any visitExplainQuery(MemgraphCypher::ExplainQueryContext *ctx) override; + + /** + * @return ProfileQuery* + */ + antlrcpp::Any visitProfileQuery(MemgraphCypher::ProfileQueryContext *ctx) override; + + /** + * @return InfoQuery* + */ + antlrcpp::Any 
visitInfoQuery(MemgraphCypher::InfoQueryContext *ctx) override; + + /** + * @return Constraint + */ + antlrcpp::Any visitConstraint(MemgraphCypher::ConstraintContext *ctx) override; + + /** + * @return ConstraintQuery* + */ + antlrcpp::Any visitConstraintQuery(MemgraphCypher::ConstraintQueryContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitAuthQuery(MemgraphCypher::AuthQueryContext *ctx) override; + + /** + * @return DumpQuery* + */ + antlrcpp::Any visitDumpQuery(MemgraphCypher::DumpQueryContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitReplicationQuery(MemgraphCypher::ReplicationQueryContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitSetReplicationRole(MemgraphCypher::SetReplicationRoleContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitShowReplicationRole(MemgraphCypher::ShowReplicationRoleContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitRegisterReplica(MemgraphCypher::RegisterReplicaContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitDropReplica(MemgraphCypher::DropReplicaContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitShowReplicas(MemgraphCypher::ShowReplicasContext *ctx) override; + + /** + * @return LockPathQuery* + */ + antlrcpp::Any visitLockPathQuery(MemgraphCypher::LockPathQueryContext *ctx) override; + + /** + * @return LoadCsvQuery* + */ + antlrcpp::Any visitLoadCsv(MemgraphCypher::LoadCsvContext *ctx) override; + + /** + * @return FreeMemoryQuery* + */ + antlrcpp::Any visitFreeMemoryQuery(MemgraphCypher::FreeMemoryQueryContext *ctx) override; + + /** + * @return TriggerQuery* + */ + antlrcpp::Any visitTriggerQuery(MemgraphCypher::TriggerQueryContext *ctx) override; + + /** + * @return CreateTrigger* + */ + antlrcpp::Any visitCreateTrigger(MemgraphCypher::CreateTriggerContext *ctx) override; + + /** + 
* @return DropTrigger* + */ + antlrcpp::Any visitDropTrigger(MemgraphCypher::DropTriggerContext *ctx) override; + + /** + * @return ShowTriggers* + */ + antlrcpp::Any visitShowTriggers(MemgraphCypher::ShowTriggersContext *ctx) override; + + /** + * @return IsolationLevelQuery* + */ + antlrcpp::Any visitIsolationLevelQuery(MemgraphCypher::IsolationLevelQueryContext *ctx) override; + + /** + * @return CreateSnapshotQuery* + */ + antlrcpp::Any visitCreateSnapshotQuery(MemgraphCypher::CreateSnapshotQueryContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitStreamQuery(MemgraphCypher::StreamQueryContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitCreateStream(MemgraphCypher::CreateStreamContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitConfigKeyValuePair(MemgraphCypher::ConfigKeyValuePairContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitConfigMap(MemgraphCypher::ConfigMapContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitKafkaCreateStream(MemgraphCypher::KafkaCreateStreamContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitKafkaCreateStreamConfig(MemgraphCypher::KafkaCreateStreamConfigContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitPulsarCreateStreamConfig(MemgraphCypher::PulsarCreateStreamConfigContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitPulsarCreateStream(MemgraphCypher::PulsarCreateStreamContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitCommonCreateStreamConfig(MemgraphCypher::CommonCreateStreamConfigContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitDropStream(MemgraphCypher::DropStreamContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitStartStream(MemgraphCypher::StartStreamContext *ctx) override; + + /** + * 
@return StreamQuery* + */ + antlrcpp::Any visitStartAllStreams(MemgraphCypher::StartAllStreamsContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitStopStream(MemgraphCypher::StopStreamContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitStopAllStreams(MemgraphCypher::StopAllStreamsContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitShowStreams(MemgraphCypher::ShowStreamsContext *ctx) override; + + /** + * @return StreamQuery* + */ + antlrcpp::Any visitCheckStream(MemgraphCypher::CheckStreamContext *ctx) override; + + /** + * @return SettingQuery* + */ + antlrcpp::Any visitSettingQuery(MemgraphCypher::SettingQueryContext *ctx) override; + + /** + * @return SetSetting* + */ + antlrcpp::Any visitSetSetting(MemgraphCypher::SetSettingContext *ctx) override; + + /** + * @return ShowSetting* + */ + antlrcpp::Any visitShowSetting(MemgraphCypher::ShowSettingContext *ctx) override; + + /** + * @return ShowSettings* + */ + antlrcpp::Any visitShowSettings(MemgraphCypher::ShowSettingsContext *ctx) override; + + /** + * @return VersionQuery* + */ + antlrcpp::Any visitVersionQuery(MemgraphCypher::VersionQueryContext *ctx) override; + + /** + * @return CypherUnion* + */ + antlrcpp::Any visitCypherUnion(MemgraphCypher::CypherUnionContext *ctx) override; + + /** + * @return SingleQuery* + */ + antlrcpp::Any visitSingleQuery(MemgraphCypher::SingleQueryContext *ctx) override; + + /** + * @return Clause* or vector<Clause*>!!! 
+ */ + antlrcpp::Any visitClause(MemgraphCypher::ClauseContext *ctx) override; + + /** + * @return Match* + */ + antlrcpp::Any visitCypherMatch(MemgraphCypher::CypherMatchContext *ctx) override; + + /** + * @return Create* + */ + antlrcpp::Any visitCreate(MemgraphCypher::CreateContext *ctx) override; + + /** + * @return CallProcedure* + */ + antlrcpp::Any visitCallProcedure(MemgraphCypher::CallProcedureContext *ctx) override; + + /** + * @return std::string + */ + antlrcpp::Any visitUserOrRoleName(MemgraphCypher::UserOrRoleNameContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitCreateRole(MemgraphCypher::CreateRoleContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitDropRole(MemgraphCypher::DropRoleContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitShowRoles(MemgraphCypher::ShowRolesContext *ctx) override; + + /** + * @return IndexQuery* + */ + antlrcpp::Any visitCreateIndex(MemgraphCypher::CreateIndexContext *ctx) override; + + /** + * @return DropIndex* + */ + antlrcpp::Any visitDropIndex(MemgraphCypher::DropIndexContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitCreateUser(MemgraphCypher::CreateUserContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitSetPassword(MemgraphCypher::SetPasswordContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitDropUser(MemgraphCypher::DropUserContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitShowUsers(MemgraphCypher::ShowUsersContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitSetRole(MemgraphCypher::SetRoleContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitClearRole(MemgraphCypher::ClearRoleContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitGrantPrivilege(MemgraphCypher::GrantPrivilegeContext *ctx) override; + + /** + * @return AuthQuery* + */ 
+ antlrcpp::Any visitDenyPrivilege(MemgraphCypher::DenyPrivilegeContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitRevokePrivilege(MemgraphCypher::RevokePrivilegeContext *ctx) override; + + /** + * @return AuthQuery::Privilege + */ + antlrcpp::Any visitPrivilege(MemgraphCypher::PrivilegeContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitShowPrivileges(MemgraphCypher::ShowPrivilegesContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitShowRoleForUser(MemgraphCypher::ShowRoleForUserContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) override; + + /** + * @return Return* + */ + antlrcpp::Any visitCypherReturn(MemgraphCypher::CypherReturnContext *ctx) override; + + /** + * @return Return* + */ + antlrcpp::Any visitReturnBody(MemgraphCypher::ReturnBodyContext *ctx) override; + + /** + * @return pair<bool, vector<NamedExpression*>> first member is true if + * asterisk was found in return + * expressions. 
+ */ + antlrcpp::Any visitReturnItems(MemgraphCypher::ReturnItemsContext *ctx) override; + + /** + * @return vector<NamedExpression*> + */ + antlrcpp::Any visitReturnItem(MemgraphCypher::ReturnItemContext *ctx) override; + + /** + * @return vector<SortItem> + */ + antlrcpp::Any visitOrder(MemgraphCypher::OrderContext *ctx) override; + + /** + * @return SortItem + */ + antlrcpp::Any visitSortItem(MemgraphCypher::SortItemContext *ctx) override; + + /** + * @return NodeAtom* + */ + antlrcpp::Any visitNodePattern(MemgraphCypher::NodePatternContext *ctx) override; + + /** + * @return vector<LabelIx> + */ + antlrcpp::Any visitNodeLabels(MemgraphCypher::NodeLabelsContext *ctx) override; + + /** + * @return unordered_map<PropertyIx, Expression*> + */ + antlrcpp::Any visitProperties(MemgraphCypher::PropertiesContext *ctx) override; + + /** + * @return map<std::string, Expression*> + */ + antlrcpp::Any visitMapLiteral(MemgraphCypher::MapLiteralContext *ctx) override; + + /** + * @return vector<Expression*> + */ + antlrcpp::Any visitListLiteral(MemgraphCypher::ListLiteralContext *ctx) override; + + /** + * @return PropertyIx + */ + antlrcpp::Any visitPropertyKeyName(MemgraphCypher::PropertyKeyNameContext *ctx) override; + + /** + * @return string + */ + antlrcpp::Any visitSymbolicName(MemgraphCypher::SymbolicNameContext *ctx) override; + + /** + * @return vector<Pattern*> + */ + antlrcpp::Any visitPattern(MemgraphCypher::PatternContext *ctx) override; + + /** + * @return Pattern* + */ + antlrcpp::Any visitPatternPart(MemgraphCypher::PatternPartContext *ctx) override; + + /** + * @return Pattern* + */ + antlrcpp::Any visitPatternElement(MemgraphCypher::PatternElementContext *ctx) override; + + /** + * @return vector<pair<EdgeAtom*, NodeAtom*>> + */ + antlrcpp::Any visitPatternElementChain(MemgraphCypher::PatternElementChainContext *ctx) override; + + /** + * @return EdgeAtom* + */ + antlrcpp::Any visitRelationshipPattern(MemgraphCypher::RelationshipPatternContext *ctx)
override; + + /** + * This should never be called. Everything is done directly in + * visitRelationshipPattern. + */ + antlrcpp::Any visitRelationshipDetail(MemgraphCypher::RelationshipDetailContext *ctx) override; + + /** + * This should never be called. Everything is done directly in + * visitRelationshipPattern. + */ + antlrcpp::Any visitRelationshipLambda(MemgraphCypher::RelationshipLambdaContext *ctx) override; + + /** + * @return vector<EdgeTypeIx> + */ + antlrcpp::Any visitRelationshipTypes(MemgraphCypher::RelationshipTypesContext *ctx) override; + + /** + * @return std::tuple<EdgeAtom::Type, int64_t, int64_t>. + */ + antlrcpp::Any visitVariableExpansion(MemgraphCypher::VariableExpansionContext *ctx) override; + + /** + * Top level expression, does nothing. + * + * @return Expression* + */ + antlrcpp::Any visitExpression(MemgraphCypher::ExpressionContext *ctx) override; + + /** + * OR. + * + * @return Expression* + */ + antlrcpp::Any visitExpression12(MemgraphCypher::Expression12Context *ctx) override; + + /** + * XOR. + * + * @return Expression* + */ + antlrcpp::Any visitExpression11(MemgraphCypher::Expression11Context *ctx) override; + + /** + * AND. + * + * @return Expression* + */ + antlrcpp::Any visitExpression10(MemgraphCypher::Expression10Context *ctx) override; + + /** + * NOT. + * + * @return Expression* + */ + antlrcpp::Any visitExpression9(MemgraphCypher::Expression9Context *ctx) override; + + /** + * Comparisons. + * + * @return Expression* + */ + antlrcpp::Any visitExpression8(MemgraphCypher::Expression8Context *ctx) override; + + /** + * Never call this. Everything related to generating code for comparison + * operators should be done in visitExpression8. + */ + antlrcpp::Any visitPartialComparisonExpression(MemgraphCypher::PartialComparisonExpressionContext *ctx) override; + + /** + * Addition and subtraction. 
+ * + * @return Expression* + */ + antlrcpp::Any visitExpression7(MemgraphCypher::Expression7Context *ctx) override; + + /** + * Multiplication, division, modding. + * + * @return Expression* + */ + antlrcpp::Any visitExpression6(MemgraphCypher::Expression6Context *ctx) override; + + /** + * Power. + * + * @return Expression* + */ + antlrcpp::Any visitExpression5(MemgraphCypher::Expression5Context *ctx) override; + + /** + * Unary minus and plus. + * + * @return Expression* + */ + antlrcpp::Any visitExpression4(MemgraphCypher::Expression4Context *ctx) override; + + /** + * IS NULL, IS NOT NULL, STARTS WITH, ENDS WITH, =~, ... + * + * @return Expression* + */ + antlrcpp::Any visitExpression3a(MemgraphCypher::Expression3aContext *ctx) override; + + /** + * Does nothing, everything is done in visitExpression3a. + * + * @return Expression* + */ + antlrcpp::Any visitStringAndNullOperators(MemgraphCypher::StringAndNullOperatorsContext *ctx) override; + + /** + * List indexing and slicing. + * + * @return Expression* + */ + antlrcpp::Any visitExpression3b(MemgraphCypher::Expression3bContext *ctx) override; + + /** + * Does nothing, everything is done in visitExpression3b. + */ + antlrcpp::Any visitListIndexingOrSlicing(MemgraphCypher::ListIndexingOrSlicingContext *ctx) override; + + /** + * Node labels test. + * + * @return Expression* + */ + antlrcpp::Any visitExpression2a(MemgraphCypher::Expression2aContext *ctx) override; + + /** + * Property lookup. + * + * @return Expression* + */ + antlrcpp::Any visitExpression2b(MemgraphCypher::Expression2bContext *ctx) override; + + /** + * Literals, params, list comprehension...
+ * + * @return Expression* + */ + antlrcpp::Any visitAtom(MemgraphCypher::AtomContext *ctx) override; + + /** + * @return ParameterLookup* + */ + antlrcpp::Any visitParameter(MemgraphCypher::ParameterContext *ctx) override; + + /** + * @return Expression* + */ + antlrcpp::Any visitParenthesizedExpression(MemgraphCypher::ParenthesizedExpressionContext *ctx) override; + + /** + * @return Expression* + */ + antlrcpp::Any visitFunctionInvocation(MemgraphCypher::FunctionInvocationContext *ctx) override; + + /** + * @return string - uppercased + */ + antlrcpp::Any visitFunctionName(MemgraphCypher::FunctionNameContext *ctx) override; + + /** + * @return Expression* + */ + antlrcpp::Any visitLiteral(MemgraphCypher::LiteralContext *ctx) override; + + /** + * Convert escaped string from a query to unescaped utf8 string. + * + * @return string + */ + antlrcpp::Any visitStringLiteral(const std::string &escaped); + + /** + * @return bool + */ + antlrcpp::Any visitBooleanLiteral(MemgraphCypher::BooleanLiteralContext *ctx) override; + + /** + * @return TypedValue with either double or int + */ + antlrcpp::Any visitNumberLiteral(MemgraphCypher::NumberLiteralContext *ctx) override; + + /** + * @return int64_t + */ + antlrcpp::Any visitIntegerLiteral(MemgraphCypher::IntegerLiteralContext *ctx) override; + + /** + * @return double + */ + antlrcpp::Any visitDoubleLiteral(MemgraphCypher::DoubleLiteralContext *ctx) override; + + /** + * @return Delete* + */ + antlrcpp::Any visitCypherDelete(MemgraphCypher::CypherDeleteContext *ctx) override; + + /** + * @return Where* + */ + antlrcpp::Any visitWhere(MemgraphCypher::WhereContext *ctx) override; + + /** + * @return vector<Clause*> + */ + antlrcpp::Any visitSet(MemgraphCypher::SetContext *ctx) override; + + /** + * @return Clause* + */ + antlrcpp::Any visitSetItem(MemgraphCypher::SetItemContext *ctx) override; + + /** + * @return vector<Clause*> + */ + antlrcpp::Any visitRemove(MemgraphCypher::RemoveContext *ctx) override; + + /** + *
@return Clause* + */ + antlrcpp::Any visitRemoveItem(MemgraphCypher::RemoveItemContext *ctx) override; + + /** + * @return PropertyLookup* + */ + antlrcpp::Any visitPropertyExpression(MemgraphCypher::PropertyExpressionContext *ctx) override; + + /** + * @return IfOperator* + */ + antlrcpp::Any visitCaseExpression(MemgraphCypher::CaseExpressionContext *ctx) override; + + /** + * Never call this. Ast generation for this production is done in + * @c visitCaseExpression. + */ + antlrcpp::Any visitCaseAlternatives(MemgraphCypher::CaseAlternativesContext *ctx) override; + + /** + * @return With* + */ + antlrcpp::Any visitWith(MemgraphCypher::WithContext *ctx) override; + + /** + * @return Merge* + */ + antlrcpp::Any visitMerge(MemgraphCypher::MergeContext *ctx) override; + + /** + * @return Unwind* + */ + antlrcpp::Any visitUnwind(MemgraphCypher::UnwindContext *ctx) override; + + /** + * Never call this. Ast generation for these expressions should be done by + * explicitly visiting the members of @c FilterExpressionContext. + */ + antlrcpp::Any visitFilterExpression(MemgraphCypher::FilterExpressionContext *) override; + + /** + * @return Foreach* + */ + antlrcpp::Any visitForeach(MemgraphCypher::ForeachContext *ctx) override; + + public: + Query *query() { return query_; } + const static std::string kAnonPrefix; + + struct QueryInfo { + bool is_cacheable{true}; + bool has_load_csv{false}; + }; + + const auto &GetQueryInfo() const { return query_info_; } + + private: + LabelIx AddLabel(const std::string &name); + PropertyIx AddProperty(const std::string &name); + EdgeTypeIx AddEdgeType(const std::string &name); + + ParsingContext context_; + AstStorage *storage_; + + std::unordered_map<uint8_t, std::variant<Expression *, std::string, std::vector<std::string>, + std::unordered_map<Expression *, Expression *>>> + memory_; + // Set of identifiers from queries. + std::unordered_set<std::string> users_identifiers; + // Identifiers that user didn't name. 
+ std::vector<Identifier **> anonymous_identifiers; + Query *query_ = nullptr; + // All return items which are not variables must be aliased in WITH. + // We use this flag in visitReturnItem to check whether we are in WITH or + // RETURN. + bool in_with_ = false; + + QueryInfo query_info_; +}; +} // namespace memgraph::query::v2::frontend diff --git a/src/query/v2/frontend/ast/pretty_print.cpp b/src/query/v2/frontend/ast/pretty_print.cpp new file mode 100644 index 000000000..7aaa6ccb1 --- /dev/null +++ b/src/query/v2/frontend/ast/pretty_print.cpp @@ -0,0 +1,311 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt.
+ +#include "query/v2/frontend/ast/pretty_print.hpp" + +#include <type_traits> + +#include "query/v2/frontend/ast/ast.hpp" +#include "utils/algorithm.hpp" +#include "utils/string.hpp" + +namespace memgraph::query::v2 { + +namespace { + +class ExpressionPrettyPrinter : public ExpressionVisitor<void> { + public: + explicit ExpressionPrettyPrinter(std::ostream *out); + + // Unary operators + void Visit(NotOperator &op) override; + void Visit(UnaryPlusOperator &op) override; + void Visit(UnaryMinusOperator &op) override; + void Visit(IsNullOperator &op) override; + + // Binary operators + void Visit(OrOperator &op) override; + void Visit(XorOperator &op) override; + void Visit(AndOperator &op) override; + void Visit(AdditionOperator &op) override; + void Visit(SubtractionOperator &op) override; + void Visit(MultiplicationOperator &op) override; + void Visit(DivisionOperator &op) override; + void Visit(ModOperator &op) override; + void Visit(NotEqualOperator &op) override; + void Visit(EqualOperator &op) override; + void Visit(LessOperator &op) override; + void Visit(GreaterOperator &op) override; + void Visit(LessEqualOperator &op) override; + void Visit(GreaterEqualOperator &op) override; + void Visit(InListOperator &op) override; + void Visit(SubscriptOperator &op) override; + + // Other + void Visit(ListSlicingOperator &op) override; + void Visit(IfOperator &op) override; + void Visit(ListLiteral &op) override; + void Visit(MapLiteral &op) override; + void Visit(LabelsTest &op) override; + void Visit(Aggregation &op) override; + void Visit(Function &op) override; + void Visit(Reduce &op) override; + void Visit(Coalesce &op) override; + void Visit(Extract &op) override; + void Visit(All &op) override; + void Visit(Single &op) override; + void Visit(Any &op) override; + void Visit(None &op) override; + void Visit(Identifier &op) override; + void Visit(PrimitiveLiteral &op) override; + void Visit(PropertyLookup &op) override; + void Visit(ParameterLookup &op) override; 
+ void Visit(NamedExpression &op) override; + void Visit(RegexMatch &op) override; + + private: + std::ostream *out_; +}; + +// Declare all of the different `PrintObject` overloads upfront since they're +// mutually recursive. Without this, overload resolution depends on the ordering +// of the overloads within the source, which is quite fragile. + +template <typename T> +void PrintObject(std::ostream *out, const T &arg); + +void PrintObject(std::ostream *out, const std::string &str); + +void PrintObject(std::ostream *out, Aggregation::Op op); + +void PrintObject(std::ostream *out, Expression *expr); + +void PrintObject(std::ostream *out, Identifier *expr); + +void PrintObject(std::ostream *out, const storage::v3::PropertyValue &value); + +template <typename T> +void PrintObject(std::ostream *out, const std::vector<T> &vec); + +template <typename K, typename V> +void PrintObject(std::ostream *out, const std::map<K, V> &map); + +template <typename T> +void PrintObject(std::ostream *out, const T &arg) { + static_assert(!std::is_convertible<T, Expression *>::value, + "This overload shouldn't be called with pointers convertible " + "to Expression *. 
This means your other PrintObject overloads aren't " + "being called for certain AST nodes when they should (or perhaps such " + "overloads don't exist yet)."); + *out << arg; +} + +void PrintObject(std::ostream *out, const std::string &str) { *out << utils::Escape(str); } + +void PrintObject(std::ostream *out, Aggregation::Op op) { *out << Aggregation::OpToString(op); } + +void PrintObject(std::ostream *out, Expression *expr) { + if (expr) { + ExpressionPrettyPrinter printer{out}; + expr->Accept(printer); + } else { + *out << "<null>"; + } +} + +void PrintObject(std::ostream *out, Identifier *expr) { PrintObject(out, static_cast<Expression *>(expr)); } + +void PrintObject(std::ostream *out, const storage::v3::PropertyValue &value) { + switch (value.type()) { + case storage::v3::PropertyValue::Type::Null: + *out << "null"; + break; + + case storage::v3::PropertyValue::Type::String: + PrintObject(out, value.ValueString()); + break; + + case storage::v3::PropertyValue::Type::Bool: + *out << (value.ValueBool() ? 
"true" : "false"); + break; + + case storage::v3::PropertyValue::Type::Int: + PrintObject(out, value.ValueInt()); + break; + + case storage::v3::PropertyValue::Type::Double: + PrintObject(out, value.ValueDouble()); + break; + + case storage::v3::PropertyValue::Type::List: + PrintObject(out, value.ValueList()); + break; + + case storage::v3::PropertyValue::Type::Map: + PrintObject(out, value.ValueMap()); + break; + case storage::v3::PropertyValue::Type::TemporalData: + PrintObject(out, value.ValueTemporalData()); + break; + } +} + +template <typename T> +void PrintObject(std::ostream *out, const std::vector<T> &vec) { + *out << "["; + utils::PrintIterable(*out, vec, ", ", [](auto &stream, const auto &item) { PrintObject(&stream, item); }); + *out << "]"; +} + +template <typename K, typename V> +void PrintObject(std::ostream *out, const std::map<K, V> &map) { + *out << "{"; + utils::PrintIterable(*out, map, ", ", [](auto &stream, const auto &item) { + PrintObject(&stream, item.first); + stream << ": "; + PrintObject(&stream, item.second); + }); + *out << "}"; +} + +template <typename T> +void PrintOperatorArgs(std::ostream *out, const T &arg) { + *out << " "; + PrintObject(out, arg); + *out << ")"; +} + +template <typename T, typename... Ts> +void PrintOperatorArgs(std::ostream *out, const T &arg, const Ts &...args) { + *out << " "; + PrintObject(out, arg); + PrintOperatorArgs(out, args...); +} + +template <typename... 
Ts> +void PrintOperator(std::ostream *out, const std::string &name, const Ts &...args) { + *out << "(" << name; + PrintOperatorArgs(out, args...); +} + +ExpressionPrettyPrinter::ExpressionPrettyPrinter(std::ostream *out) : out_(out) {} + +#define UNARY_OPERATOR_VISIT(OP_NODE, OP_STR) \ + void ExpressionPrettyPrinter::Visit(OP_NODE &op) { PrintOperator(out_, OP_STR, op.expression_); } + +UNARY_OPERATOR_VISIT(NotOperator, "Not"); +UNARY_OPERATOR_VISIT(UnaryPlusOperator, "+"); +UNARY_OPERATOR_VISIT(UnaryMinusOperator, "-"); +UNARY_OPERATOR_VISIT(IsNullOperator, "IsNull"); + +#undef UNARY_OPERATOR_VISIT + +#define BINARY_OPERATOR_VISIT(OP_NODE, OP_STR) \ + void ExpressionPrettyPrinter::Visit(OP_NODE &op) { PrintOperator(out_, OP_STR, op.expression1_, op.expression2_); } + +BINARY_OPERATOR_VISIT(OrOperator, "Or"); +BINARY_OPERATOR_VISIT(XorOperator, "Xor"); +BINARY_OPERATOR_VISIT(AndOperator, "And"); +BINARY_OPERATOR_VISIT(AdditionOperator, "+"); +BINARY_OPERATOR_VISIT(SubtractionOperator, "-"); +BINARY_OPERATOR_VISIT(MultiplicationOperator, "*"); +BINARY_OPERATOR_VISIT(DivisionOperator, "/"); +BINARY_OPERATOR_VISIT(ModOperator, "%"); +BINARY_OPERATOR_VISIT(NotEqualOperator, "!="); +BINARY_OPERATOR_VISIT(EqualOperator, "=="); +BINARY_OPERATOR_VISIT(LessOperator, "<"); +BINARY_OPERATOR_VISIT(GreaterOperator, ">"); +BINARY_OPERATOR_VISIT(LessEqualOperator, "<="); +BINARY_OPERATOR_VISIT(GreaterEqualOperator, ">="); +BINARY_OPERATOR_VISIT(InListOperator, "In"); +BINARY_OPERATOR_VISIT(SubscriptOperator, "Subscript"); + +#undef BINARY_OPERATOR_VISIT + +void ExpressionPrettyPrinter::Visit(ListSlicingOperator &op) { + PrintOperator(out_, "ListSlicing", op.list_, op.lower_bound_, op.upper_bound_); +} + +void ExpressionPrettyPrinter::Visit(IfOperator &op) { + PrintOperator(out_, "If", op.condition_, op.then_expression_, op.else_expression_); +} + +void ExpressionPrettyPrinter::Visit(ListLiteral &op) { PrintOperator(out_, "ListLiteral", op.elements_); } + +void 
ExpressionPrettyPrinter::Visit(MapLiteral &op) { + std::map<std::string, Expression *> map; + for (const auto &kv : op.elements_) { + map[kv.first.name] = kv.second; + } + PrintObject(out_, map); +} + +void ExpressionPrettyPrinter::Visit(LabelsTest &op) { PrintOperator(out_, "LabelsTest", op.expression_); } + +void ExpressionPrettyPrinter::Visit(Aggregation &op) { PrintOperator(out_, "Aggregation", op.op_); } + +void ExpressionPrettyPrinter::Visit(Function &op) { PrintOperator(out_, "Function", op.function_name_, op.arguments_); } + +void ExpressionPrettyPrinter::Visit(Reduce &op) { + PrintOperator(out_, "Reduce", op.accumulator_, op.initializer_, op.identifier_, op.list_, op.expression_); +} + +void ExpressionPrettyPrinter::Visit(Coalesce &op) { PrintOperator(out_, "Coalesce", op.expressions_); } + +void ExpressionPrettyPrinter::Visit(Extract &op) { + PrintOperator(out_, "Extract", op.identifier_, op.list_, op.expression_); +} + +void ExpressionPrettyPrinter::Visit(All &op) { + PrintOperator(out_, "All", op.identifier_, op.list_expression_, op.where_->expression_); +} + +void ExpressionPrettyPrinter::Visit(Single &op) { + PrintOperator(out_, "Single", op.identifier_, op.list_expression_, op.where_->expression_); +} + +void ExpressionPrettyPrinter::Visit(Any &op) { + PrintOperator(out_, "Any", op.identifier_, op.list_expression_, op.where_->expression_); +} + +void ExpressionPrettyPrinter::Visit(None &op) { + PrintOperator(out_, "None", op.identifier_, op.list_expression_, op.where_->expression_); +} + +void ExpressionPrettyPrinter::Visit(Identifier &op) { PrintOperator(out_, "Identifier", op.name_); } + +void ExpressionPrettyPrinter::Visit(PrimitiveLiteral &op) { PrintObject(out_, op.value_); } + +void ExpressionPrettyPrinter::Visit(PropertyLookup &op) { + PrintOperator(out_, "PropertyLookup", op.expression_, op.property_.name); +} + +void ExpressionPrettyPrinter::Visit(ParameterLookup &op) { PrintOperator(out_, "ParameterLookup", op.token_position_); } + +void 
ExpressionPrettyPrinter::Visit(NamedExpression &op) { + PrintOperator(out_, "NamedExpression", op.name_, op.expression_); +} + +void ExpressionPrettyPrinter::Visit(RegexMatch &op) { PrintOperator(out_, "=~", op.string_expr_, op.regex_); } + +} // namespace + +void PrintExpression(Expression *expr, std::ostream *out) { + ExpressionPrettyPrinter printer{out}; + expr->Accept(printer); +} + +void PrintExpression(NamedExpression *expr, std::ostream *out) { + ExpressionPrettyPrinter printer{out}; + expr->Accept(printer); +} + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/ast/pretty_print.hpp b/src/query/v2/frontend/ast/pretty_print.hpp new file mode 100644 index 000000000..d6047c349 --- /dev/null +++ b/src/query/v2/frontend/ast/pretty_print.hpp @@ -0,0 +1,23 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include <iostream> + +#include "query/v2/frontend/ast/ast.hpp" + +namespace memgraph::query::v2 { + +void PrintExpression(Expression *expr, std::ostream *out); +void PrintExpression(NamedExpression *expr, std::ostream *out); + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/opencypher/grammar/Cypher.g4 b/src/query/v2/frontend/opencypher/grammar/Cypher.g4 new file mode 100644 index 000000000..6ce84db1a --- /dev/null +++ b/src/query/v2/frontend/opencypher/grammar/Cypher.g4 @@ -0,0 +1,391 @@ +/* + * Copyright (c) 2015-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +parser grammar Cypher; + +options { tokenVocab=CypherLexer; } + +cypher : statement ';'? 
EOF ; + +statement : query ; + +query : cypherQuery + | indexQuery + | explainQuery + | profileQuery + | infoQuery + | constraintQuery + ; + +constraintQuery : ( CREATE | DROP ) CONSTRAINT ON constraint ; + +constraint : '(' nodeName=variable ':' labelName ')' ASSERT EXISTS '(' constraintPropertyList ')' + | '(' nodeName=variable ':' labelName ')' ASSERT constraintPropertyList IS UNIQUE + | '(' nodeName=variable ':' labelName ')' ASSERT '(' constraintPropertyList ')' IS NODE KEY + ; + +constraintPropertyList : variable propertyLookup ( ',' variable propertyLookup )* ; + +storageInfo : STORAGE INFO ; + +indexInfo : INDEX INFO ; + +constraintInfo : CONSTRAINT INFO ; + +infoQuery : SHOW ( storageInfo | indexInfo | constraintInfo ) ; + +explainQuery : EXPLAIN cypherQuery ; + +profileQuery : PROFILE cypherQuery ; + +cypherQuery : singleQuery ( cypherUnion )* ( queryMemoryLimit )? ; + +indexQuery : createIndex | dropIndex; + +singleQuery : clause ( clause )* ; + +cypherUnion : ( UNION ALL singleQuery ) + | ( UNION singleQuery ) + ; + +clause : cypherMatch + | unwind + | merge + | create + | set + | cypherDelete + | remove + | with + | cypherReturn + | callProcedure + ; + +cypherMatch : OPTIONAL? MATCH pattern where? ; + +unwind : UNWIND expression AS variable ; + +merge : MERGE patternPart ( mergeAction )* ; + +mergeAction : ( ON MATCH set ) + | ( ON CREATE set ) + ; + +create : CREATE pattern ; + +set : SET setItem ( ',' setItem )* ; + +setItem : ( propertyExpression '=' expression ) + | ( variable '=' expression ) + | ( variable '+=' expression ) + | ( variable nodeLabels ) + ; + +cypherDelete : DETACH? DELETE expression ( ',' expression )* ; + +remove : REMOVE removeItem ( ',' removeItem )* ; + +removeItem : ( variable nodeLabels ) + | propertyExpression + ; + +with : WITH ( DISTINCT )? returnBody ( where )? ; + +cypherReturn : RETURN ( DISTINCT )? returnBody ; + +callProcedure : CALL procedureName '(' ( expression ( ',' expression )* )? ')' ( procedureMemoryLimit )? 
( yieldProcedureResults )? ; + +procedureName : symbolicName ( '.' symbolicName )* ; + +yieldProcedureResults : YIELD ( '*' | ( procedureResult ( ',' procedureResult )* ) ) ; + +memoryLimit : MEMORY ( UNLIMITED | LIMIT literal ( MB | KB ) ) ; + +queryMemoryLimit : QUERY memoryLimit ; + +procedureMemoryLimit : PROCEDURE memoryLimit ; + +procedureResult : ( variable AS variable ) | variable ; + +returnBody : returnItems ( order )? ( skip )? ( limit )? ; + +returnItems : ( '*' ( ',' returnItem )* ) + | ( returnItem ( ',' returnItem )* ) + ; + +returnItem : ( expression AS variable ) + | expression + ; + +order : ORDER BY sortItem ( ',' sortItem )* ; + +skip : L_SKIP expression ; + +limit : LIMIT expression ; + +sortItem : expression ( ASCENDING | ASC | DESCENDING | DESC )? ; + +where : WHERE expression ; + +pattern : patternPart ( ',' patternPart )* ; + +patternPart : ( variable '=' anonymousPatternPart ) + | anonymousPatternPart + ; + +anonymousPatternPart : patternElement ; + +patternElement : ( nodePattern ( patternElementChain )* ) + | ( '(' patternElement ')' ) + ; + +nodePattern : '(' ( variable )? ( nodeLabels )? ( properties )? ')' ; + +patternElementChain : relationshipPattern nodePattern ; + +relationshipPattern : ( leftArrowHead dash ( relationshipDetail )? dash rightArrowHead ) + | ( leftArrowHead dash ( relationshipDetail )? dash ) + | ( dash ( relationshipDetail )? dash rightArrowHead ) + | ( dash ( relationshipDetail )? dash ) + ; + +leftArrowHead : '<' | LeftArrowHeadPart ; +rightArrowHead : '>' | RightArrowHeadPart ; +dash : '-' | DashPart ; + +relationshipDetail : '[' ( name=variable )? ( relationshipTypes )? ( variableExpansion )? properties ']' + | '[' ( name=variable )? ( relationshipTypes )? ( variableExpansion )? relationshipLambda ( total_weight=variable )? (relationshipLambda )? ']' + | '[' ( name=variable )? ( relationshipTypes )? ( variableExpansion )? (properties )* ( relationshipLambda total_weight=variable )? (relationshipLambda )? 
']'; + +relationshipLambda: '(' traversed_edge=variable ',' traversed_node=variable '|' expression ')'; + +variableExpansion : '*' (BFS | WSHORTEST)? ( expression )? ( '..' ( expression )? )? ; + +properties : mapLiteral + | parameter + ; + +relationshipTypes : ':' relTypeName ( '|' ':'? relTypeName )* ; + +nodeLabels : nodeLabel ( nodeLabel )* ; + +nodeLabel : ':' labelName ; + +labelName : symbolicName ; + +relTypeName : symbolicName ; + +expression : expression12 ; + +expression12 : expression11 ( OR expression11 )* ; + +expression11 : expression10 ( XOR expression10 )* ; + +expression10 : expression9 ( AND expression9 )* ; + +expression9 : ( NOT )* expression8 ; + +expression8 : expression7 ( partialComparisonExpression )* ; + +expression7 : expression6 ( ( '+' expression6 ) | ( '-' expression6 ) )* ; + +expression6 : expression5 ( ( '*' expression5 ) | ( '/' expression5 ) | ( '%' expression5 ) )* ; + +expression5 : expression4 ( '^' expression4 )* ; + +expression4 : ( ( '+' | '-' ) )* expression3a ; + +expression3a : expression3b ( stringAndNullOperators )* ; + +stringAndNullOperators : ( ( ( ( '=~' ) | ( IN ) | ( STARTS WITH ) | ( ENDS WITH ) | ( CONTAINS ) ) expression3b) | ( IS CYPHERNULL ) | ( IS NOT CYPHERNULL ) ) ; + +expression3b : expression2a ( listIndexingOrSlicing )* ; + +listIndexingOrSlicing : ( '[' expression ']' ) + | ( '[' lower_bound=expression? '..' upper_bound=expression? ']' ) + ; + +expression2a : expression2b ( nodeLabels )? 
; + +expression2b : atom ( propertyLookup )* ; + +atom : literal + | parameter + | caseExpression + | ( COUNT '(' '*' ')' ) + | listComprehension + | patternComprehension + | ( FILTER '(' filterExpression ')' ) + | ( EXTRACT '(' extractExpression ')' ) + | ( REDUCE '(' reduceExpression ')' ) + | ( COALESCE '(' expression ( ',' expression )* ')' ) + | ( ALL '(' filterExpression ')' ) + | ( ANY '(' filterExpression ')' ) + | ( NONE '(' filterExpression ')' ) + | ( SINGLE '(' filterExpression ')' ) + | relationshipsPattern + | parenthesizedExpression + | functionInvocation + | variable + ; + +literal : numberLiteral + | StringLiteral + | booleanLiteral + | CYPHERNULL + | mapLiteral + | listLiteral + ; + +booleanLiteral : TRUE + | FALSE + ; + +listLiteral : '[' ( expression ( ',' expression )* )? ']' ; + +partialComparisonExpression : ( '=' expression7 ) + | ( '<>' expression7 ) + | ( '!=' expression7 ) + | ( '<' expression7 ) + | ( '>' expression7 ) + | ( '<=' expression7 ) + | ( '>=' expression7 ) + ; + +parenthesizedExpression : '(' expression ')' ; + +relationshipsPattern : nodePattern ( patternElementChain )+ ; + +filterExpression : idInColl ( where )? ; + +reduceExpression : accumulator=variable '=' initial=expression ',' idInColl '|' expression ; + +extractExpression : idInColl '|' expression ; + +idInColl : variable IN expression ; + +functionInvocation : functionName '(' ( DISTINCT )? ( expression ( ',' expression )* )? ')' ; + +functionName : symbolicName ( '.' symbolicName )* ; + +listComprehension : '[' filterExpression ( '|' expression )? ']' ; + +patternComprehension : '[' ( variable '=' )? relationshipsPattern ( WHERE expression )? '|' expression ']' ; + +propertyLookup : '.' ( propertyKeyName ) ; + +caseExpression : ( ( CASE ( caseAlternatives )+ ) | ( CASE test=expression ( caseAlternatives )+ ) ) ( ELSE else_expression=expression )? 
END ; + +caseAlternatives : WHEN when_expression=expression THEN then_expression=expression ; + +variable : symbolicName ; + +numberLiteral : doubleLiteral + | integerLiteral + ; + +mapLiteral : '{' ( propertyKeyName ':' expression ( ',' propertyKeyName ':' expression )* )? '}' ; + +parameter : '$' ( symbolicName | DecimalLiteral ) ; + +propertyExpression : atom ( propertyLookup )+ ; + +propertyKeyName : symbolicName ; + +integerLiteral : DecimalLiteral + | OctalLiteral + | HexadecimalLiteral + ; + +createIndex : CREATE INDEX ON ':' labelName ( '(' propertyKeyName ')' )? ; + +dropIndex : DROP INDEX ON ':' labelName ( '(' propertyKeyName ')' )? ; + +doubleLiteral : FloatingLiteral ; + +cypherKeyword : ALL + | AND + | ANY + | AS + | ASC + | ASCENDING + | ASSERT + | BFS + | BY + | CALL + | CASE + | CONSTRAINT + | CONTAINS + | COUNT + | CREATE + | CYPHERNULL + | DELETE + | DESC + | DESCENDING + | DETACH + | DISTINCT + | ELSE + | END + | ENDS + | EXISTS + | EXPLAIN + | EXTRACT + | FALSE + | FILTER + | IN + | INDEX + | INFO + | IS + | KEY + | LIMIT + | L_SKIP + | MATCH + | MERGE + | NODE + | NONE + | NOT + | ON + | OPTIONAL + | OR + | ORDER + | PROCEDURE + | PROFILE + | QUERY + | REDUCE + | REMOVE + | RETURN + | SET + | SHOW + | SINGLE + | STARTS + | STORAGE + | THEN + | TRUE + | UNION + | UNIQUE + | UNWIND + | WHEN + | WHERE + | WITH + | WSHORTEST + | XOR + | YIELD + ; + +symbolicName : UnescapedSymbolicName + | EscapedSymbolicName + | cypherKeyword + ; diff --git a/src/query/v2/frontend/opencypher/grammar/CypherLexer.g4 b/src/query/v2/frontend/opencypher/grammar/CypherLexer.g4 new file mode 100644 index 000000000..1377fbc82 --- /dev/null +++ b/src/query/v2/frontend/opencypher/grammar/CypherLexer.g4 @@ -0,0 +1,208 @@ +/* + * When changing this grammar make sure to update constants in + * src/query/frontend/stripped_lexer_constants.hpp (kKeywords, kSpecialTokens + * and bitsets) if needed. 
+ */ + +lexer grammar CypherLexer ; + +import UnicodeCategories ; + +/* Skip whitespace and comments. */ +Skipped : ( Whitespace | Comment ) -> skip ; + +fragment Whitespace : '\u0020' + | [\u0009-\u000D] + | [\u001C-\u001F] + | '\u1680' | '\u180E' + | [\u2000-\u200A] + | '\u2028' | '\u2029' + | '\u205F' + | '\u3000' + | '\u00A0' + | '\u202F' + ; + +fragment Comment : '/*' .*? '*/' + | '//' ~[\r\n]* + ; + +/* Special symbols. */ +LPAREN : '(' ; +RPAREN : ')' ; +LBRACK : '[' ; +RBRACK : ']' ; +LBRACE : '{' ; +RBRACE : '}' ; + +COMMA : ',' ; +DOT : '.' ; +DOTS : '..' ; +COLON : ':' ; +SEMICOLON : ';' ; +DOLLAR : '$' ; +PIPE : '|' ; + +EQ : '=' ; +LT : '<' ; +GT : '>' ; +LTE : '<=' ; +GTE : '>=' ; +NEQ1 : '<>' ; +NEQ2 : '!=' ; +SIM : '=~' ; + +PLUS : '+' ; +MINUS : '-' ; +ASTERISK : '*' ; +SLASH : '/' ; +PERCENT : '%' ; +CARET : '^' ; +PLUS_EQ : '+=' ; + +/* Some random unicode characters that can be used to draw arrows. */ +LeftArrowHeadPart : '⟨' | '〈' | '﹤' | '<' ; +RightArrowHeadPart : '⟩' | '〉' | '﹥' | '>' ; +DashPart : '' | '‐' | '‑' | '‒' | '–' | '—' | '―' + | '−' | '﹘' | '﹣' | '-' + ; + +/* Cypher reserved words. 
*/ +ALL : A L L ; +AND : A N D ; +ANY : A N Y ; +AS : A S ; +ASC : A S C ; +ASCENDING : A S C E N D I N G ; +ASSERT : A S S E R T ; +BFS : B F S ; +BY : B Y ; +CALL : C A L L ; +CASE : C A S E ; +COALESCE : C O A L E S C E ; +CONSTRAINT : C O N S T R A I N T ; +CONTAINS : C O N T A I N S ; +COUNT : C O U N T ; +CREATE : C R E A T E ; +CYPHERNULL : N U L L ; +DELETE : D E L E T E ; +DESC : D E S C ; +DESCENDING : D E S C E N D I N G ; +DETACH : D E T A C H ; +DISTINCT : D I S T I N C T ; +DROP : D R O P ; +ELSE : E L S E ; +END : E N D ; +ENDS : E N D S ; +EXISTS : E X I S T S ; +EXPLAIN : E X P L A I N ; +EXTRACT : E X T R A C T ; +FALSE : F A L S E ; +FILTER : F I L T E R ; +IN : I N ; +INDEX : I N D E X ; +INFO : I N F O ; +IS : I S ; +KB : K B ; +KEY : K E Y ; +LIMIT : L I M I T ; +L_SKIP : S K I P ; +MATCH : M A T C H ; +MB : M B ; +MEMORY : M E M O R Y ; +MERGE : M E R G E ; +NODE : N O D E ; +NONE : N O N E ; +NOT : N O T ; +ON : O N ; +OPTIONAL : O P T I O N A L ; +OR : O R ; +ORDER : O R D E R ; +PROCEDURE : P R O C E D U R E ; +PROFILE : P R O F I L E ; +QUERY : Q U E R Y ; +REDUCE : R E D U C E ; +REMOVE : R E M O V E ; +RETURN : R E T U R N ; +SET : S E T ; +SHOW : S H O W ; +SINGLE : S I N G L E ; +STARTS : S T A R T S ; +STORAGE : S T O R A G E ; +THEN : T H E N ; +TRUE : T R U E ; +UNION : U N I O N ; +UNIQUE : U N I Q U E ; +UNLIMITED : U N L I M I T E D ; +UNWIND : U N W I N D ; +WHEN : W H E N ; +WHERE : W H E R E ; +WITH : W I T H ; +WSHORTEST : W S H O R T E S T ; +XOR : X O R ; +YIELD : Y I E L D ; + +/* Double and single quoted string literals. */ +StringLiteral : '"' ( ~[\\"] | EscapeSequence )* '"' + | '\'' ( ~[\\'] | EscapeSequence )* '\'' + ; + +fragment EscapeSequence : '\\' ( B | F | N | R | T | '\\' | '\'' | '"' ) + | '\\u' HexDigit HexDigit HexDigit HexDigit + | '\\U' HexDigit HexDigit HexDigit HexDigit + HexDigit HexDigit HexDigit HexDigit + ; + +/* Number literals. 
*/ +DecimalLiteral : '0' | NonZeroDigit ( DecDigit )* ; +OctalLiteral : '0' ( OctDigit )+ ; +HexadecimalLiteral : '0x' ( HexDigit )+ ; +FloatingLiteral : DecDigit* '.' DecDigit+ ( E '-'? DecDigit+ )? + | DecDigit+ ( '.' DecDigit* )? ( E '-'? DecDigit+ ) + | DecDigit+ ( E '-'? DecDigit+ ) + ; + +fragment NonZeroDigit : [1-9] ; +fragment DecDigit : [0-9] ; +fragment OctDigit : [0-7] ; +fragment HexDigit : [0-9] | [a-f] | [A-F] ; + +/* Symbolic names. */ +UnescapedSymbolicName : IdentifierStart ( IdentifierPart )* ; +EscapedSymbolicName : ( '`' ~[`]* '`' )+ ; + +/** + * Based on the unicode identifier and pattern syntax + * (http://www.unicode.org/reports/tr31/) + * and extended with a few characters. + */ +IdentifierStart : ID_Start | Pc ; +IdentifierPart : ID_Continue | Sc ; + +/* Hack for case-insensitive reserved words */ +fragment A : 'A' | 'a' ; +fragment B : 'B' | 'b' ; +fragment C : 'C' | 'c' ; +fragment D : 'D' | 'd' ; +fragment E : 'E' | 'e' ; +fragment F : 'F' | 'f' ; +fragment G : 'G' | 'g' ; +fragment H : 'H' | 'h' ; +fragment I : 'I' | 'i' ; +fragment J : 'J' | 'j' ; +fragment K : 'K' | 'k' ; +fragment L : 'L' | 'l' ; +fragment M : 'M' | 'm' ; +fragment N : 'N' | 'n' ; +fragment O : 'O' | 'o' ; +fragment P : 'P' | 'p' ; +fragment Q : 'Q' | 'q' ; +fragment R : 'R' | 'r' ; +fragment S : 'S' | 's' ; +fragment T : 'T' | 't' ; +fragment U : 'U' | 'u' ; +fragment V : 'V' | 'v' ; +fragment W : 'W' | 'w' ; +fragment X : 'X' | 'x' ; +fragment Y : 'Y' | 'y' ; +fragment Z : 'Z' | 'z' ; diff --git a/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 new file mode 100644 index 000000000..b412a474a --- /dev/null +++ b/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -0,0 +1,376 @@ +/* + * Copyright 2021 Memgraph Ltd. 
+ * + * Use of this software is governed by the Business Source License + * included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source + * License, and you may not use this file except in compliance with the Business Source License. + * + * As of the Change Date specified in that file, in accordance with + * the Business Source License, use of this software will be governed + * by the Apache License, Version 2.0, included in the file + * licenses/APL.txt. + */ + +/* Memgraph specific part of Cypher grammar with enterprise features. */ + +parser grammar MemgraphCypher ; + +options { tokenVocab=MemgraphCypherLexer; } + +import Cypher ; + +memgraphCypherKeyword : cypherKeyword + | AFTER + | ALTER + | ASYNC + | AUTH + | BAD + | BATCH_INTERVAL + | BATCH_LIMIT + | BATCH_SIZE + | BEFORE + | BOOTSTRAP_SERVERS + | CHECK + | CLEAR + | COMMIT + | COMMITTED + | CONFIG + | CONFIGS + | CONSUMER_GROUP + | CREDENTIALS + | CSV + | DATA + | DELIMITER + | DATABASE + | DENY + | DROP + | DUMP + | EXECUTE + | FOR + | FOREACH + | FREE + | FROM + | GLOBAL + | GRANT + | HEADER + | IDENTIFIED + | ISOLATION + | KAFKA + | LEVEL + | LOAD + | LOCK + | MAIN + | MODE + | NEXT + | NO + | PASSWORD + | PULSAR + | PORT + | PRIVILEGES + | READ + | REGISTER + | REPLICA + | REPLICAS + | REPLICATION + | REVOKE + | ROLE + | ROLES + | QUOTE + | SESSION + | SETTING + | SETTINGS + | SNAPSHOT + | START + | STATS + | STREAM + | STREAMS + | SYNC + | TIMEOUT + | TO + | TOPICS + | TRANSACTION + | TRANSFORM + | TRIGGER + | TRIGGERS + | UNCOMMITTED + | UNLOCK + | UPDATE + | USER + | USERS + | VERSION + ; + +symbolicName : UnescapedSymbolicName + | EscapedSymbolicName + | memgraphCypherKeyword + ; + +query : cypherQuery + | indexQuery + | explainQuery + | profileQuery + | infoQuery + | constraintQuery + | authQuery + | dumpQuery + | replicationQuery + | lockPathQuery + | freeMemoryQuery + | triggerQuery + | isolationLevelQuery + | createSnapshotQuery + | 
streamQuery + | settingQuery + | versionQuery + ; + +authQuery : createRole + | dropRole + | showRoles + | createUser + | setPassword + | dropUser + | showUsers + | setRole + | clearRole + | grantPrivilege + | denyPrivilege + | revokePrivilege + | showPrivileges + | showRoleForUser + | showUsersForRole + ; + +replicationQuery : setReplicationRole + | showReplicationRole + | registerReplica + | dropReplica + | showReplicas + ; + +triggerQuery : createTrigger + | dropTrigger + | showTriggers + ; + +clause : cypherMatch + | unwind + | merge + | create + | set + | cypherDelete + | remove + | with + | cypherReturn + | callProcedure + | loadCsv + | foreach + ; + +updateClause : set + | remove + | create + | merge + | cypherDelete + | foreach + ; + +foreach : FOREACH '(' variable IN expression '|' updateClause+ ')' ; + +streamQuery : checkStream + | createStream + | dropStream + | startStream + | startAllStreams + | stopStream + | stopAllStreams + | showStreams + ; + +settingQuery : setSetting + | showSetting + | showSettings + ; + +loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER + ( IGNORE BAD ) ? + ( DELIMITER delimiter ) ? + ( QUOTE quote ) ? + AS rowVar ; + +csvFile : literal ; + +delimiter : literal ; + +quote : literal ; + +rowVar : variable ; + +userOrRoleName : symbolicName ; + +createRole : CREATE ROLE role=userOrRoleName ; + +dropRole : DROP ROLE role=userOrRoleName ; + +showRoles : SHOW ROLES ; + +createUser : CREATE USER user=userOrRoleName + ( IDENTIFIED BY password=literal )? 
; + +setPassword : SET PASSWORD FOR user=userOrRoleName TO password=literal; + +dropUser : DROP USER user=userOrRoleName ; + +showUsers : SHOW USERS ; + +setRole : SET ROLE FOR user=userOrRoleName TO role=userOrRoleName; + +clearRole : CLEAR ROLE FOR user=userOrRoleName ; + +grantPrivilege : GRANT ( ALL PRIVILEGES | privileges=privilegeList ) TO userOrRole=userOrRoleName ; + +denyPrivilege : DENY ( ALL PRIVILEGES | privileges=privilegeList ) TO userOrRole=userOrRoleName ; + +revokePrivilege : REVOKE ( ALL PRIVILEGES | privileges=privilegeList ) FROM userOrRole=userOrRoleName ; + +privilege : CREATE + | DELETE + | MATCH + | MERGE + | SET + | REMOVE + | INDEX + | STATS + | AUTH + | CONSTRAINT + | DUMP + | REPLICATION + | READ_FILE + | FREE_MEMORY + | TRIGGER + | CONFIG + | DURABILITY + | STREAM + | MODULE_READ + | MODULE_WRITE + | WEBSOCKET + ; + +privilegeList : privilege ( ',' privilege )* ; + +showPrivileges : SHOW PRIVILEGES FOR userOrRole=userOrRoleName ; + +showRoleForUser : SHOW ROLE FOR user=userOrRoleName ; + +showUsersForRole : SHOW USERS FOR role=userOrRoleName ; + +dumpQuery: DUMP DATABASE ; + +setReplicationRole : SET REPLICATION ROLE TO ( MAIN | REPLICA ) + ( WITH PORT port=literal ) ? ; + +showReplicationRole : SHOW REPLICATION ROLE ; + +replicaName : symbolicName ; + +socketAddress : literal ; + +registerReplica : REGISTER REPLICA replicaName ( SYNC | ASYNC ) + ( WITH TIMEOUT timeout=literal ) ? + TO socketAddress ; + +dropReplica : DROP REPLICA replicaName ; + +showReplicas : SHOW REPLICAS ; + +lockPathQuery : ( LOCK | UNLOCK ) DATA DIRECTORY ; + +freeMemoryQuery : FREE MEMORY ; + +triggerName : symbolicName ; + +triggerStatement : .*? ; + +emptyVertex : '(' ')' ; + +emptyEdge : dash dash rightArrowHead ; + +createTrigger : CREATE TRIGGER triggerName ( ON ( emptyVertex | emptyEdge ) ? ( CREATE | UPDATE | DELETE ) ) ? 
+ ( AFTER | BEFORE ) COMMIT EXECUTE triggerStatement ; + +dropTrigger : DROP TRIGGER triggerName ; + +showTriggers : SHOW TRIGGERS ; + +isolationLevel : SNAPSHOT ISOLATION | READ COMMITTED | READ UNCOMMITTED ; + +isolationLevelScope : GLOBAL | SESSION | NEXT ; + +isolationLevelQuery : SET isolationLevelScope TRANSACTION ISOLATION LEVEL isolationLevel ; + +createSnapshotQuery : CREATE SNAPSHOT ; + +streamName : symbolicName ; + +symbolicNameWithMinus : symbolicName ( MINUS symbolicName )* ; + +symbolicNameWithDotsAndMinus: symbolicNameWithMinus ( DOT symbolicNameWithMinus )* ; + +symbolicTopicNames : symbolicNameWithDotsAndMinus ( COMMA symbolicNameWithDotsAndMinus )* ; + +topicNames : symbolicTopicNames | literal ; + +commonCreateStreamConfig : TRANSFORM transformationName=procedureName + | BATCH_INTERVAL batchInterval=literal + | BATCH_SIZE batchSize=literal + ; + +createStream : kafkaCreateStream | pulsarCreateStream ; + +configKeyValuePair : literal ':' literal ; + +configMap : '{' ( configKeyValuePair ( ',' configKeyValuePair )* )? '}' ; + +kafkaCreateStreamConfig : TOPICS topicNames + | CONSUMER_GROUP consumerGroup=symbolicNameWithDotsAndMinus + | BOOTSTRAP_SERVERS bootstrapServers=literal + | CONFIGS configsMap=configMap + | CREDENTIALS credentialsMap=configMap + | commonCreateStreamConfig + ; + +kafkaCreateStream : CREATE KAFKA STREAM streamName ( kafkaCreateStreamConfig ) * ; + + +pulsarCreateStreamConfig : TOPICS topicNames + | SERVICE_URL serviceUrl=literal + | commonCreateStreamConfig + ; + +pulsarCreateStream : CREATE PULSAR STREAM streamName ( pulsarCreateStreamConfig ) * ; + +dropStream : DROP STREAM streamName ; + +startStream : START STREAM streamName ( BATCH_LIMIT batchLimit=literal ) ? ( TIMEOUT timeout=literal ) ? ; + +startAllStreams : START ALL STREAMS ; + +stopStream : STOP STREAM streamName ; + +stopAllStreams : STOP ALL STREAMS ; + +showStreams : SHOW STREAMS ; + +checkStream : CHECK STREAM streamName ( BATCH_LIMIT batchLimit=literal ) ? 
( TIMEOUT timeout=literal ) ? ; + +settingName : literal ; + +settingValue : literal ; + +setSetting : SET DATABASE SETTING settingName TO settingValue ; + +showSetting : SHOW DATABASE SETTING settingName ; + +showSettings : SHOW DATABASE SETTINGS ; + +versionQuery : SHOW VERSION ; diff --git a/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 new file mode 100644 index 000000000..55e5d53a2 --- /dev/null +++ b/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -0,0 +1,116 @@ +/* + * Copyright 2021 Memgraph Ltd. + * + * Use of this software is governed by the Business Source License + * included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source + * License, and you may not use this file except in compliance with the Business Source License. + * + * As of the Change Date specified in that file, in accordance with + * the Business Source License, use of this software will be governed + * by the Apache License, Version 2.0, included in the file + * licenses/APL.txt. + */ + +/* Memgraph specific Cypher reserved words used for enterprise features. */ + +/* + * When changing this grammar make sure to update constants in + * src/query/frontend/stripped_lexer_constants.hpp (kKeywords, kSpecialTokens + * and bitsets) if needed. 
+ */ + +lexer grammar MemgraphCypherLexer ; + +import CypherLexer ; + +UNDERSCORE : '_' ; + +AFTER : A F T E R ; +ALTER : A L T E R ; +ASYNC : A S Y N C ; +AUTH : A U T H ; +BAD : B A D ; +BATCH_INTERVAL : B A T C H UNDERSCORE I N T E R V A L ; +BATCH_LIMIT : B A T C H UNDERSCORE L I M I T ; +BATCH_SIZE : B A T C H UNDERSCORE S I Z E ; +BEFORE : B E F O R E ; +BOOTSTRAP_SERVERS : B O O T S T R A P UNDERSCORE S E R V E R S ; +CHECK : C H E C K ; +CLEAR : C L E A R ; +COMMIT : C O M M I T ; +COMMITTED : C O M M I T T E D ; +CONFIG : C O N F I G ; +CONFIGS : C O N F I G S; +CONSUMER_GROUP : C O N S U M E R UNDERSCORE G R O U P ; +CREDENTIALS : C R E D E N T I A L S ; +CSV : C S V ; +DATA : D A T A ; +DELIMITER : D E L I M I T E R ; +DATABASE : D A T A B A S E ; +DENY : D E N Y ; +DIRECTORY : D I R E C T O R Y ; +DROP : D R O P ; +DUMP : D U M P ; +DURABILITY : D U R A B I L I T Y ; +EXECUTE : E X E C U T E ; +FOR : F O R ; +FOREACH : F O R E A C H; +FREE : F R E E ; +FREE_MEMORY : F R E E UNDERSCORE M E M O R Y ; +FROM : F R O M ; +GLOBAL : G L O B A L ; +GRANT : G R A N T ; +GRANTS : G R A N T S ; +HEADER : H E A D E R ; +IDENTIFIED : I D E N T I F I E D ; +IGNORE : I G N O R E ; +ISOLATION : I S O L A T I O N ; +KAFKA : K A F K A ; +LEVEL : L E V E L ; +LOAD : L O A D ; +LOCK : L O C K ; +MAIN : M A I N ; +MODE : M O D E ; +MODULE_READ : M O D U L E UNDERSCORE R E A D ; +MODULE_WRITE : M O D U L E UNDERSCORE W R I T E ; +NEXT : N E X T ; +NO : N O ; +PASSWORD : P A S S W O R D ; +PORT : P O R T ; +PRIVILEGES : P R I V I L E G E S ; +PULSAR : P U L S A R ; +READ : R E A D ; +READ_FILE : R E A D UNDERSCORE F I L E ; +REGISTER : R E G I S T E R ; +REPLICA : R E P L I C A ; +REPLICAS : R E P L I C A S ; +REPLICATION : R E P L I C A T I O N ; +REVOKE : R E V O K E ; +ROLE : R O L E ; +ROLES : R O L E S ; +QUOTE : Q U O T E ; +SERVICE_URL : S E R V I C E UNDERSCORE U R L ; +SESSION : S E S S I O N ; +SETTING : S E T T I N G ; +SETTINGS : S E T T I N G S ; +SNAPSHOT : S N 
A P S H O T ; +START : S T A R T ; +STATS : S T A T S ; +STOP : S T O P ; +STREAM : S T R E A M ; +STREAMS : S T R E A M S ; +SYNC : S Y N C ; +TIMEOUT : T I M E O U T ; +TO : T O ; +TOPICS : T O P I C S; +TRANSACTION : T R A N S A C T I O N ; +TRANSFORM : T R A N S F O R M ; +TRIGGER : T R I G G E R ; +TRIGGERS : T R I G G E R S ; +UNCOMMITTED : U N C O M M I T T E D ; +UNLOCK : U N L O C K ; +UPDATE : U P D A T E ; +USER : U S E R ; +USERS : U S E R S ; +VERSION : V E R S I O N ; +WEBSOCKET : W E B S O C K E T ; diff --git a/src/query/v2/frontend/opencypher/grammar/UnicodeCategories.g4 b/src/query/v2/frontend/opencypher/grammar/UnicodeCategories.g4 new file mode 100644 index 000000000..aa19ed96e --- /dev/null +++ b/src/query/v2/frontend/opencypher/grammar/UnicodeCategories.g4 @@ -0,0 +1,15 @@ +/** + * Unicode character categories used in openCypher lexer. This is separated from + * the lexer grammar as you probably don't want to ever change this (or even see + * this at all). + */ + +lexer grammar UnicodeCategories ; + +fragment ID_Start : 
[A-Za-z\u00AA\u00B5\u00BA\u00C0-\u00D6\u00D8-\u00F6\u00F8-\u02C1\u02C6-\u02D1\u02E0-\u02E4\u02EC\u02EE\u0370-\u0374\u0376-\u0377\u037A-\u037D\u0386\u0388-\u038A\u038C\u038E-\u03A1\u03A3-\u03F5\u03F7-\u0481\u048A-\u0527\u0531-\u0556\u0559\u0561-\u0587\u05D0-\u05EA\u05F0-\u05F2\u0620-\u064A\u066E-\u066F\u0671-\u06D3\u06D5\u06E5-\u06E6\u06EE-\u06EF\u06FA-\u06FC\u06FF\u0710\u0712-\u072F\u074D-\u07A5\u07B1\u07CA-\u07EA\u07F4-\u07F5\u07FA\u0800-\u0815\u081A\u0824\u0828\u0840-\u0858\u08A0\u08A2-\u08AC\u0904-\u0939\u093D\u0950\u0958-\u0961\u0971-\u0977\u0979-\u097F\u0985-\u098C\u098F-\u0990\u0993-\u09A8\u09AA-\u09B0\u09B2\u09B6-\u09B9\u09BD\u09CE\u09DC-\u09DD\u09DF-\u09E1\u09F0-\u09F1\u0A05-\u0A0A\u0A0F-\u0A10\u0A13-\u0A28\u0A2A-\u0A30\u0A32-\u0A33\u0A35-\u0A36\u0A38-\u0A39\u0A59-\u0A5C\u0A5E\u0A72-\u0A74\u0A85-\u0A8D\u0A8F-\u0A91\u0A93-\u0AA8\u0AAA-\u0AB0\u0AB2-\u0AB3\u0AB5-\u0AB9\u0ABD\u0AD0\u0AE0-\u0AE1\u0B05-\u0B0C\u0B0F-\u0B10\u0B13-\u0B28\u0B2A-\u0B30\u0B32-\u0B33\u0B35-\u0B39\u0B3D\u0B5C-\u0B5D\u0B5F-\u0B61\u0B71\u0B83\u0B85-\u0B8A\u0B8E-\u0B90\u0B92-\u0B95\u0B99-\u0B9A\u0B9C\u0B9E-\u0B9F\u0BA3-\u0BA4\u0BA8-\u0BAA\u0BAE-\u0BB9\u0BD0\u0C05-\u0C0C\u0C0E-\u0C10\u0C12-\u0C28\u0C2A-\u0C33\u0C35-\u0C39\u0C3D\u0C58-\u0C59\u0C60-\u0C61\u0C85-\u0C8C\u0C8E-\u0C90\u0C92-\u0CA8\u0CAA-\u0CB3\u0CB5-\u0CB9\u0CBD\u0CDE\u0CE0-\u0CE1\u0CF1-\u0CF2\u0D05-\u0D0C\u0D0E-\u0D10\u0D12-\u0D3A\u0D3D\u0D4E\u0D60-\u0D61\u0D7A-\u0D7F\u0D85-\u0D96\u0D9A-\u0DB1\u0DB3-\u0DBB\u0DBD\u0DC0-\u0DC6\u0E01-\u0E30\u0E32-\u0E33\u0E40-\u0E46\u0E81-\u0E82\u0E84\u0E87-\u0E88\u0E8A\u0E8D\u0E94-\u0E97\u0E99-\u0E9F\u0EA1-\u0EA3\u0EA5\u0EA7\u0EAA-\u0EAB\u0EAD-\u0EB0\u0EB2-\u0EB3\u0EBD\u0EC0-\u0EC4\u0EC6\u0EDC-\u0EDF\u0F00\u0F40-\u0F47\u0F49-\u0F6C\u0F88-\u0F8C\u1000-\u102A\u103F\u1050-\u1055\u105A-\u105D\u1061\u1065-\u1066\u106E-\u1070\u1075-\u1081\u108E\u10A0-\u10C5\u10C7\u10CD\u10D0-\u10FA\u10FC-\u1248\u124A-\u124D\u1250-\u1256\u1258\u125A-\u125D\u1260-\u1288\u128A-\u128D\u1290-\u12B0\u12B2-\u12B5\u12B8-\u12BE\u1
2C0\u12C2-\u12C5\u12C8-\u12D6\u12D8-\u1310\u1312-\u1315\u1318-\u135A\u1380-\u138F\u13A0-\u13F4\u1401-\u166C\u166F-\u167F\u1681-\u169A\u16A0-\u16EA\u16EE-\u16F0\u1700-\u170C\u170E-\u1711\u1720-\u1731\u1740-\u1751\u1760-\u176C\u176E-\u1770\u1780-\u17B3\u17D7\u17DC\u1820-\u1877\u1880-\u18A8\u18AA\u18B0-\u18F5\u1900-\u191C\u1950-\u196D\u1970-\u1974\u1980-\u19AB\u19C1-\u19C7\u1A00-\u1A16\u1A20-\u1A54\u1AA7\u1B05-\u1B33\u1B45-\u1B4B\u1B83-\u1BA0\u1BAE-\u1BAF\u1BBA-\u1BE5\u1C00-\u1C23\u1C4D-\u1C4F\u1C5A-\u1C7D\u1CE9-\u1CEC\u1CEE-\u1CF1\u1CF5-\u1CF6\u1D00-\u1DBF\u1E00-\u1F15\u1F18-\u1F1D\u1F20-\u1F45\u1F48-\u1F4D\u1F50-\u1F57\u1F59\u1F5B\u1F5D\u1F5F-\u1F7D\u1F80-\u1FB4\u1FB6-\u1FBC\u1FBE\u1FC2-\u1FC4\u1FC6-\u1FCC\u1FD0-\u1FD3\u1FD6-\u1FDB\u1FE0-\u1FEC\u1FF2-\u1FF4\u1FF6-\u1FFC\u2071\u207F\u2090-\u209C\u2102\u2107\u210A-\u2113\u2115\u2118-\u211D\u2124\u2126\u2128\u212A-\u2139\u213C-\u213F\u2145-\u2149\u214E\u2160-\u2188\u2C00-\u2C2E\u2C30-\u2C5E\u2C60-\u2CE4\u2CEB-\u2CEE\u2CF2-\u2CF3\u2D00-\u2D25\u2D27\u2D2D\u2D30-\u2D67\u2D6F\u2D80-\u2D96\u2DA0-\u2DA6\u2DA8-\u2DAE\u2DB0-\u2DB6\u2DB8-\u2DBE\u2DC0-\u2DC6\u2DC8-\u2DCE\u2DD0-\u2DD6\u2DD8-\u2DDE\u3005-\u3007\u3021-\u3029\u3031-\u3035\u3038-\u303C\u3041-\u3096\u309B-\u309F\u30A1-\u30FA\u30FC-\u30FF\u3105-\u312D\u3131-\u318E\u31A0-\u31BA\u31F0-\u31FF\u3400-\u4DB5\u4E00-\u9FCC\uA000-\uA48C\uA4D0-\uA4FD\uA500-\uA60C\uA610-\uA61F\uA62A-\uA62B\uA640-\uA66E\uA67F-\uA697\uA6A0-\uA6EF\uA717-\uA71F\uA722-\uA788\uA78B-\uA78E\uA790-\uA793\uA7A0-\uA7AA\uA7F8-\uA801\uA803-\uA805\uA807-\uA80A\uA80C-\uA822\uA840-\uA873\uA882-\uA8B3\uA8F2-\uA8F7\uA8FB\uA90A-\uA925\uA930-\uA946\uA960-\uA97C\uA984-\uA9B2\uA9CF\uAA00-\uAA28\uAA40-\uAA42\uAA44-\uAA4B\uAA60-\uAA76\uAA7A\uAA80-\uAAAF\uAAB1\uAAB5-\uAAB6\uAAB9-\uAABD\uAAC0\uAAC2\uAADB-\uAADD\uAAE0-\uAAEA\uAAF2-\uAAF4\uAB01-\uAB06\uAB09-\uAB0E\uAB11-\uAB16\uAB20-\uAB26\uAB28-\uAB2E\uABC0-\uABE2\uAC00-\uD7A3\uD7B0-\uD7C6\uD7CB-\uD7FB\uF900-\uFA6D\uFA70-\uFAD9\uFB00-\uFB06\uFB13-\uFB17\uFB1D\uFB1F-\uFB28\u
FB2A-\uFB36\uFB38-\uFB3C\uFB3E\uFB40-\uFB41\uFB43-\uFB44\uFB46-\uFBB1\uFBD3-\uFD3D\uFD50-\uFD8F\uFD92-\uFDC7\uFDF0-\uFDFB\uFE70-\uFE74\uFE76-\uFEFC\uFF21-\uFF3A\uFF41-\uFF5A\uFF66-\uFFBE\uFFC2-\uFFC7\uFFCA-\uFFCF\uFFD2-\uFFD7\uFFDA-\uFFDC] ; + +fragment ID_Continue : [0-9A-Z_a-z\u00AA\u00B5\u00B7\u00BA\u00C0-\u00D6\u00D8-\u00F6\u00F8-\u02C1\u02C6-\u02D1\u02E0-\u02E4\u02EC\u02EE\u0300-\u0374\u0376-\u0377\u037A-\u037D\u0386-\u038A\u038C\u038E-\u03A1\u03A3-\u03F5\u03F7-\u0481\u0483-\u0487\u048A-\u0527\u0531-\u0556\u0559\u0561-\u0587\u0591-\u05BD\u05BF\u05C1-\u05C2\u05C4-\u05C5\u05C7\u05D0-\u05EA\u05F0-\u05F2\u0610-\u061A\u0620-\u0669\u066E-\u06D3\u06D5-\u06DC\u06DF-\u06E8\u06EA-\u06FC\u06FF\u0710-\u074A\u074D-\u07B1\u07C0-\u07F5\u07FA\u0800-\u082D\u0840-\u085B\u08A0\u08A2-\u08AC\u08E4-\u08FE\u0900-\u0963\u0966-\u096F\u0971-\u0977\u0979-\u097F\u0981-\u0983\u0985-\u098C\u098F-\u0990\u0993-\u09A8\u09AA-\u09B0\u09B2\u09B6-\u09B9\u09BC-\u09C4\u09C7-\u09C8\u09CB-\u09CE\u09D7\u09DC-\u09DD\u09DF-\u09E3\u09E6-\u09F1\u0A01-\u0A03\u0A05-\u0A0A\u0A0F-\u0A10\u0A13-\u0A28\u0A2A-\u0A30\u0A32-\u0A33\u0A35-\u0A36\u0A38-\u0A39\u0A3C\u0A3E-\u0A42\u0A47-\u0A48\u0A4B-\u0A4D\u0A51\u0A59-\u0A5C\u0A5E\u0A66-\u0A75\u0A81-\u0A83\u0A85-\u0A8D\u0A8F-\u0A91\u0A93-\u0AA8\u0AAA-\u0AB0\u0AB2-\u0AB3\u0AB5-\u0AB9\u0ABC-\u0AC5\u0AC7-\u0AC9\u0ACB-\u0ACD\u0AD0\u0AE0-\u0AE3\u0AE6-\u0AEF\u0B01-\u0B03\u0B05-\u0B0C\u0B0F-\u0B10\u0B13-\u0B28\u0B2A-\u0B30\u0B32-\u0B33\u0B35-\u0B39\u0B3C-\u0B44\u0B47-\u0B48\u0B4B-\u0B4D\u0B56-\u0B57\u0B5C-\u0B5D\u0B5F-\u0B63\u0B66-\u0B6F\u0B71\u0B82-\u0B83\u0B85-\u0B8A\u0B8E-\u0B90\u0B92-\u0B95\u0B99-\u0B9A\u0B9C\u0B9E-\u0B9F\u0BA3-\u0BA4\u0BA8-\u0BAA\u0BAE-\u0BB9\u0BBE-\u0BC2\u0BC6-\u0BC8\u0BCA-\u0BCD\u0BD0\u0BD7\u0BE6-\u0BEF\u0C01-\u0C03\u0C05-\u0C0C\u0C0E-\u0C10\u0C12-\u0C28\u0C2A-\u0C33\u0C35-\u0C39\u0C3D-\u0C44\u0C46-\u0C48\u0C4A-\u0C4D\u0C55-\u0C56\u0C58-\u0C59\u0C60-\u0C63\u0C66-\u0C6F\u0C82-\u0C83\u0C85-\u0C8C\u0C8E-\u0C90\u0C92-\u0CA8\u0CAA-\u0CB3\u0CB5-\u0CB9\u0CBC-\u0
CC4\u0CC6-\u0CC8\u0CCA-\u0CCD\u0CD5-\u0CD6\u0CDE\u0CE0-\u0CE3\u0CE6-\u0CEF\u0CF1-\u0CF2\u0D02-\u0D03\u0D05-\u0D0C\u0D0E-\u0D10\u0D12-\u0D3A\u0D3D-\u0D44\u0D46-\u0D48\u0D4A-\u0D4E\u0D57\u0D60-\u0D63\u0D66-\u0D6F\u0D7A-\u0D7F\u0D82-\u0D83\u0D85-\u0D96\u0D9A-\u0DB1\u0DB3-\u0DBB\u0DBD\u0DC0-\u0DC6\u0DCA\u0DCF-\u0DD4\u0DD6\u0DD8-\u0DDF\u0DF2-\u0DF3\u0E01-\u0E3A\u0E40-\u0E4E\u0E50-\u0E59\u0E81-\u0E82\u0E84\u0E87-\u0E88\u0E8A\u0E8D\u0E94-\u0E97\u0E99-\u0E9F\u0EA1-\u0EA3\u0EA5\u0EA7\u0EAA-\u0EAB\u0EAD-\u0EB9\u0EBB-\u0EBD\u0EC0-\u0EC4\u0EC6\u0EC8-\u0ECD\u0ED0-\u0ED9\u0EDC-\u0EDF\u0F00\u0F18-\u0F19\u0F20-\u0F29\u0F35\u0F37\u0F39\u0F3E-\u0F47\u0F49-\u0F6C\u0F71-\u0F84\u0F86-\u0F97\u0F99-\u0FBC\u0FC6\u1000-\u1049\u1050-\u109D\u10A0-\u10C5\u10C7\u10CD\u10D0-\u10FA\u10FC-\u1248\u124A-\u124D\u1250-\u1256\u1258\u125A-\u125D\u1260-\u1288\u128A-\u128D\u1290-\u12B0\u12B2-\u12B5\u12B8-\u12BE\u12C0\u12C2-\u12C5\u12C8-\u12D6\u12D8-\u1310\u1312-\u1315\u1318-\u135A\u135D-\u135F\u1369-\u1371\u1380-\u138F\u13A0-\u13F4\u1401-\u166C\u166F-\u167F\u1681-\u169A\u16A0-\u16EA\u16EE-\u16F0\u1700-\u170C\u170E-\u1714\u1720-\u1734\u1740-\u1753\u1760-\u176C\u176E-\u1770\u1772-\u1773\u1780-\u17D3\u17D7\u17DC-\u17DD\u17E0-\u17E9\u180B-\u180D\u1810-\u1819\u1820-\u1877\u1880-\u18AA\u18B0-\u18F5\u1900-\u191C\u1920-\u192B\u1930-\u193B\u1946-\u196D\u1970-\u1974\u1980-\u19AB\u19B0-\u19C9\u19D0-\u19DA\u1A00-\u1A1B\u1A20-\u1A5E\u1A60-\u1A7C\u1A7F-\u1A89\u1A90-\u1A99\u1AA7\u1B00-\u1B4B\u1B50-\u1B59\u1B6B-\u1B73\u1B80-\u1BF3\u1C00-\u1C37\u1C40-\u1C49\u1C4D-\u1C7D\u1CD0-\u1CD2\u1CD4-\u1CF6\u1D00-\u1DE6\u1DFC-\u1F15\u1F18-\u1F1D\u1F20-\u1F45\u1F48-\u1F4D\u1F50-\u1F57\u1F59\u1F5B\u1F5D\u1F5F-\u1F7D\u1F80-\u1FB4\u1FB6-\u1FBC\u1FBE\u1FC2-\u1FC4\u1FC6-\u1FCC\u1FD0-\u1FD3\u1FD6-\u1FDB\u1FE0-\u1FEC\u1FF2-\u1FF4\u1FF6-\u1FFC\u203F-\u2040\u2054\u2071\u207F\u2090-\u209C\u20D0-\u20DC\u20E1\u20E5-\u20F0\u2102\u2107\u210A-\u2113\u2115\u2118-\u211D\u2124\u2126\u2128\u212A-\u2139\u213C-\u213F\u2145-\u2149\u214E\u2160-\u2188\u2C00-
\u2C2E\u2C30-\u2C5E\u2C60-\u2CE4\u2CEB-\u2CF3\u2D00-\u2D25\u2D27\u2D2D\u2D30-\u2D67\u2D6F\u2D7F-\u2D96\u2DA0-\u2DA6\u2DA8-\u2DAE\u2DB0-\u2DB6\u2DB8-\u2DBE\u2DC0-\u2DC6\u2DC8-\u2DCE\u2DD0-\u2DD6\u2DD8-\u2DDE\u2DE0-\u2DFF\u3005-\u3007\u3021-\u302F\u3031-\u3035\u3038-\u303C\u3041-\u3096\u3099-\u309F\u30A1-\u30FA\u30FC-\u30FF\u3105-\u312D\u3131-\u318E\u31A0-\u31BA\u31F0-\u31FF\u3400-\u4DB5\u4E00-\u9FCC\uA000-\uA48C\uA4D0-\uA4FD\uA500-\uA60C\uA610-\uA62B\uA640-\uA66F\uA674-\uA67D\uA67F-\uA697\uA69F-\uA6F1\uA717-\uA71F\uA722-\uA788\uA78B-\uA78E\uA790-\uA793\uA7A0-\uA7AA\uA7F8-\uA827\uA840-\uA873\uA880-\uA8C4\uA8D0-\uA8D9\uA8E0-\uA8F7\uA8FB\uA900-\uA92D\uA930-\uA953\uA960-\uA97C\uA980-\uA9C0\uA9CF-\uA9D9\uAA00-\uAA36\uAA40-\uAA4D\uAA50-\uAA59\uAA60-\uAA76\uAA7A-\uAA7B\uAA80-\uAAC2\uAADB-\uAADD\uAAE0-\uAAEF\uAAF2-\uAAF6\uAB01-\uAB06\uAB09-\uAB0E\uAB11-\uAB16\uAB20-\uAB26\uAB28-\uAB2E\uABC0-\uABEA\uABEC-\uABED\uABF0-\uABF9\uAC00-\uD7A3\uD7B0-\uD7C6\uD7CB-\uD7FB\uF900-\uFA6D\uFA70-\uFAD9\uFB00-\uFB06\uFB13-\uFB17\uFB1D-\uFB28\uFB2A-\uFB36\uFB38-\uFB3C\uFB3E\uFB40-\uFB41\uFB43-\uFB44\uFB46-\uFBB1\uFBD3-\uFD3D\uFD50-\uFD8F\uFD92-\uFDC7\uFDF0-\uFDFB\uFE00-\uFE0F\uFE20-\uFE26\uFE33-\uFE34\uFE4D-\uFE4F\uFE70-\uFE74\uFE76-\uFEFC\uFF10-\uFF19\uFF21-\uFF3A\uFF3F\uFF41-\uFF5A\uFF66-\uFFBE\uFFC2-\uFFC7\uFFCA-\uFFCF\uFFD2-\uFFD7\uFFDA-\uFFDC] ; + +fragment Pc : [\u005F\u203F\u2040\u2054\uFE33\uFE34\uFE4D\uFE4E\uFE4F\uFF3F] ; + +fragment Sc : [$\u00A2-\u00A5\u058F\u060B\u09F2-\u09F3\u09FB\u0AF1\u0BF9\u0E3F\u17DB\u20A0-\u20BA\uA838\uFDFC\uFE69\uFF04\uFFE0-\uFFE1\uFFE5-\uFFE6] ; diff --git a/src/query/v2/frontend/opencypher/parser.hpp b/src/query/v2/frontend/opencypher/parser.hpp new file mode 100644 index 000000000..003209318 --- /dev/null +++ b/src/query/v2/frontend/opencypher/parser.hpp @@ -0,0 +1,68 @@ +// Copyright 2022 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <string> + +#include "antlr4-runtime.h" +#include "query/v2/exceptions.hpp" +#include "query/v2/frontend/opencypher/generated/MemgraphCypher.h" +#include "query/v2/frontend/opencypher/generated/MemgraphCypherLexer.h" + +namespace memgraph::query::v2::frontend::opencypher { + +/** + * Generates openCypher AST + * This thing must me a class since parser.cypher() returns pointer and there is + * no way for us to get ownership over the object. + */ +class Parser { + public: + /** + * @param query incoming query that has to be compiled into query plan + * the first step is to generate AST + */ + Parser(const std::string query) : query_(std::move(query)) { + parser_.removeErrorListeners(); + parser_.addErrorListener(&error_listener_); + tree_ = parser_.cypher(); + if (parser_.getNumberOfSyntaxErrors()) { + throw query::v2::SyntaxException(error_listener_.error_); + } + } + + auto tree() { return tree_; } + + private: + class FirstMessageErrorListener : public antlr4::BaseErrorListener { + void syntaxError(antlr4::Recognizer *, antlr4::Token *, size_t line, size_t position, const std::string &message, + std::exception_ptr) override { + if (error_.empty()) { + error_ = "line " + std::to_string(line) + ":" + std::to_string(position + 1) + " " + message; + } + } + + public: + std::string error_; + }; + + FirstMessageErrorListener error_listener_; + std::string query_; + antlr4::ANTLRInputStream input_{query_}; + 
antlropencypher::MemgraphCypherLexer lexer_{&input_}; + antlr4::CommonTokenStream tokens_{&lexer_}; + + // generate ast + antlropencypher::MemgraphCypher parser_{&tokens_}; + antlr4::tree::ParseTree *tree_ = nullptr; +}; +} // namespace memgraph::query::v2::frontend::opencypher diff --git a/src/query/v2/frontend/parsing.cpp b/src/query/v2/frontend/parsing.cpp new file mode 100644 index 000000000..1f3208d9a --- /dev/null +++ b/src/query/v2/frontend/parsing.cpp @@ -0,0 +1,184 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/frontend/parsing.hpp" + +#include <cctype> +#include <codecvt> +#include <locale> +#include <stdexcept> + +#include "query/v2/exceptions.hpp" +#include "utils/logging.hpp" +#include "utils/string.hpp" + +namespace memgraph::query::v2::frontend { + +int64_t ParseIntegerLiteral(const std::string &s) { + try { + // Not really correct since long long can have a bigger range than int64_t. + return static_cast<int64_t>(std::stoll(s, 0, 0)); + } catch (const std::out_of_range &) { + throw SemanticException("Integer literal exceeds 64 bits."); + } +} + +std::string ParseStringLiteral(const std::string &s) { + // These functions is declared as lambda since its semantics is highly + // specific for this conxtext and shouldn't be used elsewhere. 
+ auto EncodeEscapedUnicodeCodepointUtf32 = [](const std::string &s, int &i) { + const int kLongUnicodeLength = 8; + int j = i + 1; + while (j < static_cast<int>(s.size()) - 1 && j < i + kLongUnicodeLength + 1 && isxdigit(s[j])) { + ++j; + } + if (j - i == kLongUnicodeLength + 1) { + char32_t t = stoi(s.substr(i + 1, kLongUnicodeLength), 0, 16); + i += kLongUnicodeLength; + std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter; + return converter.to_bytes(t); + } + throw SyntaxException( + "Expected 8 hex digits as unicode codepoint started with \\U. " + "Use \\u for 4 hex digits format."); + }; + auto EncodeEscapedUnicodeCodepointUtf16 = [](const std::string &s, int &i) { + const int kShortUnicodeLength = 4; + int j = i + 1; + while (j < static_cast<int>(s.size()) - 1 && j < i + kShortUnicodeLength + 1 && isxdigit(s[j])) { + ++j; + } + if (j - i >= kShortUnicodeLength + 1) { + char16_t t = stoi(s.substr(i + 1, kShortUnicodeLength), 0, 16); + if (t >= 0xD800 && t <= 0xDBFF) { + // t is high surrogate pair. Expect one more utf16 codepoint. 
+ j = i + kShortUnicodeLength + 1; + if (j >= static_cast<int>(s.size()) - 1 || s[j] != '\\') { + throw SemanticException("Invalid UTF codepoint."); + } + ++j; + if (j >= static_cast<int>(s.size()) - 1 || (s[j] != 'u' && s[j] != 'U')) { + throw SemanticException("Invalid UTF codepoint."); + } + ++j; + int k = j; + while (k < static_cast<int>(s.size()) - 1 && k < j + kShortUnicodeLength && isxdigit(s[k])) { + ++k; + } + if (k != j + kShortUnicodeLength) { + throw SemanticException("Invalid UTF codepoint."); + } + char16_t surrogates[3] = {t, static_cast<char16_t>(stoi(s.substr(j, kShortUnicodeLength), 0, 16)), 0}; + i += kShortUnicodeLength + 2 + kShortUnicodeLength; + std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> converter; + return converter.to_bytes(surrogates); + } else { + i += kShortUnicodeLength; + std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> converter; + return converter.to_bytes(t); + } + } + throw SyntaxException( + "Expected 4 hex digits as unicode codepoint started with \\u. " + "Use \\U for 8 hex digits format."); + }; + + std::string unescaped; + bool escape = false; + + // First and last char is quote, we don't need to look at them. 
+ for (int i = 1; i < static_cast<int>(s.size()) - 1; ++i) { + if (escape) { + switch (s[i]) { + case '\\': + unescaped += '\\'; + break; + case '\'': + unescaped += '\''; + break; + case '"': + unescaped += '"'; + break; + case 'B': + case 'b': + unescaped += '\b'; + break; + case 'F': + case 'f': + unescaped += '\f'; + break; + case 'N': + case 'n': + unescaped += '\n'; + break; + case 'R': + case 'r': + unescaped += '\r'; + break; + case 'T': + case 't': + unescaped += '\t'; + break; + case 'U': + try { + unescaped += EncodeEscapedUnicodeCodepointUtf32(s, i); + } catch (const std::range_error &) { + throw SemanticException("Invalid UTF codepoint."); + } + break; + case 'u': + try { + unescaped += EncodeEscapedUnicodeCodepointUtf16(s, i); + } catch (const std::range_error &) { + throw SemanticException("Invalid UTF codepoint."); + } + break; + default: + // This should never happen, except grammar changes and we don't + // notice change in this production. + DLOG_FATAL("can't happen"); + throw std::exception(); + } + escape = false; + } else if (s[i] == '\\') { + escape = true; + } else { + unescaped += s[i]; + } + } + return unescaped; +} + +double ParseDoubleLiteral(const std::string &s) { + try { + return utils::ParseDouble(s); + } catch (const utils::BasicException &) { + throw SemanticException("Couldn't parse string to double."); + } +} + +std::string ParseParameter(const std::string &s) { + DMG_ASSERT(s[0] == '$', "Invalid string passed as parameter name"); + if (s[1] != '`') return s.substr(1); + // If parameter name is escaped symbolic name then symbolic name should be + // unescaped and leading and trailing backquote should be removed. 
+ DMG_ASSERT(s.size() > 3U && s.back() == '`', "Invalid string passed as parameter name"); + std::string out; + for (int i = 2; i < static_cast<int>(s.size()) - 1; ++i) { + if (s[i] == '`') { + ++i; + } + out.push_back(s[i]); + } + return out; +} + +} // namespace memgraph::query::v2::frontend diff --git a/src/query/v2/frontend/parsing.hpp b/src/query/v2/frontend/parsing.hpp new file mode 100644 index 000000000..2ba05b0d6 --- /dev/null +++ b/src/query/v2/frontend/parsing.hpp @@ -0,0 +1,27 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +#pragma once + +#include <cstdint> +#include <string> + +namespace memgraph::query::v2::frontend { + +// These are the functions for parsing literals and parameter names from +// opencypher query. +int64_t ParseIntegerLiteral(const std::string &s); +std::string ParseStringLiteral(const std::string &s); +double ParseDoubleLiteral(const std::string &s); +std::string ParseParameter(const std::string &s); + +} // namespace memgraph::query::v2::frontend diff --git a/src/query/v2/frontend/semantic/required_privileges.cpp b/src/query/v2/frontend/semantic/required_privileges.cpp new file mode 100644 index 000000000..0790529cf --- /dev/null +++ b/src/query/v2/frontend/semantic/required_privileges.cpp @@ -0,0 +1,152 @@ +// Copyright 2022 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/ast/ast_visitor.hpp" +#include "query/v2/procedure/module.hpp" +#include "utils/memory.hpp" + +namespace memgraph::query::v2 { + +class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVisitor { + public: + using HierarchicalTreeVisitor::PostVisit; + using HierarchicalTreeVisitor::PreVisit; + using HierarchicalTreeVisitor::Visit; + using QueryVisitor<void>::Visit; + + std::vector<AuthQuery::Privilege> privileges() { return privileges_; } + + void Visit(IndexQuery &) override { AddPrivilege(AuthQuery::Privilege::INDEX); } + + void Visit(AuthQuery &) override { AddPrivilege(AuthQuery::Privilege::AUTH); } + + void Visit(ExplainQuery &query) override { query.cypher_query_->Accept(*this); } + + void Visit(ProfileQuery &query) override { query.cypher_query_->Accept(*this); } + + void Visit(InfoQuery &info_query) override { + switch (info_query.info_type_) { + case InfoQuery::InfoType::INDEX: + // TODO: This should be INDEX | STATS, but we don't have support for + // *or* with privileges. + AddPrivilege(AuthQuery::Privilege::INDEX); + break; + case InfoQuery::InfoType::STORAGE: + AddPrivilege(AuthQuery::Privilege::STATS); + break; + case InfoQuery::InfoType::CONSTRAINT: + // TODO: This should be CONSTRAINT | STATS, but we don't have support + // for *or* with privileges. 
+ AddPrivilege(AuthQuery::Privilege::CONSTRAINT); + break; + } + } + + void Visit(ConstraintQuery &constraint_query) override { AddPrivilege(AuthQuery::Privilege::CONSTRAINT); } + + void Visit(CypherQuery &query) override { + query.single_query_->Accept(*this); + for (auto *cypher_union : query.cypher_unions_) { + cypher_union->Accept(*this); + } + } + + void Visit(DumpQuery &dump_query) override { AddPrivilege(AuthQuery::Privilege::DUMP); } + + void Visit(LockPathQuery &lock_path_query) override { AddPrivilege(AuthQuery::Privilege::DURABILITY); } + + void Visit(FreeMemoryQuery &free_memory_query) override { AddPrivilege(AuthQuery::Privilege::FREE_MEMORY); } + + void Visit(TriggerQuery &trigger_query) override { AddPrivilege(AuthQuery::Privilege::TRIGGER); } + + void Visit(StreamQuery &stream_query) override { AddPrivilege(AuthQuery::Privilege::STREAM); } + + void Visit(ReplicationQuery &replication_query) override { AddPrivilege(AuthQuery::Privilege::REPLICATION); } + + void Visit(IsolationLevelQuery &isolation_level_query) override { AddPrivilege(AuthQuery::Privilege::CONFIG); } + + void Visit(CreateSnapshotQuery &create_snapshot_query) override { AddPrivilege(AuthQuery::Privilege::DURABILITY); } + + void Visit(SettingQuery & /*setting_query*/) override { AddPrivilege(AuthQuery::Privilege::CONFIG); } + + void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); } + + bool PreVisit(Create & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::CREATE); + return false; + } + bool PreVisit(CallProcedure &procedure) override { + const auto maybe_proc = + procedure::FindProcedure(procedure::gModuleRegistry, procedure.procedure_name_, utils::NewDeleteResource()); + if (maybe_proc && maybe_proc->second->info.required_privilege) { + AddPrivilege(*maybe_proc->second->info.required_privilege); + } + return false; + } + bool PreVisit(Delete & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::DELETE); + return false; + 
} + bool PreVisit(Match & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::MATCH); + return false; + } + bool PreVisit(Merge & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::MERGE); + return false; + } + bool PreVisit(SetProperty & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::SET); + return false; + } + bool PreVisit(SetProperties & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::SET); + return false; + } + bool PreVisit(SetLabels & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::SET); + return false; + } + bool PreVisit(RemoveProperty & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::REMOVE); + return false; + } + bool PreVisit(RemoveLabels & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::REMOVE); + return false; + } + bool PreVisit(LoadCsv & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::READ_FILE); + return false; + } + + bool Visit(Identifier & /*unused*/) override { return true; } + bool Visit(PrimitiveLiteral & /*unused*/) override { return true; } + bool Visit(ParameterLookup & /*unused*/) override { return true; } + + private: + void AddPrivilege(AuthQuery::Privilege privilege) { + if (!utils::Contains(privileges_, privilege)) { + privileges_.push_back(privilege); + } + } + + std::vector<AuthQuery::Privilege> privileges_; +}; + +std::vector<AuthQuery::Privilege> GetRequiredPrivileges(Query *query) { + PrivilegeExtractor extractor; + query->Accept(extractor); + return extractor.privileges(); +} + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/semantic/required_privileges.hpp b/src/query/v2/frontend/semantic/required_privileges.hpp new file mode 100644 index 000000000..943f786a3 --- /dev/null +++ b/src/query/v2/frontend/semantic/required_privileges.hpp @@ -0,0 +1,18 @@ +// Copyright 2022 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "query/v2/frontend/ast/ast.hpp" + +namespace memgraph::query::v2 { +std::vector<AuthQuery::Privilege> GetRequiredPrivileges(Query *query); +} // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/semantic/symbol.lcp b/src/query/v2/frontend/semantic/symbol.lcp new file mode 100644 index 000000000..c5b0b8030 --- /dev/null +++ b/src/query/v2/frontend/semantic/symbol.lcp @@ -0,0 +1,89 @@ +;; Copyright 2022 Memgraph Ltd. +;; +;; Use of this software is governed by the Business Source License +;; included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +;; License, and you may not use this file except in compliance with the Business Source License. +;; +;; As of the Change Date specified in that file, in accordance with +;; the Business Source License, use of this software will be governed +;; by the Apache License, Version 2.0, included in the file +;; licenses/APL.txt. 
+
+#>cpp
+#pragma once
+
+#include <string>
+
+#include "utils/typeinfo.hpp"
+cpp<#
+
+(lcp:namespace memgraph)
+(lcp:namespace query)
+(lcp:namespace v2)
+
+;; A named symbol of the query: its name, its position, whether the user
+;; declared it, its type and the position of its token.
+;; NOTE(review): `position` has no :initval while the other scalar slots do;
+;; presumably the default constructor below leaves it uninitialized -- confirm
+;; against the code LCP generates.
+(lcp:define-class symbol ()
+  ((name "std::string" :scope :public)
+   (position :int64_t :scope :public)
+   (user-declared :bool :initval "true" :scope :public)
+   (type "Type" :initval "Type::ANY" :scope :public)
+   (token-position :int64_t :initval "-1" :scope :public))
+  (:public
+    ;; This is similar to TypedValue::Type, but this has `Any` type.
+    ;; TODO: Make a better Type structure which can store a generic List.
+    (lcp:define-enum type (any vertex edge path number edge-list)
+      (:serialize))
+    #>cpp
+    // TODO: Generate enum to string conversion from LCP. Note, that this is
+    // displayed to the end user, so we may want to have a pretty name of each
+    // value.
+    // The array is indexed by the enum value, so its order must follow the
+    // order of the enumerators above.
+    static std::string TypeToString(Type type) {
+      const char *enum_string[] = {"Any",  "Vertex", "Edge",
+                                   "Path", "Number", "EdgeList"};
+      return enum_string[static_cast<int>(type)];
+    }
+
+    Symbol() {}
+    // NOTE(review): `position` and `token_position` are taken as int here
+    // while the slots are declared int64_t -- confirm this is intended.
+    Symbol(const std::string &name, int position, bool user_declared,
+           Type type = Type::ANY, int token_position = -1)
+        : name_(name),
+          position_(position),
+          user_declared_(user_declared),
+          type_(type),
+          token_position_(token_position) {}
+
+    // Equality looks at position, name and type only; user_declared_ and
+    // token_position_ do not take part.
+    bool operator==(const Symbol &other) const {
+      return position_ == other.position_ && name_ == other.name_ &&
+             type_ == other.type_;
+    }
+    bool operator!=(const Symbol &other) const { return !operator==(other); }
+
+    // TODO: Remove these since members are public
+    // NOTE(review): position() and token_position() return int although the
+    // slots are int64_t, so large values would be narrowed -- confirm.
+    const auto &name() const { return name_; }
+    int position() const { return position_; }
+    Type type() const { return type_; }
+    bool user_declared() const { return user_declared_; }
+    int token_position() const { return token_position_; }
+    cpp<#)
+  (:serialize (:slk)))
+
+(lcp:pop-namespace) ;; v2
+(lcp:pop-namespace) ;; query
+(lcp:pop-namespace) ;; memgraph
+
+#>cpp
+namespace std {
+
+// Hash over the same three fields that Symbol::operator== compares
+// (position, name, type).
+template <>
+struct hash<memgraph::query::v2::Symbol> {
+  size_t operator()(const memgraph::query::v2::Symbol &symbol) const {
+    size_t prime = 265443599u;
+    size_t hash = std::hash<int>{}(symbol.position());
+    hash ^= prime * std::hash<std::string>{}(symbol.name());
+    hash ^= prime * std::hash<int>{}(static_cast<int>(symbol.type()));
+    return hash;
+  }
+};
+
+}  // namespace std
+
+cpp<#
diff --git a/src/query/v2/frontend/semantic/symbol_generator.cpp b/src/query/v2/frontend/semantic/symbol_generator.cpp
new file mode 100644
index 000000000..64e3604b1
--- /dev/null
+++ b/src/query/v2/frontend/semantic/symbol_generator.cpp
@@ -0,0 +1,625 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+// Copyright 2017 Memgraph
+//
+// Created by Teon Banek on 24-03-2017
+
+#include "query/v2/frontend/semantic/symbol_generator.hpp"
+
+#include <algorithm>
+#include <optional>
+#include <ranges>
+#include <unordered_set>
+#include <variant>
+
+#include "query/v2/frontend/ast/ast.hpp"
+#include "query/v2/frontend/ast/ast_visitor.hpp"
+#include "utils/algorithm.hpp"
+#include "utils/logging.hpp"
+
+namespace memgraph::query::v2 {
+
+namespace {
+// Builds a name -> Identifier* lookup table. `emplace` does not overwrite, so
+// for a repeated name the first identifier in the vector is the one kept.
+std::unordered_map<std::string, Identifier *> GeneratePredefinedIdentifierMap(
+    const std::vector<Identifier *> &predefined_identifiers) {
+  std::unordered_map<std::string, Identifier *> identifier_map;
+  for (const auto &identifier : predefined_identifiers) {
+    identifier_map.emplace(identifier->name_, identifier);
+  }
+
+  return identifier_map;
+}
+}  // namespace
+
+// The generator starts with a single, empty scope.
+SymbolGenerator::SymbolGenerator(SymbolTable *symbol_table, const std::vector<Identifier *> &predefined_identifiers)
+    : symbol_table_(symbol_table),
+      predefined_identifiers_{GeneratePredefinedIdentifierMap(predefined_identifiers)},
+      scopes_(1, Scope()) {}
+
+// Looks `name` up in `scope` only. Returns the symbol if it is there,
+// std::nullopt otherwise. Throws TypeMismatchError when the symbol is found,
+// neither the requested nor the stored type is ANY, and the two differ.
+std::optional<Symbol> SymbolGenerator::FindSymbolInScope(const std::string &name, const Scope &scope,
+                                                         Symbol::Type type) {
+  if (auto it = scope.symbols.find(name); it != scope.symbols.end()) {
+    const auto &symbol = it->second;
+    // Unless we have `ANY` type, check that types match.
+    if (type != Symbol::Type::ANY && symbol.type() != Symbol::Type::ANY && type != symbol.type()) {
+      throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()), Symbol::TypeToString(type));
+    }
+    return symbol;
+  }
+  return std::nullopt;
+}
+
+// Creates a new symbol in the symbol table and binds it under `name` in the
+// innermost scope, replacing any symbol bound to that name there.
+auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type, int token_position) {
+  auto symbol = symbol_table_->CreateSymbol(name, user_declared, type, token_position);
+  scopes_.back().symbols[name] = symbol;
+  return symbol;
+}
+
+// Returns the symbol bound to `name` in the innermost scope, creating it there
+// if it is missing. Outer scopes are not consulted.
+auto SymbolGenerator::GetOrCreateSymbolLocalScope(const std::string &name, bool user_declared, Symbol::Type type) {
+  auto &scope = scopes_.back();
+  if (auto maybe_symbol = FindSymbolInScope(name, scope, type); maybe_symbol) {
+    return *maybe_symbol;
+  }
+  return CreateSymbol(name, user_declared, type);
+}
+
+// Like GetOrCreateSymbolLocalScope, but searches all scopes from the innermost
+// to the outermost before creating the symbol in the innermost one.
+auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type) {
+  // NOLINTNEXTLINE
+  for (auto scope = scopes_.rbegin(); scope != scopes_.rend(); ++scope) {
+    if (auto maybe_symbol = FindSymbolInScope(name, *scope, type); maybe_symbol) {
+      return *maybe_symbol;
+    }
+  }
+  return CreateSymbol(name, user_declared, type);
+}
+
+// Shared handling of the body of RETURN and WITH. `where` is passed by the
+// WITH handler; the RETURN handler calls this without it.
+void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
+  auto &scope = scopes_.back();
+  for (auto &expr : body.named_expressions) {
+    expr->Accept(*this);
+  }
+  std::vector<Symbol> user_symbols;
+  if (body.all_identifiers) {
+    // Carry over user symbols because '*' appeared.
+    for (const auto &sym_pair : scope.symbols) {
+      if (!sym_pair.second.user_declared()) {
+        continue;
+      }
+      user_symbols.emplace_back(sym_pair.second);
+    }
+    if (user_symbols.empty()) {
+      throw SemanticException("There are no variables in scope to use for '*'.");
+    }
+  }
+  // WITH/RETURN clause removes declarations of all the previous variables and
+  // declares only those established through named expressions. New declarations
+  // must not be visible inside named expressions themselves.
+  bool removed_old_names = false;
+  if ((!where && body.order_by.empty()) || scope.has_aggregation) {
+    // WHERE and ORDER BY need to see both the old and new symbols, unless we
+    // have an aggregation. Therefore, we can clear the symbols immediately if
+    // there is neither ORDER BY nor WHERE, or we have an aggregation.
+    scope.symbols.clear();
+    removed_old_names = true;
+  }
+  // Create symbols for named expressions.
+  std::unordered_set<std::string> new_names;
+  for (const auto &user_sym : user_symbols) {
+    new_names.insert(user_sym.name());
+    scope.symbols[user_sym.name()] = user_sym;
+  }
+  for (auto &named_expr : body.named_expressions) {
+    const auto &name = named_expr->name_;
+    // `insert` reports false in `.second` when the name is already present.
+    if (!new_names.insert(name).second) {
+      throw SemanticException("Multiple results with the same name '{}' are not allowed.", name);
+    }
+    // An improvement would be to infer the type of the expression, so that the
+    // new symbol would have a more specific type.
+    named_expr->MapTo(CreateSymbol(name, true, Symbol::Type::ANY, named_expr->token_position_));
+  }
+  scope.in_order_by = true;
+  for (const auto &order_pair : body.order_by) {
+    order_pair.expression->Accept(*this);
+  }
+  scope.in_order_by = false;
+  if (body.skip) {
+    scope.in_skip = true;
+    body.skip->Accept(*this);
+    scope.in_skip = false;
+  }
+  if (body.limit) {
+    scope.in_limit = true;
+    body.limit->Accept(*this);
+    scope.in_limit = false;
+  }
+  if (where) where->Accept(*this);
+  if (!removed_old_names) {
+    // We have an ORDER BY or WHERE, but no aggregation, which means we didn't
+    // clear the old symbols, so do it now. We cannot just call clear, because
+    // we've added new symbols.
+    for (auto sym_it = scope.symbols.begin(); sym_it != scope.symbols.end();) {
+      if (new_names.find(sym_it->first) == new_names.end()) {
+        sym_it = scope.symbols.erase(sym_it);
+      } else {
+        sym_it++;
+      }
+    }
+  }
+  scopes_.back().has_aggregation = false;
+}
+
+// Query
+
+// Remembers the return names of the previous single query, so that
+// PostVisit(CypherUnion) can compare them with those of the current one.
+bool SymbolGenerator::PreVisit(SingleQuery &) {
+  prev_return_names_ = curr_return_names_;
+  curr_return_names_.clear();
+  return true;
+}
+
+// Union
+
+// Each part of a UNION starts from a fresh innermost scope.
+bool SymbolGenerator::PreVisit(CypherUnion &) {
+  scopes_.back() = Scope();
+  return true;
+}
+
+bool SymbolGenerator::PostVisit(CypherUnion &cypher_union) {
+  if (prev_return_names_ != curr_return_names_) {
+    throw SemanticException("All subqueries in an UNION must have the same column names.");
+  }
+
+  // create new symbols for the result of the union
+  for (const auto &name : curr_return_names_) {
+    auto symbol = CreateSymbol(name, false);
+    cypher_union.union_symbols_.push_back(symbol);
+  }
+
+  return true;
+}
+
+// Clauses
+
+bool SymbolGenerator::PreVisit(Create &) {
+  scopes_.back().in_create = true;
+  return true;
+}
+bool SymbolGenerator::PostVisit(Create &) {
+  scopes_.back().in_create = false;
+  return true;
+}
+
+// Only the arguments are visited here; the result identifiers are declared in
+// PostVisit, so the arguments cannot refer to them.
+bool SymbolGenerator::PreVisit(CallProcedure &call_proc) {
+  for (auto *expr : call_proc.arguments_) {
+    expr->Accept(*this);
+  }
+  return false;
+}
+
+// Declares each result identifier; a name which is already bound in the local
+// scope is an error.
+bool SymbolGenerator::PostVisit(CallProcedure &call_proc) {
+  for (auto *ident : call_proc.result_identifiers_) {
+    if (HasSymbolLocalScope(ident->name_)) {
+      throw RedeclareVariableError(ident->name_);
+    }
+    ident->MapTo(CreateSymbol(ident->name_, true));
+  }
+  return true;
+}
+
+// Children of LOAD CSV are not visited.
+bool SymbolGenerator::PreVisit(LoadCsv &load_csv) { return false; }
+
+// Declares the row variable; redeclaring a locally bound name is an error.
+bool SymbolGenerator::PostVisit(LoadCsv &load_csv) {
+  if (HasSymbolLocalScope(load_csv.row_var_->name_)) {
+    throw RedeclareVariableError(load_csv.row_var_->name_);
+  }
+  load_csv.row_var_->MapTo(CreateSymbol(load_csv.row_var_->name_, true));
+  return true;
+}
+
+bool SymbolGenerator::PreVisit(Return &ret) {
+  auto &scope = scopes_.back();
+  scope.in_return = true;
+  VisitReturnBody(ret.body_);
+  scope.in_return = false;
+  return false;  // We handled the traversal ourselves.
+}
+
+// Records the names left in scope after RETURN as the current return names.
+bool SymbolGenerator::PostVisit(Return &) {
+  for (const auto &name_symbol : scopes_.back().symbols) curr_return_names_.insert(name_symbol.first);
+  return true;
+}
+
+bool SymbolGenerator::PreVisit(With &with) {
+  auto &scope = scopes_.back();
+  scope.in_with = true;
+  VisitReturnBody(with.body_, with.where_);
+  scope.in_with = false;
+  return false;  // We handled the traversal ourselves.
+}
+
+bool SymbolGenerator::PreVisit(Where &) {
+  scopes_.back().in_where = true;
+  return true;
+}
+bool SymbolGenerator::PostVisit(Where &) {
+  scopes_.back().in_where = false;
+  return true;
+}
+
+bool SymbolGenerator::PreVisit(Merge &) {
+  scopes_.back().in_merge = true;
+  return true;
+}
+bool SymbolGenerator::PostVisit(Merge &) {
+  scopes_.back().in_merge = false;
+  return true;
+}
+
+// Declares the UNWIND variable; redeclaring a locally bound name is an error.
+bool SymbolGenerator::PostVisit(Unwind &unwind) {
+  const auto &name = unwind.named_expression_->name_;
+  if (HasSymbolLocalScope(name)) {
+    throw RedeclareVariableError(name);
+  }
+  unwind.named_expression_->MapTo(CreateSymbol(name, true));
+  return true;
+}
+
+bool SymbolGenerator::PreVisit(Match &) {
+  scopes_.back().in_match = true;
+  return true;
+}
+bool SymbolGenerator::PostVisit(Match &) {
+  auto &scope = scopes_.back();
+  scope.in_match = false;
+  // Check variables in property maps after visiting Match, so that they can
+  // reference symbols out of bind order.
+  for (auto &ident : scope.identifiers_in_match) {
+    if (!HasSymbolLocalScope(ident->name_) && !ConsumePredefinedIdentifier(ident->name_))
+      throw UnboundVariableError(ident->name_);
+    // NOTE(review): operator[] default-constructs a Symbol when the name is
+    // absent from scope.symbols; presumably ConsumePredefinedIdentifier binds
+    // the name first -- confirm.
+    ident->MapTo(scope.symbols[ident->name_]);
+  }
+  scope.identifiers_in_match.clear();
+  return true;
+}
+
+// FOREACH opens a new scope holding its loop variable; PostVisit pops it.
+bool SymbolGenerator::PreVisit(Foreach &for_each) {
+  const auto &name = for_each.named_expression_->name_;
+  scopes_.emplace_back(Scope());
+  scopes_.back().in_foreach = true;
+  for_each.named_expression_->MapTo(
+      CreateSymbol(name, true, Symbol::Type::ANY, for_each.named_expression_->token_position_));
+  return true;
+}
+bool SymbolGenerator::PostVisit([[maybe_unused]] Foreach &for_each) {
+  scopes_.pop_back();
+  return true;
+}
+
+// Expressions
+
+// Decides, from the flags of the innermost scope, whether the identifier
+// declares a symbol, references one, or has to be resolved once the enclosing
+// MATCH has been visited.
+SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
+  auto &scope = scopes_.back();
+  if (scope.in_skip || scope.in_limit) {
+    throw SemanticException("Variables are not allowed in {}.", scope.in_skip ? "SKIP" : "LIMIT");
+  }
+  Symbol symbol;
+  if (scope.in_pattern && !(scope.in_node_atom || scope.visiting_edge)) {
+    // If we are in the pattern, and outside of a node or an edge, the
+    // identifier is the pattern name.
+    symbol = GetOrCreateSymbolLocalScope(ident.name_, ident.user_declared_, Symbol::Type::PATH);
+  } else if (scope.in_pattern && scope.in_pattern_atom_identifier) {
+    // Patterns used to create nodes and edges cannot redeclare already
+    // established bindings. Declaration only happens in single node
+    // patterns and in edge patterns. OpenCypher example,
+    // `MATCH (n) CREATE (n)` should throw an error that `n` is already
+    // declared. While `MATCH (n) CREATE (n) -[:R]-> (n)` is allowed,
+    // since `n` now references the bound node instead of declaring it.
+    if ((scope.in_create_node || scope.in_create_edge) && HasSymbolLocalScope(ident.name_)) {
+      throw RedeclareVariableError(ident.name_);
+    }
+    auto type = Symbol::Type::VERTEX;
+    if (scope.visiting_edge) {
+      // Edge referencing is not allowed (like in Neo4j):
+      // `MATCH (n) - [r] -> (n) - [r] -> (n) RETURN r` is not allowed.
+      if (HasSymbolLocalScope(ident.name_)) {
+        throw RedeclareVariableError(ident.name_);
+      }
+      type = scope.visiting_edge->IsVariable() ? Symbol::Type::EDGE_LIST : Symbol::Type::EDGE;
+    }
+    symbol = GetOrCreateSymbolLocalScope(ident.name_, ident.user_declared_, type);
+  } else if (scope.in_pattern && !scope.in_pattern_atom_identifier && scope.in_match) {
+    if (scope.in_edge_range && scope.visiting_edge->identifier_->name_ == ident.name_) {
+      // Prevent variable path bounds to reference the identifier which is bound
+      // by the variable path itself.
+      throw UnboundVariableError(ident.name_);
+    }
+    // Variables in property maps or bounds of variable length path during MATCH
+    // can reference symbols bound later in the same MATCH. We collect them
+    // here, so that they can be checked after visiting Match.
+    // In this branch `symbol` is still default-constructed when it is mapped
+    // below; PostVisit(Match) maps the identifier again.
+    scope.identifiers_in_match.emplace_back(&ident);
+  } else {
+    // Everything else references a bound symbol.
+    if (!HasSymbol(ident.name_) && !ConsumePredefinedIdentifier(ident.name_)) throw UnboundVariableError(ident.name_);
+    symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, Symbol::Type::ANY);
+  }
+  ident.MapTo(symbol);
+  return true;
+}
+
+bool SymbolGenerator::PreVisit(Aggregation &aggr) {
+  auto &scope = scopes_.back();
+  // Check if the aggregation can be used in this context. This check should
+  // probably move to a separate phase, which checks if the query is well
+  // formed.
+  if ((!scope.in_return && !scope.in_with) || scope.in_order_by || scope.in_skip || scope.in_limit || scope.in_where) {
+    throw SemanticException("Aggregation functions are only allowed in WITH and RETURN.");
+  }
+  if (scope.in_aggregation) {
+    throw SemanticException(
+        "Using aggregation functions inside aggregation functions is not "
+        "allowed.");
+  }
+  if (scope.num_if_operators) {
+    // Neo allows aggregations here and produces very interesting behaviors.
+    // To simplify implementation at this moment we decided to completely
+    // disallow aggregations inside of the CASE.
+    // However, in some cases aggregation makes perfect sense, for example:
+    //    CASE count(n) WHEN 10 THEN "YES" ELSE "NO" END.
+    // TODO: Rethink of allowing aggregations in some parts of the CASE
+    // construct.
+    throw SemanticException("Using aggregation functions inside of CASE is not allowed.");
+  }
+  // Create a virtual symbol for aggregation result.
+  // Currently, we only have aggregation operators which return numbers.
+  auto aggr_name = Aggregation::OpToString(aggr.op_) + std::to_string(aggr.symbol_pos_);
+  aggr.MapTo(CreateSymbol(aggr_name, false, Symbol::Type::NUMBER));
+  scope.in_aggregation = true;
+  scope.has_aggregation = true;
+  return true;
+}
+
+bool SymbolGenerator::PostVisit(Aggregation &) {
+  scopes_.back().in_aggregation = false;
+  return true;
+}
+
+// Counts how deeply IfOperator nodes are nested; PreVisit(Aggregation) rejects
+// aggregations while the counter is non-zero.
+bool SymbolGenerator::PreVisit(IfOperator &) {
+  ++scopes_.back().num_if_operators;
+  return true;
+}
+
+bool SymbolGenerator::PostVisit(IfOperator &) {
+  --scopes_.back().num_if_operators;
+  return true;
+}
+
+// All, Single, Any, None, Reduce and Extract follow one pattern: the list (and
+// for Reduce the initializer) is visited first, then the inner expression is
+// handed to VisitWithIdentifiers together with the identifiers it introduces.
+// Returning false stops the default traversal of the children.
+
+bool SymbolGenerator::PreVisit(All &all) {
+  all.list_expression_->Accept(*this);
+  VisitWithIdentifiers(all.where_->expression_, {all.identifier_});
+  return false;
+}
+
+bool SymbolGenerator::PreVisit(Single &single) {
+  single.list_expression_->Accept(*this);
+  VisitWithIdentifiers(single.where_->expression_, {single.identifier_});
+  return false;
+}
+
+bool SymbolGenerator::PreVisit(Any &any) {
+  any.list_expression_->Accept(*this);
+  VisitWithIdentifiers(any.where_->expression_, {any.identifier_});
+  return false;
+}
+
+bool SymbolGenerator::PreVisit(None &none) {
+  none.list_expression_->Accept(*this);
+  VisitWithIdentifiers(none.where_->expression_, {none.identifier_});
+  return false;
+}
+
+bool SymbolGenerator::PreVisit(Reduce &reduce) {
+  reduce.initializer_->Accept(*this);
+  reduce.list_->Accept(*this);
+  VisitWithIdentifiers(reduce.expression_, {reduce.accumulator_, reduce.identifier_});
+  return false;
+}
+
+bool SymbolGenerator::PreVisit(Extract &extract) {
+  extract.list_->Accept(*this);
+  VisitWithIdentifiers(extract.expression_, {extract.identifier_});
+  return false;
+}
+
+// Pattern and its subparts.
+ +bool SymbolGenerator::PreVisit(Pattern &pattern) { + auto &scope = scopes_.back(); + scope.in_pattern = true; + if ((scope.in_create || scope.in_merge) && pattern.atoms_.size() == 1U) { + MG_ASSERT(utils::IsSubtype(*pattern.atoms_[0], NodeAtom::kType), "Expected a single NodeAtom in Pattern"); + scope.in_create_node = true; + } + return true; +} + +bool SymbolGenerator::PostVisit(Pattern &) { + auto &scope = scopes_.back(); + scope.in_pattern = false; + scope.in_create_node = false; + return true; +} + +bool SymbolGenerator::PreVisit(NodeAtom &node_atom) { + auto &scope = scopes_.back(); + auto check_node_semantic = [&node_atom, &scope, this](const bool props_or_labels) { + const auto &node_name = node_atom.identifier_->name_; + if ((scope.in_create || scope.in_merge) && props_or_labels && HasSymbolLocalScope(node_name)) { + throw SemanticException("Cannot create node '" + node_name + + "' with labels or properties, because it is already declared."); + } + scope.in_pattern_atom_identifier = true; + node_atom.identifier_->Accept(*this); + scope.in_pattern_atom_identifier = false; + }; + + scope.in_node_atom = true; + if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&node_atom.properties_)) { + bool props_or_labels = !properties->empty() || !node_atom.labels_.empty(); + + check_node_semantic(props_or_labels); + for (auto kv : *properties) { + kv.second->Accept(*this); + } + + return false; + } + auto &properties_parameter = std::get<ParameterLookup *>(node_atom.properties_); + bool props_or_labels = !properties_parameter || !node_atom.labels_.empty(); + + check_node_semantic(props_or_labels); + properties_parameter->Accept(*this); + return false; +} + +bool SymbolGenerator::PostVisit(NodeAtom &) { + scopes_.back().in_node_atom = false; + return true; +} + +bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) { + auto &scope = scopes_.back(); + scope.visiting_edge = &edge_atom; + if (scope.in_create || scope.in_merge) { + 
scope.in_create_edge = true; + if (edge_atom.edge_types_.size() != 1U) { + throw SemanticException( + "A single relationship type must be specified " + "when creating an edge."); + } + if (scope.in_create && // Merge allows bidirectionality + edge_atom.direction_ == EdgeAtom::Direction::BOTH) { + throw SemanticException( + "Bidirectional relationship are not supported " + "when creating an edge"); + } + if (edge_atom.IsVariable()) { + throw SemanticException( + "Variable length relationships are not supported when creating an " + "edge."); + } + } + if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&edge_atom.properties_)) { + for (auto kv : *properties) { + kv.second->Accept(*this); + } + } else { + std::get<ParameterLookup *>(edge_atom.properties_)->Accept(*this); + } + if (edge_atom.IsVariable()) { + scope.in_edge_range = true; + if (edge_atom.lower_bound_) { + edge_atom.lower_bound_->Accept(*this); + } + if (edge_atom.upper_bound_) { + edge_atom.upper_bound_->Accept(*this); + } + scope.in_edge_range = false; + scope.in_pattern = false; + if (edge_atom.filter_lambda_.expression) { + VisitWithIdentifiers(edge_atom.filter_lambda_.expression, + {edge_atom.filter_lambda_.inner_edge, edge_atom.filter_lambda_.inner_node}); + } else { + // Create inner symbols, but don't bind them in scope, since they are to + // be used in the missing filter expression. 
+ auto *inner_edge = edge_atom.filter_lambda_.inner_edge; + inner_edge->MapTo(symbol_table_->CreateSymbol(inner_edge->name_, inner_edge->user_declared_, Symbol::Type::EDGE)); + auto *inner_node = edge_atom.filter_lambda_.inner_node; + inner_node->MapTo( + symbol_table_->CreateSymbol(inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX)); + } + if (edge_atom.weight_lambda_.expression) { + VisitWithIdentifiers(edge_atom.weight_lambda_.expression, + {edge_atom.weight_lambda_.inner_edge, edge_atom.weight_lambda_.inner_node}); + } + scope.in_pattern = true; + } + scope.in_pattern_atom_identifier = true; + edge_atom.identifier_->Accept(*this); + scope.in_pattern_atom_identifier = false; + if (edge_atom.total_weight_) { + if (HasSymbolLocalScope(edge_atom.total_weight_->name_)) { + throw RedeclareVariableError(edge_atom.total_weight_->name_); + } + edge_atom.total_weight_->MapTo(GetOrCreateSymbolLocalScope( + edge_atom.total_weight_->name_, edge_atom.total_weight_->user_declared_, Symbol::Type::NUMBER)); + } + return false; +} + +bool SymbolGenerator::PostVisit(EdgeAtom &) { + auto &scope = scopes_.back(); + scope.visiting_edge = nullptr; + scope.in_create_edge = false; + return true; +} + +void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<Identifier *> &identifiers) { + auto &scope = scopes_.back(); + std::vector<std::pair<std::optional<Symbol>, Identifier *>> prev_symbols; + // Collect previous symbols if they exist. + for (const auto &identifier : identifiers) { + std::optional<Symbol> prev_symbol; + auto prev_symbol_it = scope.symbols.find(identifier->name_); + if (prev_symbol_it != scope.symbols.end()) { + prev_symbol = prev_symbol_it->second; + } + identifier->MapTo(CreateSymbol(identifier->name_, identifier->user_declared_)); + prev_symbols.emplace_back(prev_symbol, identifier); + } + // Visit the expression with the new symbols bound. + expr->Accept(*this); + // Restore back to previous symbols. 
+ for (const auto &prev : prev_symbols) { + const auto &prev_symbol = prev.first; + const auto &identifier = prev.second; + if (prev_symbol) { + scope.symbols[identifier->name_] = *prev_symbol; + } else { + scope.symbols.erase(identifier->name_); + } + } +} + +bool SymbolGenerator::HasSymbol(const std::string &name) const { + return std::ranges::any_of(scopes_, [&name](const auto &scope) { return scope.symbols.contains(name); }); +} + +bool SymbolGenerator::HasSymbolLocalScope(const std::string &name) const { + return scopes_.back().symbols.contains(name); +} + +bool SymbolGenerator::ConsumePredefinedIdentifier(const std::string &name) { + auto it = predefined_identifiers_.find(name); + + if (it == predefined_identifiers_.end()) { + return false; + } + + // we can only use the predefined identifier in a single scope so we remove it after creating + // a symbol for it + auto &identifier = it->second; + MG_ASSERT(!identifier->user_declared_, "Predefined symbols cannot be user declared!"); + identifier->MapTo(CreateSymbol(identifier->name_, identifier->user_declared_)); + predefined_identifiers_.erase(it); + return true; +} + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/semantic/symbol_generator.hpp b/src/query/v2/frontend/semantic/symbol_generator.hpp new file mode 100644 index 000000000..991717bdf --- /dev/null +++ b/src/query/v2/frontend/semantic/symbol_generator.hpp @@ -0,0 +1,176 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +// Copyright 2017 Memgraph +// +// Created by Teon Banek on 11-03-2017 + +#pragma once + +#include <optional> +#include <vector> + +#include "query/v2/exceptions.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/semantic/symbol_table.hpp" + +namespace memgraph::query::v2 { + +/// Visits the AST and generates symbols for variables. +/// +/// During the process of symbol generation, simple semantic checks are +/// performed. Such as, redeclaring a variable or conflicting expectations of +/// variable types. +class SymbolGenerator : public HierarchicalTreeVisitor { + public: + explicit SymbolGenerator(SymbolTable *symbol_table, const std::vector<Identifier *> &predefined_identifiers); + + using HierarchicalTreeVisitor::PostVisit; + using HierarchicalTreeVisitor::PreVisit; + using HierarchicalTreeVisitor::Visit; + using typename HierarchicalTreeVisitor::ReturnType; + + // Query + bool PreVisit(SingleQuery &) override; + + // Union + bool PreVisit(CypherUnion &) override; + bool PostVisit(CypherUnion &) override; + + // Clauses + bool PreVisit(Create &) override; + bool PostVisit(Create &) override; + bool PreVisit(CallProcedure &) override; + bool PostVisit(CallProcedure &) override; + bool PreVisit(LoadCsv &) override; + bool PostVisit(LoadCsv &) override; + bool PreVisit(Return &) override; + bool PostVisit(Return &) override; + bool PreVisit(With &) override; + bool PreVisit(Where &) override; + bool PostVisit(Where &) override; + bool PreVisit(Merge &) override; + bool PostVisit(Merge &) override; + bool PostVisit(Unwind &) override; + bool PreVisit(Match &) override; + bool PostVisit(Match &) override; + bool PreVisit(Foreach &) override; + bool PostVisit(Foreach &) override; + + // Expressions + ReturnType Visit(Identifier &) override; + ReturnType Visit(PrimitiveLiteral &) override { return true; } + ReturnType Visit(ParameterLookup &) override { return true; } + bool PreVisit(Aggregation &) override; + bool PostVisit(Aggregation &) 
override; + bool PreVisit(IfOperator &) override; + bool PostVisit(IfOperator &) override; + bool PreVisit(All &) override; + bool PreVisit(Single &) override; + bool PreVisit(Any &) override; + bool PreVisit(None &) override; + bool PreVisit(Reduce &) override; + bool PreVisit(Extract &) override; + + // Pattern and its subparts. + bool PreVisit(Pattern &) override; + bool PostVisit(Pattern &) override; + bool PreVisit(NodeAtom &) override; + bool PostVisit(NodeAtom &) override; + bool PreVisit(EdgeAtom &) override; + bool PostVisit(EdgeAtom &) override; + + private: + // Scope stores the state of where we are when visiting the AST and a map of + // names to symbols. + struct Scope { + bool in_pattern{false}; + bool in_merge{false}; + bool in_create{false}; + // in_create_node is true if we are creating or merging *only* a node. + // Therefore, it is *not* equivalent to (in_create || in_merge) && + // in_node_atom. + bool in_create_node{false}; + // True if creating an edge; + // shortcut for (in_create || in_merge) && visiting_edge. + bool in_create_edge{false}; + bool in_node_atom{false}; + EdgeAtom *visiting_edge{nullptr}; + bool in_aggregation{false}; + bool in_return{false}; + bool in_with{false}; + bool in_skip{false}; + bool in_limit{false}; + bool in_order_by{false}; + bool in_where{false}; + bool in_match{false}; + bool in_foreach{false}; + // True when visiting a pattern atom (node or edge) identifier, which can be + // reused or created in the pattern itself. + bool in_pattern_atom_identifier{false}; + // True when visiting range bounds of a variable path. + bool in_edge_range{false}; + // True if the return/with contains an aggregation in any named expression. + bool has_aggregation{false}; + // Map from variable names to symbols. + std::map<std::string, Symbol> symbols; + // Identifiers found in property maps of patterns or as variable length path + // bounds in a single Match clause. They need to be checked after visiting + // Match. 
Identifiers created by naming vertices, edges and paths are *not* + // stored in here. + std::vector<Identifier *> identifiers_in_match; + // Number of nested IfOperators. + int num_if_operators{0}; + }; + + static std::optional<Symbol> FindSymbolInScope(const std::string &name, const Scope &scope, Symbol::Type type); + + bool HasSymbol(const std::string &name) const; + bool HasSymbolLocalScope(const std::string &name) const; + + // @return true if it added a predefined identifier with that name + bool ConsumePredefinedIdentifier(const std::string &name); + + // Returns a freshly generated symbol. Previous mapping of the same name to a + // different symbol is replaced with the new one. + auto CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY, + int token_position = -1); + + auto GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY); + // Returns the symbol by name. If the mapping already exists, checks if the + // types match. Otherwise, returns a new symbol. + auto GetOrCreateSymbolLocalScope(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY); + + void VisitReturnBody(ReturnBody &body, Where *where = nullptr); + + void VisitWithIdentifiers(Expression *, const std::vector<Identifier *> &); + + SymbolTable *symbol_table_; + + // Identifiers which are injected from outside the query. Each identifier + // is mapped by its name. 
+ std::unordered_map<std::string, Identifier *> predefined_identifiers_; + std::vector<Scope> scopes_; + std::unordered_set<std::string> prev_return_names_; + std::unordered_set<std::string> curr_return_names_; +}; + +inline SymbolTable MakeSymbolTable(CypherQuery *query, const std::vector<Identifier *> &predefined_identifiers = {}) { + SymbolTable symbol_table; + SymbolGenerator symbol_generator(&symbol_table, predefined_identifiers); + query->single_query_->Accept(symbol_generator); + for (auto *cypher_union : query->cypher_unions_) { + cypher_union->Accept(symbol_generator); + } + return symbol_table; +} + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/semantic/symbol_table.hpp b/src/query/v2/frontend/semantic/symbol_table.hpp new file mode 100644 index 000000000..a4ccf7e76 --- /dev/null +++ b/src/query/v2/frontend/semantic/symbol_table.hpp @@ -0,0 +1,64 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include <map> +#include <string> + +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/semantic/symbol.hpp" +#include "utils/logging.hpp" + +namespace memgraph::query::v2 { + +class SymbolTable final { + public: + SymbolTable() {} + const Symbol &CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY, + int32_t token_position = -1) { + MG_ASSERT(table_.size() <= std::numeric_limits<int32_t>::max(), + "SymbolTable size doesn't fit into 32-bit integer!"); + auto got = table_.emplace(position_, Symbol(name, position_, user_declared, type, token_position)); + MG_ASSERT(got.second, "Duplicate symbol ID!"); + position_++; + return got.first->second; + } + + // TODO(buda): This is the same logic as in the cypher_main_visitor. During + // parsing phase symbol table doesn't exist. Figure out a better solution. + const Symbol &CreateAnonymousSymbol(Symbol::Type type = Symbol::Type::ANY) { + int id = 1; + while (true) { + static const std::string &kAnonPrefix = "anon"; + std::string name_candidate = kAnonPrefix + std::to_string(id++); + if (std::find_if(std::begin(table_), std::end(table_), [&name_candidate](const auto &item) -> bool { + return item.second.name_ == name_candidate; + }) == std::end(table_)) { + return CreateSymbol(name_candidate, false, type); + } + } + } + + const Symbol &at(const Identifier &ident) const { return table_.at(ident.symbol_pos_); } + const Symbol &at(const NamedExpression &nexpr) const { return table_.at(nexpr.symbol_pos_); } + const Symbol &at(const Aggregation &aggr) const { return table_.at(aggr.symbol_pos_); } + + // TODO: Remove these since members are public + int32_t max_position() const { return static_cast<int32_t>(table_.size()); } + + const auto &table() const { return table_; } + + int32_t position_{0}; + std::map<int32_t, Symbol> table_; +}; + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/stripped.cpp 
b/src/query/v2/frontend/stripped.cpp new file mode 100644 index 000000000..3d50d57d2 --- /dev/null +++ b/src/query/v2/frontend/stripped.cpp @@ -0,0 +1,535 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

#include "query/v2/frontend/stripped.hpp"

#include <cctype>
#include <cstdint>
#include <iostream>
#include <span>
#include <string>
#include <vector>

#include "query/v2/exceptions.hpp"
#include "query/v2/frontend/opencypher/generated/MemgraphCypher.h"
#include "query/v2/frontend/opencypher/generated/MemgraphCypherBaseVisitor.h"
#include "query/v2/frontend/opencypher/generated/MemgraphCypherLexer.h"
#include "query/v2/frontend/parsing.hpp"
#include "query/v2/frontend/stripped_lexer_constants.hpp"
#include "utils/fnv.hpp"
#include "utils/logging.hpp"
#include "utils/string.hpp"

namespace memgraph::query::v2::frontend {

using namespace lexer_constants;

// Tokenizes `query` and builds its stripped form.
//
// The constructor works in three passes:
//  1. Split original_ into tokens. At every position each Match* function is tried and the longest
//     match wins; on equal lengths the matcher tried first wins because `update` uses a strict `>`.
//     Throws LexingException("Invalid query.") when nothing matches. For a CREATE TRIGGER query,
//     tokenizing stops right after EXECUTE and the rest of the text is kept as one raw chunk.
//  2. Turn tokens into strings: whitespace is dropped, true/false/string/int/real literals are
//     stored in literals_ and replaced with placeholder tokens, parameters are recorded in
//     parameters_. The strings are joined with single spaces into query_ and hashed into hash_.
//  3. For every RETURN, record the original text of each named expression that has no AS alias in
//     named_exprs_, keyed by the position of its first token in the stripped query.
StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
  enum class Token {
    UNMATCHED,
    KEYWORD,  // Including true, false and null.
    SPECIAL,  // +, .., +=, (, { and so on.
    STRING,
    INT,  // Decimal, octal and hexadecimal.
    REAL,
    PARAMETER,
    ESCAPED_NAME,
    UNESCAPED_NAME,
    SPACE
  };

  std::vector<std::pair<Token, std::string>> tokens;
  std::string unstripped_chunk;
  for (int i = 0; i < static_cast<int>(original_.size());) {
    Token token = Token::UNMATCHED;
    int len = 0;
    // Keeps the longest match seen so far; an equally long later match does not replace it.
    auto update = [&](int new_len, Token new_token) {
      if (new_len > len) {
        len = new_len;
        token = new_token;
      }
    };
    update(MatchKeyword(i), Token::KEYWORD);
    update(MatchSpecial(i), Token::SPECIAL);
    update(MatchString(i), Token::STRING);
    update(MatchDecimalInt(i), Token::INT);
    update(MatchOctalInt(i), Token::INT);
    update(MatchHexadecimalInt(i), Token::INT);
    update(MatchReal(i), Token::REAL);
    update(MatchParameter(i), Token::PARAMETER);
    update(MatchEscapedName(i), Token::ESCAPED_NAME);
    update(MatchUnescapedName(i), Token::UNESCAPED_NAME);
    update(MatchWhitespaceAndComments(i), Token::SPACE);
    if (token == Token::UNMATCHED) throw LexingException("Invalid query.");
    tokens.emplace_back(token, original_.substr(i, len));
    i += len;

    // If we notice execute, we possibly create a trigger which has defined statements.
    // The statements will be parsed separately later on so we skip it for now.
    if (utils::IEquals(tokens.back().second, "execute")) {
      // check if it's CREATE TRIGGER query
      std::span token_span{tokens};

      // query could start with spaces and/or comments
      if (token_span.front().first == Token::SPACE) {
        token_span = token_span.subspan(1);
      }

      // we need to check that first and third elements are correct keywords
      // CREATE<SPACE>TRIGGER<SPACE>trigger-name...EXECUTE
      // trigger-name (5th element) can also be "execute" so we verify that the size is larger than 5
      if (token_span.size() > 5 && utils::IEquals(token_span[0].second, "create") &&
          utils::IEquals(token_span[2].second, "trigger")) {
        unstripped_chunk = original_.substr(i);
        break;
      }
    }
  }

  std::vector<std::string> token_strings;
  // A helper function that stores literal and its token position in a
  // literals_. In stripped query text literal is replaced with a new_value.
  // new_value can be any value that is lexed as a literal.
  auto replace_stripped = [this, &token_strings](int position, const auto &value, const std::string &new_value) {
    literals_.Add(position, storage::v3::PropertyValue(value));
    token_strings.push_back(new_value);
  };

  // Copy original tokens because we need to use original case in named
  // expressions and keywords in tokens will be lowercased in the next loop.
  auto original_tokens = tokens;
  // For every token in original query remember token index in stripped query.
  std::vector<int> position_mapping(tokens.size(), -1);

  // Convert tokens to strings, perform filtering, store literals and nonaliased
  // named expressions in return.
  for (int i = 0; i < static_cast<int>(tokens.size()); ++i) {
    auto &token = tokens[i];

    // We need to shift token index for every parameter since antlr's parser
    // thinks of parameter as two tokens.
    int token_index = token_strings.size() + parameters_.size();
    switch (token.first) {
      case Token::UNMATCHED:
        LOG_FATAL("Shouldn't happen");
      case Token::KEYWORD: {
        // We don't strip NULL, since it can appear in special expressions
        // like IS NULL and IS NOT NULL, but we strip true and false keywords.
        if (utils::IEquals(token.second, "true")) {
          replace_stripped(token_index, true, kStrippedBooleanToken);
        } else if (utils::IEquals(token.second, "false")) {
          replace_stripped(token_index, false, kStrippedBooleanToken);
        } else {
          token_strings.push_back(token.second);
        }
      } break;
      case Token::SPACE:
        break;
      case Token::STRING:
        replace_stripped(token_index, ParseStringLiteral(token.second), kStrippedStringToken);
        break;
      case Token::INT:
        replace_stripped(token_index, ParseIntegerLiteral(token.second), kStrippedIntToken);
        break;
      case Token::REAL:
        replace_stripped(token_index, ParseDoubleLiteral(token.second), kStrippedDoubleToken);
        break;
      case Token::SPECIAL:
      case Token::ESCAPED_NAME:
      case Token::UNESCAPED_NAME:
        token_strings.push_back(token.second);
        break;
      case Token::PARAMETER:
        parameters_[token_index] = ParseParameter(token.second);
        token_strings.push_back(token.second);
        break;
    }

    // Whitespace tokens keep the initial -1 since they do not appear in the stripped query.
    if (token.first != Token::SPACE) {
      position_mapping[i] = token_index;
    }
  }

  // The raw text after EXECUTE of a CREATE TRIGGER query is appended as a single token string.
  if (!unstripped_chunk.empty()) {
    token_strings.push_back(std::move(unstripped_chunk));
  }

  query_ = utils::Join(token_strings, " ");
  hash_ = utils::Fnv(query_);

  auto it = tokens.begin();
  while (it != tokens.end()) {
    // Store nonaliased named expressions in returns in named_exprs_.
    it = std::find_if(it, tokens.end(),
                      [](const std::pair<Token, std::string> &a) { return utils::IEquals(a.second, "return"); });
    // There is no RETURN so there is nothing to do here.
    if (it == tokens.end()) return;
    // Skip RETURN;
    ++it;

    // Now we need to parse cypherReturn production from opencypher grammar.
    // Skip leading whitespaces and DISTINCT statement if there is one.
    while (it != tokens.end() && it->first == Token::SPACE) {
      ++it;
    }
    if (it != tokens.end() && utils::IEquals(it->second, "distinct")) {
      ++it;
    }

    // If the query is invalid, either antlr parser or cypher_main_visitor will
    // report an error.
    // TODO: we shouldn't rely on the fact that those checks will be done
    // after this step. We should do them here.
    while (it < tokens.end()) {
      // Disregard leading whitespace
      while (it != tokens.end() && it->first == Token::SPACE) {
        ++it;
      }
      // There is only whitespace, nothing to do...
      if (it == tokens.end()) break;

      bool has_as = false;
      auto last_non_space = it;
      auto jt = it;
      // We should track number of opened braces and parentheses so that we can
      // recognize if comma is a named expression separator or part of the
      // list literal / function call.
      // Note: num_open_braces counts '[' / ']' and num_open_brackets counts '{' / '}'. All three
      // counters are only tested for being non-zero, so the naming does not affect the result.
      int num_open_braces = 0;
      int num_open_parantheses = 0;
      int num_open_brackets = 0;
      // Advance jt to the end of the current named expression: a top-level comma, one of the
      // keywords ORDER/SKIP/LIMIT/UNION/QUERY, a semicolon, or the end of the tokens.
      for (;
           jt != tokens.end() && (jt->second != "," || num_open_braces || num_open_parantheses || num_open_brackets) &&
           !utils::IEquals(jt->second, "order") && !utils::IEquals(jt->second, "skip") &&
           !utils::IEquals(jt->second, "limit") && !utils::IEquals(jt->second, "union") &&
           !utils::IEquals(jt->second, "query") && jt->second != ";";
           ++jt) {
        if (jt->second == "(") {
          ++num_open_parantheses;
        } else if (jt->second == ")") {
          --num_open_parantheses;
        } else if (jt->second == "[") {
          ++num_open_braces;
        } else if (jt->second == "]") {
          --num_open_braces;
        } else if (jt->second == "{") {
          ++num_open_brackets;
        } else if (jt->second == "}") {
          --num_open_brackets;
        }
        has_as |= utils::IEquals(jt->second, "as");
        if (jt->first != Token::SPACE) {
          last_non_space = jt;
        }
      }
      if (!has_as) {
        // Named expression is not aliased. Save string disregarding leading and
        // trailing whitespaces.
        std::string s;
        // Translate the iterators from `tokens` to the same offsets in `original_tokens`.
        auto begin_token = it - tokens.begin() + original_tokens.begin();
        auto end_token = last_non_space - tokens.begin() + original_tokens.begin() + 1;
        for (auto kt = begin_token; kt != end_token; ++kt) {
          s += kt->second;
        }
        named_exprs_[position_mapping[it - tokens.begin()]] = s;
      }
      if (jt != tokens.end() && jt->second == ",") {
        // There are more named expressions.
        it = jt + 1;
      } else {
        // We're done with this return statement
        break;
      }
    }
  }
}

// Returns the bytes of the first UTF-8 encoded symbol of `_s` as a string (1 to 4 bytes).
// Throws LexingException("Invalid character.") if the lead byte is not a valid UTF-8 lead byte or
// if any expected continuation byte does not have the form 10xxxxxx. Continuation bytes are read
// one at a time, and the next one is only read after the previous one has been validated.
std::string GetFirstUtf8Symbol(const char *_s) {
  // According to
  // https://stackoverflow.com/questions/16260033/reinterpret-cast-between-char-and-stduint8-t-safe
  // this checks if casting from const char * to uint8_t is undefined behaviour.
  static_assert(std::is_same<std::uint8_t, unsigned char>::value,
                "This library requires std::uint8_t to be implemented as "
                "unsigned char.");
  const uint8_t *s = reinterpret_cast<const uint8_t *>(_s);
  // 0xxxxxxx: single byte symbol.
  if ((*s >> 7) == 0x00) return std::string(_s, _s + 1);
  // 110xxxxx: two byte symbol.
  if ((*s >> 5) == 0x06) {
    auto *s1 = s + 1;
    if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
    return std::string(_s, _s + 2);
  }
  // 1110xxxx: three byte symbol.
  if ((*s >> 4) == 0x0e) {
    auto *s1 = s + 1;
    if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
    auto *s2 = s + 2;
    if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
    return std::string(_s, _s + 3);
  }
  // 11110xxx: four byte symbol.
  if ((*s >> 3) == 0x1e) {
    auto *s1 = s + 1;
    if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
    auto *s2 = s + 2;
    if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
    auto *s3 = s + 3;
    if ((*s3 >> 6) != 0x02) throw LexingException("Invalid character.");
    return std::string(_s, _s + 4);
  }
  throw LexingException("Invalid character.");
}

// Return codepoint of first utf8 symbol and its encoded length.
+std::pair<int, int> GetFirstUtf8SymbolCodepoint(const char *_s) { + static_assert(std::is_same<std::uint8_t, unsigned char>::value, + "This library requires std::uint8_t to be implemented as " + "unsigned char."); + const uint8_t *s = reinterpret_cast<const uint8_t *>(_s); + if ((*s >> 7) == 0x00) return {*s & 0x7f, 1}; + if ((*s >> 5) == 0x06) { + auto *s1 = s + 1; + if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character."); + return {((*s & 0x1f) << 6) | (*s1 & 0x3f), 2}; + } + if ((*s >> 4) == 0x0e) { + auto *s1 = s + 1; + if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character."); + auto *s2 = s + 2; + if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character."); + return {((*s & 0x0f) << 12) | ((*s1 & 0x3f) << 6) | (*s2 & 0x3f), 3}; + } + if ((*s >> 3) == 0x1e) { + auto *s1 = s + 1; + if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character."); + auto *s2 = s + 2; + if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character."); + auto *s3 = s + 3; + if ((*s3 >> 6) != 0x02) throw LexingException("Invalid character."); + return {((*s & 0x07) << 18) | ((*s1 & 0x3f) << 12) | ((*s2 & 0x3f) << 6) | (*s3 & 0x3f), 4}; + } + throw LexingException("Invalid character."); +} + +// From here until end of file there are functions that calculate matches for +// every possible token. Functions are more or less compatible with Cypher.g4 +// grammar. Unfortunately, they contain a lof of special cases and shouldn't +// be changed without good reasons. +// +// Here be dragons, do not touch! 
+// ____ __ +// { --.\ | .)%%%)%% +// '-._\\ | (\___ %)%%(%%(%%% +// `\\|{/ ^ _)-%(%%%%)%%;%%% +// .'^^^^^^^ /` %%)%%%%)%%%' +// //\ ) , / '%%%%(%%' +// , _.'/ `\<-- \< +// `^^^` ^^ ^^ +int StrippedQuery::MatchKeyword(int start) const { return kKeywords.Match<tolower>(original_.c_str() + start); } + +int StrippedQuery::MatchSpecial(int start) const { return kSpecialTokens.Match(original_.c_str() + start); } + +int StrippedQuery::MatchString(int start) const { + if (original_[start] != '"' && original_[start] != '\'') return 0; + char start_char = original_[start]; + for (auto *p = original_.data() + start + 1; *p; ++p) { + if (*p == start_char) return p - (original_.data() + start) + 1; + if (*p == '\\') { + ++p; + if (*p == '\\' || *p == '\'' || *p == '"' || *p == 'B' || *p == 'b' || *p == 'F' || *p == 'f' || *p == 'N' || + *p == 'n' || *p == 'R' || *p == 'r' || *p == 'T' || *p == 't') { + // Allowed escaped characters. + continue; + } else if (*p == 'U' || *p == 'u') { + int cnt = 0; + auto *r = p + 1; + while (isxdigit(*r) && cnt < 8) { + ++cnt; + ++r; + } + if (!*r) return 0; + if (cnt < 4) return 0; + if (cnt >= 4 && cnt < 8) { + p += 4; + } + if (cnt >= 8) { + p += 8; + } + } else { + return 0; + } + } + } + return 0; +} + +int StrippedQuery::MatchDecimalInt(int start) const { + if (original_[start] == '0') return 1; + int i = start; + while (i < static_cast<int>(original_.size()) && isdigit(original_[i])) { + ++i; + } + return i - start; +} + +int StrippedQuery::MatchOctalInt(int start) const { + if (original_[start] != '0') return 0; + int i = start + 1; + while (i < static_cast<int>(original_.size()) && '0' <= original_[i] && original_[i] <= '7') { + ++i; + } + if (i == start + 1) return 0; + return i - start; +} + +int StrippedQuery::MatchHexadecimalInt(int start) const { + if (original_[start] != '0') return 0; + if (start + 1 >= static_cast<int>(original_.size())) return 0; + if (original_[start + 1] != 'x') return 0; + int i = start + 2; + while (i < 
static_cast<int>(original_.size()) && isxdigit(original_[i])) { + ++i; + } + if (i == start + 2) return 0; + return i - start; +} + +int StrippedQuery::MatchReal(int start) const { + enum class State { START, BEFORE_DOT, DOT, AFTER_DOT, E, E_MINUS, AFTER_E }; + State state = State::START; + auto i = start; + while (i < static_cast<int>(original_.size())) { + if (original_[i] == '.') { + if (state != State::BEFORE_DOT && state != State::START) break; + state = State::DOT; + } else if ('0' <= original_[i] && original_[i] <= '9') { + if (state == State::START) { + state = State::BEFORE_DOT; + } else if (state == State::DOT) { + state = State::AFTER_DOT; + } else if (state == State::E || state == State::E_MINUS) { + state = State::AFTER_E; + } + } else if (original_[i] == 'e' || original_[i] == 'E') { + if (state != State::BEFORE_DOT && state != State::AFTER_DOT) break; + state = State::E; + } else if (original_[i] == '-') { + if (state != State::E) break; + state = State::E_MINUS; + } else { + break; + } + ++i; + } + if (state == State::DOT) --i; + if (state == State::E) --i; + if (state == State::E_MINUS) i -= 2; + return i - start; +} + +int StrippedQuery::MatchParameter(int start) const { + int len = original_.size(); + if (start + 1 == len) return 0; + if (original_[start] != '$') return 0; + int max_len = 0; + max_len = std::max(max_len, MatchUnescapedName(start + 1)); + max_len = std::max(max_len, MatchEscapedName(start + 1)); + max_len = std::max(max_len, MatchKeyword(start + 1)); + max_len = std::max(max_len, MatchDecimalInt(start + 1)); + if (max_len == 0) return 0; + return 1 + max_len; +} + +int StrippedQuery::MatchEscapedName(int start) const { + int len = original_.size(); + int i = start; + while (i < len) { + if (original_[i] != '`') break; + int j = i + 1; + while (j < len && original_[j] != '`') { + ++j; + } + if (j == len) break; + i = j + 1; + } + return i - start; +} + +int StrippedQuery::MatchUnescapedName(int start) const { + auto i = start; + 
auto got = GetFirstUtf8SymbolCodepoint(original_.data() + i); + if (got.first >= lexer_constants::kBitsetSize || !kUnescapedNameAllowedStarts[got.first]) { + return 0; + } + i += got.second; + while (i < static_cast<int>(original_.size())) { + got = GetFirstUtf8SymbolCodepoint(original_.data() + i); + if (got.first >= lexer_constants::kBitsetSize || !kUnescapedNameAllowedParts[got.first]) { + break; + } + i += got.second; + } + return i - start; +} + +int StrippedQuery::MatchWhitespaceAndComments(int start) const { + enum class State { OUT, IN_LINE_COMMENT, IN_BLOCK_COMMENT }; + State state = State::OUT; + int i = start; + int len = original_.size(); + // We need to remember at which position comment started because if we fail + // to match comment finish we have a match until comment start position. + int comment_position = -1; + while (i < len) { + if (state == State::OUT) { + auto got = GetFirstUtf8SymbolCodepoint(original_.data() + i); + if (got.first < lexer_constants::kBitsetSize && kSpaceParts[got.first]) { + i += got.second; + } else if (i + 1 < len && original_[i] == '/' && original_[i + 1] == '*') { + comment_position = i; + state = State::IN_BLOCK_COMMENT; + i += 2; + } else if (i + 1 < len && original_[i] == '/' && original_[i + 1] == '/') { + comment_position = i; + if (i + 2 < len) { + // Special case for an empty line comment starting right at the end of + // the query. 
+ state = State::IN_LINE_COMMENT; + } + i += 2; + } else { + break; + } + } else if (state == State::IN_LINE_COMMENT) { + if (original_[i] == '\n') { + state = State::OUT; + ++i; + } else if (i + 1 < len && original_[i] == '\r' && original_[i + 1] == '\n') { + state = State::OUT; + i += 2; + } else if (original_[i] == '\r') { + break; + } else if (i + 1 == len) { + state = State::OUT; + ++i; + } else { + ++i; + } + } else if (state == State::IN_BLOCK_COMMENT) { + if (i + 1 < len && original_[i] == '*' && original_[i + 1] == '/') { + i += 2; + state = State::OUT; + } else { + ++i; + } + } + } + if (state != State::OUT) return comment_position - start; + return i - start; +} + +} // namespace memgraph::query::v2::frontend diff --git a/src/query/v2/frontend/stripped.hpp b/src/query/v2/frontend/stripped.hpp new file mode 100644 index 000000000..e70a9c671 --- /dev/null +++ b/src/query/v2/frontend/stripped.hpp @@ -0,0 +1,103 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <string> +#include <unordered_map> + +#include "query/v2/parameters.hpp" +#include "utils/fnv.hpp" + +namespace memgraph::query::v2::frontend { + +// Strings used to replace original tokens. Different types are replaced with +// different token. 
+const std::string kStrippedIntToken = "0"; +const std::string kStrippedDoubleToken = "0.0"; +const std::string kStrippedStringToken = "\"a\""; +const std::string kStrippedBooleanToken = "true"; + +/** + * StrippedQuery contains: + * * stripped query + * * literals stripped from query + * * hash of stripped query + */ +class StrippedQuery { + public: + /** + * Strips the input query and stores stripped query, stripped arguments and + * stripped query hash. + * + * @param query Input query. + */ + explicit StrippedQuery(const std::string &query); + + /** + * Copy constructor is deleted because we don't want to make unnecessary + * copies of this object (copying of string and vector could be expensive) + */ + StrippedQuery(const StrippedQuery &other) = delete; + StrippedQuery &operator=(const StrippedQuery &other) = delete; + + /** + * Move is allowed operation because it is not expensive and we can + * move the object after it was created. + */ + StrippedQuery(StrippedQuery &&other) = default; + StrippedQuery &operator=(StrippedQuery &&other) = default; + + const std::string &query() const { return query_; } + const auto &original_query() const { return original_; } + const auto &literals() const { return literals_; } + const auto &named_expressions() const { return named_exprs_; } + const auto ¶meters() const { return parameters_; } + uint64_t hash() const { return hash_; } + + private: + // Return len of matched keyword if something is matched, otherwise 0. + int MatchKeyword(int start) const; + int MatchString(int start) const; + int MatchSpecial(int start) const; + int MatchDecimalInt(int start) const; + int MatchOctalInt(int start) const; + int MatchHexadecimalInt(int start) const; + int MatchReal(int start) const; + int MatchParameter(int start) const; + int MatchEscapedName(int start) const; + int MatchUnescapedName(int start) const; + int MatchWhitespaceAndComments(int start) const; + + // Original query. + std::string original_; + + // Stripped query. 
+ std::string query_; + + // Token positions of stripped out literals mapped to their values. + // TODO: Parameters class really doesn't provide anything interesting. This + // could be changed to std::unordered_map, but first we need to rewrite (or + // get rid of) hardcoded queries which expect Parameters. + Parameters literals_; + + // Token positions of query parameters mapped to their names. + std::unordered_map<int, std::string> parameters_; + + // Token positions of nonaliased named expressions in return statement mapped + // to their original (unstripped) string. + std::unordered_map<int, std::string> named_exprs_; + + // Hash based on the stripped query. + uint64_t hash_; +}; + +} // namespace memgraph::query::v2::frontend diff --git a/src/query/v2/frontend/stripped_lexer_constants.hpp b/src/query/v2/frontend/stripped_lexer_constants.hpp new file mode 100644 index 000000000..4e52fbdc4 --- /dev/null +++ b/src/query/v2/frontend/stripped_lexer_constants.hpp @@ -0,0 +1,2924 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <bitset> +#include <initializer_list> +#include <string> +#include <unordered_set> +#include <vector> + +namespace memgraph::query::v2 { +namespace lexer_constants { + +namespace trie { + +// Trie data structure implemented to be used in StrippedQuery. If you want to +// change it please rerun benchmark/stripped to be sure that performance of +// StrippedQuery is not degraded by the change. 
+// Also there are no tests that
+// directly test this class, but there are tests that test StrippedQuery.
+
+namespace detail {
+// Identity mapping; the default `Map` for Trie::Match, i.e. bytes are looked
+// up unchanged.
+inline int Noop(int x) { return x; }
+};  // namespace detail
+
+// Byte-wise trie over std::string keys. Nodes live in a vector and refer to
+// each other by index; index 0 is the root and doubles as the "no child"
+// marker, since the root is never anyone's child.
+class Trie {
+ public:
+  Trie() {}
+  // Builds a trie containing every string in `l`.
+  Trie(std::initializer_list<std::string> l) {
+    for (const auto &s : l) {
+      Insert(s);
+    }
+  }
+
+  // Adds `s` to the trie, creating nodes for bytes that are not present yet.
+  // Inserting a string that is already present leaves the trie unchanged.
+  void Insert(const std::string &s) {
+    int node_id = kRootIndex;
+    for (const auto &_c : s) {
+      // Index the child table by the unsigned byte value.
+      const unsigned char &c = reinterpret_cast<const unsigned char &>(_c);
+      int &next_node_id = nodes_[node_id].next[c];
+      if (next_node_id == 0) {
+        next_node_id = nodes_.size();
+        // First assign then emplace_back because after emplace_back reference
+        // could be invalid.
+        node_id = next_node_id;
+        nodes_.emplace_back();
+      } else {
+        node_id = next_node_id;
+      }
+    }
+    nodes_[node_id].finish = true;
+  }
+
+  // Returns the length of the longest inserted string that is a prefix of the
+  // NUL-terminated string `s`, or 0 if there is none. Each byte of `s` is
+  // passed through `Map` before the lookup (e.g. tolower for case-insensitive
+  // matching against lowercase keys).
+  // NOTE(review): `Map` is assumed to return a value in [0, 255]; nothing here
+  // checks that before indexing `next`.
+  template <int (*Map)(int c) = detail::Noop>
+  int Match(const char *s) const {
+    int node_id = kRootIndex;
+    int longest_found_len = 0;
+    // `i` is the number of bytes consumed once the current byte is accepted.
+    int i = 1;
+    for (const char *p = s; *p; ++p, ++i) {
+      const unsigned char &c = reinterpret_cast<const unsigned char &>(*p);
+      node_id = nodes_[node_id].next[Map(c)];
+      if (node_id == 0) break;
+      if (nodes_[node_id].finish) {
+        longest_found_len = i;
+      }
+    }
+    return longest_found_len;
+  }
+
+ private:
+  struct Node {
+    // Child index for every possible byte value (256 entries); 0 = no child.
+    int next[1 << (sizeof(unsigned char) * 8)] = {};
+    // True if an inserted string ends at this node.
+    bool finish = false;
+  };
+
+  const static int kRootIndex = 0;
+  // Starts with a single node: the root.
+  std::vector<Node> nodes_{1};
+};
+}  // namespace trie
+
+// All word constants should be lowercase in this file.
+
+// Number of codepoints covered by the lookup bitsets in this file; codepoints
+// at or above this value are never treated as members of a set.
+const int kBitsetSize = 65536;
+
+// Keywords recognized by the query stripper. Entries must be lowercase.
+// Every entry needs its own trailing comma: adjacent string literals are
+// concatenated by the compiler, which silently merges two keywords into one.
+const trie::Trie kKeywords = {
+    "union",
+    "all",
+    "optional",
+    "match",
+    "unwind",
+    "as",
+    "merge",
+    "on",
+    "create",
+    "set",
+    "detach",
+    "delete",
+    "remove",
+    "with",
+    "distinct",
+    "return",
+    "order",
+    "by",
+    "skip",
+    "limit",
+    "ascending",
+    "asc",
+    "descending",
+    "desc",
+    "where",
+    "or",
+    "xor",
+    "and",
+    "not",
+    "in",
+    "starts",
+    "ends",
+    "contains",
+    "is",
+    "null",
+    "case",
+    "when",
+    "then",
+    "else",
+    "end",
+    "count",
+    "filter",
+    "extract",
+    "any",
+    "none",
+    "single",
+    "true",
+    "false",
+    "reduce",
+    "coalesce",
+    "user",
+    "password",
+    "alter",
+    "drop",
+    "show",
+    "stats",
+    "unique",
+    "explain",
+    "profile",
+    "storage",
+    "index",
+    "info",
+    "exists",
+    "assert",
+    "constraint",
+    "node",
+    "key",
+    "dump",
+    "database",
+    "call",
+    "yield",
+    "memory",
+    "mb",
+    "kb",
+    "unlimited",
+    "free",
+    "procedure",
+    "query",
+    "free_memory",
+    "read_file",
+    "lock_path",
+    "after",
+    "before",
+    "execute",
+    "transaction",
+    "trigger",
+    "triggers",
+    "update",
+    // NOTE(review): the next two are spelled with a single 'm'; confirm
+    // against the lexer grammar whether "committed"/"uncommitted" is intended.
+    "comitted",
+    "uncomitted",
+    "global",
+    "isolation",
+    "level",
+    "next",
+    "read",
+    "session",
+    "snapshot",
+    "batch_limit",
+    "batch_interval",
+    "batch_size",
+    "consumer_group",
+    "start",
+    "stream",
+    "streams",
+    "transform",
+    "topics",
+    "check",
+    "setting",
+    "settings",
+    "bootstrap_servers",
+    "kafka",
+    "pulsar",
+    "service_url",
+    "version",
+    "websocket",
+    "foreach",
+};
+
+// Unicode codepoints that are allowed at the start of the unescaped name.
+const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts( + std::string("00000000000000000000000000000000000111001111110011111100111111000111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111100000000000111111111111111111111111110100001111111111111111111111111" + "10000000000000000000000000000000000001111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111101111100000000000000000000000000000000111000000000" + "00000000000000011000000000000000000000000000000000000000000000000000000011" + "11111111110000000000000000000000000000000000000000111111111111111111111111" + "11111111111111111111111111111100111111111111111111111111111111111111111111" + "11111111111111111111110000000000000000001111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111100000000000000000000000000000000011111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111011011010111110111111111111101111111111010000011" + "11100000000000011111110000000000000000000000000000000000000011111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111100111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111110000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000111111111111111111111111111111111111111111" + "11111110000111111111111111111111110000000000001111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111000000000000000000000000000001" + "11111111111111111111111111111111110000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000001111111011111110000000001111110011111100111" + "11100000000000011100000001111111111100111000000000000000000000000101001111" + "10011000101111111111111111111111111111111111111111111111110000010001111111" + "11111111111111110000000000000000000011111111011100000000000000000000000111" + "11111111111111111111111111111111111111000000000000000000000000000000000000" + "00000000000010000000000000000000000000000111111111111111111111111111111111" + 
"11111111111111000000011111111111111111111111111111000000000000000000000000" + "01111111111111111111111100000000001111111111111111111111111111000000000000" + "00100011111100000000000000000000000000000000000000000000000000000000000000" + "11111111111111111111111111111111111111111111111111000000000000001111111111" + "11111111111111111111111111111111111111111100000000000000000000000000000111" + "11111111111111111111011110111011111111110000000000000000000000000000000000" + "00000000000000000000000000000000000000000001111111111100000000000011110111" + "10011111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111001111111110000000000000000000000000000000" + "00000000111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111000000001111111111111111111111111000000000000000011111111111" + "11111111111111111111111111111111111100000000000000000000110000000000111111" + "11111111110001111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111100111111111111" + "11111111111111111111111111111111110000000000000000000000000000000000000000" + "00000000000000000000000000011111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111000000000000000000000000000000000000000000000000000111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111110000000000000000" + "00000000000000000000000000000000000000000000000000000000001111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111100000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000011111111111111110000000000" + "00000000000000000000000000000000000000000001111111111111111111111111110000" + "00000000000001111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111100011111111111111111111111111111111111111" + "11100000111101111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111011111000011111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111100001111100111110000000" + "11111111100000000000000000000000001110000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000011111" + "11011111110111111101111111011111110111111101111111011111110000000001111111" + "11111111111111110000000000000000100000001111111111111111111111111111111111" + "11111111111111111111110010000010111111111111111111111111111111111111110000" + "00000000110001111000000111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111011111111111111111111111111111111111111111111111011111111111111111" + "11111111111111111111111111111100000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000011111111111111111111111111111" + "11111111111100000000000000000100001111100000111100111111111111111101010100" + "00001111110010111111111100100001000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000111111111111100" + "00000000000000100000000000001000000000000000000000000000010000000000000000" + 
"00011000000000000000000000000000000000000000000000000000000000000000000111" + "11110111000001111111111111000011111100111100011111110111000101111111011111" + "11111111111111111111111111111111111111111111111100111111111111111111111111" + "11111110101010111111110011111100111111111111111111111111111111111111110011" + "11110011111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111000000000000" + "00000000000000000000000000000000000000000000000000001111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111110000000001100011110111100000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00001111111111111111111111111111111111110000000000111000000000000000000000" + "00000000000000000000111111111111111111111111111111111111000000000000000000" + "00000000111111111111111111111111111111111111111111110000000000110000000000" + "00011111111111111111111111111111100000000000000000000000000000000000000000" + "00000000000000111111100000000000000000111111111111111111111111111111111111" + "11111111111000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000010000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000011111111111111111111111111111111111" + "11111111111111111100000000011111111111111111111111000000000000000000000000" + "00000000000000000000000000000000111111100000000000000000000011111111111111" + "11111111111111111111111111111100000000000111110011111111111111111111111111" + "11110000000000000000000000000000000000000000000000000001111111111111111111" + 
"11111111110000000000111111111111111111111111111111111111111111111111111111" + "11111111111111110000010111111111111111111111111111111111111111110000000011" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111100000000000000000000000000000000000000000000000000000000000000" + "00000100001000000000000000000000000000000000001111111111111111111111111111" + "11111111111111111111111100000000000000011101111111111111000000000000001111" + "11111111111111000000000000001111111111111111110000000000000011110111111111" + "11110000000000000001110001111111111111111111111111111111111111111111111111" + "11111111111111111111111111000001111111111111111111111111101111111111111111" + "10011111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111110000000000001111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111100000000000000001111" + "11111111111100000000000000000000000000000000000001111111111111111111111111" + "11111111111111111111111111111111111111111100111101111111111111111111111111" + "11111111111111111111111111111111011111111111111100111101011111110011110111" + "11111111111111111111111111111100111101111111111111111111111111111111111111" + "11110011110101111111001111011111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111110111111111" + "11111111111111111111111111111111110010000010111111111111111111111111111111" + "11111111000000000000000001000000000000111111111111100001110000000110001000" + "11110000111111000000000000000010000000000000000000011111111111111111111111" + "11111111111111111111000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000001111100000000" + "00000000000000000001111111111111111111111111111111111110111111110000000000" + "00000000000000000000000000000000000000000000000000000100000000000000000000" + "00000000000011110000000000000000000001011111001000000000110111101100101011" + "10111111101111000000100101100101100000000000000000000000000000000000000000" + "00000000000000000111111100000000000011011111111111111111111111111111111111" + "11111111111110000000000000000000000000000000000000000000000000000000000111" + "11110010111111111011111111111111111111111100011111111111111111100000111111" + "00000000000000000000000011000000000000000001000000000000000010011111111111" + "11111111111111111111111111111101110111111110000000000000000001100000000000" + "00001101000000000000000000000000000000001000111110111111111101111111111111" + "11111111110111011111111000000000000000000000000000000000001100000011000000" + "00000000000000000000100011111011111111110111111111111111111111110111011111" + "11100000000000000000000000000000000000000000000000000001000000000000000000" + "00001111111111110001110001100011010110001111011100011111101000000000000000" + "00100000000000000011101100000000000000000000000000000010001111101101111111" + "01111111111111111111111001100111111110000000000000000000000000000000000011" + "00000000000000010000000000000000001000111110110111111101111111111111111111" + 
"11101110111111111000000000000000011100000000000000000001011110000000000000" + "00000000000000000011011011011111110111111111111111111111100110000111111000" + "00000000000000001100000000000000111011000000000000010000000000000000100011" + "11000101111111011111111111111111111110011001111111100000111111101111111000" + "00000000000011111111110000000100000000000000000010001111111111111111111111" + "11111111111111111111111111111111000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000011111111111010000000000000000" + "00000000000000000000000000000000000000000000000000000001111111111111111111" + "11111100000000000000000000000100010000000001000011111111111111111111110000" + "01000011000000000111111111111111111111111111111111000000000000000000000000" + "10000000000011111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111100000000000000000000000000000111111111111111111" + "11111111111101000000000000000010011100000000001100000001100000000000000010" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111110110000000000000000000000000000000000011111111111" + "11111111111111111111111111111111000000000000000000000000000000000000000000" + "00011100000111111111111111111111111111000000000000000000000000000000000000" + "00000000000000000000000000000000000011111111111111111111111111111111111111" + "10000000100111111111111111111111111111111111111110000000001111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111000000" + "00111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111110111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11101111111111111111111101011101000000001111001101111100000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000101000000011111000000000000001111111111" + "11000011111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111011111111111111111111111111111110111111111111111111111" + "11000001000010000000000100000000000000000000000000000000000000000000000111" + "11111111111111111111111010000111111111111111111111111110000000000000000000" + "0000000000000000000000000000000000000000000000")); + +// Unicode codepoints that are allowed at the middle of the unescaped name. +const std::bitset<kBitsetSize> kUnescapedNameAllowedParts( + std::string("00000000000000000000000001100011000111001111110011111100111111000111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111100000000000111111111111111111111111110100001111111111111111111111111" + "10000000111111111100000000000100000001111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111101111100000010000000000000000000000000111000000000" + "00000000000000011000000000000111111100000000000000001111111111111111000111" + "11111111110000000000000000000000000000000000000000111111111111111111111111" + "11111111111111111111111111111100111111111111111111111111111111111111111111" + "11111111111111111111110000000000000000001111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111100000000000000000000000000000000011111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111011011010111110111111111111101111111111110000011" + "11100000000000011111110000000000000000000000000000000000000011111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111100111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111110000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000111111111111111111111111111111111111111111" + "11111110000111111111111111111111110000000000001111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111000000111111111100110111111111" + "11111111111111111111111111111111110000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000001111111011111110000000001111110011111100111" + "11100000000001111100111111111111111100111000000000000000000000000111111111" + "11111111111111111111111111111111111111111111111111111111110000110001111111" + "11111111111111110000001111111111001111111111111100000000011111111111111111" + "11111111111111111111111111111111111111000000000000000000000000000000000000" + "00111111111110000000000000011111111111111111111111111111111111111111111111" + "11111111111111111100011111111111111111111111111111000000000000111111111111" + "11111111111111111111111100111111111111111111111111111111111111111111111100" + "00100011111111111111111111111100000011111111110000000000011111111111111111" + "11111111111111111111111111111111111111111111111111110000000000001111111111" + "11111111111111111111111111111111111111111100000001000000000000000011111111" + "11111111111111111111111111111111111111110000000000000000000000000000000000" + "00000000000000000000000000000000000000000001111111111100000000000011110111" + "10011111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111001111111110000000000000000000000000000000" + "00000011111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111100000001111111111111111111111111011111111110000111111111111" + "11111111111111111111111111111111111100000000000000000000111111111111111111" + 
"11111111110001111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111100111111111111" + "11111111111111111111111111111111110000000000000000000000000000000000000000" + "00000000000000000000000000011111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111000000000000000000000000000000000000000000000000000111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111110000000000000000" + "00000000000000000000000000000000000000000000000000000000001111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111100000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000011111111111111110000000000" + "00000000000000000000000000000000000000000001111111111111111111111111110000" + "00000000000001111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111100011111111111111111111111111111111111111" + "11100000111101111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111011111110011111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111100001111100111110111111" + "11111111100000000000000000000000001110000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000011111111111111111111111111111111011111" + "11011111110111111101111111011111110111111101111111011111110000000001111111" + "11111111111111111000000000000000100000001111111111111111111111111111111111" + "11111111111111111111110010000010111111111111111111111111111111111111110000" + "00000000111111111000000111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111011111111111111111111111111111111111111111111111011111111111111111" + "11111111111111111111111111111100000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000011111111111111111111111111111" + "11111111111100000000000000000100001111100000111100111111111111111101010100" + "00001111110010111111111100100001000000000000000001111111111110001000011111" + "11111111000000000000000000000111111111111111111111111111000111111111111100" + "00000000000000100000000000001000000000000000000000000000010000000000000000" + "00011000000000000000000000000000000000000000000000000000000000000000000111" + "11110111000001111111111111000011111100111100011111110111000101111111011111" + "11111111111111111111111111111111111111111111111100111111111111111111111111" + "11111110101010111111110011111100111111111111111111111111111111111111110011" + "11110011111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111100000000" + "00000000000001111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111110000000001111111111111111111111111111111111101110000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00001111111111111111111111111111111111111111111111111000111111111100000000" + "11111111111111111111111111111111111111111111111111111111000000000000111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111100000000000011111111100000000000000000" + "11111111110000111111111111111111111111111111111111111111111111111111111111" + "11111111111111110000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000010000000000000111111111100000011111111111001" + "11111111111111111111111111110111111111111111111111111111111111111111111111" + "11111111111111111100001111111111111111111111111111000000000000000000000000" + "00000000000001111111111100000011111111111111111111111111000011111111111111" + "11111111111111111111111111111100000000000111110011111111111111111111111111" + "11111111111111000000000011111111111100001111111111110001111111111111111111" + "11111111110000000000111111111111111111111111111111111111111111111111111111" + "11111111111111110000011111111111111111111111111111111111111111110000000011" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111100000011111111110011100000000000000000000000000000000011111111" + "11001110001000111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111100000000000011011101111111111111000000000000111111" + "11111111111111000000000001111111111111111111110000000000011111110111111111" + "11110000000000000001110001111111111111111111111111111111111111111111111111" + "11111111111111111111111111000001111111111111111111111111101111111111111111" + "10011111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + 
"11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111110000000000001111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111100000000000000001111" + "11111111111100000000000000111111111000000000111001111111111111111111111111" + "11111111111111111111111111111111111111111100111101111111111111111111111111" + "11111111111111111111111111111111011111111111111100111101011111110011110111" + "11111111111111111111111111111100111101111111111111111111111111111111111111" + "11110011110101111111001111011111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111110111111111" + "11111111111111111111111111111111110010000010111111111111111111111111111111" + "11111111001111111111111111111111111111111111111111111111111111111111111111" + "11111111111111000000111111111111111111111111111111111111111111111111111111" + "11111111111111111111000000000000000000000000000000000000000000000000000000" + "00010000000001111111111111111111111111111111111110111111111111111111011111" + "11111111111111100001111111111111111111111111111111111110111111111100001010" + "10000000000011111111110000001100000000000000000000000100000000000000000000" + "00000000000011110011111111110011111101011111001110111111111111101100101011" + "10111111101111000000100101100101100000000000000000000000000000000000000011" + 
"11111111011111111111111110000111111111111111111111111111111111111111111111" + "11111111111110000000000000110000000000000000001111111101011111100001000111" + "11110010111111111011111111111111111111111100011111111111111111101100111111" + "00000000001111111111001111000000001000000001111101110111111110011111111111" + "11111111111111111111111111111101110111111110110000000000000001101111111111" + "00111101000000011000000011110111011111111100111110111111111101111111111111" + "11111111110111011111111011000000000000000000111111111100111100000011011000" + "00001111011101111111100011111011111111110111111111111111111111110111011111" + "11101110000000100000000011111111110000000000000010000001001111011100011111" + "00001111111111110001110001100011010110001111011100011111101100000000000000" + "00101111111111001111101100001100000000111001100111111111001111101101111111" + "01111111111111111111111001100111111110111000000000000000101111111111001111" + "00000000000000010011101110111111111100111110110111111101111111111111111111" + "11101110111111111011100000000000111111111111111100000001011110000000100011" + "10011000011111010011011011011111110111111111111111111111100110000111111011" + "10000010000000111111111111110011111011000010000000011110011001111111110011" + "11000101111111011111111111111111111110011001111111101110111111101111111011" + "11111111001111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111101111111111111111111111111110000000000" + "00000000000000000000000000000000000000000000011111111111010000000000000000" + "00000000000000000000000000000000000000000000000000001111111111111111111111" + "11111100000000000000000011111111111111111111111111111111111111111111110000" + "01000011111111111111111111111111111111111111111111111111111100000000000000" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111100111111111111111111111111111111111111111111111" + 
"11111111111111000000000000000010011111111111111111110111111111100111111110" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111110000111111111111111111111111111111111111111111" + "11111111111111111111111111111111000001111111111100001000000000000000000000" + "00011100000111111111111111111111111111000000001011011010111111111111111111" + "11111111111111111111111111101000000011111111111111111111111111111111111111" + "10000000100111111111111111111111111111111111111110000000001111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111001111" + "10111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111110111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11101111111111111111111101011111000000001111001101111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111100000000000000000101000000011111000000000000001111111111" + "11000011111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111111111111111111111111111111111111111111111111111111111" + "11111111111111111111011111111111111111111111111111110111111111111111111111" + "11000001001010000000000100001111000000000000000000000000000000000000000111" + "11111111111111111111111010000111111111111111111111111110000000111111111100" + "0000000000000000000000000000000000000000000000")); + +const std::bitset<kBitsetSize> kSpaceParts( + 
std::string("00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000100000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000001000000000000000000000000000" + "00000000000000000000100000110000000000000000000000000000011111111111000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000100000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000010000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + 
"00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000010000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000000" + "0000000000000111110000000000000011111000000000")); + +const trie::Trie kSpecialTokens = {";", + ",", + "=", + "+=", + "*", + "(", + ")", + "[", + "]", + ":", + "|", + "..", + "+", + "-", + "/", + "%", + "^", + "=~", + "<>", + "!=", + "<", + ">", + "<=", + ">=", + ".", + "{", + "}", + "$", + "\xE2\x9F\xA8", // u8"\u27e8" + "\xE3\x80\x88", // u8"\u3008" + "\xEF\xB9\xA4", // u8"\ufe64" + "\xEF\xBC\x9C", // u8"\uff1c" + "\xE2\x9F\xA9", // u8"\u27e9" + "\xE3\x80\x89", // u8"\u3009" + "\xEF\xB9\xA5", // u8"\ufe65" + "\xEF\xBC\x9E", // u8"\uff1e" + "\xC2\xAD", // u8"\u00ad" + "\xE2\x80\x90", // u8"\u2010" + "\xE2\x80\x91", // u8"\u2011" + "\xE2\x80\x92", // u8"\u2012" + "\xE2\x80\x93", // u8"\u2013" + "\xE2\x80\x94", // u8"\u2014" + "\xE2\x80\x95", // u8"\u2015" + "\xE2\x88\x92", // u8"\u2212" + "\xEF\xB9\x98", // u8"\ufe58" + "\xEF\xB9\xA3", // u8"\ufe63" + "\xEF\xBC\x8D"}; // u8"\uff0d" +} // namespace lexer_constants +} // namespace memgraph::query::v2 diff --git a/src/query/v2/interpret/awesome_memgraph_functions.cpp b/src/query/v2/interpret/awesome_memgraph_functions.cpp new file mode 100644 index 000000000..9dbe3ccba --- /dev/null +++ b/src/query/v2/interpret/awesome_memgraph_functions.cpp 
@@ -0,0 +1,1323 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/interpret/awesome_memgraph_functions.hpp" + +#include <algorithm> +#include <cctype> +#include <cmath> +#include <cstdlib> +#include <functional> +#include <random> +#include <string_view> +#include <type_traits> + +#include "query/v2/db_accessor.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/procedure/cypher_types.hpp" +#include "query/v2/procedure/mg_procedure_impl.hpp" +#include "query/v2/procedure/module.hpp" +#include "query/v2/typed_value.hpp" +#include "utils/string.hpp" +#include "utils/temporal.hpp" + +namespace memgraph::query::v2 { +namespace { + +//////////////////////////////////////////////////////////////////////////////// +// eDSL using template magic for describing a type of an awesome memgraph +// function and checking if the passed in arguments match the description. +// +// To use the type checking eDSL, you should put a `FType` invocation in the +// body of your awesome Memgraph function. `FType` takes type arguments as the +// description of the function type signature. Each runtime argument will be +// checked in order corresponding to given compile time type arguments. These +// type arguments can come in two forms: +// +// * final, primitive type descriptor and +// * combinator type descriptor. +// +// The primitive type descriptors are defined as empty structs, they are right +// below this documentation. 
+// +// Combinator type descriptors are defined as structs taking additional type +// parameters, you can find these further below in the implementation. Of +// primary interest are `Or` and `Optional` type combinators. +// +// With `Or` you can describe that an argument can be any of the types listed in +// `Or`. For example, `Or<Null, Bool, Integer>` allows an argument to be either +// `Null` or a boolean or an integer. +// +// The `Optional` combinator is used to define optional arguments to a function. +// These must come as the last positional arguments. Naturally, you can use `Or` +// inside `Optional`. So for example, `Optional<Or<Null, Bool>, Integer>` +// describes that a function takes 2 optional arguments. The 1st one must be +// either a `Null` or a boolean, while the 2nd one must be an integer. The type +// signature check will succeed in the following cases. +// +// * No optional arguments were supplied. +// * One argument was supplied and it passes `Or<Null, Bool>` check. +// * Two arguments were supplied, the 1st one passes `Or<Null, Bool>` check +// and the 2nd one passes `Integer` check. +// +// Runtime arguments to `FType` are: function name, pointer to arguments and the +// number of received arguments. +// +// Full example. +// +// FType<Or<Null, String>, NonNegativeInteger, +// Optional<NonNegativeInteger>>("substring", args, nargs); +// +// The above will check that `substring` function received the 2 required +// arguments. Optionally, the function may take a 3rd argument. The 1st argument +// must be either a `Null` or a character string. The 2nd argument is required +// to be a non-negative integer. If the 3rd argument was supplied, it will also +// be checked that it is a non-negative integer. If any of these checks fail, +// `FType` will throw a `QueryRuntimeException` with an appropriate error +// message. 
+//////////////////////////////////////////////////////////////////////////////// + +struct Null {}; +struct Bool {}; +struct Integer {}; +struct PositiveInteger {}; +struct NonZeroInteger {}; +struct NonNegativeInteger {}; +struct Double {}; +struct Number {}; +struct List {}; +struct String {}; +struct Map {}; +struct Edge {}; +struct Vertex {}; +struct Path {}; +struct Date {}; +struct LocalTime {}; +struct LocalDateTime {}; +struct Duration {}; + +template <class ArgType> +bool ArgIsType(const TypedValue &arg) { + if constexpr (std::is_same_v<ArgType, Null>) { + return arg.IsNull(); + } else if constexpr (std::is_same_v<ArgType, Bool>) { + return arg.IsBool(); + } else if constexpr (std::is_same_v<ArgType, Integer>) { + return arg.IsInt(); + } else if constexpr (std::is_same_v<ArgType, PositiveInteger>) { + return arg.IsInt() && arg.ValueInt() > 0; + } else if constexpr (std::is_same_v<ArgType, NonZeroInteger>) { + return arg.IsInt() && arg.ValueInt() != 0; + } else if constexpr (std::is_same_v<ArgType, NonNegativeInteger>) { + return arg.IsInt() && arg.ValueInt() >= 0; + } else if constexpr (std::is_same_v<ArgType, Double>) { + return arg.IsDouble(); + } else if constexpr (std::is_same_v<ArgType, Number>) { + return arg.IsNumeric(); + } else if constexpr (std::is_same_v<ArgType, List>) { + return arg.IsList(); + } else if constexpr (std::is_same_v<ArgType, String>) { + return arg.IsString(); + } else if constexpr (std::is_same_v<ArgType, Map>) { + return arg.IsMap(); + } else if constexpr (std::is_same_v<ArgType, Vertex>) { + return arg.IsVertex(); + } else if constexpr (std::is_same_v<ArgType, Edge>) { + return arg.IsEdge(); + } else if constexpr (std::is_same_v<ArgType, Path>) { + return arg.IsPath(); + } else if constexpr (std::is_same_v<ArgType, Date>) { + return arg.IsDate(); + } else if constexpr (std::is_same_v<ArgType, LocalTime>) { + return arg.IsLocalTime(); + } else if constexpr (std::is_same_v<ArgType, LocalDateTime>) { + return 
arg.IsLocalDateTime(); + } else if constexpr (std::is_same_v<ArgType, Duration>) { + return arg.IsDuration(); + } else if constexpr (std::is_same_v<ArgType, void>) { + return true; + } else { + static_assert(std::is_same_v<ArgType, Null>, "Unknown ArgType"); + } + return false; +} + +template <class ArgType> +constexpr const char *ArgTypeName() { + // The type names returned should be standardized openCypher type names. + // https://github.com/opencypher/openCypher/blob/master/docs/openCypher9.pdf + if constexpr (std::is_same_v<ArgType, Null>) { + return "null"; + } else if constexpr (std::is_same_v<ArgType, Bool>) { + return "boolean"; + } else if constexpr (std::is_same_v<ArgType, Integer>) { + return "integer"; + } else if constexpr (std::is_same_v<ArgType, PositiveInteger>) { + return "positive integer"; + } else if constexpr (std::is_same_v<ArgType, NonZeroInteger>) { + return "non-zero integer"; + } else if constexpr (std::is_same_v<ArgType, NonNegativeInteger>) { + return "non-negative integer"; + } else if constexpr (std::is_same_v<ArgType, Double>) { + return "float"; + } else if constexpr (std::is_same_v<ArgType, Number>) { + return "number"; + } else if constexpr (std::is_same_v<ArgType, List>) { + return "list"; + } else if constexpr (std::is_same_v<ArgType, String>) { + return "string"; + } else if constexpr (std::is_same_v<ArgType, Map>) { + return "map"; + } else if constexpr (std::is_same_v<ArgType, Vertex>) { + return "node"; + } else if constexpr (std::is_same_v<ArgType, Edge>) { + return "relationship"; + } else if constexpr (std::is_same_v<ArgType, Path>) { + return "path"; + } else if constexpr (std::is_same_v<ArgType, void>) { + return "void"; + } else if constexpr (std::is_same_v<ArgType, Date>) { + return "Date"; + } else if constexpr (std::is_same_v<ArgType, LocalTime>) { + return "LocalTime"; + } else if constexpr (std::is_same_v<ArgType, LocalDateTime>) { + return "LocalDateTime"; + } else if constexpr (std::is_same_v<ArgType, Duration>) 
{ + return "Duration"; + } else { + static_assert(std::is_same_v<ArgType, Null>, "Unknown ArgType"); + } + return "<unknown-type>"; +} + +template <class... ArgType> +struct Or; + +template <class ArgType> +struct Or<ArgType> { + static bool Check(const TypedValue &arg) { return ArgIsType<ArgType>(arg); } + + static std::string TypeNames() { return ArgTypeName<ArgType>(); } +}; + +template <class ArgType, class... ArgTypes> +struct Or<ArgType, ArgTypes...> { + static bool Check(const TypedValue &arg) { + if (ArgIsType<ArgType>(arg)) return true; + return Or<ArgTypes...>::Check(arg); + } + + static std::string TypeNames() { + if constexpr (sizeof...(ArgTypes) > 1) { + return fmt::format("'{}', {}", ArgTypeName<ArgType>(), Or<ArgTypes...>::TypeNames()); + } else { + return fmt::format("'{}' or '{}'", ArgTypeName<ArgType>(), Or<ArgTypes...>::TypeNames()); + } + } +}; + +template <class T> +struct IsOrType { + static constexpr bool value = false; +}; + +template <class... ArgTypes> +struct IsOrType<Or<ArgTypes...>> { + static constexpr bool value = true; +}; + +template <class... ArgTypes> +struct Optional; + +template <class ArgType> +struct Optional<ArgType> { + static constexpr size_t size = 1; + + static void Check(const char *name, const TypedValue *args, int64_t nargs, int64_t pos) { + if (nargs == 0) return; + const TypedValue &arg = args[0]; + if constexpr (IsOrType<ArgType>::value) { + if (!ArgType::Check(arg)) { + throw QueryRuntimeException("Optional '{}' argument at position {} must be either {}.", name, pos, + ArgType::TypeNames()); + } + } else { + if (!ArgIsType<ArgType>(arg)) + throw QueryRuntimeException("Optional '{}' argument at position {} must be '{}'.", name, pos, + ArgTypeName<ArgType>()); + } + } +}; + +template <class ArgType, class... 
ArgTypes> +struct Optional<ArgType, ArgTypes...> { + static constexpr size_t size = 1 + sizeof...(ArgTypes); + + static void Check(const char *name, const TypedValue *args, int64_t nargs, int64_t pos) { + if (nargs == 0) return; + Optional<ArgType>::Check(name, args, nargs, pos); + Optional<ArgTypes...>::Check(name, args + 1, nargs - 1, pos + 1); + } +}; + +template <class T> +struct IsOptional { + static constexpr bool value = false; +}; + +template <class... ArgTypes> +struct IsOptional<Optional<ArgTypes...>> { + static constexpr bool value = true; +}; + +template <class ArgType, class... ArgTypes> +constexpr size_t FTypeRequiredArgs() { + if constexpr (IsOptional<ArgType>::value) { + static_assert(sizeof...(ArgTypes) == 0, "Optional arguments must be last!"); + return 0; + } else if constexpr (sizeof...(ArgTypes) == 0) { + return 1; + } else { + return 1U + FTypeRequiredArgs<ArgTypes...>(); + } +} + +template <class ArgType, class... ArgTypes> +constexpr size_t FTypeOptionalArgs() { + if constexpr (IsOptional<ArgType>::value) { + static_assert(sizeof...(ArgTypes) == 0, "Optional arguments must be last!"); + return ArgType::size; + } else if constexpr (sizeof...(ArgTypes) == 0) { + return 0; + } else { + return FTypeOptionalArgs<ArgTypes...>(); + } +} + +template <class ArgType, class... 
ArgTypes> +void FType(const char *name, const TypedValue *args, int64_t nargs, int64_t pos = 1) { + if constexpr (std::is_same_v<ArgType, void>) { + if (nargs != 0) { + throw QueryRuntimeException("'{}' requires no arguments.", name); + } + return; + } + static constexpr int64_t required_args = FTypeRequiredArgs<ArgType, ArgTypes...>(); + static constexpr int64_t optional_args = FTypeOptionalArgs<ArgType, ArgTypes...>(); + static constexpr int64_t total_args = required_args + optional_args; + if constexpr (optional_args > 0) { + if (nargs < required_args || nargs > total_args) { + throw QueryRuntimeException("'{}' requires between {} and {} arguments.", name, required_args, total_args); + } + } else { + if (nargs != required_args) { + throw QueryRuntimeException("'{}' requires exactly {} {}.", name, required_args, + required_args == 1 ? "argument" : "arguments"); + } + } + const TypedValue &arg = args[0]; + if constexpr (IsOrType<ArgType>::value) { + if (!ArgType::Check(arg)) { + throw QueryRuntimeException("'{}' argument at position {} must be either {}.", name, pos, ArgType::TypeNames()); + } + } else if constexpr (IsOptional<ArgType>::value) { + static_assert(sizeof...(ArgTypes) == 0, "Optional arguments must be last!"); + ArgType::Check(name, args, nargs, pos); + } else { + if (!ArgIsType<ArgType>(arg)) { + throw QueryRuntimeException("'{}' argument at position {} must be '{}'", name, pos, ArgTypeName<ArgType>()); + } + } + if constexpr (sizeof...(ArgTypes) > 0) { + FType<ArgTypes...>(name, args + 1, nargs - 1, pos + 1); + } +} + +//////////////////////////////////////////////////////////////////////////////// +// END function type description eDSL +//////////////////////////////////////////////////////////////////////////////// + +// Predicate functions. +// Neo4j has all, any, exists, none, single +// Those functions are a little bit different since they take a filterExpression +// as an argument. 
+// There are all, any, none and single productions in the opencypher grammar, but it +// will be trivial to also add exists. +// TODO: Implement this. + +// Scalar functions. +// We don't have a way to implement id function since we don't store any. If it +// is really necessary we could probably map vlist* to id. +// TODO: Implement length (it works on a path, but we didn't define path +// structure yet). +// TODO: Implement size(pattern), for example size((a)-[:X]-()) should return +// number of results of this pattern. I don't think we will ever do this. +// TODO: Implement rest of the list functions. +// TODO: Implement degrees, haversin, radians +// TODO: Implement spatial functions + +TypedValue EndNode(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Edge>>("endNode", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + return TypedValue(args[0].ValueEdge().To(), ctx.memory); +} + +TypedValue Head(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, List>>("head", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + const auto &list = args[0].ValueList(); + if (list.empty()) return TypedValue(ctx.memory); + return TypedValue(list[0], ctx.memory); +} + +TypedValue Last(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, List>>("last", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + const auto &list = args[0].ValueList(); + if (list.empty()) return TypedValue(ctx.memory); + return TypedValue(list.back(), ctx.memory); +} + +TypedValue Properties(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Vertex, Edge>>("properties", args, nargs); + auto *dba = ctx.db_accessor; + auto get_properties = [&](const auto &record_accessor) { + TypedValue::TMap properties(ctx.memory); + auto maybe_props = record_accessor.Properties(ctx.view); + if (maybe_props.HasError()) { + switch 
(maybe_props.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to get properties from a deleted object."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get properties from an object that doesn't exist."); + case storage::v3::Error::SERIALIZATION_ERROR: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Unexpected error when getting properties."); + } + } + for (const auto &property : *maybe_props) { + properties.emplace(dba->PropertyToName(property.first), property.second); + } + return TypedValue(std::move(properties)); + }; + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValue(ctx.memory); + } else if (value.IsVertex()) { + return get_properties(value.ValueVertex()); + } else { + return get_properties(value.ValueEdge()); + } +} + +TypedValue Size(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, List, String, Map, Path>>("size", args, nargs); + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValue(ctx.memory); + } else if (value.IsList()) { + return TypedValue(static_cast<int64_t>(value.ValueList().size()), ctx.memory); + } else if (value.IsString()) { + return TypedValue(static_cast<int64_t>(value.ValueString().size()), ctx.memory); + } else if (value.IsMap()) { + // neo4j doesn't implement size for map, but I don't see a good reason not + // to do it. 
+ return TypedValue(static_cast<int64_t>(value.ValueMap().size()), ctx.memory); + } else { + return TypedValue(static_cast<int64_t>(value.ValuePath().edges().size()), ctx.memory); + } +} + +TypedValue StartNode(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Edge>>("startNode", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + return TypedValue(args[0].ValueEdge().From(), ctx.memory); +} + +namespace { + +size_t UnwrapDegreeResult(storage::v3::Result<size_t> maybe_degree) { + if (maybe_degree.HasError()) { + switch (maybe_degree.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to get degree of a deleted node."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get degree of a node that doesn't exist."); + case storage::v3::Error::SERIALIZATION_ERROR: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Unexpected error when getting node degree."); + } + } + return *maybe_degree; +} + +} // namespace + +TypedValue Degree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Vertex>>("degree", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + const auto &vertex = args[0].ValueVertex(); + size_t out_degree = UnwrapDegreeResult(vertex.OutDegree(ctx.view)); + size_t in_degree = UnwrapDegreeResult(vertex.InDegree(ctx.view)); + return TypedValue(static_cast<int64_t>(out_degree + in_degree), ctx.memory); +} + +TypedValue InDegree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Vertex>>("inDegree", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + const auto &vertex = args[0].ValueVertex(); + size_t in_degree = UnwrapDegreeResult(vertex.InDegree(ctx.view)); + return TypedValue(static_cast<int64_t>(in_degree), ctx.memory); +} + +TypedValue 
OutDegree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Vertex>>("outDegree", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + const auto &vertex = args[0].ValueVertex(); + size_t out_degree = UnwrapDegreeResult(vertex.OutDegree(ctx.view)); + return TypedValue(static_cast<int64_t>(out_degree), ctx.memory); +} + +TypedValue ToBoolean(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Bool, Integer, String>>("toBoolean", args, nargs); + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValue(ctx.memory); + } else if (value.IsBool()) { + return TypedValue(value.ValueBool(), ctx.memory); + } else if (value.IsInt()) { + return TypedValue(value.ValueInt() != 0L, ctx.memory); + } else { + auto s = utils::ToUpperCase(utils::Trim(value.ValueString())); + if (s == "TRUE") return TypedValue(true, ctx.memory); + if (s == "FALSE") return TypedValue(false, ctx.memory); + // I think this is just stupid and that exception should be thrown, but + // neo4j does it this way... 
+ return TypedValue(ctx.memory); + } +} + +TypedValue ToFloat(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Number, String>>("toFloat", args, nargs); + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValue(ctx.memory); + } else if (value.IsInt()) { + return TypedValue(static_cast<double>(value.ValueInt()), ctx.memory); + } else if (value.IsDouble()) { + return TypedValue(value, ctx.memory); + } else { + try { + return TypedValue(utils::ParseDouble(utils::Trim(value.ValueString())), ctx.memory); + } catch (const utils::BasicException &) { + return TypedValue(ctx.memory); + } + } +} + +TypedValue ToInteger(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Bool, Number, String>>("toInteger", args, nargs); + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValue(ctx.memory); + } else if (value.IsBool()) { + return TypedValue(value.ValueBool() ? 1L : 0L, ctx.memory); + } else if (value.IsInt()) { + return TypedValue(value, ctx.memory); + } else if (value.IsDouble()) { + return TypedValue(static_cast<int64_t>(value.ValueDouble()), ctx.memory); + } else { + try { + // Yup, this is correct. String is valid if it has floating point + // number, then it is parsed and converted to int. 
+ return TypedValue(static_cast<int64_t>(utils::ParseDouble(utils::Trim(value.ValueString()))), ctx.memory); + } catch (const utils::BasicException &) { + return TypedValue(ctx.memory); + } + } +} + +TypedValue Type(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Edge>>("type", args, nargs); + auto *dba = ctx.db_accessor; + if (args[0].IsNull()) return TypedValue(ctx.memory); + return TypedValue(dba->EdgeTypeToName(args[0].ValueEdge().EdgeType()), ctx.memory); +} +
+// Returns the name of the runtime type of its single argument as a string
+// (e.g. "INTEGER", "LIST", "NODE"). A Null argument is not propagated here: it
+// yields the string "NULL".
+//
+// Throws QueryRuntimeException (via FType) when the argument count or the
+// argument type is not accepted. The name reported in those messages is
+// "valueType", so that they do not get confused with the ones raised by Type().
+//
+// NOTE(review): the FType list below does not contain Date, LocalTime,
+// LocalDateTime or Duration, so the temporal cases of the switch look
+// unreachable through this entry point -- confirm whether temporal values
+// should be accepted as well.
+TypedValue ValueType(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
+  FType<Or<Null, Bool, Integer, Double, String, List, Map, Vertex, Edge, Path>>("valueType", args, nargs);
+  // The type names returned should be standardized openCypher type names.
+  // https://github.com/opencypher/openCypher/blob/master/docs/openCypher9.pdf
+  switch (args[0].type()) {
+    case TypedValue::Type::Null:
+      return TypedValue("NULL", ctx.memory);
+    case TypedValue::Type::Bool:
+      return TypedValue("BOOLEAN", ctx.memory);
+    case TypedValue::Type::Int:
+      return TypedValue("INTEGER", ctx.memory);
+    case TypedValue::Type::Double:
+      return TypedValue("FLOAT", ctx.memory);
+    case TypedValue::Type::String:
+      return TypedValue("STRING", ctx.memory);
+    case TypedValue::Type::List:
+      return TypedValue("LIST", ctx.memory);
+    case TypedValue::Type::Map:
+      return TypedValue("MAP", ctx.memory);
+    case TypedValue::Type::Vertex:
+      return TypedValue("NODE", ctx.memory);
+    case TypedValue::Type::Edge:
+      return TypedValue("RELATIONSHIP", ctx.memory);
+    case TypedValue::Type::Path:
+      return TypedValue("PATH", ctx.memory);
+    case TypedValue::Type::Date:
+      return TypedValue("DATE", ctx.memory);
+    case TypedValue::Type::LocalTime:
+      return TypedValue("LOCAL_TIME", ctx.memory);
+    case TypedValue::Type::LocalDateTime:
+      return TypedValue("LOCAL_DATE_TIME", ctx.memory);
+    case TypedValue::Type::Duration:
+      return TypedValue("DURATION", ctx.memory);
+  }
+  // Every enumerator handled above returns. This keeps the function from
+  // falling off its end should a value outside of those ever show up; no
+  // `default:` label is used so the compiler can still report unhandled
+  // enumerators.
+  throw QueryRuntimeException("'valueType' received a value of an unknown type.");
+}
+
+// TODO: How is Keys different from 
Properties function? +TypedValue Keys(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Vertex, Edge>>("keys", args, nargs); + auto *dba = ctx.db_accessor; + auto get_keys = [&](const auto &record_accessor) { + TypedValue::TVector keys(ctx.memory); + auto maybe_props = record_accessor.Properties(ctx.view); + if (maybe_props.HasError()) { + switch (maybe_props.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to get keys from a deleted object."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get keys from an object that doesn't exist."); + case storage::v3::Error::SERIALIZATION_ERROR: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Unexpected error when getting keys."); + } + } + for (const auto &property : *maybe_props) { + keys.emplace_back(dba->PropertyToName(property.first)); + } + return TypedValue(std::move(keys)); + }; + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValue(ctx.memory); + } else if (value.IsVertex()) { + return get_keys(value.ValueVertex()); + } else { + return get_keys(value.ValueEdge()); + } +} + +TypedValue Labels(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Vertex>>("labels", args, nargs); + auto *dba = ctx.db_accessor; + if (args[0].IsNull()) return TypedValue(ctx.memory); + TypedValue::TVector labels(ctx.memory); + auto maybe_labels = args[0].ValueVertex().Labels(ctx.view); + if (maybe_labels.HasError()) { + switch (maybe_labels.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to get labels from a deleted node."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get labels from a node that doesn't exist."); + case storage::v3::Error::SERIALIZATION_ERROR: + case 
storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Unexpected error when getting labels."); + } + } + for (const auto &label : *maybe_labels) { + labels.emplace_back(dba->LabelToName(label)); + } + return TypedValue(std::move(labels)); +} + +TypedValue Nodes(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Path>>("nodes", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + const auto &vertices = args[0].ValuePath().vertices(); + TypedValue::TVector values(ctx.memory); + values.reserve(vertices.size()); + for (const auto &v : vertices) values.emplace_back(v); + return TypedValue(std::move(values)); +} + +TypedValue Relationships(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Path>>("relationships", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + const auto &edges = args[0].ValuePath().edges(); + TypedValue::TVector values(ctx.memory); + values.reserve(edges.size()); + for (const auto &e : edges) values.emplace_back(e); + return TypedValue(std::move(values)); +} + +TypedValue Range(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Integer>, Or<Null, Integer>, Optional<Or<Null, NonZeroInteger>>>("range", args, nargs); + for (int64_t i = 0; i < nargs; ++i) + if (args[i].IsNull()) return TypedValue(ctx.memory); + auto lbound = args[0].ValueInt(); + auto rbound = args[1].ValueInt(); + int64_t step = nargs == 3 ? 
args[2].ValueInt() : 1; + TypedValue::TVector list(ctx.memory); + if (lbound <= rbound && step > 0) { + for (auto i = lbound; i <= rbound; i += step) { + list.emplace_back(i); + } + } else if (lbound >= rbound && step < 0) { + for (auto i = lbound; i >= rbound; i += step) { + list.emplace_back(i); + } + } + return TypedValue(std::move(list)); +} + +TypedValue Tail(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, List>>("tail", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + TypedValue::TVector list(args[0].ValueList(), ctx.memory); + if (list.empty()) return TypedValue(std::move(list)); + list.erase(list.begin()); + return TypedValue(std::move(list)); +} + +TypedValue UniformSample(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, List>, Or<Null, NonNegativeInteger>>("uniformSample", args, nargs); + static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()}; + if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory); + const auto &population = args[0].ValueList(); + auto population_size = population.size(); + if (population_size == 0) return TypedValue(ctx.memory); + auto desired_length = args[1].ValueInt(); + std::uniform_int_distribution<uint64_t> rand_dist{0, population_size - 1}; + TypedValue::TVector sampled(ctx.memory); + sampled.reserve(desired_length); + for (int64_t i = 0; i < desired_length; ++i) { + sampled.emplace_back(population[rand_dist(pseudo_rand_gen_)]); + } + return TypedValue(std::move(sampled)); +} + +TypedValue Abs(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Number>>("abs", args, nargs); + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValue(ctx.memory); + } else if (value.IsInt()) { + return TypedValue(std::abs(value.ValueInt()), ctx.memory); + } else { + return TypedValue(std::abs(value.ValueDouble()), ctx.memory); + } +} + +#define 
WRAP_CMATH_FLOAT_FUNCTION(name, lowercased_name) \ + TypedValue name(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { \ + FType<Or<Null, Number>>(#lowercased_name, args, nargs); \ + const auto &value = args[0]; \ + if (value.IsNull()) { \ + return TypedValue(ctx.memory); \ + } else if (value.IsInt()) { \ + return TypedValue(lowercased_name(value.ValueInt()), ctx.memory); \ + } else { \ + return TypedValue(lowercased_name(value.ValueDouble()), ctx.memory); \ + } \ + } + +WRAP_CMATH_FLOAT_FUNCTION(Ceil, ceil) +WRAP_CMATH_FLOAT_FUNCTION(Floor, floor) +// We are not completely compatible with neo4j in this function because +// neo4j rounds -0.5, -1.5, -2.5... to 0, -1, -2... +WRAP_CMATH_FLOAT_FUNCTION(Round, round) +WRAP_CMATH_FLOAT_FUNCTION(Exp, exp) +WRAP_CMATH_FLOAT_FUNCTION(Log, log) +WRAP_CMATH_FLOAT_FUNCTION(Log10, log10) +WRAP_CMATH_FLOAT_FUNCTION(Sqrt, sqrt) +WRAP_CMATH_FLOAT_FUNCTION(Acos, acos) +WRAP_CMATH_FLOAT_FUNCTION(Asin, asin) +WRAP_CMATH_FLOAT_FUNCTION(Atan, atan) +WRAP_CMATH_FLOAT_FUNCTION(Cos, cos) +WRAP_CMATH_FLOAT_FUNCTION(Sin, sin) +WRAP_CMATH_FLOAT_FUNCTION(Tan, tan) + +#undef WRAP_CMATH_FLOAT_FUNCTION + +TypedValue Atan2(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Number>, Or<Null, Number>>("atan2", args, nargs); + if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory); + auto to_double = [](const TypedValue &t) -> double { + if (t.IsInt()) { + return t.ValueInt(); + } else { + return t.ValueDouble(); + } + }; + double y = to_double(args[0]); + double x = to_double(args[1]); + return TypedValue(atan2(y, x), ctx.memory); +} + +TypedValue Sign(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Number>>("sign", args, nargs); + auto sign = [&](auto x) { return TypedValue((0 < x) - (x < 0), ctx.memory); }; + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValue(ctx.memory); + } else if (value.IsInt()) { + return 
sign(value.ValueInt()); + } else { + return sign(value.ValueDouble()); + } +} + +TypedValue E(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<void>("e", args, nargs); + return TypedValue(M_E, ctx.memory); +} + +TypedValue Pi(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<void>("pi", args, nargs); + return TypedValue(M_PI, ctx.memory); +} + +TypedValue Rand(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<void>("rand", args, nargs); + static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()}; + static thread_local std::uniform_real_distribution<> rand_dist_{0, 1}; + return TypedValue(rand_dist_(pseudo_rand_gen_), ctx.memory); +} + +template <class TPredicate> +TypedValue StringMatchOperator(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, String>, Or<Null, String>>(TPredicate::name, args, nargs); + if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory); + const auto &s1 = args[0].ValueString(); + const auto &s2 = args[1].ValueString(); + return TypedValue(TPredicate{}(s1, s2), ctx.memory); +} + +// Check if s1 starts with s2. +struct StartsWithPredicate { + static constexpr const char *name = "startsWith"; + bool operator()(const TypedValue::TString &s1, const TypedValue::TString &s2) const { + if (s1.size() < s2.size()) return false; + return std::equal(s2.begin(), s2.end(), s1.begin()); + } +}; +auto StartsWith = StringMatchOperator<StartsWithPredicate>; + +// Check if s1 ends with s2. +struct EndsWithPredicate { + static constexpr const char *name = "endsWith"; + bool operator()(const TypedValue::TString &s1, const TypedValue::TString &s2) const { + if (s1.size() < s2.size()) return false; + return std::equal(s2.rbegin(), s2.rend(), s1.rbegin()); + } +}; +auto EndsWith = StringMatchOperator<EndsWithPredicate>; + +// Check if s1 contains s2. 
+struct ContainsPredicate { + static constexpr const char *name = "contains"; + bool operator()(const TypedValue::TString &s1, const TypedValue::TString &s2) const { + if (s1.size() < s2.size()) return false; + return s1.find(s2) != std::string::npos; + } +}; +auto Contains = StringMatchOperator<ContainsPredicate>; + +TypedValue Assert(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Bool, Optional<String>>("assert", args, nargs); + if (!args[0].ValueBool()) { + std::string message("Assertion failed"); + if (nargs == 2) { + message += ": "; + message += args[1].ValueString(); + } + message += "."; + throw QueryRuntimeException(message); + } + return TypedValue(args[0], ctx.memory); +} + +TypedValue Counter(const TypedValue *args, int64_t nargs, const FunctionContext &context) { + FType<String, Integer, Optional<NonZeroInteger>>("counter", args, nargs); + int64_t step = 1; + if (nargs == 3) { + step = args[2].ValueInt(); + } + + auto [it, inserted] = context.counters->emplace(args[0].ValueString(), args[1].ValueInt()); + auto value = it->second; + it->second += step; + + return TypedValue(value, context.memory); +} + +TypedValue Id(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, Vertex, Edge>>("id", args, nargs); + const auto &arg = args[0]; + if (arg.IsNull()) { + return TypedValue(ctx.memory); + } else if (arg.IsVertex()) { + return TypedValue(arg.ValueVertex().CypherId(), ctx.memory); + } else { + return TypedValue(arg.ValueEdge().CypherId(), ctx.memory); + } +} + +TypedValue ToString(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, String, Number, Bool>>("toString", args, nargs); + const auto &arg = args[0]; + if (arg.IsNull()) { + return TypedValue(ctx.memory); + } else if (arg.IsString()) { + return TypedValue(arg, ctx.memory); + } else if (arg.IsInt()) { + // TODO: This is making a pointless copy of std::string, we may want to + // use a different conversion 
to string + return TypedValue(std::to_string(arg.ValueInt()), ctx.memory); + } else if (arg.IsDouble()) { + return TypedValue(std::to_string(arg.ValueDouble()), ctx.memory); + } else { + return TypedValue(arg.ValueBool() ? "true" : "false", ctx.memory); + } +} +
+// Returns a timestamp as a number of microseconds.
+//
+// Without an argument the result is `ctx.timestamp`. With a Date, LocalTime or
+// LocalDateTime argument it is that value's MicrosecondsSinceEpoch(), and with
+// a Duration argument it is the duration's `microseconds` member.
+//
+// Throws QueryRuntimeException (via FType) when more than one argument is
+// given or the argument is of any other type.
+TypedValue Timestamp(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
+  FType<Optional<Or<Date, LocalTime, LocalDateTime, Duration>>>("timestamp", args, nargs);
+  // The argument is optional, so `args[0]` must not be inspected when the
+  // caller passed nothing; the other temporal functions in this file handle
+  // the zero-argument call the same way.
+  if (nargs == 0) {
+    return TypedValue(ctx.timestamp, ctx.memory);
+  }
+  const auto &arg = args[0];
+  if (arg.IsDate()) {
+    return TypedValue(arg.ValueDate().MicrosecondsSinceEpoch(), ctx.memory);
+  }
+  if (arg.IsLocalTime()) {
+    return TypedValue(arg.ValueLocalTime().MicrosecondsSinceEpoch(), ctx.memory);
+  }
+  if (arg.IsLocalDateTime()) {
+    return TypedValue(arg.ValueLocalDateTime().MicrosecondsSinceEpoch(), ctx.memory);
+  }
+  if (arg.IsDuration()) {
+    return TypedValue(arg.ValueDuration().microseconds, ctx.memory);
+  }
+  // Kept from the original as a fallback; presumably unreachable once FType has
+  // accepted the argument as one of the four temporal types.
+  return TypedValue(ctx.timestamp, ctx.memory);
+}
+
+TypedValue Left(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, String>, Or<Null, NonNegativeInteger>>("left", args, nargs); + if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory); + return TypedValue(utils::Substr(args[0].ValueString(), 0, args[1].ValueInt()), ctx.memory); +} + +TypedValue Right(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, String>, Or<Null, NonNegativeInteger>>("right", args, nargs); + if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory); + const auto &str = args[0].ValueString(); + auto len = args[1].ValueInt(); + return len <= str.size() ? 
TypedValue(utils::Substr(str, str.size() - len, len), ctx.memory) + : TypedValue(str, ctx.memory); +} + +TypedValue CallStringFunction(const TypedValue *args, int64_t nargs, utils::MemoryResource *memory, const char *name, + std::function<TypedValue::TString(const TypedValue::TString &)> fun) { + FType<Or<Null, String>>(name, args, nargs); + if (args[0].IsNull()) return TypedValue(memory); + return TypedValue(fun(args[0].ValueString()), memory); +} + +TypedValue LTrim(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "lTrim", + [&](const auto &str) { return TypedValue::TString(utils::LTrim(str), ctx.memory); }); +} + +TypedValue RTrim(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "rTrim", + [&](const auto &str) { return TypedValue::TString(utils::RTrim(str), ctx.memory); }); +} + +TypedValue Trim(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "trim", + [&](const auto &str) { return TypedValue::TString(utils::Trim(str), ctx.memory); }); +} + +TypedValue Reverse(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "reverse", + [&](const auto &str) { return utils::Reversed(str, ctx.memory); }); +} + +TypedValue ToLower(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "toLower", [&](const auto &str) { + TypedValue::TString res(ctx.memory); + utils::ToLowerCase(&res, str); + return res; + }); +} + +TypedValue ToUpper(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "toUpper", [&](const auto &str) { + TypedValue::TString res(ctx.memory); + utils::ToUpperCase(&res, str); + return res; + }); +} + +TypedValue Replace(const TypedValue *args, 
int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, String>, Or<Null, String>, Or<Null, String>>("replace", args, nargs); + if (args[0].IsNull() || args[1].IsNull() || args[2].IsNull()) { + return TypedValue(ctx.memory); + } + TypedValue::TString replaced(ctx.memory); + utils::Replace(&replaced, args[0].ValueString(), args[1].ValueString(), args[2].ValueString()); + return TypedValue(std::move(replaced)); +} + +TypedValue Split(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, String>, Or<Null, String>>("split", args, nargs); + if (args[0].IsNull() || args[1].IsNull()) { + return TypedValue(ctx.memory); + } + TypedValue::TVector result(ctx.memory); + utils::Split(&result, args[0].ValueString(), args[1].ValueString()); + return TypedValue(std::move(result)); +} + +TypedValue Substring(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Or<Null, String>, NonNegativeInteger, Optional<NonNegativeInteger>>("substring", args, nargs); + if (args[0].IsNull()) return TypedValue(ctx.memory); + const auto &str = args[0].ValueString(); + auto start = args[1].ValueInt(); + if (nargs == 2) return TypedValue(utils::Substr(str, start), ctx.memory); + auto len = args[2].ValueInt(); + return TypedValue(utils::Substr(str, start, len), ctx.memory); +} + +TypedValue ToByteString(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<String>("toByteString", args, nargs); + const auto &str = args[0].ValueString(); + if (str.empty()) return TypedValue("", ctx.memory); + if (!utils::StartsWith(str, "0x") && !utils::StartsWith(str, "0X")) { + throw QueryRuntimeException("'toByteString' argument must start with '0x'"); + } + const auto &hex_str = utils::Substr(str, 2); + auto read_hex = [](const char ch) -> unsigned char { + if (ch >= '0' && ch <= '9') return ch - '0'; + if (ch >= 'a' && ch <= 'f') return ch - 'a' + 10; + if (ch >= 'A' && ch <= 'F') return ch - 'A' + 10; + throw 
QueryRuntimeException("'toByteString' argument has an invalid character '{}'", ch); + }; + utils::pmr::string bytes(ctx.memory); + bytes.reserve((1 + hex_str.size()) / 2); + size_t i = 0; + // Treat odd length hex string as having a leading zero. + if (hex_str.size() % 2) bytes.append(1, read_hex(hex_str[i++])); + for (; i < hex_str.size(); i += 2) { + unsigned char byte = read_hex(hex_str[i]) * 16U + read_hex(hex_str[i + 1]); + // MemcpyCast in case we are converting to a signed value, so as to avoid + // undefined behaviour. + bytes.append(1, utils::MemcpyCast<decltype(bytes)::value_type>(byte)); + } + return TypedValue(std::move(bytes)); +} + +TypedValue FromByteString(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<String, Optional<PositiveInteger>>("fromByteString", args, nargs); + const auto &bytes = args[0].ValueString(); + if (bytes.empty()) return TypedValue("", ctx.memory); + size_t min_length = bytes.size(); + if (nargs == 2) min_length = std::max(min_length, static_cast<size_t>(args[1].ValueInt())); + utils::pmr::string str(ctx.memory); + str.reserve(min_length * 2 + 2); + str.append("0x"); + for (size_t pad = 0; pad < min_length - bytes.size(); ++pad) str.append(2, '0'); + // Convert the bytes to a character string in hex representation. + // Unfortunately, we don't know whether the default `char` is signed or + // unsigned, so we have to work around any potential undefined behaviour when + // conversions between the 2 occur. That's why this function is more + // complicated than it should be. + auto to_hex = [](const unsigned char val) -> char { + unsigned char ch = val < 10U ? 
static_cast<unsigned char>('0') + val : static_cast<unsigned char>('a') + val - 10U; + return utils::MemcpyCast<char>(ch); + }; + for (unsigned char byte : bytes) { + str.append(1, to_hex(byte / 16U)); + str.append(1, to_hex(byte % 16U)); + } + return TypedValue(std::move(str)); +} + +template <typename T> +concept IsNumberOrInteger = utils::SameAsAnyOf<T, Number, Integer>; + +template <IsNumberOrInteger ArgType> +void MapNumericParameters(auto ¶meter_mappings, const auto &input_parameters) { + for (const auto &[key, value] : input_parameters) { + if (auto it = parameter_mappings.find(key); it != parameter_mappings.end()) { + if (value.IsInt()) { + *it->second = value.ValueInt(); + } else if (std::is_same_v<ArgType, Number> && value.IsDouble()) { + *it->second = value.ValueDouble(); + } else { + std::string_view error = std::is_same_v<ArgType, Integer> ? "an integer." : "a numeric value."; + throw QueryRuntimeException("Invalid value for key '{}'. Expected {}", key, error); + } + } else { + throw QueryRuntimeException("Unknown key '{}'.", key); + } + } +} + +TypedValue Date(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Optional<Or<String, Map>>>("date", args, nargs); + if (nargs == 0) { + return TypedValue(utils::LocalDateTime(ctx.timestamp).date, ctx.memory); + } + + if (args[0].IsString()) { + const auto &[date_parameters, is_extended] = utils::ParseDateParameters(args[0].ValueString()); + return TypedValue(utils::Date(date_parameters), ctx.memory); + } + + utils::DateParameters date_parameters; + + using namespace std::literals; + std::unordered_map parameter_mappings = {std::pair{"year"sv, &date_parameters.year}, + std::pair{"month"sv, &date_parameters.month}, + std::pair{"day"sv, &date_parameters.day}}; + + MapNumericParameters<Integer>(parameter_mappings, args[0].ValueMap()); + return TypedValue(utils::Date(date_parameters), ctx.memory); +} + +TypedValue LocalTime(const TypedValue *args, int64_t nargs, const FunctionContext 
&ctx) { + FType<Optional<Or<String, Map>>>("localtime", args, nargs); + + if (nargs == 0) { + return TypedValue(utils::LocalDateTime(ctx.timestamp).local_time, ctx.memory); + } + + if (args[0].IsString()) { + const auto &[local_time_parameters, is_extended] = utils::ParseLocalTimeParameters(args[0].ValueString()); + return TypedValue(utils::LocalTime(local_time_parameters), ctx.memory); + } + + utils::LocalTimeParameters local_time_parameters; + + using namespace std::literals; + std::unordered_map parameter_mappings{ + std::pair{"hour"sv, &local_time_parameters.hour}, + std::pair{"minute"sv, &local_time_parameters.minute}, + std::pair{"second"sv, &local_time_parameters.second}, + std::pair{"millisecond"sv, &local_time_parameters.millisecond}, + std::pair{"microsecond"sv, &local_time_parameters.microsecond}, + }; + + MapNumericParameters<Integer>(parameter_mappings, args[0].ValueMap()); + return TypedValue(utils::LocalTime(local_time_parameters), ctx.memory); +} + +TypedValue LocalDateTime(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType<Optional<Or<String, Map>>>("localdatetime", args, nargs); + + if (nargs == 0) { + return TypedValue(utils::LocalDateTime(ctx.timestamp), ctx.memory); + } + + if (args[0].IsString()) { + const auto &[date_parameters, local_time_parameters] = ParseLocalDateTimeParameters(args[0].ValueString()); + return TypedValue(utils::LocalDateTime(date_parameters, local_time_parameters), ctx.memory); + } + + utils::DateParameters date_parameters; + utils::LocalTimeParameters local_time_parameters; + using namespace std::literals; + std::unordered_map parameter_mappings{ + std::pair{"year"sv, &date_parameters.year}, + std::pair{"month"sv, &date_parameters.month}, + std::pair{"day"sv, &date_parameters.day}, + std::pair{"hour"sv, &local_time_parameters.hour}, + std::pair{"minute"sv, &local_time_parameters.minute}, + std::pair{"second"sv, &local_time_parameters.second}, + std::pair{"millisecond"sv, 
&local_time_parameters.millisecond},
+      std::pair{"microsecond"sv, &local_time_parameters.microsecond},
+  };
+
+  MapNumericParameters<Integer>(parameter_mappings, args[0].ValueMap());
+  return TypedValue(utils::LocalDateTime(date_parameters, local_time_parameters), ctx.memory);
+}
+
+/// Implements `duration()`.
+///
+/// The argument is mandatory: a string is parsed, a map may set `day`, `hour`, `minute`, `second`, `millisecond`
+/// and `microsecond`. Unlike the other temporal constructors the map values may be doubles as well as integers.
+TypedValue Duration(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
+  FType<Or<String, Map>>("duration", args, nargs);
+
+  if (args[0].IsString()) {
+    return TypedValue(utils::Duration(ParseDurationParameters(args[0].ValueString())), ctx.memory);
+  }
+
+  utils::DurationParameters duration_parameters;
+  using namespace std::literals;
+  std::unordered_map parameter_mappings{std::pair{"day"sv, &duration_parameters.day},
+                                        std::pair{"hour"sv, &duration_parameters.hour},
+                                        std::pair{"minute"sv, &duration_parameters.minute},
+                                        std::pair{"second"sv, &duration_parameters.second},
+                                        std::pair{"millisecond"sv, &duration_parameters.millisecond},
+                                        std::pair{"microsecond"sv, &duration_parameters.microsecond}};
+  MapNumericParameters<Number>(parameter_mappings, args[0].ValueMap());
+  return TypedValue(utils::Duration(duration_parameters), ctx.memory);
+}
+
+/// Wraps a user-defined query-module function in the signature used for built-in functions.
+///
+/// The returned callable holds copies of `func` and `fully_qualified_name`. On each call it looks the function up
+/// again by name and throws QueryRuntimeException if it is no longer registered.
+std::function<TypedValue(const TypedValue *, const int64_t, const FunctionContext &)> UserFunction(
+    const mgp_func &func, const std::string &fully_qualified_name) {
+  return [func, fully_qualified_name](const TypedValue *args, int64_t nargs, const FunctionContext &ctx) -> TypedValue {
+    /// Find function is called to acquire the lock on Module pointer while user-defined function is executed
+    const auto &maybe_found =
+        procedure::FindFunction(procedure::gModuleRegistry, fully_qualified_name, utils::NewDeleteResource());
+    if (!maybe_found) {
+      throw QueryRuntimeException(
+          "Function '{}' has been unloaded. Please check query modules to confirm that function is loaded in Memgraph.",
+          fully_qualified_name);
+    }
+    /// Explicit extraction of module pointer, to clearly state that the lock is acquired.
+    // NOLINTNEXTLINE(clang-diagnostic-unused-variable)
+    const auto &module_ptr = (*maybe_found).first;
+
+    const auto &func_cb = func.cb;
+    mgp_memory memory{ctx.memory};
+    mgp_func_context functx{ctx.db_accessor, ctx.view};
+    auto graph = mgp_graph::NonWritableGraph(*ctx.db_accessor, ctx.view);
+
+    // Copy the C-style argument array into a vector for ConstructArguments. The index has the same signed type
+    // as `nargs`, so the loop condition does not mix signed and unsigned operands.
+    std::vector<TypedValue> args_list;
+    args_list.reserve(nargs);
+    for (int64_t i = 0; i < nargs; ++i) {
+      args_list.emplace_back(args[i]);
+    }
+
+    auto function_argument_list = mgp_list(ctx.memory);
+    procedure::ConstructArguments(args_list, func, fully_qualified_name, function_argument_list, graph);
+
+    // The callback reports its outcome through `maybe_res`: either an error message or a value must be set.
+    mgp_func_result maybe_res;
+    func_cb(&function_argument_list, &functx, &maybe_res, &memory);
+    if (maybe_res.error_msg) {
+      throw QueryRuntimeException(*maybe_res.error_msg);
+    }
+
+    if (!maybe_res.value) {
+      throw QueryRuntimeException(
+          "Function '{}' didn't set the result nor the error message. Please either set the result by using "
+          "mgp_func_result_set_value or the error by using mgp_func_result_set_error_msg.",
+          fully_qualified_name);
+    }
+
+    return {*(maybe_res.value), ctx.memory};
+  };
+}
+
+}  // namespace
+
+/// Returns the implementation registered under `function_name`.
+///
+/// Built-in names are compared verbatim against the upper-case spellings listed in the body. Any other name is
+/// looked up among the loaded query modules; `nullptr` is returned when nothing matches.
+std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx)> NameToFunction(
+    const std::string &function_name) {
+  // Scalar functions
+  if (function_name == "DEGREE") return Degree;
+  if (function_name == "INDEGREE") return InDegree;
+  if (function_name == "OUTDEGREE") return OutDegree;
+  if (function_name == "ENDNODE") return EndNode;
+  if (function_name == "HEAD") return Head;
+  if (function_name == kId) return Id;
+  if (function_name == "LAST") return Last;
+  if (function_name == "PROPERTIES") return Properties;
+  if (function_name == "SIZE") return Size;
+  if (function_name == "STARTNODE") return StartNode;
+  if (function_name == "TIMESTAMP") return Timestamp;
+  if (function_name == "TOBOOLEAN") return ToBoolean;
+  if (function_name == "TOFLOAT") return ToFloat;
+  if (function_name == "TOINTEGER") return ToInteger;
+  if
(function_name == "TYPE") return Type;
+  if (function_name == "VALUETYPE") return ValueType;
+
+  // List functions
+  if (function_name == "KEYS") return Keys;
+  if (function_name == "LABELS") return Labels;
+  if (function_name == "NODES") return Nodes;
+  if (function_name == "RANGE") return Range;
+  if (function_name == "RELATIONSHIPS") return Relationships;
+  if (function_name == "TAIL") return Tail;
+  if (function_name == "UNIFORMSAMPLE") return UniformSample;
+
+  // Mathematical functions - numeric
+  if (function_name == "ABS") return Abs;
+  if (function_name == "CEIL") return Ceil;
+  if (function_name == "FLOOR") return Floor;
+  if (function_name == "RAND") return Rand;
+  if (function_name == "ROUND") return Round;
+  if (function_name == "SIGN") return Sign;
+
+  // Mathematical functions - logarithmic
+  if (function_name == "E") return E;
+  if (function_name == "EXP") return Exp;
+  if (function_name == "LOG") return Log;
+  if (function_name == "LOG10") return Log10;
+  if (function_name == "SQRT") return Sqrt;
+
+  // Mathematical functions - trigonometric
+  if (function_name == "ACOS") return Acos;
+  if (function_name == "ASIN") return Asin;
+  if (function_name == "ATAN") return Atan;
+  if (function_name == "ATAN2") return Atan2;
+  if (function_name == "COS") return Cos;
+  if (function_name == "PI") return Pi;
+  if (function_name == "SIN") return Sin;
+  if (function_name == "TAN") return Tan;
+
+  // String functions
+  // kContains, kEndsWith and kStartsWith are named constants rather than literals.
+  // NOTE(review): presumably because other code refers to the same names - confirm against the header.
+  if (function_name == kContains) return Contains;
+  if (function_name == kEndsWith) return EndsWith;
+  if (function_name == "LEFT") return Left;
+  if (function_name == "LTRIM") return LTrim;
+  if (function_name == "REPLACE") return Replace;
+  if (function_name == "REVERSE") return Reverse;
+  if (function_name == "RIGHT") return Right;
+  if (function_name == "RTRIM") return RTrim;
+  if (function_name == "SPLIT") return Split;
+  if (function_name == kStartsWith) return StartsWith;
+  if (function_name == "SUBSTRING") return Substring;
+  if
(function_name == "TOLOWER") return ToLower;
+  if (function_name == "TOSTRING") return ToString;
+  if (function_name == "TOUPPER") return ToUpper;
+  if (function_name == "TRIM") return Trim;
+
+  // Memgraph specific functions
+  if (function_name == "ASSERT") return Assert;
+  if (function_name == "COUNTER") return Counter;
+  if (function_name == "TOBYTESTRING") return ToByteString;
+  if (function_name == "FROMBYTESTRING") return FromByteString;
+
+  // Functions for temporal types
+  if (function_name == "DATE") return Date;
+  if (function_name == "LOCALTIME") return LocalTime;
+  if (function_name == "LOCALDATETIME") return LocalDateTime;
+  if (function_name == "DURATION") return Duration;
+
+  // Not a built-in: fall back to functions registered by loaded query modules.
+  const auto &maybe_found =
+      procedure::FindFunction(procedure::gModuleRegistry, function_name, utils::NewDeleteResource());
+
+  if (maybe_found) {
+    // The second element of the lookup result is the function descriptor; UserFunction stores its own copy.
+    const auto *func = (*maybe_found).second;
+    return UserFunction(*func, function_name);
+  }
+
+  // Unknown name: the empty std::function tells the caller that nothing matched.
+  return nullptr;
+}
+
+}  // namespace memgraph::query::v2
diff --git a/src/query/v2/interpret/awesome_memgraph_functions.hpp b/src/query/v2/interpret/awesome_memgraph_functions.hpp
new file mode 100644
index 000000000..27c0bc50c
--- /dev/null
+++ b/src/query/v2/interpret/awesome_memgraph_functions.hpp
@@ -0,0 +1,50 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#pragma once
+
+#include <cstdint>
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "storage/v3/view.hpp"
+#include "utils/memory.hpp"
+
+namespace memgraph::query::v2 {
+
+class DbAccessor;
+class TypedValue;
+
+// Function names available as constants. They are `inline constexpr` rather than members of an anonymous
+// namespace so that every translation unit including this header shares one definition.
+inline constexpr char kStartsWith[] = "STARTSWITH";
+inline constexpr char kEndsWith[] = "ENDSWITH";
+inline constexpr char kContains[] = "CONTAINS";
+inline constexpr char kId[] = "ID";
+
+/// Per-call environment handed to every function implementation.
+struct FunctionContext {
+  DbAccessor *db_accessor;
+  /// Memory resource the function is expected to allocate its result from.
+  utils::MemoryResource *memory;
+  int64_t timestamp;
+  std::unordered_map<std::string, int64_t> *counters;
+  storage::v3::View view;
+};
+
+/// Return the function implementation with the given name.
+///
+/// Note, returned function signature uses C-style access to an array to allow
+/// having an array stored anywhere the caller likes, as long as it is
+/// contiguous in memory. Since most functions don't take many arguments, it's
+/// convenient to have them stored in the calling stack frame.
+std::function<TypedValue(const TypedValue *arguments, int64_t num_arguments, const FunctionContext &context)>
+NameToFunction(const std::string &function_name);
+
+}  // namespace memgraph::query::v2
diff --git a/src/query/v2/interpret/eval.cpp b/src/query/v2/interpret/eval.cpp
new file mode 100644
index 000000000..eba77adf9
--- /dev/null
+++ b/src/query/v2/interpret/eval.cpp
@@ -0,0 +1,35 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#include "query/v2/interpret/eval.hpp"
+
+namespace memgraph::query::v2 {
+
+/// Evaluates `expr` and returns its integer value.
+///
+/// @param what text placed at the front of the error message to describe the evaluated quantity.
+/// @throw QueryRuntimeException if the value cannot be read as an integer.
+int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr, const std::string &what) {
+  TypedValue value = expr->Accept(*evaluator);
+  try {
+    return value.ValueInt();
+  } catch (const TypedValueException &) {
+    throw QueryRuntimeException(what + " must be an int");
+  }
+}
+
+/// Evaluates an optional memory-limit expression and scales it by `memory_scale`.
+///
+/// @return `std::nullopt` when `memory_limit` is null, otherwise the evaluated limit multiplied by `memory_scale`.
+/// @throw QueryRuntimeException if the value is not an integer greater than zero, or if the product would not
+///        fit in `size_t`.
+std::optional<size_t> EvaluateMemoryLimit(ExpressionEvaluator *eval, Expression *memory_limit, size_t memory_scale) {
+  if (!memory_limit) return std::nullopt;
+  auto limit_value = memory_limit->Accept(*eval);
+  // Zero and negative limits are both rejected.
+  if (!limit_value.IsInt() || limit_value.ValueInt() <= 0)
+    throw QueryRuntimeException("Memory limit must be a positive integer.");
+  const auto limit = static_cast<size_t>(limit_value.ValueInt());
+  // A zero scale can never overflow and must not be used as a divisor.
+  if (memory_scale != 0 && std::numeric_limits<size_t>::max() / memory_scale < limit)
+    throw QueryRuntimeException("Memory limit overflow.");
+  return limit * memory_scale;
+}
+
+}  // namespace memgraph::query::v2
diff --git a/src/query/v2/interpret/eval.hpp b/src/query/v2/interpret/eval.hpp
new file mode 100644
index 000000000..10cee31a9
--- /dev/null
+++ b/src/query/v2/interpret/eval.hpp
@@ -0,0 +1,764 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+/// @file
+#pragma once
+
+#include <algorithm>
+#include <limits>
+#include <map>
+#include <optional>
+#include <regex>
+#include <vector>
+
+#include "query/v2/common.hpp"
+#include "query/v2/context.hpp"
+#include "query/v2/db_accessor.hpp"
+#include "query/v2/exceptions.hpp"
+#include "query/v2/frontend/ast/ast.hpp"
+#include "query/v2/frontend/semantic/symbol_table.hpp"
+#include "query/v2/interpret/frame.hpp"
+#include "query/v2/typed_value.hpp"
+#include "utils/exceptions.hpp"
+
+namespace memgraph::query::v2 {
+
+/// Visitor that evaluates an expression tree to a TypedValue.
+///
+/// Values of symbols are read from and written to the frame; results are allocated with the memory resource of
+/// the evaluation context.
+class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
+ public:
+  /// Stores the given pointers and the addresses of `symbol_table` and `ctx`; nothing is copied.
+  ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, DbAccessor *dba,
+                      storage::v3::View view)
+      : frame_(frame), symbol_table_(&symbol_table), ctx_(&ctx), dba_(dba), view_(view) {}
+
+  using ExpressionVisitor<TypedValue>::Visit;
+
+  /// Memory resource of the evaluation context.
+  utils::MemoryResource *GetMemoryResource() const { return ctx_->memory; }
+
+  /// Evaluates the wrapped expression, stores the result in the frame under the expression's symbol and returns it.
+  TypedValue Visit(NamedExpression &named_expression) override {
+    const auto &symbol = symbol_table_->at(named_expression);
+    auto value = named_expression.expression_->Accept(*this);
+    frame_->at(symbol) = value;
+    return value;
+  }
+
+  /// Returns a copy of the frame value bound to the identifier's symbol.
+  TypedValue Visit(Identifier &ident) override {
+    return TypedValue(frame_->at(symbol_table_->at(ident)), ctx_->memory);
+  }
+
+// Generates Visit(OP_NODE &) for a binary operator: both operands are evaluated, CPP_OP is applied, and a
+// TypedValueException is rethrown as a QueryRuntimeException that names the Cypher spelling of the operator.
+#define BINARY_OPERATOR_VISITOR(OP_NODE, CPP_OP, CYPHER_OP) \
+  TypedValue Visit(OP_NODE &op) override { \
+    auto val1 = op.expression1_->Accept(*this); \
+    auto val2 = op.expression2_->Accept(*this); \
+    try { \
+      return val1 CPP_OP val2; \
+    } catch (const TypedValueException &) { \
+      throw QueryRuntimeException("Invalid types: {} and {} for '{}'.", val1.type(), val2.type(), #CYPHER_OP); \
+    } \
+  }
+
+// Same as above for unary operators.
+#define UNARY_OPERATOR_VISITOR(OP_NODE, CPP_OP, CYPHER_OP) \
+  TypedValue Visit(OP_NODE &op) override { \
+    auto val = op.expression_->Accept(*this); \
+    try { \
+      return CPP_OP val; \
+    } catch (const TypedValueException &) { \
+      throw
QueryRuntimeException("Invalid type {} for '{}'.", val.type(), #CYPHER_OP); \
+    } \
+  }
+
+  BINARY_OPERATOR_VISITOR(OrOperator, ||, OR);
+  BINARY_OPERATOR_VISITOR(XorOperator, ^, XOR);
+  BINARY_OPERATOR_VISITOR(AdditionOperator, +, +);
+  BINARY_OPERATOR_VISITOR(SubtractionOperator, -, -);
+  BINARY_OPERATOR_VISITOR(MultiplicationOperator, *, *);
+  BINARY_OPERATOR_VISITOR(DivisionOperator, /, /);
+  BINARY_OPERATOR_VISITOR(ModOperator, %, %);
+  BINARY_OPERATOR_VISITOR(NotEqualOperator, !=, <>);
+  BINARY_OPERATOR_VISITOR(EqualOperator, ==, =);
+  BINARY_OPERATOR_VISITOR(LessOperator, <, <);
+  BINARY_OPERATOR_VISITOR(GreaterOperator, >, >);
+  BINARY_OPERATOR_VISITOR(LessEqualOperator, <=, <=);
+  BINARY_OPERATOR_VISITOR(GreaterEqualOperator, >=, >=);
+
+  UNARY_OPERATOR_VISITOR(NotOperator, !, NOT);
+  UNARY_OPERATOR_VISITOR(UnaryPlusOperator, +, +);
+  UNARY_OPERATOR_VISITOR(UnaryMinusOperator, -, -);
+
+#undef BINARY_OPERATOR_VISITOR
+#undef UNARY_OPERATOR_VISITOR
+
+  /// AND is written out by hand instead of generated so that the right operand is skipped when the left one is
+  /// already false.
+  TypedValue Visit(AndOperator &op) override {
+    auto value1 = op.expression1_->Accept(*this);
+    if (value1.IsBool() && !value1.ValueBool()) {
+      // If first expression is false, don't evaluate the second one.
+      return value1;
+    }
+    auto value2 = op.expression2_->Accept(*this);
+    try {
+      return value1 && value2;
+    } catch (const TypedValueException &) {
+      throw QueryRuntimeException("Invalid types: {} and {} for AND.", value1.type(), value2.type());
+    }
+  }
+
+  /// Evaluates the condition and then exactly one of the two branches.
+  ///
+  /// @throw QueryRuntimeException if the condition is neither Null nor a boolean.
+  TypedValue Visit(IfOperator &if_operator) override {
+    auto condition = if_operator.condition_->Accept(*this);
+    if (condition.IsNull()) {
+      // A Null condition is not true, so CASE must fall through to the ELSE branch.
+      return if_operator.else_expression_->Accept(*this);
+    }
+    if (condition.type() != TypedValue::Type::Bool) {
+      // At the moment IfOperator is used only in CASE construct.
+      throw QueryRuntimeException("CASE expected boolean expression, got {}.", condition.type());
+    }
+    if (condition.ValueBool()) {
+      return if_operator.then_expression_->Accept(*this);
+    }
+    return if_operator.else_expression_->Accept(*this);
+  }
+
+  /// Evaluates `expression1 IN expression2`.
+  ///
+  /// Both operands are always evaluated. The result is Null when the list is Null, false for an empty list,
+  /// Null for a Null left operand with a non-empty list, true as soon as one element compares equal, and otherwise
+  /// Null if any comparison was Null or false if none was.
+  /// @throw QueryRuntimeException if the right operand is neither Null nor a list.
+  TypedValue Visit(InListOperator &in_list) override {
+    auto literal = in_list.expression1_->Accept(*this);
+    auto _list = in_list.expression2_->Accept(*this);
+    if (_list.IsNull()) {
+      return TypedValue(ctx_->memory);
+    }
+    // Exceptions have higher priority than returning nulls when list expression
+    // is not null.
+    if (_list.type() != TypedValue::Type::List) {
+      throw QueryRuntimeException("IN expected a list, got {}.", _list.type());
+    }
+    const auto &list = _list.ValueList();
+
+    // If literal is NULL there is no need to try to compare it with every
+    // element in the list since result of every comparison will be NULL. There
+    // is one special case that we must test explicitly: if list is empty then
+    // result is false since no comparison will be performed.
+    if (list.empty()) return TypedValue(false, ctx_->memory);
+    if (literal.IsNull()) return TypedValue(ctx_->memory);
+
+    // Remember whether any comparison was Null; that decides between Null and false when nothing matches.
+    auto has_null = false;
+    for (const auto &element : list) {
+      auto result = literal == element;
+      if (result.IsNull()) {
+        has_null = true;
+      } else if (result.ValueBool()) {
+        return TypedValue(true, ctx_->memory);
+      }
+    }
+    if (has_null) {
+      return TypedValue(ctx_->memory);
+    }
+    return TypedValue(false, ctx_->memory);
+  }
+
+  /// Evaluates `lhs[index]` for lists, maps, nodes and edges.
+  ///
+  /// A Null operand yields Null. Lists take an integer index, where a negative index counts from the end and an
+  /// out-of-range index yields Null. Maps, nodes and edges take a string key.
+  /// @throw QueryRuntimeException if `lhs` or `index` has an unsupported type.
+  TypedValue Visit(SubscriptOperator &list_indexing) override {
+    auto lhs = list_indexing.expression1_->Accept(*this);
+    auto index = list_indexing.expression2_->Accept(*this);
+    // The type of lhs is checked before the Null shortcut below, so a wrong type throws even for a Null index.
+    if (!lhs.IsList() && !lhs.IsMap() && !lhs.IsVertex() && !lhs.IsEdge() && !lhs.IsNull())
+      throw QueryRuntimeException(
+          "Expected a list, a map, a node or an edge to index with '[]', got "
+          "{}.",
+          lhs.type());
+    if (lhs.IsNull() || index.IsNull()) return TypedValue(ctx_->memory);
+    if (lhs.IsList()) {
+      if (!index.IsInt()) throw QueryRuntimeException("Expected an integer as a list index, got {}.", index.type());
+      auto index_int = index.ValueInt();
+      // NOTE: Take non-const reference to list, so that we can move out the
+      // indexed element as the result.
+      auto &list = lhs.ValueList();
+      if (index_int < 0) {
+        index_int += static_cast<int64_t>(list.size());
+      }
+      if (index_int >= static_cast<int64_t>(list.size()) || index_int < 0) return TypedValue(ctx_->memory);
+      // NOTE: Explicit move is needed, so that we return the move constructed
+      // value and preserve the correct MemoryResource.
+      return std::move(list[index_int]);
+    }
+
+    if (lhs.IsMap()) {
+      if (!index.IsString()) throw QueryRuntimeException("Expected a string as a map index, got {}.", index.type());
+      // NOTE: Take non-const reference to map, so that we can move out the
+      // looked-up element as the result.
+      auto &map = lhs.ValueMap();
+      auto found = map.find(index.ValueString());
+      if (found == map.end()) return TypedValue(ctx_->memory);
+      // NOTE: Explicit move is needed, so that we return the move constructed
+      // value and preserve the correct MemoryResource.
+      return std::move(found->second);
+    }
+
+    if (lhs.IsVertex()) {
+      if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type());
+      return TypedValue(GetProperty(lhs.ValueVertex(), index.ValueString()), ctx_->memory);
+    }
+
+    if (lhs.IsEdge()) {
+      if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type());
+      return TypedValue(GetProperty(lhs.ValueEdge(), index.ValueString()), ctx_->memory);
+    }
+
+    // lhs is Null
+    return TypedValue(ctx_->memory);
+  }
+
+  /// Evaluates `list[lower..upper]`.
+  ///
+  /// A missing lower bound defaults to 0 and a missing upper bound to the largest int64_t. Negative bounds count
+  /// from the end of the list and both bounds are clamped to the list's extent. The result is Null if the list
+  /// or either bound is Null.
+  /// @throw QueryRuntimeException if a bound is not an integer or the operand is not a list.
+  TypedValue Visit(ListSlicingOperator &op) override {
+    // If some type is null we can't return null, because throwing exception
+    // on illegal type has higher priority.
+    auto is_null = false;
+    // Evaluates one bound, records a Null in `is_null`, and falls back to `default_value` for an absent bound.
+    auto get_bound = [&](Expression *bound_expr, int64_t default_value) {
+      if (bound_expr) {
+        auto bound = bound_expr->Accept(*this);
+        if (bound.type() == TypedValue::Type::Null) {
+          is_null = true;
+        } else if (bound.type() != TypedValue::Type::Int) {
+          throw QueryRuntimeException("Expected an integer for a bound in list slicing, got {}.", bound.type());
+        }
+        return bound;
+      }
+      return TypedValue(default_value, ctx_->memory);
+    };
+    // The upper bound is evaluated before the lower bound, and both before the list itself.
+    auto _upper_bound = get_bound(op.upper_bound_, std::numeric_limits<int64_t>::max());
+    auto _lower_bound = get_bound(op.lower_bound_, 0);
+
+    auto _list = op.list_->Accept(*this);
+    if (_list.type() == TypedValue::Type::Null) {
+      is_null = true;
+    } else if (_list.type() != TypedValue::Type::List) {
+      throw QueryRuntimeException("Expected a list to slice, got {}.", _list.type());
+    }
+
+    if (is_null) {
+      return TypedValue(ctx_->memory);
+    }
+    const auto &list = _list.ValueList();
+    // Maps a possibly negative bound into the range [0, list.size()].
+    auto normalise_bound = [&](int64_t bound) {
+      if
(bound < 0) {
+        bound = static_cast<int64_t>(list.size()) + bound;
+      }
+      return std::max(static_cast<int64_t>(0), std::min(bound, static_cast<int64_t>(list.size())));
+    };
+    auto lower_bound = normalise_bound(_lower_bound.ValueInt());
+    auto upper_bound = normalise_bound(_upper_bound.ValueInt());
+    // An empty or inverted range yields an empty list.
+    if (upper_bound <= lower_bound) {
+      return TypedValue(TypedValue::TVector(ctx_->memory), ctx_->memory);
+    }
+    return TypedValue(TypedValue::TVector(list.begin() + lower_bound, list.begin() + upper_bound, ctx_->memory));
+  }
+
+  /// Returns a boolean telling whether the operand evaluated to Null.
+  TypedValue Visit(IsNullOperator &is_null) override {
+    auto value = is_null.expression_->Accept(*this);
+    return TypedValue(value.IsNull(), ctx_->memory);
+  }
+
+  /// Evaluates `expression.property` for nodes, edges, maps and temporal values.
+  ///
+  /// Null yields Null and a missing map key yields Null. Temporal values expose a fixed set of component names.
+  /// @throw QueryRuntimeException for an unknown temporal component or an operand type without properties.
+  TypedValue Visit(PropertyLookup &property_lookup) override {
+    auto expression_result = property_lookup.expression_->Accept(*this);
+    // Each helper returns the requested component, or std::nullopt when the name is not one it knows.
+    auto maybe_date = [this](const auto &date, const auto &prop_name) -> std::optional<TypedValue> {
+      if (prop_name == "year") {
+        return TypedValue(date.year, ctx_->memory);
+      }
+      if (prop_name == "month") {
+        return TypedValue(date.month, ctx_->memory);
+      }
+      if (prop_name == "day") {
+        return TypedValue(date.day, ctx_->memory);
+      }
+      return std::nullopt;
+    };
+    auto maybe_local_time = [this](const auto &lt, const auto &prop_name) -> std::optional<TypedValue> {
+      if (prop_name == "hour") {
+        return TypedValue(lt.hour, ctx_->memory);
+      }
+      if (prop_name == "minute") {
+        return TypedValue(lt.minute, ctx_->memory);
+      }
+      if (prop_name == "second") {
+        return TypedValue(lt.second, ctx_->memory);
+      }
+      if (prop_name == "millisecond") {
+        return TypedValue(lt.millisecond, ctx_->memory);
+      }
+      if (prop_name == "microsecond") {
+        return TypedValue(lt.microsecond, ctx_->memory);
+      }
+      return std::nullopt;
+    };
+    auto maybe_duration = [this](const auto &dur, const auto &prop_name) -> std::optional<TypedValue> {
+      if (prop_name == "day") {
+        return TypedValue(dur.Days(), ctx_->memory);
+      }
+      if (prop_name == "hour") {
+        return TypedValue(dur.SubDaysAsHours(),
ctx_->memory);
+      }
+      if (prop_name == "minute") {
+        return TypedValue(dur.SubDaysAsMinutes(), ctx_->memory);
+      }
+      if (prop_name == "second") {
+        return TypedValue(dur.SubDaysAsSeconds(), ctx_->memory);
+      }
+      if (prop_name == "millisecond") {
+        return TypedValue(dur.SubDaysAsMilliseconds(), ctx_->memory);
+      }
+      if (prop_name == "microsecond") {
+        return TypedValue(dur.SubDaysAsMicroseconds(), ctx_->memory);
+      }
+      if (prop_name == "nanosecond") {
+        return TypedValue(dur.SubDaysAsNanoseconds(), ctx_->memory);
+      }
+      return std::nullopt;
+    };
+    // Dispatch on the runtime type of the evaluated operand.
+    switch (expression_result.type()) {
+      case TypedValue::Type::Null:
+        return TypedValue(ctx_->memory);
+      case TypedValue::Type::Vertex:
+        return TypedValue(GetProperty(expression_result.ValueVertex(), property_lookup.property_), ctx_->memory);
+      case TypedValue::Type::Edge:
+        return TypedValue(GetProperty(expression_result.ValueEdge(), property_lookup.property_), ctx_->memory);
+      case TypedValue::Type::Map: {
+        // NOTE: Take non-const reference to map, so that we can move out the
+        // looked-up element as the result.
+        auto &map = expression_result.ValueMap();
+        auto found = map.find(property_lookup.property_.name.c_str());
+        if (found == map.end()) return TypedValue(ctx_->memory);
+        // NOTE: Explicit move is needed, so that we return the move constructed
+        // value and preserve the correct MemoryResource.
+        return std::move(found->second);
+      }
+      case TypedValue::Type::Duration: {
+        const auto &prop_name = property_lookup.property_.name;
+        const auto &dur = expression_result.ValueDuration();
+        if (auto dur_field = maybe_duration(dur, prop_name); dur_field) {
+          return std::move(*dur_field);
+        }
+        throw QueryRuntimeException("Invalid property name {} for Duration", prop_name);
+      }
+      case TypedValue::Type::Date: {
+        const auto &prop_name = property_lookup.property_.name;
+        const auto &date = expression_result.ValueDate();
+        if (auto date_field = maybe_date(date, prop_name); date_field) {
+          return std::move(*date_field);
+        }
+        throw QueryRuntimeException("Invalid property name {} for Date", prop_name);
+      }
+      case TypedValue::Type::LocalTime: {
+        const auto &prop_name = property_lookup.property_.name;
+        const auto &lt = expression_result.ValueLocalTime();
+        if (auto lt_field = maybe_local_time(lt, prop_name); lt_field) {
+          return std::move(*lt_field);
+        }
+        throw QueryRuntimeException("Invalid property name {} for LocalTime", prop_name);
+      }
+      case TypedValue::Type::LocalDateTime: {
+        // A local date-time exposes the date components first, then the time components.
+        const auto &prop_name = property_lookup.property_.name;
+        const auto &ldt = expression_result.ValueLocalDateTime();
+        if (auto date_field = maybe_date(ldt.date, prop_name); date_field) {
+          return std::move(*date_field);
+        }
+        if (auto lt_field = maybe_local_time(ldt.local_time, prop_name); lt_field) {
+          return std::move(*lt_field);
+        }
+        throw QueryRuntimeException("Invalid property name {} for LocalDateTime", prop_name);
+      }
+      default:
+        throw QueryRuntimeException("Only nodes, edges, maps and temporal types have properties to be looked-up.");
+    }
+  }
+
+  /// Tests whether a node carries every label listed in the expression.
+  ///
+  /// Null yields Null.
+  /// @throw QueryRuntimeException if the operand is neither Null nor a node, or if reading a label fails.
+  TypedValue Visit(LabelsTest &labels_test) override {
+    auto expression_result = labels_test.expression_->Accept(*this);
+    switch (expression_result.type()) {
+      case TypedValue::Type::Null:
+        return TypedValue(ctx_->memory);
+      case TypedValue::Type::Vertex: {
+        const auto &vertex = expression_result.ValueVertex();
+        for (const auto &label :
labels_test.labels_) {
+          auto has_label = vertex.HasLabel(view_, GetLabel(label));
+          if (has_label.HasError() && has_label.GetError() == storage::v3::Error::NONEXISTENT_OBJECT) {
+            // This is a very nasty and temporary hack in order to make MERGE
+            // work. The old storage had the following logic when returning an
+            // `OLD` view: `return old ? old : new`. That means that if the
+            // `OLD` view didn't exist, it returned the NEW view. With this hack
+            // we simulate that behavior.
+            // TODO (mferencevic, teon.banek): Remove once MERGE is
+            // reimplemented.
+            has_label = vertex.HasLabel(storage::v3::View::NEW, GetLabel(label));
+          }
+          // Translate any remaining storage error into a query-level exception.
+          if (has_label.HasError()) {
+            switch (has_label.GetError()) {
+              case storage::v3::Error::DELETED_OBJECT:
+                throw QueryRuntimeException("Trying to access labels on a deleted node.");
+              case storage::v3::Error::NONEXISTENT_OBJECT:
+                throw query::v2::QueryRuntimeException("Trying to access labels from a node that doesn't exist.");
+              case storage::v3::Error::SERIALIZATION_ERROR:
+              case storage::v3::Error::VERTEX_HAS_EDGES:
+              case storage::v3::Error::PROPERTIES_DISABLED:
+                throw QueryRuntimeException("Unexpected error when accessing labels.");
+            }
+          }
+          // The first missing label decides the result.
+          if (!*has_label) {
+            return TypedValue(false, ctx_->memory);
+          }
+        }
+        return TypedValue(true, ctx_->memory);
+      }
+      default:
+        throw QueryRuntimeException("Only nodes have labels.");
+    }
+  }
+
+  /// Wraps the literal's stored value in a TypedValue.
+  TypedValue Visit(PrimitiveLiteral &literal) override {
+    // TODO: no need to evaluate constants, we can write it to frame in one
+    // of the previous phases.
+    return TypedValue(literal.value_, ctx_->memory);
+  }
+
+  /// Evaluates every element expression, in order, into a list.
+  TypedValue Visit(ListLiteral &literal) override {
+    TypedValue::TVector result(ctx_->memory);
+    result.reserve(literal.elements_.size());
+    for (const auto &expression : literal.elements_) result.emplace_back(expression->Accept(*this));
+    // `result` is not used afterwards, so hand it over instead of copying every element.
+    return TypedValue(std::move(result), ctx_->memory);
+  }
+
+  /// Evaluates every value expression into a map keyed by the property name.
+  TypedValue Visit(MapLiteral &literal) override {
+    TypedValue::TMap result(ctx_->memory);
+    for (const auto &pair : literal.elements_) result.emplace(pair.first.name, pair.second->Accept(*this));
+    // `result` is not used afterwards, so hand it over instead of copying every entry.
+    return TypedValue(std::move(result), ctx_->memory);
+  }
+
+  /// Returns a copy of the frame value bound to the aggregation's symbol.
+  TypedValue Visit(Aggregation &aggregation) override {
+    return TypedValue(frame_->at(symbol_table_->at(aggregation)), ctx_->memory);
+  }
+
+  /// Returns the first argument that does not evaluate to Null, or Null if all of them do.
+  ///
+  /// Arguments are evaluated left to right and evaluation stops at the first non-Null value.
+  /// @throw QueryRuntimeException if there are no arguments.
+  TypedValue Visit(Coalesce &coalesce) override {
+    auto &exprs = coalesce.expressions_;
+
+    if (exprs.empty()) {
+      throw QueryRuntimeException("'coalesce' requires at least one argument.");
+    }
+
+    for (const auto &expr : exprs) {
+      TypedValue val(expr->Accept(*this), ctx_->memory);
+      if (!val.IsNull()) {
+        return val;
+      }
+    }
+
+    return TypedValue(ctx_->memory);
+  }
+
+  /// Evaluates the arguments and calls the function implementation stored in the AST node.
+  TypedValue Visit(Function &function) override {
+    FunctionContext function_ctx{dba_, ctx_->memory, ctx_->timestamp, &ctx_->counters, view_};
+    // Stack allocate evaluated arguments when there's a small number of them.
+    if (function.arguments_.size() <= 8) {
+      // All eight slots are constructed with the evaluation memory resource; only the first size() are assigned.
+      TypedValue arguments[8] = {TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory),
+                                 TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory),
+                                 TypedValue(ctx_->memory), TypedValue(ctx_->memory)};
+      for (size_t i = 0; i < function.arguments_.size(); ++i) {
+        arguments[i] = function.arguments_[i]->Accept(*this);
+      }
+      auto res = function.function_(arguments, function.arguments_.size(), function_ctx);
+      // The function must have produced its result with the evaluation memory resource.
+      MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
+      return res;
+    } else {
+      // More than eight arguments: keep them in a vector instead.
+      TypedValue::TVector arguments(ctx_->memory);
+      arguments.reserve(function.arguments_.size());
+      for (const auto &argument : function.arguments_) {
+        arguments.emplace_back(argument->Accept(*this));
+      }
+      auto res = function.function_(arguments.data(), arguments.size(), function_ctx);
+      MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
+      return res;
+    }
+  }
+
+  /// Folds a list into a single value.
+  ///
+  /// The accumulator starts at the initializer. For every element the accumulator and the element are written to
+  /// the frame and the reduce expression is evaluated to produce the next accumulator. A Null list yields Null.
+  /// @throw QueryRuntimeException if the operand is neither Null nor a list.
+  TypedValue Visit(Reduce &reduce) override {
+    auto list_value = reduce.list_->Accept(*this);
+    if (list_value.IsNull()) {
+      return TypedValue(ctx_->memory);
+    }
+    if (list_value.type() != TypedValue::Type::List) {
+      throw QueryRuntimeException("REDUCE expected a list, got {}.", list_value.type());
+    }
+    const auto &list = list_value.ValueList();
+    const auto &element_symbol = symbol_table_->at(*reduce.identifier_);
+    const auto &accumulator_symbol = symbol_table_->at(*reduce.accumulator_);
+    auto accumulator = reduce.initializer_->Accept(*this);
+    for (const auto &element : list) {
+      frame_->at(accumulator_symbol) = accumulator;
+      frame_->at(element_symbol) = element;
+      accumulator = reduce.expression_->Accept(*this);
+    }
+    return accumulator;
+  }
+
+  /// Maps the extract expression over a list; Null elements are passed through as Null without evaluation.
+  ///
+  /// A Null list yields Null.
+  /// @throw QueryRuntimeException if the operand is neither Null nor a list.
+  TypedValue Visit(Extract &extract) override {
+    auto list_value = extract.list_->Accept(*this);
+    if (list_value.IsNull()) {
+      return TypedValue(ctx_->memory);
+    }
+    if (list_value.type() != TypedValue::Type::List) {
+      throw QueryRuntimeException("EXTRACT expected a list, got {}.", list_value.type());
+    }
+    const auto &list = list_value.ValueList();
+    const auto &element_symbol = symbol_table_->at(*extract.identifier_);
+    TypedValue::TVector result(ctx_->memory);
+    result.reserve(list.size());
+    for (const auto &element : list) {
+      if (element.IsNull()) {
+        result.emplace_back();
+      } else {
+        frame_->at(element_symbol) = element;
+        result.emplace_back(extract.expression_->Accept(*this));
+      }
+    }
+    // `result` is not used afterwards, so hand it over instead of copying every element.
+    return TypedValue(std::move(result), ctx_->memory);
+  }
+
+  /// Evaluates the ALL predicate over a list.
+  ///
+  /// Returns false as soon as the predicate is false for an element. Otherwise the result is Null if no element
+  /// produced a boolean, false if some predicate results were Null, and true if every result was true.
+  /// A Null list yields Null.
+  /// @throw QueryRuntimeException if the operand is not a list or the predicate is neither Null nor boolean.
+  TypedValue Visit(All &all) override {
+    auto list_value = all.list_expression_->Accept(*this);
+    if (list_value.IsNull()) {
+      return TypedValue(ctx_->memory);
+    }
+    if (list_value.type() != TypedValue::Type::List) {
+      throw QueryRuntimeException("ALL expected a list, got {}.", list_value.type());
+    }
+    const auto &list = list_value.ValueList();
+    const auto &symbol = symbol_table_->at(*all.identifier_);
+    bool has_null_elements = false;
+    bool has_value = false;
+    for (const auto &element : list) {
+      frame_->at(symbol) = element;
+      auto result = all.where_->expression_->Accept(*this);
+      if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
+        throw QueryRuntimeException("Predicate of ALL must evaluate to boolean, got {}.", result.type());
+      }
+      if (!result.IsNull()) {
+        has_value = true;
+        if (!result.ValueBool()) {
+          return TypedValue(false, ctx_->memory);
+        }
+      } else {
+        has_null_elements = true;
+      }
+    }
+    if (!has_value) {
+      return TypedValue(ctx_->memory);
+    }
+    if (has_null_elements) {
+      return TypedValue(false, ctx_->memory);
+    } else {
+      return TypedValue(true, ctx_->memory);
+    }
+  }
+
+  /// Evaluates the SINGLE predicate over a list: true when exactly one element satisfies the predicate.
+  ///
+  /// A Null list yields Null, as does a list for which no predicate result was a boolean.
+  /// @throw QueryRuntimeException if the operand is not a list or the predicate is neither Null nor boolean.
+  TypedValue Visit(Single &single) override {
+    auto list_value = single.list_expression_->Accept(*this);
+    if (list_value.IsNull()) {
+      return TypedValue(ctx_->memory);
+    }
+    if (list_value.type() != TypedValue::Type::List) {
+      throw QueryRuntimeException("SINGLE expected a list, got {}.", list_value.type());
+    }
+    const auto &list = list_value.ValueList();
+    const auto &symbol = symbol_table_->at(*single.identifier_);
+ bool has_value = false; + bool predicate_satisfied = false; + for (const auto &element : list) { + frame_->at(symbol) = element; + auto result = single.where_->expression_->Accept(*this); + if (!result.IsNull() && result.type() != TypedValue::Type::Bool) { + throw QueryRuntimeException("Predicate of SINGLE must evaluate to boolean, got {}.", result.type()); + } + if (result.type() == TypedValue::Type::Bool) { + has_value = true; + } + if (result.IsNull() || !result.ValueBool()) { + continue; + } + // Return false if more than one element satisfies the predicate. + if (predicate_satisfied) { + return TypedValue(false, ctx_->memory); + } else { + predicate_satisfied = true; + } + } + if (!has_value) { + return TypedValue(ctx_->memory); + } else { + return TypedValue(predicate_satisfied, ctx_->memory); + } + } + + TypedValue Visit(Any &any) override { + auto list_value = any.list_expression_->Accept(*this); + if (list_value.IsNull()) { + return TypedValue(ctx_->memory); + } + if (list_value.type() != TypedValue::Type::List) { + throw QueryRuntimeException("ANY expected a list, got {}.", list_value.type()); + } + const auto &list = list_value.ValueList(); + const auto &symbol = symbol_table_->at(*any.identifier_); + bool has_value = false; + for (const auto &element : list) { + frame_->at(symbol) = element; + auto result = any.where_->expression_->Accept(*this); + if (!result.IsNull() && result.type() != TypedValue::Type::Bool) { + throw QueryRuntimeException("Predicate of ANY must evaluate to boolean, got {}.", result.type()); + } + if (!result.IsNull()) { + has_value = true; + if (result.ValueBool()) { + return TypedValue(true, ctx_->memory); + } + } + } + // Return Null if all elements are Null + if (!has_value) { + return TypedValue(ctx_->memory); + } else { + return TypedValue(false, ctx_->memory); + } + } + + TypedValue Visit(None &none) override { + auto list_value = none.list_expression_->Accept(*this); + if (list_value.IsNull()) { + return 
TypedValue(ctx_->memory); + } + if (list_value.type() != TypedValue::Type::List) { + throw QueryRuntimeException("NONE expected a list, got {}.", list_value.type()); + } + const auto &list = list_value.ValueList(); + const auto &symbol = symbol_table_->at(*none.identifier_); + bool has_value = false; + for (const auto &element : list) { + frame_->at(symbol) = element; + auto result = none.where_->expression_->Accept(*this); + if (!result.IsNull() && result.type() != TypedValue::Type::Bool) { + throw QueryRuntimeException("Predicate of NONE must evaluate to boolean, got {}.", result.type()); + } + if (!result.IsNull()) { + has_value = true; + if (result.ValueBool()) { + return TypedValue(false, ctx_->memory); + } + } + } + // Return Null if all elements are Null + if (!has_value) { + return TypedValue(ctx_->memory); + } else { + return TypedValue(true, ctx_->memory); + } + } + /// Returns the value of ctx_->parameters at the lookup's token position, allocated with ctx_->memory. + TypedValue Visit(ParameterLookup &param_lookup) override { + return TypedValue(ctx_->parameters.AtTokenPosition(param_lookup.token_position_), ctx_->memory); + } + /// Whole-string regex match (std::regex_match). Null if either operand is Null or the target is not a string; @throw QueryRuntimeException if the pattern is not a string or is rejected by std::regex. + TypedValue Visit(RegexMatch &regex_match) override { + auto target_string_value = regex_match.string_expr_->Accept(*this); + auto regex_value = regex_match.regex_->Accept(*this); + if (target_string_value.IsNull() || regex_value.IsNull()) { + return TypedValue(ctx_->memory); + } + if (regex_value.type() != TypedValue::Type::String) { + throw QueryRuntimeException("Regular expression must evaluate to a string, got {}.", regex_value.type()); + } + if (target_string_value.type() != TypedValue::Type::String) { + // Instead of error, we return Null which makes it compatible in case we + // use indexed lookup which filters out any non-string properties. + // Assuming a property lookup is the target_string_value. 
+ return TypedValue(ctx_->memory); + } + const auto &target_string = target_string_value.ValueString(); + try { + std::regex regex(regex_value.ValueString()); + return TypedValue(std::regex_match(target_string, regex), ctx_->memory); + } catch (const std::regex_error &e) { + throw QueryRuntimeException("Regex error in '{}': {}", regex_value.ValueString(), e.what()); + } + } + + private: + template <class TRecordAccessor> + storage::v3::PropertyValue GetProperty(const TRecordAccessor &record_accessor, PropertyIx prop) { + auto maybe_prop = record_accessor.GetProperty(view_, ctx_->properties[prop.ix]); + if (maybe_prop.HasError() && maybe_prop.GetError() == storage::v3::Error::NONEXISTENT_OBJECT) { + // This is a very nasty and temporary hack in order to make MERGE work. + // The old storage had the following logic when returning an `OLD` view: + // `return old ? old : new`. That means that if the `OLD` view didn't + // exist, it returned the NEW view. With this hack we simulate that + // behavior. + // TODO (mferencevic, teon.banek): Remove once MERGE is reimplemented. 
+ maybe_prop = record_accessor.GetProperty(storage::v3::View::NEW, ctx_->properties[prop.ix]); + } + if (maybe_prop.HasError()) { + switch (maybe_prop.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to get a property from a deleted object."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get a property from an object that doesn't exist."); + case storage::v3::Error::SERIALIZATION_ERROR: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Unexpected error when getting a property."); + } + } + return *maybe_prop; + } + /// Reads the property called `name` (resolved via dba_->NameToProperty) in view_; @throw QueryRuntimeException on any storage error. + template <class TRecordAccessor> + storage::v3::PropertyValue GetProperty(const TRecordAccessor &record_accessor, const std::string_view name) { + auto maybe_prop = record_accessor.GetProperty(view_, dba_->NameToProperty(name)); + if (maybe_prop.HasError() && maybe_prop.GetError() == storage::v3::Error::NONEXISTENT_OBJECT) { + // This is a very nasty and temporary hack in order to make MERGE work. + // The old storage had the following logic when returning an `OLD` view: + // `return old ? old : new`. That means that if the `OLD` view didn't + // exist, it returned the NEW view. With this hack we simulate that + // behavior, so the retry below must query View::NEW rather than view_ again. + // TODO (mferencevic, teon.banek): Remove once MERGE is reimplemented. 
+ maybe_prop = record_accessor.GetProperty(storage::v3::View::NEW, dba_->NameToProperty(name)); + } + if (maybe_prop.HasError()) { + switch (maybe_prop.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to get a property from a deleted object."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get a property from an object that doesn't exist."); + case storage::v3::Error::SERIALIZATION_ERROR: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Unexpected error when getting a property."); + } + } + return *maybe_prop; + } + + storage::v3::LabelId GetLabel(LabelIx label) { return ctx_->labels[label.ix]; } + + Frame *frame_; + const SymbolTable *symbol_table_; + const EvaluationContext *ctx_; + DbAccessor *dba_; + // which switching approach should be used when evaluating + storage::v3::View view_; +}; + +/// A helper function for evaluating an expression that's an int. +/// +/// @param what - Name of what's getting evaluated. Used for user feedback (via +/// exception) when the evaluated value is not an int. +/// @throw QueryRuntimeException if expression doesn't evaluate to an int. +int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr, const std::string &what); + +std::optional<size_t> EvaluateMemoryLimit(ExpressionEvaluator *eval, Expression *memory_limit, size_t memory_scale); + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/interpret/frame.hpp b/src/query/v2/interpret/frame.hpp new file mode 100644 index 000000000..6b02a8a6c --- /dev/null +++ b/src/query/v2/interpret/frame.hpp @@ -0,0 +1,45 @@ +// Copyright 2022 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <vector> + +#include "query/v2/frontend/semantic/symbol_table.hpp" +#include "query/v2/typed_value.hpp" +#include "utils/logging.hpp" +#include "utils/memory.hpp" +#include "utils/pmr/vector.hpp" + +namespace memgraph::query::v2 { +/// Fixed-size collection of TypedValue slots; each slot is addressed by Symbol::position(). +class Frame { + public: + /// Create a Frame of given size backed by a utils::NewDeleteResource(). NOTE(review): elems_ is built from `size` in the initializer list before the MG_ASSERT in the body runs. + explicit Frame(int64_t size) : elems_(size, utils::NewDeleteResource()) { MG_ASSERT(size >= 0); } + /// Create a Frame of given size whose slots are allocated from `memory` (same initialization order as above). + Frame(int64_t size, utils::MemoryResource *memory) : elems_(size, memory) { MG_ASSERT(size >= 0); } + /// Slot lookup through the container's operator[] at symbol.position(). + TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position()]; } + const TypedValue &operator[](const Symbol &symbol) const { return elems_[symbol.position()]; } + /// Slot lookup through the container's at() -- presumably bounds-checked like std::vector::at; confirm against utils::pmr::vector. + TypedValue &at(const Symbol &symbol) { return elems_.at(symbol.position()); } + const TypedValue &at(const Symbol &symbol) const { return elems_.at(symbol.position()); } + /// Mutable access to the underlying slot container. + auto &elems() { return elems_; } + /// Memory resource of the allocator that backs elems_. + utils::MemoryResource *GetMemoryResource() const { return elems_.get_allocator().GetMemoryResource(); } + + private: + utils::pmr::vector<TypedValue> elems_; +}; + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/interpreter.cpp b/src/query/v2/interpreter.cpp new file mode 100644 index 000000000..a60d77dec --- /dev/null +++ b/src/query/v2/interpreter.cpp @@ -0,0 +1,2412 @@ +// Copyright 2022 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/interpreter.hpp" + +#include <fmt/core.h> +#include <algorithm> +#include <atomic> +#include <chrono> +#include <cstddef> +#include <cstdint> +#include <functional> +#include <limits> +#include <optional> + +#include "memory/memory_control.hpp" +#include "query/v2/constants.hpp" +#include "query/v2/context.hpp" +#include "query/v2/cypher_query_interpreter.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/dump.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/ast/ast_visitor.hpp" +#include "query/v2/frontend/ast/cypher_main_visitor.hpp" +#include "query/v2/frontend/opencypher/parser.hpp" +#include "query/v2/frontend/semantic/required_privileges.hpp" +#include "query/v2/frontend/semantic/symbol_generator.hpp" +#include "query/v2/interpret/eval.hpp" +#include "query/v2/metadata.hpp" +#include "query/v2/plan/planner.hpp" +#include "query/v2/plan/profile.hpp" +#include "query/v2/plan/vertex_count_cache.hpp" +#include "query/v2/stream/common.hpp" +#include "query/v2/trigger.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/storage.hpp" +#include "utils/algorithm.hpp" +#include "utils/csv_parsing.hpp" +#include "utils/event_counter.hpp" +#include "utils/exceptions.hpp" +#include "utils/flag_validation.hpp" +#include "utils/license.hpp" +#include "utils/likely.hpp" +#include "utils/logging.hpp" +#include 
"utils/memory.hpp" +#include "utils/memory_tracker.hpp" +#include "utils/readable_size.hpp" +#include "utils/settings.hpp" +#include "utils/string.hpp" +#include "utils/tsc.hpp" +#include "utils/variant_helpers.hpp" + +namespace EventCounter { +extern Event ReadQuery; +extern Event WriteQuery; +extern Event ReadWriteQuery; + +extern const Event LabelIndexCreated; +extern const Event LabelPropertyIndexCreated; + +extern const Event StreamsCreated; +extern const Event TriggersCreated; +} // namespace EventCounter + +namespace memgraph::query::v2 { + +namespace { +void UpdateTypeCount(const plan::ReadWriteTypeChecker::RWType type) { + switch (type) { + case plan::ReadWriteTypeChecker::RWType::R: + EventCounter::IncrementCounter(EventCounter::ReadQuery); + break; + case plan::ReadWriteTypeChecker::RWType::W: + EventCounter::IncrementCounter(EventCounter::WriteQuery); + break; + case plan::ReadWriteTypeChecker::RWType::RW: + EventCounter::IncrementCounter(EventCounter::ReadWriteQuery); + break; + default: + break; + } +} + +struct Callback { + std::vector<std::string> header; + using CallbackFunction = std::function<std::vector<std::vector<TypedValue>>()>; + CallbackFunction fn; + bool should_abort_query{false}; +}; + +TypedValue EvaluateOptionalExpression(Expression *expression, ExpressionEvaluator *eval) { + return expression ? 
expression->Accept(*eval) : TypedValue(); +} + +template <typename TResult> +std::optional<TResult> GetOptionalValue(query::v2::Expression *expression, ExpressionEvaluator &evaluator) { + if (expression != nullptr) { + auto int_value = expression->Accept(evaluator); + MG_ASSERT(int_value.IsNull() || int_value.IsInt()); + if (int_value.IsInt()) { + return TResult{int_value.ValueInt()}; + } + } + return {}; +}; + +std::optional<std::string> GetOptionalStringValue(query::v2::Expression *expression, ExpressionEvaluator &evaluator) { + if (expression != nullptr) { + auto value = expression->Accept(evaluator); + MG_ASSERT(value.IsNull() || value.IsString()); + if (value.IsString()) { + return {std::string(value.ValueString().begin(), value.ValueString().end())}; + } + } + return {}; +}; + +class ReplQueryHandler final : public query::v2::ReplicationQueryHandler { + public: + explicit ReplQueryHandler(storage::v3::Storage *db) : db_(db) {} + + /// @throw QueryRuntimeException if an error ocurred. + void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional<int64_t> port) override { + if (replication_role == ReplicationQuery::ReplicationRole::MAIN) { + if (!db_->SetMainReplicationRole()) { + throw QueryRuntimeException("Couldn't set role to main!"); + } + } + if (replication_role == ReplicationQuery::ReplicationRole::REPLICA) { + if (!port || *port < 0 || *port > std::numeric_limits<uint16_t>::max()) { + throw QueryRuntimeException("Port number invalid!"); + } + if (!db_->SetReplicaRole( + io::network::Endpoint(query::v2::kDefaultReplicationServerIp, static_cast<uint16_t>(*port)))) { + throw QueryRuntimeException("Couldn't set role to replica!"); + } + } + } + + /// @throw QueryRuntimeException if an error ocurred. 
+ ReplicationQuery::ReplicationRole ShowReplicationRole() const override { + switch (db_->GetReplicationRole()) { + case storage::v3::ReplicationRole::MAIN: + return ReplicationQuery::ReplicationRole::MAIN; + case storage::v3::ReplicationRole::REPLICA: + return ReplicationQuery::ReplicationRole::REPLICA; + } + throw QueryRuntimeException("Couldn't show replication role - invalid role set!"); + } + + /// @throw QueryRuntimeException if an error ocurred. + void RegisterReplica(const std::string &name, const std::string &socket_address, + const ReplicationQuery::SyncMode sync_mode, const std::optional<double> timeout, + const std::chrono::seconds replica_check_frequency) override { + if (db_->GetReplicationRole() == storage::v3::ReplicationRole::REPLICA) { + // replica can't register another replica + throw QueryRuntimeException("Replica can't register another replica!"); + } + + storage::v3::replication::ReplicationMode repl_mode; + switch (sync_mode) { + case ReplicationQuery::SyncMode::ASYNC: { + repl_mode = storage::v3::replication::ReplicationMode::ASYNC; + break; + } + case ReplicationQuery::SyncMode::SYNC: { + repl_mode = storage::v3::replication::ReplicationMode::SYNC; + break; + } + } + + auto maybe_ip_and_port = + io::network::Endpoint::ParseSocketOrIpAddress(socket_address, query::v2::kDefaultReplicationPort); + if (maybe_ip_and_port) { + auto [ip, port] = *maybe_ip_and_port; + auto ret = db_->RegisterReplica( + name, {std::move(ip), port}, repl_mode, + {.timeout = timeout, .replica_check_frequency = replica_check_frequency, .ssl = std::nullopt}); + if (ret.HasError()) { + throw QueryRuntimeException(fmt::format("Couldn't register replica '{}'!", name)); + } + } else { + throw QueryRuntimeException("Invalid socket address!"); + } + } + + /// @throw QueryRuntimeException if an error ocurred. 
+ void DropReplica(const std::string &replica_name) override { + if (db_->GetReplicationRole() == storage::v3::ReplicationRole::REPLICA) { + // replica can't unregister a replica + throw QueryRuntimeException("Replica can't unregister a replica!"); + } + if (!db_->UnregisterReplica(replica_name)) { + throw QueryRuntimeException(fmt::format("Couldn't unregister the replica '{}'", replica_name)); + } + } + + using Replica = ReplicationQueryHandler::Replica; + std::vector<Replica> ShowReplicas() const override { + if (db_->GetReplicationRole() == storage::v3::ReplicationRole::REPLICA) { + // replica can't show registered replicas (it shouldn't have any) + throw QueryRuntimeException("Replica can't show registered replicas (it shouldn't have any)!"); + } + + auto repl_infos = db_->ReplicasInfo(); + std::vector<Replica> replicas; + replicas.reserve(repl_infos.size()); + + const auto from_info = [](const auto &repl_info) -> Replica { + Replica replica; + replica.name = repl_info.name; + replica.socket_address = repl_info.endpoint.SocketAddress(); + switch (repl_info.mode) { + case storage::v3::replication::ReplicationMode::SYNC: + replica.sync_mode = ReplicationQuery::SyncMode::SYNC; + break; + case storage::v3::replication::ReplicationMode::ASYNC: + replica.sync_mode = ReplicationQuery::SyncMode::ASYNC; + break; + } + if (repl_info.timeout) { + replica.timeout = *repl_info.timeout; + } + + switch (repl_info.state) { + case storage::v3::replication::ReplicaState::READY: + replica.state = ReplicationQuery::ReplicaState::READY; + break; + case storage::v3::replication::ReplicaState::REPLICATING: + replica.state = ReplicationQuery::ReplicaState::REPLICATING; + break; + case storage::v3::replication::ReplicaState::RECOVERY: + replica.state = ReplicationQuery::ReplicaState::RECOVERY; + break; + case storage::v3::replication::ReplicaState::INVALID: + replica.state = ReplicationQuery::ReplicaState::INVALID; + break; + } + + return replica; + }; + + 
std::transform(repl_infos.begin(), repl_infos.end(), std::back_inserter(replicas), from_info); + return replicas; + } + + private: + storage::v3::Storage *db_; +}; +/// Builds the Callback that runs `auth_query` against `auth`; the optional password expression is evaluated here, before the callback is created. +/// @throw utils::BasicException if the action is enterprise-only and the license check reports an error. +/// @throw QueryRuntimeException if the action is not handled below; the returned callback may also throw it when executed. +Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Parameters &parameters, + DbAccessor *db_accessor) { + // Empty frame for evaluation of password expression. This is OK since + // password should be either null or string literal and its evaluation + // should not depend on frame. + Frame frame(0); + SymbolTable symbol_table; + EvaluationContext evaluation_context; + // TODO: MemoryResource for EvaluationContext, it should probably be passed as + // the argument to Callback. + evaluation_context.timestamp = QueryTimestamp(); + evaluation_context.parameters = parameters; + ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::v3::View::OLD); + + std::string username = auth_query->user_; + std::string rolename = auth_query->role_; + std::string user_or_role = auth_query->user_or_role_; + std::vector<AuthQuery::Privilege> privileges = auth_query->privileges_; + auto password = EvaluateOptionalExpression(auth_query->password_, &evaluator); + + Callback callback; + + const auto license_check_result = utils::license::global_license_checker.IsValidLicense(utils::global_settings); + + static const std::unordered_set enterprise_only_methods{ + AuthQuery::Action::CREATE_ROLE, AuthQuery::Action::DROP_ROLE, AuthQuery::Action::SET_ROLE, + AuthQuery::Action::CLEAR_ROLE, AuthQuery::Action::GRANT_PRIVILEGE, AuthQuery::Action::DENY_PRIVILEGE, + AuthQuery::Action::REVOKE_PRIVILEGE, AuthQuery::Action::SHOW_PRIVILEGES, AuthQuery::Action::SHOW_USERS_FOR_ROLE, + AuthQuery::Action::SHOW_ROLE_FOR_USER}; + + if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) { + throw utils::BasicException( + 
utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication features")); + } + + switch (auth_query->action_) { + case AuthQuery::Action::CREATE_USER: + callback.fn = [auth, username, password, valid_enterprise_license = !license_check_result.HasError()] { + MG_ASSERT(password.IsString() || password.IsNull()); + if (!auth->CreateUser(username, password.IsString() ? std::make_optional(std::string(password.ValueString())) + : std::nullopt)) { + throw QueryRuntimeException("User '{}' already exists.", username); + } + + // If the license is not valid we create users with admin access + if (!valid_enterprise_license) { + spdlog::warn("Granting all the privileges to {}.", username); + auth->GrantPrivilege(username, kPrivilegesAll); + } + + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + case AuthQuery::Action::DROP_USER: + callback.fn = [auth, username] { + if (!auth->DropUser(username)) { + throw QueryRuntimeException("User '{}' doesn't exist.", username); + } + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + case AuthQuery::Action::SET_PASSWORD: + callback.fn = [auth, username, password] { + MG_ASSERT(password.IsString() || password.IsNull()); + auth->SetPassword(username, + password.IsString() ? 
std::make_optional(std::string(password.ValueString())) : std::nullopt); + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + case AuthQuery::Action::CREATE_ROLE: + callback.fn = [auth, rolename] { + if (!auth->CreateRole(rolename)) { + throw QueryRuntimeException("Role '{}' already exists.", rolename); + } + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + case AuthQuery::Action::DROP_ROLE: + callback.fn = [auth, rolename] { + if (!auth->DropRole(rolename)) { + throw QueryRuntimeException("Role '{}' doesn't exist.", rolename); + } + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + case AuthQuery::Action::SHOW_USERS: + callback.header = {"user"}; + callback.fn = [auth] { + std::vector<std::vector<TypedValue>> rows; + auto usernames = auth->GetUsernames(); + rows.reserve(usernames.size()); + for (auto &&username : usernames) { + rows.emplace_back(std::vector<TypedValue>{username}); + } + return rows; + }; + return callback; + case AuthQuery::Action::SHOW_ROLES: + callback.header = {"role"}; + callback.fn = [auth] { + std::vector<std::vector<TypedValue>> rows; + auto rolenames = auth->GetRolenames(); + rows.reserve(rolenames.size()); + for (auto &&rolename : rolenames) { + rows.emplace_back(std::vector<TypedValue>{rolename}); + } + return rows; + }; + return callback; + case AuthQuery::Action::SET_ROLE: + callback.fn = [auth, username, rolename] { + auth->SetRole(username, rolename); + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + case AuthQuery::Action::CLEAR_ROLE: + callback.fn = [auth, username] { + auth->ClearRole(username); + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + case AuthQuery::Action::GRANT_PRIVILEGE: + callback.fn = [auth, user_or_role, privileges] { + auth->GrantPrivilege(user_or_role, privileges); + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + case AuthQuery::Action::DENY_PRIVILEGE: + callback.fn = 
[auth, user_or_role, privileges] { + auth->DenyPrivilege(user_or_role, privileges); + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + case AuthQuery::Action::REVOKE_PRIVILEGE: { + callback.fn = [auth, user_or_role, privileges] { + auth->RevokePrivilege(user_or_role, privileges); + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + } + case AuthQuery::Action::SHOW_PRIVILEGES: + callback.header = {"privilege", "effective", "description"}; + callback.fn = [auth, user_or_role] { return auth->GetPrivileges(user_or_role); }; + return callback; + case AuthQuery::Action::SHOW_ROLE_FOR_USER: + callback.header = {"role"}; + callback.fn = [auth, username] { + auto maybe_rolename = auth->GetRolenameForUser(username); + return std::vector<std::vector<TypedValue>>{ + std::vector<TypedValue>{TypedValue(maybe_rolename ? *maybe_rolename : "null")}}; + }; + return callback; + case AuthQuery::Action::SHOW_USERS_FOR_ROLE: + callback.header = {"users"}; + callback.fn = [auth, rolename] { + std::vector<std::vector<TypedValue>> rows; + auto usernames = auth->GetUsernamesForRole(rolename); + rows.reserve(usernames.size()); + for (auto &&username : usernames) { + rows.emplace_back(std::vector<TypedValue>{username}); + } + return rows; + }; + return callback; + default: + throw QueryRuntimeException("Unsupported auth query action!");  // never flow off the end of a value-returning function + } +} + +Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &parameters, + InterpreterContext *interpreter_context, DbAccessor *db_accessor, + std::vector<Notification> *notifications) { + Frame frame(0); + SymbolTable symbol_table; + EvaluationContext evaluation_context; + // TODO: MemoryResource for EvaluationContext, it should probably be passed as + // the argument to Callback. 
+ evaluation_context.timestamp = QueryTimestamp(); + evaluation_context.parameters = parameters; + ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::v3::View::OLD); + + Callback callback; + switch (repl_query->action_) { + case ReplicationQuery::Action::SET_REPLICATION_ROLE: { + auto port = EvaluateOptionalExpression(repl_query->port_, &evaluator); + std::optional<int64_t> maybe_port; + if (port.IsInt()) { + maybe_port = port.ValueInt(); + } + if (maybe_port == 7687 && repl_query->role_ == ReplicationQuery::ReplicationRole::REPLICA) { + notifications->emplace_back(SeverityLevel::WARNING, NotificationCode::REPLICA_PORT_WARNING, + "Be careful the replication port must be different from the memgraph port!"); + } + callback.fn = [handler = ReplQueryHandler{interpreter_context->db}, role = repl_query->role_, + maybe_port]() mutable { + handler.SetReplicationRole(role, maybe_port); + return std::vector<std::vector<TypedValue>>(); + }; + notifications->emplace_back( + SeverityLevel::INFO, NotificationCode::SET_REPLICA, + fmt::format("Replica role set to {}.", + repl_query->role_ == ReplicationQuery::ReplicationRole::MAIN ? 
"MAIN" : "REPLICA")); + return callback; + } + case ReplicationQuery::Action::SHOW_REPLICATION_ROLE: { + callback.header = {"replication role"}; + callback.fn = [handler = ReplQueryHandler{interpreter_context->db}] { + auto mode = handler.ShowReplicationRole(); + switch (mode) { + case ReplicationQuery::ReplicationRole::MAIN: { + return std::vector<std::vector<TypedValue>>{{TypedValue("main")}}; + } + case ReplicationQuery::ReplicationRole::REPLICA: { + return std::vector<std::vector<TypedValue>>{{TypedValue("replica")}}; + } + } + }; + return callback; + } + case ReplicationQuery::Action::REGISTER_REPLICA: { + const auto &name = repl_query->replica_name_; + const auto &sync_mode = repl_query->sync_mode_; + auto socket_address = repl_query->socket_address_->Accept(evaluator); + auto timeout = EvaluateOptionalExpression(repl_query->timeout_, &evaluator); + const auto replica_check_frequency = interpreter_context->config.replication_replica_check_frequency; + std::optional<double> maybe_timeout; + if (timeout.IsDouble()) { + maybe_timeout = timeout.ValueDouble(); + } else if (timeout.IsInt()) { + maybe_timeout = static_cast<double>(timeout.ValueInt()); + } + callback.fn = [handler = ReplQueryHandler{interpreter_context->db}, name, socket_address, sync_mode, + maybe_timeout, replica_check_frequency]() mutable { + handler.RegisterReplica(name, std::string(socket_address.ValueString()), sync_mode, maybe_timeout, + replica_check_frequency); + return std::vector<std::vector<TypedValue>>(); + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::REGISTER_REPLICA, + fmt::format("Replica {} is registered.", repl_query->replica_name_)); + return callback; + } + + case ReplicationQuery::Action::DROP_REPLICA: { + const auto &name = repl_query->replica_name_; + callback.fn = [handler = ReplQueryHandler{interpreter_context->db}, name]() mutable { + handler.DropReplica(name); + return std::vector<std::vector<TypedValue>>(); + }; + 
notifications->emplace_back(SeverityLevel::INFO, NotificationCode::DROP_REPLICA, + fmt::format("Replica {} is dropped.", repl_query->replica_name_)); + return callback; + } + + case ReplicationQuery::Action::SHOW_REPLICAS: { + callback.header = {"name", "socket_address", "sync_mode", "timeout", "state"}; + callback.fn = [handler = ReplQueryHandler{interpreter_context->db}, replica_nfields = callback.header.size()] { + const auto &replicas = handler.ShowReplicas(); + auto typed_replicas = std::vector<std::vector<TypedValue>>{}; + typed_replicas.reserve(replicas.size()); + for (const auto &replica : replicas) { + std::vector<TypedValue> typed_replica; + typed_replica.reserve(replica_nfields); + + typed_replica.emplace_back(TypedValue(replica.name)); + typed_replica.emplace_back(TypedValue(replica.socket_address)); + + switch (replica.sync_mode) { + case ReplicationQuery::SyncMode::SYNC: + typed_replica.emplace_back(TypedValue("sync")); + break; + case ReplicationQuery::SyncMode::ASYNC: + typed_replica.emplace_back(TypedValue("async")); + break; + } + + if (replica.timeout) { + typed_replica.emplace_back(TypedValue(*replica.timeout)); + } else { + typed_replica.emplace_back(TypedValue()); + } + + switch (replica.state) { + case ReplicationQuery::ReplicaState::READY: + typed_replica.emplace_back(TypedValue("ready")); + break; + case ReplicationQuery::ReplicaState::REPLICATING: + typed_replica.emplace_back(TypedValue("replicating")); + break; + case ReplicationQuery::ReplicaState::RECOVERY: + typed_replica.emplace_back(TypedValue("recovery")); + break; + case ReplicationQuery::ReplicaState::INVALID: + typed_replica.emplace_back(TypedValue("invalid")); + break; + } + + typed_replicas.emplace_back(std::move(typed_replica)); + } + return typed_replicas; + }; + return callback; + } + } +} + +std::optional<std::string> StringPointerToOptional(const std::string *str) { + return str == nullptr ? 
std::nullopt : std::make_optional(*str); +} + +stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, ExpressionEvaluator &evaluator) { + return { + .batch_interval = GetOptionalValue<std::chrono::milliseconds>(stream_query->batch_interval_, evaluator) + .value_or(stream::kDefaultBatchInterval), + .batch_size = GetOptionalValue<int64_t>(stream_query->batch_size_, evaluator).value_or(stream::kDefaultBatchSize), + .transformation_name = stream_query->transform_name_}; +} + +std::vector<std::string> EvaluateTopicNames(ExpressionEvaluator &evaluator, + std::variant<Expression *, std::vector<std::string>> topic_variant) { + return std::visit(utils::Overloaded{[&](Expression *expression) { + auto topic_names = expression->Accept(evaluator); + MG_ASSERT(topic_names.IsString()); + return utils::Split(topic_names.ValueString(), ","); + }, + [&](std::vector<std::string> topic_names) { return topic_names; }}, + std::move(topic_variant)); +} + +Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, ExpressionEvaluator &evaluator, + InterpreterContext *interpreter_context, + const std::string *username) { + static constexpr std::string_view kDefaultConsumerGroup = "mg_consumer"; + std::string consumer_group{stream_query->consumer_group_.empty() ? 
kDefaultConsumerGroup + : stream_query->consumer_group_}; + + auto bootstrap = GetOptionalStringValue(stream_query->bootstrap_servers_, evaluator); + if (bootstrap && bootstrap->empty()) { + throw SemanticException("Bootstrap servers must not be an empty string!"); + } + auto common_stream_info = GetCommonStreamInfo(stream_query, evaluator); + + const auto get_config_map = [&evaluator](std::unordered_map<Expression *, Expression *> map, + std::string_view map_name) -> std::unordered_map<std::string, std::string> { + std::unordered_map<std::string, std::string> config_map; + for (const auto [key_expr, value_expr] : map) { + const auto key = key_expr->Accept(evaluator); + const auto value = value_expr->Accept(evaluator); + if (!key.IsString() || !value.IsString()) { + throw SemanticException("{} must contain only string keys and values!", map_name); + } + config_map.emplace(key.ValueString(), value.ValueString()); + } + return config_map; + }; + + return [interpreter_context, stream_name = stream_query->stream_name_, + topic_names = EvaluateTopicNames(evaluator, stream_query->topic_names_), + consumer_group = std::move(consumer_group), common_stream_info = std::move(common_stream_info), + bootstrap_servers = std::move(bootstrap), owner = StringPointerToOptional(username), + configs = get_config_map(stream_query->configs_, "Configs"), + credentials = get_config_map(stream_query->credentials_, "Credentials")]() mutable { + std::string bootstrap = bootstrap_servers + ? 
std::move(*bootstrap_servers) + : std::string{interpreter_context->config.default_kafka_bootstrap_servers}; + interpreter_context->streams.Create<query::v2::stream::KafkaStream>(stream_name, + {.common_info = std::move(common_stream_info), + .topics = std::move(topic_names), + .consumer_group = std::move(consumer_group), + .bootstrap_servers = std::move(bootstrap), + .configs = std::move(configs), + .credentials = std::move(credentials)}, + std::move(owner)); + + return std::vector<std::vector<TypedValue>>{}; + }; +} + +Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, ExpressionEvaluator &evaluator, + InterpreterContext *interpreter_context, + const std::string *username) { + auto service_url = GetOptionalStringValue(stream_query->service_url_, evaluator); + if (service_url && service_url->empty()) { + throw SemanticException("Service URL must not be an empty string!"); + } + auto common_stream_info = GetCommonStreamInfo(stream_query, evaluator); + return [interpreter_context, stream_name = stream_query->stream_name_, + topic_names = EvaluateTopicNames(evaluator, stream_query->topic_names_), + common_stream_info = std::move(common_stream_info), service_url = std::move(service_url), + owner = StringPointerToOptional(username)]() mutable { + std::string url = + service_url ? 
std::move(*service_url) : std::string{interpreter_context->config.default_pulsar_service_url}; + interpreter_context->streams.Create<query::v2::stream::PulsarStream>( + stream_name, + {.common_info = std::move(common_stream_info), .topics = std::move(topic_names), .service_url = std::move(url)}, + std::move(owner)); + + return std::vector<std::vector<TypedValue>>{}; + }; +} + +Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶meters, + InterpreterContext *interpreter_context, DbAccessor *db_accessor, + const std::string *username, std::vector<Notification> *notifications) { + Frame frame(0); + SymbolTable symbol_table; + EvaluationContext evaluation_context; + // TODO: MemoryResource for EvaluationContext, it should probably be passed as + // the argument to Callback. + evaluation_context.timestamp = QueryTimestamp(); + evaluation_context.parameters = parameters; + ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::v3::View::OLD); + + Callback callback; + switch (stream_query->action_) { + case StreamQuery::Action::CREATE_STREAM: { + EventCounter::IncrementCounter(EventCounter::StreamsCreated); + switch (stream_query->type_) { + case StreamQuery::Type::KAFKA: + callback.fn = GetKafkaCreateCallback(stream_query, evaluator, interpreter_context, username); + break; + case StreamQuery::Type::PULSAR: + callback.fn = GetPulsarCreateCallback(stream_query, evaluator, interpreter_context, username); + break; + } + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CREATE_STREAM, + fmt::format("Created stream {}.", stream_query->stream_name_)); + return callback; + } + case StreamQuery::Action::START_STREAM: { + const auto batch_limit = GetOptionalValue<int64_t>(stream_query->batch_limit_, evaluator); + const auto timeout = GetOptionalValue<std::chrono::milliseconds>(stream_query->timeout_, evaluator); + + if (batch_limit.has_value()) { + if (batch_limit.value() < 0) { + throw 
utils::BasicException("Parameter BATCH_LIMIT cannot hold negative value"); + } + + callback.fn = [interpreter_context, stream_name = stream_query->stream_name_, batch_limit, timeout]() { + interpreter_context->streams.StartWithLimit(stream_name, static_cast<uint64_t>(batch_limit.value()), timeout); + return std::vector<std::vector<TypedValue>>{}; + }; + } else { + callback.fn = [interpreter_context, stream_name = stream_query->stream_name_]() { + interpreter_context->streams.Start(stream_name); + return std::vector<std::vector<TypedValue>>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::START_STREAM, + fmt::format("Started stream {}.", stream_query->stream_name_)); + } + return callback; + } + case StreamQuery::Action::START_ALL_STREAMS: { + callback.fn = [interpreter_context]() { + interpreter_context->streams.StartAll(); + return std::vector<std::vector<TypedValue>>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::START_ALL_STREAMS, "Started all streams."); + return callback; + } + case StreamQuery::Action::STOP_STREAM: { + callback.fn = [interpreter_context, stream_name = stream_query->stream_name_]() { + interpreter_context->streams.Stop(stream_name); + return std::vector<std::vector<TypedValue>>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::STOP_STREAM, + fmt::format("Stopped stream {}.", stream_query->stream_name_)); + return callback; + } + case StreamQuery::Action::STOP_ALL_STREAMS: { + callback.fn = [interpreter_context]() { + interpreter_context->streams.StopAll(); + return std::vector<std::vector<TypedValue>>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::STOP_ALL_STREAMS, "Stopped all streams."); + return callback; + } + case StreamQuery::Action::DROP_STREAM: { + callback.fn = [interpreter_context, stream_name = stream_query->stream_name_]() { + interpreter_context->streams.Drop(stream_name); + return 
std::vector<std::vector<TypedValue>>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::DROP_STREAM, + fmt::format("Dropped stream {}.", stream_query->stream_name_)); + return callback; + } + case StreamQuery::Action::SHOW_STREAMS: { + callback.header = {"name", "type", "batch_interval", "batch_size", "transformation_name", "owner", "is running"}; + callback.fn = [interpreter_context]() { + auto streams_status = interpreter_context->streams.GetStreamInfo(); + std::vector<std::vector<TypedValue>> results; + results.reserve(streams_status.size()); + auto stream_info_as_typed_stream_info_emplace_in = [](auto &typed_status, const auto &stream_info) { + typed_status.emplace_back(stream_info.batch_interval.count()); + typed_status.emplace_back(stream_info.batch_size); + typed_status.emplace_back(stream_info.transformation_name); + }; + + for (const auto &status : streams_status) { + std::vector<TypedValue> typed_status; + typed_status.reserve(7); + typed_status.emplace_back(status.name); + typed_status.emplace_back(StreamSourceTypeToString(status.type)); + stream_info_as_typed_stream_info_emplace_in(typed_status, status.info); + if (status.owner.has_value()) { + typed_status.emplace_back(*status.owner); + } else { + typed_status.emplace_back(); + } + typed_status.emplace_back(status.is_running); + results.push_back(std::move(typed_status)); + } + + return results; + }; + return callback; + } + case StreamQuery::Action::CHECK_STREAM: { + callback.header = {"queries", "raw messages"}; + + const auto batch_limit = GetOptionalValue<int64_t>(stream_query->batch_limit_, evaluator); + if (batch_limit.has_value() && batch_limit.value() < 0) { + throw utils::BasicException("Parameter BATCH_LIMIT cannot hold negative value"); + } + + callback.fn = [interpreter_context, stream_name = stream_query->stream_name_, + timeout = GetOptionalValue<std::chrono::milliseconds>(stream_query->timeout_, evaluator), + batch_limit]() mutable { + return 
interpreter_context->streams.Check(stream_name, timeout, batch_limit); + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CHECK_STREAM, + fmt::format("Checked stream {}.", stream_query->stream_name_)); + return callback; + } + } +} + +Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters ¶meters, DbAccessor *db_accessor) { + Frame frame(0); + SymbolTable symbol_table; + EvaluationContext evaluation_context; + // TODO: MemoryResource for EvaluationContext, it should probably be passed as + // the argument to Callback. + evaluation_context.timestamp = + std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()) + .count(); + evaluation_context.parameters = parameters; + ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::v3::View::OLD); + + Callback callback; + switch (setting_query->action_) { + case SettingQuery::Action::SET_SETTING: { + const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, &evaluator); + if (!setting_name.IsString()) { + throw utils::BasicException("Setting name should be a string literal"); + } + + const auto setting_value = EvaluateOptionalExpression(setting_query->setting_value_, &evaluator); + if (!setting_value.IsString()) { + throw utils::BasicException("Setting value should be a string literal"); + } + + callback.fn = [setting_name = std::string{setting_name.ValueString()}, + setting_value = std::string{setting_value.ValueString()}]() mutable { + if (!utils::global_settings.SetValue(setting_name, setting_value)) { + throw utils::BasicException("Unknown setting name '{}'", setting_name); + } + return std::vector<std::vector<TypedValue>>{}; + }; + return callback; + } + case SettingQuery::Action::SHOW_SETTING: { + const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, &evaluator); + if (!setting_name.IsString()) { + throw utils::BasicException("Setting 
name should be a string literal"); + } + + callback.header = {"setting_value"}; + callback.fn = [setting_name = std::string{setting_name.ValueString()}] { + auto maybe_value = utils::global_settings.GetValue(setting_name); + if (!maybe_value) { + throw utils::BasicException("Unknown setting name '{}'", setting_name); + } + std::vector<std::vector<TypedValue>> results; + results.reserve(1); + + std::vector<TypedValue> setting_value; + setting_value.reserve(1); + + setting_value.emplace_back(*maybe_value); + results.push_back(std::move(setting_value)); + return results; + }; + return callback; + } + case SettingQuery::Action::SHOW_ALL_SETTINGS: { + callback.header = {"setting_name", "setting_value"}; + callback.fn = [] { + auto all_settings = utils::global_settings.AllSettings(); + std::vector<std::vector<TypedValue>> results; + results.reserve(all_settings.size()); + + for (const auto &[k, v] : all_settings) { + std::vector<TypedValue> setting_info; + setting_info.reserve(2); + + setting_info.emplace_back(k); + setting_info.emplace_back(v); + results.push_back(std::move(setting_info)); + } + + return results; + }; + return callback; + } + } +} + +// Struct for lazy pulling from a vector +struct PullPlanVector { + explicit PullPlanVector(std::vector<std::vector<TypedValue>> values) : values_(std::move(values)) {} + + // @return true if there are more unstreamed elements in vector, + // false otherwise. 
+ bool Pull(AnyStream *stream, std::optional<int> n) { + int local_counter{0}; + while (global_counter < values_.size() && (!n || local_counter < n)) { + stream->Result(values_[global_counter]); + ++global_counter; + ++local_counter; + } + + return global_counter == values_.size(); + } + + private: + int global_counter{0}; + std::vector<std::vector<TypedValue>> values_; +}; + +struct PullPlan { + explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters ¶meters, bool is_profile_query, + DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, + TriggerContextCollector *trigger_context_collector = nullptr, + std::optional<size_t> memory_limit = {}); + std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n, + const std::vector<Symbol> &output_symbols, + std::map<std::string, TypedValue> *summary); + + private: + std::shared_ptr<CachedPlan> plan_ = nullptr; + plan::UniqueCursorPtr cursor_ = nullptr; + Frame frame_; + ExecutionContext ctx_; + std::optional<size_t> memory_limit_; + + // As it's possible to query execution using multiple pulls + // we need the keep track of the total execution time across + // those pulls by accumulating the execution time. + std::chrono::duration<double> execution_time_{0}; + + // To pull the results from a query we call the `Pull` method on + // the cursor which saves the results in a Frame. + // Becuase we can't find out if there are some saved results in a frame, + // and the cursor cannot deduce if the next pull will have a result, + // we have to keep track of any unsent results from previous `PullPlan::Pull` + // manually by using this flag. 
+ bool has_unsent_results_ = false; +}; + +PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters ¶meters, const bool is_profile_query, + DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, + TriggerContextCollector *trigger_context_collector, const std::optional<size_t> memory_limit) + : plan_(plan), + cursor_(plan->plan().MakeCursor(execution_memory)), + frame_(plan->symbol_table().max_position(), execution_memory), + memory_limit_(memory_limit) { + ctx_.db_accessor = dba; + ctx_.symbol_table = plan->symbol_table(); + ctx_.evaluation_context.timestamp = QueryTimestamp(); + ctx_.evaluation_context.parameters = parameters; + ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba); + ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba); + if (interpreter_context->config.execution_timeout_sec > 0) { + ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec}; + } + ctx_.is_shutting_down = &interpreter_context->is_shutting_down; + ctx_.is_profile_query = is_profile_query; + ctx_.trigger_context_collector = trigger_context_collector; +} + +std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n, + const std::vector<Symbol> &output_symbols, + std::map<std::string, TypedValue> *summary) { + // Set up temporary memory for a single Pull. Initial memory comes from the + // stack. 256 KiB should fit on the stack and should be more than enough for a + // single `Pull`. + static constexpr size_t stack_size = 256UL * 1024UL; + char stack_data[stack_size]; + utils::ResourceWithOutOfMemoryException resource_with_exception; + utils::MonotonicBufferResource monotonic_memory(&stack_data[0], stack_size, &resource_with_exception); + // We can throw on every query because a simple queries for deleting will use only + // the stack allocated buffer. 
+ // Also, we want to throw only when the query engine requests more memory and not the storage + // so we add the exception to the allocator. + // TODO (mferencevic): Tune the parameters accordingly. + utils::PoolResource pool_memory(128, 1024, &monotonic_memory); + std::optional<utils::LimitedMemoryResource> maybe_limited_resource; + + if (memory_limit_) { + maybe_limited_resource.emplace(&pool_memory, *memory_limit_); + ctx_.evaluation_context.memory = &*maybe_limited_resource; + } else { + ctx_.evaluation_context.memory = &pool_memory; + } + + // Returns true if a result was pulled. + const auto pull_result = [&]() -> bool { return cursor_->Pull(frame_, ctx_); }; + + const auto stream_values = [&]() { + // TODO: The streamed values should also probably use the above memory. + std::vector<TypedValue> values; + values.reserve(output_symbols.size()); + + for (const auto &symbol : output_symbols) { + values.emplace_back(frame_[symbol]); + } + + stream->Result(values); + }; + + // Get the execution time of all possible result pulls and streams. + utils::Timer timer; + + int i = 0; + if (has_unsent_results_ && !output_symbols.empty()) { + // stream unsent results from previous pull + stream_values(); + ++i; + } + + for (; !n || i < n; ++i) { + if (!pull_result()) { + break; + } + + if (!output_symbols.empty()) { + stream_values(); + } + } + + // If we finished because we streamed the requested n results, + // we try to pull the next result to see if there is more. + // If there is additional result, we leave the pulled result in the frame + // and set the flag to true. + has_unsent_results_ = i == n && pull_result(); + + execution_time_ += timer.Elapsed(); + + if (has_unsent_results_) { + return std::nullopt; + } + summary->insert_or_assign("plan_execution_time", execution_time_.count()); + // We are finished with pulling all the data, therefore we can send any + // metadata about the results i.e. 
notifications and statistics + const bool is_any_counter_set = + std::any_of(ctx_.execution_stats.counters.begin(), ctx_.execution_stats.counters.end(), + [](const auto &counter) { return counter > 0; }); + if (is_any_counter_set) { + std::map<std::string, TypedValue> stats; + for (size_t i = 0; i < ctx_.execution_stats.counters.size(); ++i) { + stats.emplace(ExecutionStatsKeyToString(ExecutionStats::Key(i)), ctx_.execution_stats.counters[i]); + } + summary->insert_or_assign("stats", std::move(stats)); + } + cursor_->Shutdown(); + ctx_.profile_execution_time = execution_time_; + return GetStatsWithTotalTime(ctx_); +} + +using RWType = plan::ReadWriteTypeChecker::RWType; +} // namespace + +InterpreterContext::InterpreterContext(storage::v3::Storage *db, const InterpreterConfig config, + const std::filesystem::path &data_directory) + : db(db), trigger_store(data_directory / "triggers"), config(config), streams{this, data_directory / "streams"} {} + +Interpreter::Interpreter(InterpreterContext *interpreter_context) : interpreter_context_(interpreter_context) { + MG_ASSERT(interpreter_context_, "Interpreter context must not be NULL"); +} + +PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper) { + std::function<void()> handler; + + if (query_upper == "BEGIN") { + handler = [this] { + if (in_explicit_transaction_) { + throw ExplicitTransactionUsageException("Nested transactions are not supported."); + } + in_explicit_transaction_ = true; + expect_rollback_ = false; + + db_accessor_ = std::make_unique<storage::v3::Storage::Accessor>( + interpreter_context_->db->Access(GetIsolationLevelOverride())); + execution_db_accessor_.emplace(db_accessor_.get()); + + if (interpreter_context_->trigger_store.HasTriggers()) { + trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes()); + } + }; + } else if (query_upper == "COMMIT") { + handler = [this] { + if (!in_explicit_transaction_) { + throw 
ExplicitTransactionUsageException("No current transaction to commit."); + } + if (expect_rollback_) { + throw ExplicitTransactionUsageException( + "Transaction can't be committed because there was a previous " + "error. Please invoke a rollback instead."); + } + + try { + Commit(); + } catch (const utils::BasicException &) { + AbortCommand(nullptr); + throw; + } + + expect_rollback_ = false; + in_explicit_transaction_ = false; + }; + } else if (query_upper == "ROLLBACK") { + handler = [this] { + if (!in_explicit_transaction_) { + throw ExplicitTransactionUsageException("No current transaction to rollback."); + } + Abort(); + expect_rollback_ = false; + in_explicit_transaction_ = false; + }; + } else { + LOG_FATAL("Should not get here -- unknown transaction query!"); + } + + return {{}, + {}, + [handler = std::move(handler)](AnyStream *, std::optional<int>) { + handler(); + return QueryHandlerResult::NOTHING; + }, + RWType::NONE}; +} + +PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary, + InterpreterContext *interpreter_context, DbAccessor *dba, + utils::MemoryResource *execution_memory, std::vector<Notification> *notifications, + TriggerContextCollector *trigger_context_collector = nullptr) { + auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query); + + Frame frame(0); + SymbolTable symbol_table; + EvaluationContext evaluation_context; + evaluation_context.timestamp = QueryTimestamp(); + evaluation_context.parameters = parsed_query.parameters; + ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, dba, storage::v3::View::OLD); + const auto memory_limit = EvaluateMemoryLimit(&evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_); + if (memory_limit) { + spdlog::info("Running query with memory limit of {}", utils::GetReadableSize(*memory_limit)); + } + + if (const auto &clauses = cypher_query->single_query_->clauses_; std::any_of( + clauses.begin(), clauses.end(), 
[](const auto *clause) { return clause->GetTypeInfo() == LoadCsv::kType; })) { + notifications->emplace_back( + SeverityLevel::INFO, NotificationCode::LOAD_CSV_TIP, + "It's important to note that the parser parses the values as strings. It's up to the user to " + "convert the parsed row values to the appropriate type. This can be done using the built-in " + "conversion functions such as ToInteger, ToFloat, ToBoolean etc."); + } + + auto plan = CypherQueryToPlan(parsed_query.stripped_query.hash(), std::move(parsed_query.ast_storage), cypher_query, + parsed_query.parameters, + parsed_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba); + + summary->insert_or_assign("cost_estimate", plan->cost()); + auto rw_type_checker = plan::ReadWriteTypeChecker(); + rw_type_checker.InferRWType(const_cast<plan::LogicalOperator &>(plan->plan())); + + auto output_symbols = plan->plan().OutputSymbols(plan->symbol_table()); + + std::vector<std::string> header; + header.reserve(output_symbols.size()); + + for (const auto &symbol : output_symbols) { + // When the symbol is aliased or expanded from '*' (inside RETURN or + // WITH), then there is no token position, so use symbol name. + // Otherwise, find the name from stripped query. 
+ header.push_back( + utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first); + } + auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, + execution_memory, trigger_context_collector, memory_limit); + return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), + [pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary]( + AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { + if (pull_plan->Pull(stream, n, output_symbols, summary)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + rw_type_checker.type}; +} + +PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary, + InterpreterContext *interpreter_context, DbAccessor *dba, + utils::MemoryResource *execution_memory) { + const std::string kExplainQueryStart = "explain "; + MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kExplainQueryStart), + "Expected stripped query to start with '{}'", kExplainQueryStart); + + // Parse and cache the inner query separately (as if it was a standalone + // query), producing a fresh AST. Note that currently we cannot just reuse + // part of the already produced AST because the parameters within ASTs are + // looked up using their positions within the string that was parsed. These + // wouldn't match up if if we were to reuse the AST (produced by parsing the + // full query string) when given just the inner query to execute. 
+ ParsedQuery parsed_inner_query = + ParseQuery(parsed_query.query_string.substr(kExplainQueryStart.size()), parsed_query.user_parameters, + &interpreter_context->ast_cache, &interpreter_context->antlr_lock, interpreter_context->config.query); + + auto *cypher_query = utils::Downcast<CypherQuery>(parsed_inner_query.query); + MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in EXPLAIN"); + + auto cypher_query_plan = CypherQueryToPlan( + parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query, + parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba); + + std::stringstream printed_plan; + plan::PrettyPrint(*dba, &cypher_query_plan->plan(), &printed_plan); + + std::vector<std::vector<TypedValue>> printed_plan_rows; + for (const auto &row : utils::Split(utils::RTrim(printed_plan.str()), "\n")) { + printed_plan_rows.push_back(std::vector<TypedValue>{TypedValue(row)}); + } + + summary->insert_or_assign("explain", plan::PlanToJson(*dba, &cypher_query_plan->plan()).dump()); + + return PreparedQuery{{"QUERY PLAN"}, + std::move(parsed_query.required_privileges), + [pull_plan = std::make_shared<PullPlanVector>(std::move(printed_plan_rows))]( + AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE}; +} + +PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction, + std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, + DbAccessor *dba, utils::MemoryResource *execution_memory) { + const std::string kProfileQueryStart = "profile "; + + MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart), + "Expected stripped query to start with '{}'", kProfileQueryStart); + + // PROFILE isn't allowed inside 
multi-command (explicit) transactions. This is + // because PROFILE executes each PROFILE'd query and collects additional + // perfomance metadata that it displays to the user instead of the results + // yielded by the query. Because PROFILE has side-effects, each transaction + // that is used to execute a PROFILE query *MUST* be aborted. That isn't + // possible when using multicommand (explicit) transactions (because the user + // controls the lifetime of the transaction) and that is why PROFILE is + // explicitly disabled here in multicommand (explicit) transactions. + // NOTE: Unlike PROFILE, EXPLAIN doesn't have any unwanted side-effects (in + // transaction terms) because it doesn't execute the query, it just prints its + // query plan. That is why EXPLAIN can be used in multicommand (explicit) + // transactions. + if (in_explicit_transaction) { + throw ProfileInMulticommandTxException(); + } + + if (!interpreter_context->tsc_frequency) { + throw QueryException("TSC support is missing for PROFILE"); + } + + // Parse and cache the inner query separately (as if it was a standalone + // query), producing a fresh AST. Note that currently we cannot just reuse + // part of the already produced AST because the parameters within ASTs are + // looked up using their positions within the string that was parsed. These + // wouldn't match up if if we were to reuse the AST (produced by parsing the + // full query string) when given just the inner query to execute. 
+ ParsedQuery parsed_inner_query = + ParseQuery(parsed_query.query_string.substr(kProfileQueryStart.size()), parsed_query.user_parameters, + &interpreter_context->ast_cache, &interpreter_context->antlr_lock, interpreter_context->config.query); + + auto *cypher_query = utils::Downcast<CypherQuery>(parsed_inner_query.query); + MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in PROFILE"); + Frame frame(0); + SymbolTable symbol_table; + EvaluationContext evaluation_context; + evaluation_context.timestamp = QueryTimestamp(); + evaluation_context.parameters = parsed_inner_query.parameters; + ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, dba, storage::v3::View::OLD); + const auto memory_limit = EvaluateMemoryLimit(&evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_); + + auto cypher_query_plan = CypherQueryToPlan( + parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query, + parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba); + auto rw_type_checker = plan::ReadWriteTypeChecker(); + rw_type_checker.InferRWType(const_cast<plan::LogicalOperator &>(cypher_query_plan->plan())); + + return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}, + std::move(parsed_query.required_privileges), + [plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), + summary, dba, interpreter_context, execution_memory, memory_limit, + // We want to execute the query we are profiling lazily, so we delay + // the construction of the corresponding context. + stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{}, + pull_plan = std::shared_ptr<PullPlanVector>(nullptr)]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + // No output symbols are given so that nothing is streamed. 
+ if (!stats_and_total_time) { + stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context, + execution_memory, nullptr, memory_limit) + .Pull(stream, {}, {}, summary); + pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time)); + } + + MG_ASSERT(stats_and_total_time, "Failed to execute the query!"); + + if (pull_plan->Pull(stream, n)) { + summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump()); + return QueryHandlerResult::ABORT; + } + + return std::nullopt; + }, + rw_type_checker.type}; +} + +PreparedQuery PrepareDumpQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary, DbAccessor *dba, + utils::MemoryResource *execution_memory) { + return PreparedQuery{{"QUERY"}, + std::move(parsed_query.required_privileges), + [pull_plan = std::make_shared<PullPlanDump>(dba)]( + AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::R}; +} + +PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_transaction, + std::vector<Notification> *notifications, InterpreterContext *interpreter_context) { + if (in_explicit_transaction) { + throw IndexInMulticommandTxException(); + } + + auto *index_query = utils::Downcast<IndexQuery>(parsed_query.query); + std::function<void(Notification &)> handler; + + // Creating an index influences computed plan costs. 
+ auto invalidate_plan_cache = [plan_cache = &interpreter_context->plan_cache] { + auto access = plan_cache->access(); + for (auto &kv : access) { + access.remove(kv.first); + } + }; + + auto label = interpreter_context->db->NameToLabel(index_query->label_.name); + + std::vector<storage::v3::PropertyId> properties; + std::vector<std::string> properties_string; + properties.reserve(index_query->properties_.size()); + properties_string.reserve(index_query->properties_.size()); + for (const auto &prop : index_query->properties_) { + properties.push_back(interpreter_context->db->NameToProperty(prop.name)); + properties_string.push_back(prop.name); + } + auto properties_stringified = utils::Join(properties_string, ", "); + + if (properties.size() > 1) { + throw utils::NotYetImplemented("index on multiple properties"); + } + + Notification index_notification(SeverityLevel::INFO); + switch (index_query->action_) { + case IndexQuery::Action::CREATE: { + index_notification.code = NotificationCode::CREATE_INDEX; + index_notification.title = + fmt::format("Created index on label {} on properties {}.", index_query->label_.name, properties_stringified); + + handler = [interpreter_context, label, properties_stringified = std::move(properties_stringified), + label_name = index_query->label_.name, properties = std::move(properties), + invalidate_plan_cache = std::move(invalidate_plan_cache)](Notification &index_notification) { + if (properties.empty()) { + if (!interpreter_context->db->CreateIndex(label)) { + index_notification.code = NotificationCode::EXISTANT_INDEX; + index_notification.title = + fmt::format("Index on label {} on properties {} already exists.", label_name, properties_stringified); + } + EventCounter::IncrementCounter(EventCounter::LabelIndexCreated); + } else { + MG_ASSERT(properties.size() == 1U); + if (!interpreter_context->db->CreateIndex(label, properties[0])) { + index_notification.code = NotificationCode::EXISTANT_INDEX; + index_notification.title = + 
fmt::format("Index on label {} on properties {} already exists.", label_name, properties_stringified); + } + EventCounter::IncrementCounter(EventCounter::LabelPropertyIndexCreated); + } + invalidate_plan_cache(); + }; + break; + } + case IndexQuery::Action::DROP: { + index_notification.code = NotificationCode::DROP_INDEX; + index_notification.title = fmt::format("Dropped index on label {} on properties {}.", index_query->label_.name, + utils::Join(properties_string, ", ")); + handler = [interpreter_context, label, properties_stringified = std::move(properties_stringified), + label_name = index_query->label_.name, properties = std::move(properties), + invalidate_plan_cache = std::move(invalidate_plan_cache)](Notification &index_notification) { + if (properties.empty()) { + if (!interpreter_context->db->DropIndex(label)) { + index_notification.code = NotificationCode::NONEXISTANT_INDEX; + index_notification.title = + fmt::format("Index on label {} on properties {} doesn't exist.", label_name, properties_stringified); + } + } else { + MG_ASSERT(properties.size() == 1U); + if (!interpreter_context->db->DropIndex(label, properties[0])) { + index_notification.code = NotificationCode::NONEXISTANT_INDEX; + index_notification.title = + fmt::format("Index on label {} on properties {} doesn't exist.", label_name, properties_stringified); + } + } + invalidate_plan_cache(); + }; + break; + } + } + + return PreparedQuery{ + {}, + std::move(parsed_query.required_privileges), + [handler = std::move(handler), notifications, index_notification = std::move(index_notification)]( + AnyStream * /*stream*/, std::optional<int> /*unused*/) mutable { + handler(index_notification); + notifications->push_back(index_notification); + return QueryHandlerResult::NOTHING; + }, + RWType::W}; +} + +PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction, + std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, + DbAccessor *dba, 
utils::MemoryResource *execution_memory) { + if (in_explicit_transaction) { + throw UserModificationInMulticommandTxException(); + } + + auto *auth_query = utils::Downcast<AuthQuery>(parsed_query.query); + + auto callback = HandleAuthQuery(auth_query, interpreter_context->auth, parsed_query.parameters, dba); + + SymbolTable symbol_table; + std::vector<Symbol> output_symbols; + for (const auto &column : callback.header) { + output_symbols.emplace_back(symbol_table.CreateSymbol(column, "false")); + } + + auto plan = std::make_shared<CachedPlan>(std::make_unique<SingleNodeLogicalPlan>( + std::make_unique<plan::OutputTable>(output_symbols, + [fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }), + 0.0, AstStorage{}, symbol_table)); + + auto pull_plan = + std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory); + return PreparedQuery{ + callback.header, std::move(parsed_query.required_privileges), + [pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols), + summary](AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { + if (pull_plan->Pull(stream, n, output_symbols, summary)) { + return callback.should_abort_query ? 
QueryHandlerResult::ABORT : QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE}; +} + +PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, + std::vector<Notification> *notifications, InterpreterContext *interpreter_context, + DbAccessor *dba) { + if (in_explicit_transaction) { + throw ReplicationModificationInMulticommandTxException(); + } + + auto *replication_query = utils::Downcast<ReplicationQuery>(parsed_query.query); + auto callback = + HandleReplicationQuery(replication_query, parsed_query.parameters, interpreter_context, dba, notifications); + + return PreparedQuery{callback.header, std::move(parsed_query.required_privileges), + [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + if (UNLIKELY(!pull_plan)) { + pull_plan = std::make_shared<PullPlanVector>(callback_fn()); + } + + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE}; + // False positive report for the std::make_shared above + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) +} + +PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, + InterpreterContext *interpreter_context, DbAccessor *dba) { + if (in_explicit_transaction) { + throw LockPathModificationInMulticommandTxException(); + } + + auto *lock_path_query = utils::Downcast<LockPathQuery>(parsed_query.query); + + return PreparedQuery{{}, + std::move(parsed_query.required_privileges), + [interpreter_context, action = lock_path_query->action_]( + AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { + switch (action) { + case LockPathQuery::Action::LOCK_PATH: + if (!interpreter_context->db->LockPath()) { + throw QueryRuntimeException("Failed to lock the data directory"); + } + break; + case 
LockPathQuery::Action::UNLOCK_PATH: + if (!interpreter_context->db->UnlockPath()) { + throw QueryRuntimeException("Failed to unlock the data directory"); + } + break; + } + return QueryHandlerResult::COMMIT; + }, + RWType::NONE}; +} + +PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, + InterpreterContext *interpreter_context) { + if (in_explicit_transaction) { + throw FreeMemoryModificationInMulticommandTxException(); + } + + return PreparedQuery{ + {}, + std::move(parsed_query.required_privileges), + [interpreter_context](AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { + interpreter_context->db->FreeMemory(); + memory::PurgeUnusedMemory(); + return QueryHandlerResult::COMMIT; + }, + RWType::NONE}; +} + +TriggerEventType ToTriggerEventType(const TriggerQuery::EventType event_type) { + switch (event_type) { + case TriggerQuery::EventType::ANY: + return TriggerEventType::ANY; + + case TriggerQuery::EventType::CREATE: + return TriggerEventType::CREATE; + + case TriggerQuery::EventType::VERTEX_CREATE: + return TriggerEventType::VERTEX_CREATE; + + case TriggerQuery::EventType::EDGE_CREATE: + return TriggerEventType::EDGE_CREATE; + + case TriggerQuery::EventType::DELETE: + return TriggerEventType::DELETE; + + case TriggerQuery::EventType::VERTEX_DELETE: + return TriggerEventType::VERTEX_DELETE; + + case TriggerQuery::EventType::EDGE_DELETE: + return TriggerEventType::EDGE_DELETE; + + case TriggerQuery::EventType::UPDATE: + return TriggerEventType::UPDATE; + + case TriggerQuery::EventType::VERTEX_UPDATE: + return TriggerEventType::VERTEX_UPDATE; + + case TriggerQuery::EventType::EDGE_UPDATE: + return TriggerEventType::EDGE_UPDATE; + } +} + +Callback CreateTrigger(TriggerQuery *trigger_query, + const std::map<std::string, storage::v3::PropertyValue> &user_parameters, + InterpreterContext *interpreter_context, DbAccessor *dba, std::optional<std::string> owner) { + return { + {}, + 
[trigger_name = std::move(trigger_query->trigger_name_), trigger_statement = std::move(trigger_query->statement_), + event_type = trigger_query->event_type_, before_commit = trigger_query->before_commit_, interpreter_context, dba, + user_parameters, owner = std::move(owner)]() mutable -> std::vector<std::vector<TypedValue>> { + interpreter_context->trigger_store.AddTrigger( + std::move(trigger_name), trigger_statement, user_parameters, ToTriggerEventType(event_type), + before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, &interpreter_context->ast_cache, + dba, &interpreter_context->antlr_lock, interpreter_context->config.query, std::move(owner), + interpreter_context->auth_checker); + return {}; + }}; +} + +Callback DropTrigger(TriggerQuery *trigger_query, InterpreterContext *interpreter_context) { + return {{}, + [trigger_name = std::move(trigger_query->trigger_name_), + interpreter_context]() -> std::vector<std::vector<TypedValue>> { + interpreter_context->trigger_store.DropTrigger(trigger_name); + return {}; + }}; +} + +Callback ShowTriggers(InterpreterContext *interpreter_context) { + return {{"trigger name", "statement", "event type", "phase", "owner"}, [interpreter_context] { + std::vector<std::vector<TypedValue>> results; + auto trigger_infos = interpreter_context->trigger_store.GetTriggerInfo(); + results.reserve(trigger_infos.size()); + for (auto &trigger_info : trigger_infos) { + std::vector<TypedValue> typed_trigger_info; + typed_trigger_info.reserve(4); + typed_trigger_info.emplace_back(std::move(trigger_info.name)); + typed_trigger_info.emplace_back(std::move(trigger_info.statement)); + typed_trigger_info.emplace_back(TriggerEventTypeToString(trigger_info.event_type)); + typed_trigger_info.emplace_back(trigger_info.phase == TriggerPhase::BEFORE_COMMIT ? "BEFORE COMMIT" + : "AFTER COMMIT"); + typed_trigger_info.emplace_back(trigger_info.owner.has_value() ? 
TypedValue{*trigger_info.owner} + : TypedValue{}); + + results.push_back(std::move(typed_trigger_info)); + } + + return results; + }}; +} + +PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, + std::vector<Notification> *notifications, InterpreterContext *interpreter_context, + DbAccessor *dba, + const std::map<std::string, storage::v3::PropertyValue> &user_parameters, + const std::string *username) { + if (in_explicit_transaction) { + throw TriggerModificationInMulticommandTxException(); + } + + auto *trigger_query = utils::Downcast<TriggerQuery>(parsed_query.query); + MG_ASSERT(trigger_query); + + std::optional<Notification> trigger_notification; + auto callback = std::invoke([trigger_query, interpreter_context, dba, &user_parameters, + owner = StringPointerToOptional(username), &trigger_notification]() mutable { + switch (trigger_query->action_) { + case TriggerQuery::Action::CREATE_TRIGGER: + trigger_notification.emplace(SeverityLevel::INFO, NotificationCode::CREATE_TRIGGER, + fmt::format("Created trigger {}.", trigger_query->trigger_name_)); + EventCounter::IncrementCounter(EventCounter::TriggersCreated); + return CreateTrigger(trigger_query, user_parameters, interpreter_context, dba, std::move(owner)); + case TriggerQuery::Action::DROP_TRIGGER: + trigger_notification.emplace(SeverityLevel::INFO, NotificationCode::DROP_TRIGGER, + fmt::format("Dropped trigger {}.", trigger_query->trigger_name_)); + return DropTrigger(trigger_query, interpreter_context); + case TriggerQuery::Action::SHOW_TRIGGERS: + return ShowTriggers(interpreter_context); + } + }); + + return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), + [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}, + trigger_notification = std::move(trigger_notification), notifications]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + if 
(UNLIKELY(!pull_plan)) { + pull_plan = std::make_shared<PullPlanVector>(callback_fn()); + } + + if (pull_plan->Pull(stream, n)) { + if (trigger_notification) { + notifications->push_back(std::move(*trigger_notification)); + } + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE}; + // False positive report for the std::make_shared above + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) +} + +PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, + std::vector<Notification> *notifications, InterpreterContext *interpreter_context, + DbAccessor *dba, + const std::map<std::string, storage::v3::PropertyValue> & /*user_parameters*/, + const std::string *username) { + if (in_explicit_transaction) { + throw StreamQueryInMulticommandTxException(); + } + + auto *stream_query = utils::Downcast<StreamQuery>(parsed_query.query); + MG_ASSERT(stream_query); + auto callback = + HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, dba, username, notifications); + + return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), + [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + if (UNLIKELY(!pull_plan)) { + pull_plan = std::make_shared<PullPlanVector>(callback_fn()); + } + + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE}; + // False positive report for the std::make_shared above + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) +} + +constexpr auto ToStorageIsolationLevel(const IsolationLevelQuery::IsolationLevel isolation_level) noexcept { + switch (isolation_level) { + case IsolationLevelQuery::IsolationLevel::SNAPSHOT_ISOLATION: + return storage::v3::IsolationLevel::SNAPSHOT_ISOLATION; + case 
IsolationLevelQuery::IsolationLevel::READ_COMMITTED: + return storage::v3::IsolationLevel::READ_COMMITTED; + case IsolationLevelQuery::IsolationLevel::READ_UNCOMMITTED: + return storage::v3::IsolationLevel::READ_UNCOMMITTED; + } +} + +PreparedQuery PrepareIsolationLevelQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, + InterpreterContext *interpreter_context, Interpreter *interpreter) { + if (in_explicit_transaction) { + throw IsolationLevelModificationInMulticommandTxException(); + } + + auto *isolation_level_query = utils::Downcast<IsolationLevelQuery>(parsed_query.query); + MG_ASSERT(isolation_level_query); + + const auto isolation_level = ToStorageIsolationLevel(isolation_level_query->isolation_level_); + + auto callback = [isolation_level_query, isolation_level, interpreter_context, + interpreter]() -> std::function<void()> { + switch (isolation_level_query->isolation_level_scope_) { + case IsolationLevelQuery::IsolationLevelScope::GLOBAL: + return [interpreter_context, isolation_level] { interpreter_context->db->SetIsolationLevel(isolation_level); }; + case IsolationLevelQuery::IsolationLevelScope::SESSION: + return [interpreter, isolation_level] { interpreter->SetSessionIsolationLevel(isolation_level); }; + case IsolationLevelQuery::IsolationLevelScope::NEXT: + return [interpreter, isolation_level] { interpreter->SetNextTransactionIsolationLevel(isolation_level); }; + } + }(); + + return PreparedQuery{ + {}, + std::move(parsed_query.required_privileges), + [callback = std::move(callback)](AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { + callback(); + return QueryHandlerResult::COMMIT; + }, + RWType::NONE}; +} + +PreparedQuery PrepareCreateSnapshotQuery(ParsedQuery parsed_query, bool in_explicit_transaction, + InterpreterContext *interpreter_context) { + if (in_explicit_transaction) { + throw CreateSnapshotInMulticommandTxException(); + } + + return PreparedQuery{ + {}, + 
std::move(parsed_query.required_privileges), + [interpreter_context](AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { + if (auto maybe_error = interpreter_context->db->CreateSnapshot(); maybe_error.HasError()) { + switch (maybe_error.GetError()) { + case storage::v3::Storage::CreateSnapshotError::DisabledForReplica: + throw utils::BasicException( + "Failed to create a snapshot. Replica instances are not allowed to create them."); + } + } + return QueryHandlerResult::COMMIT; + }, + RWType::NONE}; +} + +PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, DbAccessor *dba) { + if (in_explicit_transaction) { + throw SettingConfigInMulticommandTxException{}; + } + + auto *setting_query = utils::Downcast<SettingQuery>(parsed_query.query); + MG_ASSERT(setting_query); + auto callback = HandleSettingQuery(setting_query, parsed_query.parameters, dba); + + return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), + [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + if (UNLIKELY(!pull_plan)) { + pull_plan = std::make_shared<PullPlanVector>(callback_fn()); + } + + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE}; + // False positive report for the std::make_shared above + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) +} + +PreparedQuery PrepareVersionQuery(ParsedQuery parsed_query, const bool in_explicit_transaction) { + if (in_explicit_transaction) { + throw VersionInfoInMulticommandTxException(); + } + + return PreparedQuery{{"version"}, + std::move(parsed_query.required_privileges), + [](AnyStream *stream, std::optional<int> /*n*/) { + std::vector<TypedValue> version_value; + version_value.reserve(1); + + 
version_value.emplace_back(gflags::VersionString()); + stream->Result(version_value); + return QueryHandlerResult::COMMIT; + }, + RWType::NONE}; +} + +PreparedQuery PrepareInfoQuery(ParsedQuery parsed_query, bool in_explicit_transaction, + std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, + storage::v3::Storage *db, utils::MemoryResource *execution_memory) { + if (in_explicit_transaction) { + throw InfoInMulticommandTxException(); + } + + auto *info_query = utils::Downcast<InfoQuery>(parsed_query.query); + std::vector<std::string> header; + std::function<std::pair<std::vector<std::vector<TypedValue>>, QueryHandlerResult>()> handler; + + switch (info_query->info_type_) { + case InfoQuery::InfoType::STORAGE: + header = {"storage info", "value"}; + handler = [db] { + auto info = db->GetInfo(); + std::vector<std::vector<TypedValue>> results{ + {TypedValue("vertex_count"), TypedValue(static_cast<int64_t>(info.vertex_count))}, + {TypedValue("edge_count"), TypedValue(static_cast<int64_t>(info.edge_count))}, + {TypedValue("average_degree"), TypedValue(info.average_degree)}, + {TypedValue("memory_usage"), TypedValue(static_cast<int64_t>(info.memory_usage))}, + {TypedValue("disk_usage"), TypedValue(static_cast<int64_t>(info.disk_usage))}, + {TypedValue("memory_allocated"), TypedValue(static_cast<int64_t>(utils::total_memory_tracker.Amount()))}, + {TypedValue("allocation_limit"), + TypedValue(static_cast<int64_t>(utils::total_memory_tracker.HardLimit()))}}; + return std::pair{results, QueryHandlerResult::COMMIT}; + }; + break; + case InfoQuery::InfoType::INDEX: + header = {"index type", "label", "property"}; + handler = [interpreter_context] { + auto *db = interpreter_context->db; + auto info = db->ListAllIndices(); + std::vector<std::vector<TypedValue>> results; + results.reserve(info.label.size() + info.label_property.size()); + for (const auto &item : info.label) { + results.push_back({TypedValue("label"), 
TypedValue(db->LabelToName(item)), TypedValue()}); + } + for (const auto &item : info.label_property) { + results.push_back({TypedValue("label+property"), TypedValue(db->LabelToName(item.first)), + TypedValue(db->PropertyToName(item.second))}); + } + return std::pair{results, QueryHandlerResult::NOTHING}; + }; + break; + case InfoQuery::InfoType::CONSTRAINT: + header = {"constraint type", "label", "properties"}; + handler = [interpreter_context] { + auto *db = interpreter_context->db; + auto info = db->ListAllConstraints(); + std::vector<std::vector<TypedValue>> results; + results.reserve(info.existence.size() + info.unique.size()); + for (const auto &item : info.existence) { + results.push_back({TypedValue("exists"), TypedValue(db->LabelToName(item.first)), + TypedValue(db->PropertyToName(item.second))}); + } + for (const auto &item : info.unique) { + std::vector<TypedValue> properties; + properties.reserve(item.second.size()); + for (const auto &property : item.second) { + properties.emplace_back(db->PropertyToName(property)); + } + results.push_back( + {TypedValue("unique"), TypedValue(db->LabelToName(item.first)), TypedValue(std::move(properties))}); + } + return std::pair{results, QueryHandlerResult::NOTHING}; + }; + break; + } + + return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), + [handler = std::move(handler), action = QueryHandlerResult::NOTHING, + pull_plan = std::shared_ptr<PullPlanVector>(nullptr)]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + if (!pull_plan) { + auto [results, action_on_complete] = handler(); + action = action_on_complete; + pull_plan = std::make_shared<PullPlanVector>(std::move(results)); + } + + if (pull_plan->Pull(stream, n)) { + return action; + } + return std::nullopt; + }, + RWType::NONE}; +} + +PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_transaction, + std::vector<Notification> *notifications, + 
InterpreterContext *interpreter_context) { + if (in_explicit_transaction) { + throw ConstraintInMulticommandTxException(); + } + + auto *constraint_query = utils::Downcast<ConstraintQuery>(parsed_query.query); + std::function<void(Notification &)> handler; + + auto label = interpreter_context->db->NameToLabel(constraint_query->constraint_.label.name); + std::vector<storage::v3::PropertyId> properties; + std::vector<std::string> properties_string; + properties.reserve(constraint_query->constraint_.properties.size()); + properties_string.reserve(constraint_query->constraint_.properties.size()); + for (const auto &prop : constraint_query->constraint_.properties) { + properties.push_back(interpreter_context->db->NameToProperty(prop.name)); + properties_string.push_back(prop.name); + } + auto properties_stringified = utils::Join(properties_string, ", "); + + Notification constraint_notification(SeverityLevel::INFO); + switch (constraint_query->action_type_) { + case ConstraintQuery::ActionType::CREATE: { + constraint_notification.code = NotificationCode::CREATE_CONSTRAINT; + + switch (constraint_query->constraint_.type) { + case Constraint::Type::NODE_KEY: + throw utils::NotYetImplemented("Node key constraints"); + case Constraint::Type::EXISTS: + if (properties.empty() || properties.size() > 1) { + throw SyntaxException("Exactly one property must be used for existence constraints."); + } + constraint_notification.title = fmt::format("Created EXISTS constraint on label {} on properties {}.", + constraint_query->constraint_.label.name, properties_stringified); + handler = [interpreter_context, label, label_name = constraint_query->constraint_.label.name, + properties_stringified = std::move(properties_stringified), + properties = std::move(properties)](Notification &constraint_notification) { + auto res = interpreter_context->db->CreateExistenceConstraint(label, properties[0]); + if (res.HasError()) { + auto violation = res.GetError(); + auto label_name = 
interpreter_context->db->LabelToName(violation.label); + MG_ASSERT(violation.properties.size() == 1U); + auto property_name = interpreter_context->db->PropertyToName(*violation.properties.begin()); + throw QueryRuntimeException( + "Unable to create existence constraint :{}({}), because an " + "existing node violates it.", + label_name, property_name); + } + if (res.HasValue() && !res.GetValue()) { + constraint_notification.code = NotificationCode::EXISTANT_CONSTRAINT; + constraint_notification.title = fmt::format( + "Constraint EXISTS on label {} on properties {} already exists.", label_name, properties_stringified); + } + }; + break; + case Constraint::Type::UNIQUE: + std::set<storage::v3::PropertyId> property_set; + for (const auto &property : properties) { + property_set.insert(property); + } + if (property_set.size() != properties.size()) { + throw SyntaxException("The given set of properties contains duplicates."); + } + constraint_notification.title = + fmt::format("Created UNIQUE constraint on label {} on properties {}.", + constraint_query->constraint_.label.name, utils::Join(properties_string, ", ")); + handler = [interpreter_context, label, label_name = constraint_query->constraint_.label.name, + properties_stringified = std::move(properties_stringified), + property_set = std::move(property_set)](Notification &constraint_notification) { + auto res = interpreter_context->db->CreateUniqueConstraint(label, property_set); + if (res.HasError()) { + auto violation = res.GetError(); + auto label_name = interpreter_context->db->LabelToName(violation.label); + std::stringstream property_names_stream; + utils::PrintIterable(property_names_stream, violation.properties, ", ", + [&interpreter_context](auto &stream, const auto &prop) { + stream << interpreter_context->db->PropertyToName(prop); + }); + throw QueryRuntimeException( + "Unable to create unique constraint :{}({}), because an " + "existing node violates it.", + label_name, property_names_stream.str()); + } + 
switch (res.GetValue()) { + case storage::v3::UniqueConstraints::CreationStatus::EMPTY_PROPERTIES: + throw SyntaxException( + "At least one property must be used for unique " + "constraints."); + case storage::v3::UniqueConstraints::CreationStatus::PROPERTIES_SIZE_LIMIT_EXCEEDED: + throw SyntaxException( + "Too many properties specified. Limit of {} properties " + "for unique constraints is exceeded.", + storage::v3::kUniqueConstraintsMaxProperties); + case storage::v3::UniqueConstraints::CreationStatus::ALREADY_EXISTS: + constraint_notification.code = NotificationCode::EXISTANT_CONSTRAINT; + constraint_notification.title = + fmt::format("Constraint UNIQUE on label {} on properties {} already exists.", label_name, + properties_stringified); + break; + case storage::v3::UniqueConstraints::CreationStatus::SUCCESS: + break; + } + }; + break; + } + } break; + case ConstraintQuery::ActionType::DROP: { + constraint_notification.code = NotificationCode::DROP_CONSTRAINT; + + switch (constraint_query->constraint_.type) { + case Constraint::Type::NODE_KEY: + throw utils::NotYetImplemented("Node key constraints"); + case Constraint::Type::EXISTS: + if (properties.empty() || properties.size() > 1) { + throw SyntaxException("Exactly one property must be used for existence constraints."); + } + constraint_notification.title = + fmt::format("Dropped EXISTS constraint on label {} on properties {}.", + constraint_query->constraint_.label.name, utils::Join(properties_string, ", ")); + handler = [interpreter_context, label, label_name = constraint_query->constraint_.label.name, + properties_stringified = std::move(properties_stringified), + properties = std::move(properties)](Notification &constraint_notification) { + if (!interpreter_context->db->DropExistenceConstraint(label, properties[0])) { + constraint_notification.code = NotificationCode::NONEXISTANT_CONSTRAINT; + constraint_notification.title = fmt::format( + "Constraint EXISTS on label {} on properties {} doesn't exist.", 
label_name, properties_stringified); + } + return std::vector<std::vector<TypedValue>>(); + }; + break; + case Constraint::Type::UNIQUE: + std::set<storage::v3::PropertyId> property_set; + for (const auto &property : properties) { + property_set.insert(property); + } + if (property_set.size() != properties.size()) { + throw SyntaxException("The given set of properties contains duplicates."); + } + constraint_notification.title = + fmt::format("Dropped UNIQUE constraint on label {} on properties {}.", + constraint_query->constraint_.label.name, utils::Join(properties_string, ", ")); + handler = [interpreter_context, label, label_name = constraint_query->constraint_.label.name, + properties_stringified = std::move(properties_stringified), + property_set = std::move(property_set)](Notification &constraint_notification) { + auto res = interpreter_context->db->DropUniqueConstraint(label, property_set); + switch (res) { + case storage::v3::UniqueConstraints::DeletionStatus::EMPTY_PROPERTIES: + throw SyntaxException( + "At least one property must be used for unique " + "constraints."); + break; + case storage::v3::UniqueConstraints::DeletionStatus::PROPERTIES_SIZE_LIMIT_EXCEEDED: + throw SyntaxException( + "Too many properties specified. 
Limit of {} properties for " + "unique constraints is exceeded.", + storage::v3::kUniqueConstraintsMaxProperties); + break; + case storage::v3::UniqueConstraints::DeletionStatus::NOT_FOUND: + constraint_notification.code = NotificationCode::NONEXISTANT_CONSTRAINT; + constraint_notification.title = + fmt::format("Constraint UNIQUE on label {} on properties {} doesn't exist.", label_name, + properties_stringified); + break; + case storage::v3::UniqueConstraints::DeletionStatus::SUCCESS: + break; + } + return std::vector<std::vector<TypedValue>>(); + }; + } + } break; + } + + return PreparedQuery{{}, + std::move(parsed_query.required_privileges), + [handler = std::move(handler), constraint_notification = std::move(constraint_notification), + notifications](AnyStream * /*stream*/, std::optional<int> /*n*/) mutable { + handler(constraint_notification); + notifications->push_back(constraint_notification); + return QueryHandlerResult::COMMIT; + }, + RWType::NONE}; +} + +void Interpreter::BeginTransaction() { + const auto prepared_query = PrepareTransactionQuery("BEGIN"); + prepared_query.query_handler(nullptr, {}); +} + +void Interpreter::CommitTransaction() { + const auto prepared_query = PrepareTransactionQuery("COMMIT"); + prepared_query.query_handler(nullptr, {}); + query_executions_.clear(); +} + +void Interpreter::RollbackTransaction() { + const auto prepared_query = PrepareTransactionQuery("ROLLBACK"); + prepared_query.query_handler(nullptr, {}); + query_executions_.clear(); +} + +Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, + const std::map<std::string, storage::v3::PropertyValue> ¶ms, + const std::string *username) { + if (!in_explicit_transaction_) { + query_executions_.clear(); + } + + query_executions_.emplace_back(std::make_unique<QueryExecution>()); + auto &query_execution = query_executions_.back(); + std::optional<int> qid = + in_explicit_transaction_ ? 
static_cast<int>(query_executions_.size() - 1) : std::optional<int>{}; + + // Handle transaction control queries. + + const auto upper_case_query = utils::ToUpperCase(query_string); + const auto trimmed_query = utils::Trim(upper_case_query); + + if (trimmed_query == "BEGIN" || trimmed_query == "COMMIT" || trimmed_query == "ROLLBACK") { + query_execution->prepared_query.emplace(PrepareTransactionQuery(trimmed_query)); + return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid}; + } + + // All queries other than transaction control queries advance the command in + // an explicit transaction block. + if (in_explicit_transaction_) { + AdvanceCommand(); + } + // If we're not in an explicit transaction block and we have an open + // transaction, abort it since we're about to prepare a new query. + else if (db_accessor_) { + AbortCommand(&query_execution); + } + + try { + // Set a default cost estimate of 0. Individual queries can overwrite this + // field with an improved estimate. + query_execution->summary["cost_estimate"] = 0.0; + + utils::Timer parsing_timer; + ParsedQuery parsed_query = ParseQuery(query_string, params, &interpreter_context_->ast_cache, + &interpreter_context_->antlr_lock, interpreter_context_->config.query); + query_execution->summary["parsing_time"] = parsing_timer.Elapsed().count(); + + // Some queries require an active transaction in order to be prepared. 
+ if (!in_explicit_transaction_ && + (utils::Downcast<CypherQuery>(parsed_query.query) || utils::Downcast<ExplainQuery>(parsed_query.query) || + utils::Downcast<ProfileQuery>(parsed_query.query) || utils::Downcast<DumpQuery>(parsed_query.query) || + utils::Downcast<TriggerQuery>(parsed_query.query))) { + db_accessor_ = std::make_unique<storage::v3::Storage::Accessor>( + interpreter_context_->db->Access(GetIsolationLevelOverride())); + execution_db_accessor_.emplace(db_accessor_.get()); + + if (utils::Downcast<CypherQuery>(parsed_query.query) && interpreter_context_->trigger_store.HasTriggers()) { + trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes()); + } + } + + utils::Timer planning_timer; + PreparedQuery prepared_query; + + if (utils::Downcast<CypherQuery>(parsed_query.query)) { + prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, + &*execution_db_accessor_, &query_execution->execution_memory, + &query_execution->notifications, + trigger_context_collector_ ? 
&*trigger_context_collector_ : nullptr); + } else if (utils::Downcast<ExplainQuery>(parsed_query.query)) { + prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, + &*execution_db_accessor_, &query_execution->execution_memory_with_exception); + } else if (utils::Downcast<ProfileQuery>(parsed_query.query)) { + prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, + interpreter_context_, &*execution_db_accessor_, + &query_execution->execution_memory_with_exception); + } else if (utils::Downcast<DumpQuery>(parsed_query.query)) { + prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_, + &query_execution->execution_memory); + } else if (utils::Downcast<IndexQuery>(parsed_query.query)) { + prepared_query = PrepareIndexQuery(std::move(parsed_query), in_explicit_transaction_, + &query_execution->notifications, interpreter_context_); + } else if (utils::Downcast<AuthQuery>(parsed_query.query)) { + prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, + interpreter_context_, &*execution_db_accessor_, + &query_execution->execution_memory_with_exception); + } else if (utils::Downcast<InfoQuery>(parsed_query.query)) { + prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, + interpreter_context_, interpreter_context_->db, + &query_execution->execution_memory_with_exception); + } else if (utils::Downcast<ConstraintQuery>(parsed_query.query)) { + prepared_query = PrepareConstraintQuery(std::move(parsed_query), in_explicit_transaction_, + &query_execution->notifications, interpreter_context_); + } else if (utils::Downcast<ReplicationQuery>(parsed_query.query)) { + prepared_query = + PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, + 
interpreter_context_, &*execution_db_accessor_); + } else if (utils::Downcast<LockPathQuery>(parsed_query.query)) { + prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, + &*execution_db_accessor_); + } else if (utils::Downcast<FreeMemoryQuery>(parsed_query.query)) { + prepared_query = PrepareFreeMemoryQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); + } else if (utils::Downcast<TriggerQuery>(parsed_query.query)) { + prepared_query = + PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, + interpreter_context_, &*execution_db_accessor_, params, username); + } else if (utils::Downcast<StreamQuery>(parsed_query.query)) { + prepared_query = + PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, + interpreter_context_, &*execution_db_accessor_, params, username); + } else if (utils::Downcast<IsolationLevelQuery>(parsed_query.query)) { + prepared_query = + PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, this); + } else if (utils::Downcast<CreateSnapshotQuery>(parsed_query.query)) { + prepared_query = + PrepareCreateSnapshotQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); + } else if (utils::Downcast<SettingQuery>(parsed_query.query)) { + prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_); + } else if (utils::Downcast<VersionQuery>(parsed_query.query)) { + prepared_query = PrepareVersionQuery(std::move(parsed_query), in_explicit_transaction_); + } else { + LOG_FATAL("Should not get here -- unknown query type!"); + } + + query_execution->summary["planning_time"] = planning_timer.Elapsed().count(); + query_execution->prepared_query.emplace(std::move(prepared_query)); + + const auto rw_type = query_execution->prepared_query->rw_type; + 
query_execution->summary["type"] = plan::ReadWriteTypeChecker::TypeToString(rw_type); + + UpdateTypeCount(rw_type); + + if (const auto query_type = query_execution->prepared_query->rw_type; + interpreter_context_->db->GetReplicationRole() == storage::v3::ReplicationRole::REPLICA && + (query_type == RWType::W || query_type == RWType::RW)) { + query_execution = nullptr; + throw QueryException("Write query forbidden on the replica!"); + } + + return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid}; + } catch (const utils::BasicException &) { + EventCounter::IncrementCounter(EventCounter::FailedQuery); + AbortCommand(&query_execution); + throw; + } +} + +void Interpreter::Abort() { + expect_rollback_ = false; + in_explicit_transaction_ = false; + if (!db_accessor_) return; + db_accessor_->Abort(); + execution_db_accessor_.reset(); + db_accessor_.reset(); + trigger_context_collector_.reset(); +} + +namespace { +void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, InterpreterContext *interpreter_context, + TriggerContext trigger_context) { + // Run the triggers + for (const auto &trigger : triggers.access()) { + utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; + + // create a new transaction for each trigger + auto storage_acc = interpreter_context->db->Access(); + DbAccessor db_accessor{&storage_acc}; + + trigger_context.AdaptForAccessor(&db_accessor); + try { + trigger.Execute(&db_accessor, &execution_memory, interpreter_context->config.execution_timeout_sec, + &interpreter_context->is_shutting_down, trigger_context, interpreter_context->auth_checker); + } catch (const utils::BasicException &exception) { + spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what()); + db_accessor.Abort(); + continue; + } + + auto maybe_constraint_violation = db_accessor.Commit(); + if (maybe_constraint_violation.HasError()) { + const auto &constraint_violation = 
maybe_constraint_violation.GetError(); + switch (constraint_violation.type) { + case storage::v3::ConstraintViolation::Type::EXISTENCE: { + const auto &label_name = db_accessor.LabelToName(constraint_violation.label); + MG_ASSERT(constraint_violation.properties.size() == 1U); + const auto &property_name = db_accessor.PropertyToName(*constraint_violation.properties.begin()); + spdlog::warn("Trigger '{}' failed to commit due to existence constraint violation on :{}({})", trigger.Name(), + label_name, property_name); + break; + } + case storage::v3::ConstraintViolation::Type::UNIQUE: { + const auto &label_name = db_accessor.LabelToName(constraint_violation.label); + std::stringstream property_names_stream; + utils::PrintIterable(property_names_stream, constraint_violation.properties, ", ", + [&](auto &stream, const auto &prop) { stream << db_accessor.PropertyToName(prop); }); + spdlog::warn("Trigger '{}' failed to commit due to unique constraint violation on :{}({})", trigger.Name(), + label_name, property_names_stream.str()); + break; + } + } + } + } +} +} // namespace + +void Interpreter::Commit() { + // It's possible that some queries did not finish because the user did + // not pull all of the results from the query. + // For now, we will not check if there are some unfinished queries. + // We should document clearly that all results should be pulled to complete + // a query. 
+ if (!db_accessor_) return; + + std::optional<TriggerContext> trigger_context = std::nullopt; + if (trigger_context_collector_) { + trigger_context.emplace(std::move(*trigger_context_collector_).TransformToTriggerContext()); + trigger_context_collector_.reset(); + } + + if (trigger_context) { + // Run the triggers + for (const auto &trigger : interpreter_context_->trigger_store.BeforeCommitTriggers().access()) { + utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; + AdvanceCommand(); + try { + trigger.Execute(&*execution_db_accessor_, &execution_memory, interpreter_context_->config.execution_timeout_sec, + &interpreter_context_->is_shutting_down, *trigger_context, interpreter_context_->auth_checker); + } catch (const utils::BasicException &e) { + throw utils::BasicException( + fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what())); + } + } + SPDLOG_DEBUG("Finished executing before commit triggers"); + } + + const auto reset_necessary_members = [this]() { + execution_db_accessor_.reset(); + db_accessor_.reset(); + trigger_context_collector_.reset(); + }; + + auto maybe_constraint_violation = db_accessor_->Commit(); + if (maybe_constraint_violation.HasError()) { + const auto &constraint_violation = maybe_constraint_violation.GetError(); + switch (constraint_violation.type) { + case storage::v3::ConstraintViolation::Type::EXISTENCE: { + auto label_name = execution_db_accessor_->LabelToName(constraint_violation.label); + MG_ASSERT(constraint_violation.properties.size() == 1U); + auto property_name = execution_db_accessor_->PropertyToName(*constraint_violation.properties.begin()); + reset_necessary_members(); + throw QueryException("Unable to commit due to existence constraint violation on :{}({})", label_name, + property_name); + break; + } + case storage::v3::ConstraintViolation::Type::UNIQUE: { + auto label_name = execution_db_accessor_->LabelToName(constraint_violation.label); + 
std::stringstream property_names_stream; + utils::PrintIterable( + property_names_stream, constraint_violation.properties, ", ", + [this](auto &stream, const auto &prop) { stream << execution_db_accessor_->PropertyToName(prop); }); + reset_necessary_members(); + throw QueryException("Unable to commit due to unique constraint violation on :{}({})", label_name, + property_names_stream.str()); + break; + } + } + } + + // The ordered execution of after commit triggers is heavily depending on the exclusiveness of db_accessor_->Commit(): + // only one of the transactions can be commiting at the same time, so when the commit is finished, that transaction + // probably will schedule its after commit triggers, because the other transactions that want to commit are still + // waiting for commiting or one of them just started commiting its changes. + // This means the ordered execution of after commit triggers are not guaranteed. + if (trigger_context && interpreter_context_->trigger_store.AfterCommitTriggers().size() > 0) { + interpreter_context_->after_commit_trigger_pool.AddTask( + [trigger_context = std::move(*trigger_context), interpreter_context = this->interpreter_context_, + user_transaction = std::shared_ptr(std::move(db_accessor_))]() mutable { + RunTriggersIndividually(interpreter_context->trigger_store.AfterCommitTriggers(), interpreter_context, + std::move(trigger_context)); + user_transaction->FinalizeTransaction(); + SPDLOG_DEBUG("Finished executing after commit triggers"); // NOLINT(bugprone-lambda-function-name) + }); + } + + reset_necessary_members(); + + SPDLOG_DEBUG("Finished committing the transaction"); +} + +void Interpreter::AdvanceCommand() { + if (!db_accessor_) return; + db_accessor_->AdvanceCommand(); +} + +void Interpreter::AbortCommand(std::unique_ptr<QueryExecution> *query_execution) { + if (query_execution) { + query_execution->reset(nullptr); + } + if (in_explicit_transaction_) { + expect_rollback_ = true; + } else { + Abort(); + } +} + 
+// Returns the isolation level the next storage accessor should be opened with.
+// A pending "next transaction" level takes precedence and is consumed by this
+// call; otherwise the session-wide level (possibly empty) is returned.
+std::optional<storage::v3::IsolationLevel> Interpreter::GetIsolationLevelOverride() {
+  if (!next_transaction_isolation_level) {
+    return interpreter_isolation_level;
+  }
+
+  // The per-transaction override is one-shot: hand it out and forget it.
+  const auto one_shot_level = *next_transaction_isolation_level;
+  next_transaction_isolation_level.reset();
+  return one_shot_level;
+}
+
+// Stores an isolation level that applies only to the next transaction.
+void Interpreter::SetNextTransactionIsolationLevel(const storage::v3::IsolationLevel isolation_level) {
+  next_transaction_isolation_level = isolation_level;
+}
+
+// Stores an isolation level that applies to every transaction of this interpreter.
+void Interpreter::SetSessionIsolationLevel(const storage::v3::IsolationLevel isolation_level) {
+  interpreter_isolation_level = isolation_level;
+}
+
+}  // namespace memgraph::query::v2
diff --git a/src/query/v2/interpreter.hpp b/src/query/v2/interpreter.hpp
new file mode 100644
index 000000000..b1f89f22b
--- /dev/null
+++ b/src/query/v2/interpreter.hpp
@@ -0,0 +1,436 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#pragma once
+
+#include <algorithm>
+#include <atomic>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <filesystem>
+#include <functional>
+#include <map>
+#include <memory>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include <gflags/gflags.h>
+
+#include "query/v2/auth_checker.hpp"
+#include "query/v2/config.hpp"
+#include "query/v2/context.hpp"
+#include "query/v2/cypher_query_interpreter.hpp"
+#include "query/v2/db_accessor.hpp"
+#include "query/v2/exceptions.hpp"
+#include "query/v2/frontend/ast/ast.hpp"
+#include "query/v2/frontend/ast/cypher_main_visitor.hpp"
+#include "query/v2/frontend/stripped.hpp"
+#include "query/v2/interpret/frame.hpp"
+#include "query/v2/metadata.hpp"
+#include "query/v2/plan/operator.hpp"
+#include "query/v2/plan/read_write_type_checker.hpp"
+#include "query/v2/stream.hpp"
+#include "query/v2/stream/streams.hpp"
+#include "query/v2/trigger.hpp"
+#include "query/v2/typed_value.hpp"
+#include "storage/v3/isolation_level.hpp"
+#include "utils/event_counter.hpp"
+#include "utils/logging.hpp"
+#include "utils/memory.hpp"
+#include "utils/settings.hpp"
+#include "utils/skip_list.hpp"
+#include "utils/spin_lock.hpp"
+#include "utils/thread_pool.hpp"
+#include "utils/timer.hpp"
+#include "utils/tsc.hpp"
+
+namespace EventCounter {
+extern const Event FailedQuery;
+}  // namespace EventCounter
+
+namespace memgraph::query::v2 {
+
+/// Initial block size (1 MiB) of the monotonic buffer used as per-query execution memory.
+inline constexpr size_t kExecutionMemoryBlockSize = 1UL * 1024UL * 1024UL;
+
+/// Abstract interface through which user, role and privilege management is performed.
+/// The type is neither copyable nor movable.
+class AuthQueryHandler {
+ public:
+  AuthQueryHandler() = default;
+  virtual ~AuthQueryHandler() = default;
+
+  AuthQueryHandler(const AuthQueryHandler &) = delete;
+  AuthQueryHandler(AuthQueryHandler &&) = delete;
+  AuthQueryHandler &operator=(const AuthQueryHandler &) = delete;
+  AuthQueryHandler &operator=(AuthQueryHandler &&) = delete;
+
+  /// Return false if the user already exists.
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual bool CreateUser(const std::string &username, const std::optional<std::string> &password) = 0;
+
+  /// Return false if the user does not exist.
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual bool DropUser(const std::string &username) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual void SetPassword(const std::string &username, const std::optional<std::string> &password) = 0;
+
+  /// Return false if the role already exists.
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual bool CreateRole(const std::string &rolename) = 0;
+
+  /// Return false if the role does not exist.
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual bool DropRole(const std::string &rolename) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual std::vector<TypedValue> GetUsernames() = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual std::vector<TypedValue> GetRolenames() = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual std::optional<std::string> GetRolenameForUser(const std::string &username) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual std::vector<TypedValue> GetUsernamesForRole(const std::string &rolename) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual void SetRole(const std::string &username, const std::string &rolename) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual void ClearRole(const std::string &username) = 0;
+
+  virtual std::vector<std::vector<TypedValue>> GetPrivileges(const std::string &user_or_role) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual void GrantPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual void DenyPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual void RevokePrivilege(const std::string &user_or_role,
+                               const std::vector<AuthQuery::Privilege> &privileges) = 0;
+};
+
+/// Action requested by a query handler once a query has finished executing.
+enum class QueryHandlerResult { COMMIT, ABORT, NOTHING };
+
+/// Abstract interface through which replication management is performed.
+class ReplicationQueryHandler {
+ public:
+  ReplicationQueryHandler() = default;
+  virtual ~ReplicationQueryHandler() = default;
+
+  ReplicationQueryHandler(const ReplicationQueryHandler &) = default;
+  ReplicationQueryHandler &operator=(const ReplicationQueryHandler &) = default;
+
+  ReplicationQueryHandler(ReplicationQueryHandler &&) = default;
+  ReplicationQueryHandler &operator=(ReplicationQueryHandler &&) = default;
+
+  /// Description of a single replica as returned by `ShowReplicas`.
+  struct Replica {
+    std::string name;
+    std::string socket_address;
+    ReplicationQuery::SyncMode sync_mode;
+    std::optional<double> timeout;
+    ReplicationQuery::ReplicaState state;
+  };
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional<int64_t> port) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual ReplicationQuery::ReplicationRole ShowReplicationRole() const = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual void RegisterReplica(const std::string &name, const std::string &socket_address,
+                               const ReplicationQuery::SyncMode sync_mode, const std::optional<double> timeout,
+                               const std::chrono::seconds replica_check_frequency) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual void DropReplica(const std::string &replica_name) = 0;
+
+  /// @throw QueryRuntimeException if an error occurred.
+  virtual std::vector<Replica> ShowReplicas() const = 0;
+};
+
+/**
+ * A container for data related to the preparation of a query.
+ */
+struct PreparedQuery {
+  std::vector<std::string> header;
+  std::vector<AuthQuery::Privilege> privileges;
+  std::function<std::optional<QueryHandlerResult>(AnyStream *stream, std::optional<int> n)> query_handler;
+  plan::ReadWriteTypeChecker::RWType rw_type;
+};
+
+/**
+ * Holds data shared between multiple `Interpreter` instances (which might be
+ * running concurrently).
+ *
+ * Users should initialize the context but should not modify it after it has
+ * been passed to an `Interpreter` instance.
+ */
+struct InterpreterContext {
+  explicit InterpreterContext(storage::v3::Storage *db, InterpreterConfig config,
+                              const std::filesystem::path &data_directory);
+
+  storage::v3::Storage *db;
+
+  // ANTLR has singleton instance that is shared between threads. It is
+  // protected by locks inside of ANTLR. Unfortunately, they are not protected
+  // in a very good way. Once we have ANTLR version without race conditions we
+  // can remove this lock. This will probably never happen since ANTLR
+  // developers introduce more bugs in each version. Fortunately, we have
+  // cache so this lock probably won't impact performance much...
+  utils::SpinLock antlr_lock;
+  std::optional<double> tsc_frequency{utils::GetTSCFrequency()};
+  std::atomic<bool> is_shutting_down{false};
+
+  AuthQueryHandler *auth{nullptr};
+  AuthChecker *auth_checker{nullptr};
+
+  utils::SkipList<QueryCacheEntry> ast_cache;
+  utils::SkipList<PlanCacheEntry> plan_cache;
+
+  TriggerStore trigger_store;
+  utils::ThreadPool after_commit_trigger_pool{1};
+
+  const InterpreterConfig config;
+
+  query::v2::stream::Streams streams;
+};
+
+/// Function that is used to tell all active interpreters that they should stop
+/// their ongoing execution.
+inline void Shutdown(InterpreterContext *context) { context->is_shutting_down.store(true, std::memory_order_release); }
+
+class Interpreter final {
+ public:
+  explicit Interpreter(InterpreterContext *interpreter_context);
+  Interpreter(const Interpreter &) = delete;
+  Interpreter &operator=(const Interpreter &) = delete;
+  Interpreter(Interpreter &&) = delete;
+  Interpreter &operator=(Interpreter &&) = delete;
+  ~Interpreter() { Abort(); }
+
+  struct PrepareResult {
+    std::vector<std::string> headers;
+    std::vector<query::v2::AuthQuery::Privilege> privileges;
+    std::optional<int> qid;
+  };
+
+  /**
+   * Prepare a query for execution.
+   *
+   * Preparing a query means to preprocess the query and save it for
+   * future calls of `Pull`.
+   *
+   * @throw query::v2::QueryException
+   */
+  PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::v3::PropertyValue> &params,
+                        const std::string *username);
+
+  /**
+   * Execute the last prepared query and stream *all* of the results into the
+   * given stream.
+   *
+   * It is not possible to prepare a query once and execute it multiple times,
+   * i.e. `Prepare` has to be called before *every* call to `PullAll`.
+   *
+   * TStream should be a type implementing the `Stream` concept, i.e. it should
+   * contain the member function `void Result(const std::vector<TypedValue> &)`.
+   * The provided vector argument is valid only for the duration of the call to
+   * `Result`. The stream should make an explicit copy if it wants to use it
+   * further.
+   *
+   * @throw utils::BasicException
+   * @throw query::v2::QueryException
+   */
+  template <typename TStream>
+  std::map<std::string, TypedValue> PullAll(TStream *result_stream) {
+    return Pull(result_stream);
+  }
+
+  /**
+   * Execute a prepared query and stream result into the given stream.
+   *
+   * TStream should be a type implementing the `Stream` concept, i.e. it should
+   * contain the member function `void Result(const std::vector<TypedValue> &)`.
+   * The provided vector argument is valid only for the duration of the call to
+   * `Result`. The stream should make an explicit copy if it wants to use it
+   * further.
+   *
+   * @param n If set, amount of rows to be pulled from result,
+   * otherwise all the rows are pulled.
+   * @param qid If set, id of the query from which the result should be pulled,
+   * otherwise the last query should be used.
+   *
+   * @throw utils::BasicException
+   * @throw query::v2::QueryException
+   */
+  template <typename TStream>
+  std::map<std::string, TypedValue> Pull(TStream *result_stream, std::optional<int> n = {},
+                                         std::optional<int> qid = {});
+
+  void BeginTransaction();
+
+  void CommitTransaction();
+
+  void RollbackTransaction();
+
+  void SetNextTransactionIsolationLevel(storage::v3::IsolationLevel isolation_level);
+  void SetSessionIsolationLevel(storage::v3::IsolationLevel isolation_level);
+
+  /**
+   * Abort the current multicommand transaction.
+   */
+  void Abort();
+
+ private:
+  struct QueryExecution {
+    std::optional<PreparedQuery> prepared_query;
+    utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize};
+    utils::ResourceWithOutOfMemoryException execution_memory_with_exception{&execution_memory};
+
+    std::map<std::string, TypedValue> summary;
+    std::vector<Notification> notifications;
+
+    explicit QueryExecution() = default;
+    QueryExecution(const QueryExecution &) = delete;
+    QueryExecution(QueryExecution &&) = default;
+    QueryExecution &operator=(const QueryExecution &) = delete;
+    QueryExecution &operator=(QueryExecution &&) = default;
+
+    ~QueryExecution() {
+      // We should always release the execution memory AFTER we
+      // destroy the prepared query which is using that instance
+      // of execution memory.
+      prepared_query.reset();
+      execution_memory.Release();
+    }
+  };
+
+  // Interpreter supports multiple prepared queries at the same time.
+  // The client can reference a specific query for pull using an arbitrary qid
+  // which is in our case the index of the query in the vector.
+  // To simplify the handling of the qid we avoid modifying the vector if it
+  // affects the position of the currently running queries in any way.
+  // For example, we cannot delete the prepared query from the vector because
+  // every prepared query after the deleted one will be moved by one place
+  // making their qid not equal to their index inside the vector.
+  // To avoid this, we use unique_ptr with which we manually control construction
+  // and deletion of a single query execution, i.e. when a query finishes,
+  // we reset the corresponding unique_ptr.
+  std::vector<std::unique_ptr<QueryExecution>> query_executions_;
+
+  InterpreterContext *interpreter_context_;
+
+  // This cannot be std::optional because we need to move this accessor later on into a lambda capture
+  // which is assigned to std::function. std::function requires every object to be copyable, so we
+  // move this unique_ptr into a shared_ptr.
+  std::unique_ptr<storage::v3::Storage::Accessor> db_accessor_;
+  std::optional<DbAccessor> execution_db_accessor_;
+  std::optional<TriggerContextCollector> trigger_context_collector_;
+  bool in_explicit_transaction_{false};
+  bool expect_rollback_{false};
+
+  std::optional<storage::v3::IsolationLevel> interpreter_isolation_level;
+  std::optional<storage::v3::IsolationLevel> next_transaction_isolation_level;
+
+  PreparedQuery PrepareTransactionQuery(std::string_view query_upper);
+  void Commit();
+  void AdvanceCommand();
+  void AbortCommand(std::unique_ptr<QueryExecution> *query_execution);
+  std::optional<storage::v3::IsolationLevel> GetIsolationLevelOverride();
+
+  // Number of query executions that still hold a prepared (unfinished) query.
+  size_t ActiveQueryExecutions() const {
+    return std::count_if(query_executions_.begin(), query_executions_.end(),
+                         [](const auto &execution) { return execution && execution->prepared_query; });
+  }
+};
+
+template <typename TStream>
+std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std::optional<int> n,
+                                                    std::optional<int> qid) {
+  MG_ASSERT(in_explicit_transaction_ || !qid, "qid can be only used in explicit transaction!");
+  // With no executions, size() - 1 wraps around and converts to -1, which is rejected below.
+  const int qid_value = qid ? *qid : static_cast<int>(query_executions_.size() - 1);
+
+  if (qid_value < 0 || static_cast<size_t>(qid_value) >= query_executions_.size()) {
+    throw InvalidArgumentsException("qid", "Query with specified ID does not exist!");
+  }
+
+  if (n && *n < 0) {
+    throw InvalidArgumentsException("n", "Cannot fetch negative number of results!");
+  }
+
+  auto &query_execution = query_executions_[qid_value];
+
+  MG_ASSERT(query_execution && query_execution->prepared_query, "Query already finished executing!");
+
+  // Each prepared query has its own summary so we need to somehow preserve
+  // it after it finishes executing because it gets destroyed alongside
+  // the prepared query and its execution memory.
+  std::optional<std::map<std::string, TypedValue>> maybe_summary;
+  try {
+    std::optional<QueryHandlerResult> maybe_res;
+    {
+      // Wrap the (statically polymorphic) stream type into a common type which
+      // the handler knows.
+      // Stream is using execution memory of the query_execution which
+      // can be deleted after its execution, so the stream lives in its own
+      // scope and is destroyed exactly once, before the execution is touched.
+      AnyStream stream{result_stream, &query_execution->execution_memory};
+      maybe_res = query_execution->prepared_query->query_handler(&stream, n);
+    }
+
+    // If the query finished executing, we have received a value which tells
+    // us what to do after.
+    if (maybe_res) {
+      // Save its summary
+      maybe_summary.emplace(std::move(query_execution->summary));
+      if (!query_execution->notifications.empty()) {
+        std::vector<TypedValue> notifications;
+        notifications.reserve(query_execution->notifications.size());
+        for (const auto &notification : query_execution->notifications) {
+          notifications.emplace_back(notification.ConvertToMap());
+        }
+        maybe_summary->insert_or_assign("notifications", std::move(notifications));
+      }
+      if (!in_explicit_transaction_) {
+        switch (*maybe_res) {
+          case QueryHandlerResult::COMMIT:
+            Commit();
+            break;
+          case QueryHandlerResult::ABORT:
+            Abort();
+            break;
+          case QueryHandlerResult::NOTHING:
+            // The only cases in which we have nothing to do are those where
+            // we're either in an explicit transaction or the query is such that
+            // a transaction wasn't started on a call to `Prepare()`.
+            MG_ASSERT(in_explicit_transaction_ || !db_accessor_);
+            break;
+        }
+        // As the transaction is done we can clear all the executions
+        // NOTE: we cannot clear query_execution inside the Abort and Commit
+        // methods as we will delete summary contained in them which we need
+        // after our query finished executing.
+        query_executions_.clear();
+      } else {
+        // We can only clear this execution as some of the queries
+        // in the transaction can be in unfinished state
+        query_execution.reset(nullptr);
+      }
+    }
+  } catch (const ExplicitTransactionUsageException &) {
+    query_execution.reset(nullptr);
+    throw;
+  } catch (const utils::BasicException &) {
+    EventCounter::IncrementCounter(EventCounter::FailedQuery);
+    AbortCommand(&query_execution);
+    throw;
+  }
+
+  if (maybe_summary) {
+    // return the execution summary
+    maybe_summary->insert_or_assign("has_more", false);
+    return std::move(*maybe_summary);
+  }
+
+  // don't return the execution summary as it's not finished
+  return {{"has_more", TypedValue(true)}};
+}
+}  // namespace memgraph::query::v2
diff --git a/src/query/v2/metadata.cpp b/src/query/v2/metadata.cpp
new file mode 100644
index 000000000..fe7461e79
--- /dev/null
+++ b/src/query/v2/metadata.cpp
@@ -0,0 +1,117 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#include "query/v2/metadata.hpp"
+
+#include <algorithm>
+#include <compare>
+#include <string>
+#include <string_view>
+
+namespace memgraph::query::v2 {
+
+namespace {
+using namespace std::literals;
+
+// Maps a severity level to the string emitted in a notification map.
+// The switch has no `default` so the compiler can still flag a newly added
+// enumerator; the trailing return keeps the function well defined for a value
+// outside the enumerator list.
+constexpr std::string_view GetSeverityLevelString(const SeverityLevel level) {
+  switch (level) {
+    case SeverityLevel::INFO:
+      return "INFO"sv;
+    case SeverityLevel::WARNING:
+      return "WARNING"sv;
+  }
+  return "UNKNOWN"sv;
+}
+
+// Maps a notification code to the string emitted in a notification map.
+// See GetSeverityLevelString for why there is a trailing return instead of a
+// `default` label.
+constexpr std::string_view GetCodeString(const NotificationCode code) {
+  switch (code) {
+    case NotificationCode::CREATE_CONSTRAINT:
+      return "CreateConstraint"sv;
+    case NotificationCode::CREATE_INDEX:
+      return "CreateIndex"sv;
+    case NotificationCode::CREATE_STREAM:
+      return "CreateStream"sv;
+    case NotificationCode::CHECK_STREAM:
+      return "CheckStream"sv;
+    case NotificationCode::CREATE_TRIGGER:
+      return "CreateTrigger"sv;
+    case NotificationCode::DROP_CONSTRAINT:
+      return "DropConstraint"sv;
+    case NotificationCode::DROP_REPLICA:
+      return "DropReplica"sv;
+    case NotificationCode::DROP_INDEX:
+      return "DropIndex"sv;
+    case NotificationCode::DROP_STREAM:
+      return "DropStream"sv;
+    case NotificationCode::DROP_TRIGGER:
+      return "DropTrigger"sv;
+    case NotificationCode::EXISTANT_CONSTRAINT:
+      return "ConstraintAlreadyExists"sv;
+    case NotificationCode::EXISTANT_INDEX:
+      return "IndexAlreadyExists"sv;
+    case NotificationCode::LOAD_CSV_TIP:
+      return "LoadCSVTip"sv;
+    case NotificationCode::NONEXISTANT_INDEX:
+      return "IndexDoesNotExist"sv;
+    case NotificationCode::NONEXISTANT_CONSTRAINT:
+      return "ConstraintDoesNotExist"sv;
+    case NotificationCode::REGISTER_REPLICA:
+      return "RegisterReplica"sv;
+    case NotificationCode::REPLICA_PORT_WARNING:
+      return "ReplicaPortWarning"sv;
+    case NotificationCode::SET_REPLICA:
+      return "SetReplica"sv;
+    case NotificationCode::START_STREAM:
+      return "StartStream"sv;
+    case NotificationCode::START_ALL_STREAMS:
+      return "StartAllStreams"sv;
+    case NotificationCode::STOP_STREAM:
+      return "StopStream"sv;
+    case NotificationCode::STOP_ALL_STREAMS:
+      return "StopAllStreams"sv;
+  }
+  return "Unknown"sv;
+}
+}  // namespace
+
+// `code` is value-initialized so that ConvertToMap never reads an indeterminate
+// value when the caller has not assigned a code yet.
+Notification::Notification(SeverityLevel level) : level{level}, code{} {}
+
+Notification::Notification(SeverityLevel level, NotificationCode code, std::string title, std::string description)
+    : level{level}, code{code}, title(std::move(title)), description(std::move(description)) {}
+
+Notification::Notification(SeverityLevel level, NotificationCode code, std::string title)
+    : level{level}, code{code}, title(std::move(title)) {}
+
+// Builds the map representation of this notification with the keys
+// "severity", "code", "title" and "description".
+std::map<std::string, TypedValue> Notification::ConvertToMap() const {
+  return std::map<std::string, TypedValue>{{"severity", TypedValue(GetSeverityLevelString(level))},
+                                           {"code", TypedValue(GetCodeString(code))},
+                                           {"title", TypedValue(title)},
+                                           {"description", TypedValue(description)}};
+}
+
+// Returns the summary key under which the given execution statistic is reported.
+std::string ExecutionStatsKeyToString(const ExecutionStats::Key key) {
+  switch (key) {
+    case ExecutionStats::Key::CREATED_NODES:
+      return std::string("nodes-created");
+    case ExecutionStats::Key::DELETED_NODES:
+      return std::string("nodes-deleted");
+    case ExecutionStats::Key::CREATED_EDGES:
+      return std::string("relationships-created");
+    case ExecutionStats::Key::DELETED_EDGES:
+      return std::string("relationships-deleted");
+    case ExecutionStats::Key::CREATED_LABELS:
+      return std::string("labels-added");
+    case ExecutionStats::Key::DELETED_LABELS:
+      return std::string("labels-removed");
+    case ExecutionStats::Key::UPDATED_PROPERTIES:
+      return std::string("properties-set");
+  }
+  // Only reachable for a value outside the enumerator list.
+  return std::string("unknown");
+}
+
+}  // namespace memgraph::query::v2
diff --git a/src/query/v2/metadata.hpp b/src/query/v2/metadata.hpp
new file mode 100644
index 000000000..ffc621d64
--- /dev/null
+++ b/src/query/v2/metadata.hpp
@@ -0,0 +1,90 @@
+// Copyright 2022 Memgraph Ltd.
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <cstdint> +#include <map> +#include <string> +#include <string_view> +#include <type_traits> + +#include "query/v2/typed_value.hpp" + +namespace memgraph::query::v2 { + +enum class SeverityLevel : uint8_t { INFO, WARNING }; + +enum class NotificationCode : uint8_t { + CREATE_CONSTRAINT, + CREATE_INDEX, + CHECK_STREAM, + CREATE_STREAM, + CREATE_TRIGGER, + DROP_CONSTRAINT, + DROP_INDEX, + DROP_REPLICA, + DROP_STREAM, + DROP_TRIGGER, + EXISTANT_INDEX, + EXISTANT_CONSTRAINT, + LOAD_CSV_TIP, + NONEXISTANT_INDEX, + NONEXISTANT_CONSTRAINT, + REPLICA_PORT_WARNING, + REGISTER_REPLICA, + SET_REPLICA, + START_STREAM, + START_ALL_STREAMS, + STOP_STREAM, + STOP_ALL_STREAMS, +}; + +struct Notification { + SeverityLevel level; + NotificationCode code; + std::string title; + std::string description; + + explicit Notification(SeverityLevel level); + + Notification(SeverityLevel level, NotificationCode code, std::string title, std::string description); + + Notification(SeverityLevel level, NotificationCode code, std::string title); + + std::map<std::string, TypedValue> ConvertToMap() const; +}; + +struct ExecutionStats { + public: + // All the stats have specific key to be compatible with neo4j + enum class Key : uint8_t { + CREATED_NODES, + DELETED_NODES, + CREATED_EDGES, + DELETED_EDGES, + CREATED_LABELS, + DELETED_LABELS, + UPDATED_PROPERTIES, + }; + + int64_t &operator[](Key key) { return counters[static_cast<size_t>(key)]; } + + private: + 
static constexpr auto kExecutionStatsCountersSize = std::underlying_type_t<Key>(Key::UPDATED_PROPERTIES) + 1; + + public: + std::array<int64_t, kExecutionStatsCountersSize> counters{0}; +}; + +std::string ExecutionStatsKeyToString(ExecutionStats::Key key); + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/parameters.hpp b/src/query/v2/parameters.hpp new file mode 100644 index 000000000..1e2f0744f --- /dev/null +++ b/src/query/v2/parameters.hpp @@ -0,0 +1,71 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <algorithm> +#include <utility> +#include <vector> + +#include "storage/v3/property_value.hpp" +#include "utils/logging.hpp" + +/** + * Encapsulates user provided parameters (and stripped literals) + * and provides ways of obtaining them by position. + */ +namespace memgraph::query::v2 { + +struct Parameters { + public: + /** + * Adds a value to the stripped arguments under a token position. + * + * @param position Token position in query of value. + * @param value + */ + void Add(int position, const storage::v3::PropertyValue &value) { storage_.emplace_back(position, value); } + + /** + * Returns the value found for the given token position. + * + * @param position Token position in query of value. + * @return Value for the given token position. 
+ */ + const storage::v3::PropertyValue &AtTokenPosition(int position) const { + auto found = std::find_if(storage_.begin(), storage_.end(), [&](const auto &a) { return a.first == position; }); + MG_ASSERT(found != storage_.end(), "Token position must be present in container"); + return found->second; + } + + /** + * Returns the position-th stripped value. Asserts that position is + * non-negative and that this container has at least (position + 1) elements. + * + * @param position Which stripped param is sought. + * @return Token position and value for sought param. + */ + const std::pair<int, storage::v3::PropertyValue> &At(int position) const { + MG_ASSERT(position >= 0 && position < static_cast<int>(storage_.size()), "Invalid position"); + return storage_[position]; + } + + /** Returns the number of arguments in this container */ + auto size() const { return storage_.size(); } + + auto begin() const { return storage_.begin(); } + auto end() const { return storage_.end(); } + + private: + std::vector<std::pair<int, storage::v3::PropertyValue>> storage_; +}; + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/path.hpp b/src/query/v2/path.hpp new file mode 100644 index 000000000..d16e4bba8 --- /dev/null +++ b/src/query/v2/path.hpp @@ -0,0 +1,146 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include <functional> +#include <utility> + +#include "query/v2/db_accessor.hpp" +#include "utils/logging.hpp" +#include "utils/memory.hpp" +#include "utils/pmr/vector.hpp" + +namespace memgraph::query::v2 { + +/** + * A data structure that holds a graph path. A path consists of at least one + * vertex, followed by zero or more edge + vertex extensions (thus having one + * vertex more then edges). + */ +class Path { + public: + /** Allocator type so that STL containers are aware that we need one */ + using allocator_type = utils::Allocator<char>; + + /** + * Create the path starting with the given vertex. + * Allocations are done using the given MemoryResource. + */ + explicit Path(const VertexAccessor &vertex, utils::MemoryResource *memory = utils::NewDeleteResource()) + : vertices_(memory), edges_(memory) { + Expand(vertex); + } + + /** + * Create the path starting with the given vertex and containing all other + * elements. + * Allocations are done using the default utils::NewDeleteResource(). + */ + template <typename... TOthers> + explicit Path(const VertexAccessor &vertex, const TOthers &...others) + : vertices_(utils::NewDeleteResource()), edges_(utils::NewDeleteResource()) { + Expand(vertex); + Expand(others...); + } + + /** + * Create the path starting with the given vertex and containing all other + * elements. + * Allocations are done using the given MemoryResource. + */ + template <typename... TOthers> + Path(std::allocator_arg_t, utils::MemoryResource *memory, const VertexAccessor &vertex, const TOthers &...others) + : vertices_(memory), edges_(memory) { + Expand(vertex); + Expand(others...); + } + + /** + * Construct a copy of other. + * utils::MemoryResource is obtained by calling + * std::allocator_traits<>:: + * select_on_container_copy_construction(other.GetMemoryResource()). + * Since we use utils::Allocator, which does not propagate, this means that we + * will default to utils::NewDeleteResource(). 
+ */ + Path(const Path &other) + : Path(other, + std::allocator_traits<allocator_type>::select_on_container_copy_construction(other.GetMemoryResource()) + .GetMemoryResource()) {} + + /** Construct a copy using the given utils::MemoryResource */ + Path(const Path &other, utils::MemoryResource *memory) + : vertices_(other.vertices_, memory), edges_(other.edges_, memory) {} + + /** + * Construct with the value of other. + * utils::MemoryResource is obtained from other. After the move, other will be + * empty. + */ + Path(Path &&other) noexcept : Path(std::move(other), other.GetMemoryResource()) {} + + /** + * Construct with the value of other, but use the given utils::MemoryResource. + * After the move, other may not be empty if `*memory != + * *other.GetMemoryResource()`, because an element-wise move will be + * performed. + */ + Path(Path &&other, utils::MemoryResource *memory) + : vertices_(std::move(other.vertices_), memory), edges_(std::move(other.edges_), memory) {} + + /** Copy assign other, utils::MemoryResource of `this` is used */ + Path &operator=(const Path &) = default; + + /** Move assign other, utils::MemoryResource of `this` is used. */ + Path &operator=(Path &&) = default; + + ~Path() = default; + + /** Expands the path with the given vertex. */ + void Expand(const VertexAccessor &vertex) { + DMG_ASSERT(vertices_.size() == edges_.size(), "Illegal path construction order"); + vertices_.emplace_back(vertex); + } + + /** Expands the path with the given edge. */ + void Expand(const EdgeAccessor &edge) { + DMG_ASSERT(vertices_.size() - 1 == edges_.size(), "Illegal path construction order"); + edges_.emplace_back(edge); + } + + /** Expands the path with the given elements. */ + template <typename TFirst, typename... TOthers> + void Expand(const TFirst &first, const TOthers &...others) { + Expand(first); + Expand(others...); + } + + /** Returns the number of expansions (edges) in this path. 
*/ + auto size() const { return edges_.size(); } + + auto &vertices() { return vertices_; } + auto &edges() { return edges_; } + const auto &vertices() const { return vertices_; } + const auto &edges() const { return edges_; } + + utils::MemoryResource *GetMemoryResource() const { return vertices_.get_allocator().GetMemoryResource(); } + + bool operator==(const Path &other) const { return vertices_ == other.vertices_ && edges_ == other.edges_; } + + private: + // Contains all the vertices in the path. + utils::pmr::vector<VertexAccessor> vertices_; + // Contains all the edges in the path (one less then there are vertices). + utils::pmr::vector<EdgeAccessor> edges_; +}; + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/plan/cost_estimator.hpp b/src/query/v2/plan/cost_estimator.hpp new file mode 100644 index 000000000..07a5fde0b --- /dev/null +++ b/src/query/v2/plan/cost_estimator.hpp @@ -0,0 +1,267 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/parameters.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/typed_value.hpp" + +namespace memgraph::query::v2::plan { + +/** + * Query plan execution time cost estimator, for comparing and choosing optimal + * execution plans. + * + * In Cypher the write part of the query always executes in the same + * cardinality. 
It is not allowed to execute a write operation before all the + * expansions for that query part (WITH splits a query into parts) have executed. + * For that reason cost estimation comes down to cardinality estimation for the + * read parts of the query, and their expansion. We want to compare different + * plans and try to figure out which has the optimal organization of scans, + * expansions and filters. + * + * Note that expansions and filtering can also happen during Merge, which is a + * write operation. We let that get evaluated like all other cardinality + * influencing ops. Also, Merge cardinality modification should be contained (it + * can never reduce its input cardinality), but since Merge always happens + * after the read part, and can't be reordered, we can ignore that. + * + * Limiting and accumulating (Aggregate, OrderBy, Accumulate) operations are + * cardinality modifiers that always execute at the end of the query part. Their + * cardinality influence is irrelevant because they execute the same + * for all plans for a single query part, and query part reordering is not + * allowed. + * + * This kind of cost estimation can only be used for comparing logical plans. + * Its aim is to estimate cost(A) to be less than cost(B) in every case where + * actual query execution for plan A is less than that of plan B. It can NOT be + * used to estimate how MUCH execution between A and B will differ. 
+ */ +template <class TDbAccessor> +class CostEstimator : public HierarchicalLogicalOperatorVisitor { + public: + struct CostParam { + static constexpr double kScanAll{1.0}; + static constexpr double kScanAllByLabel{1.1}; + static constexpr double MakeScanAllByLabelPropertyValue{1.1}; + static constexpr double MakeScanAllByLabelPropertyRange{1.1}; + static constexpr double MakeScanAllByLabelProperty{1.1}; + static constexpr double kExpand{2.0}; + static constexpr double kExpandVariable{3.0}; + static constexpr double kFilter{1.5}; + static constexpr double kEdgeUniquenessFilter{1.5}; + static constexpr double kUnwind{1.3}; + static constexpr double kForeach{1.0}; + }; + + struct CardParam { + static constexpr double kExpand{3.0}; + static constexpr double kExpandVariable{9.0}; + static constexpr double kFilter{0.25}; + static constexpr double kEdgeUniquenessFilter{0.95}; + }; + + struct MiscParam { + static constexpr double kUnwindNoLiteral{10.0}; + static constexpr double kForeachNoLiteral{10.0}; + }; + + using HierarchicalLogicalOperatorVisitor::PostVisit; + using HierarchicalLogicalOperatorVisitor::PreVisit; + + CostEstimator(TDbAccessor *db_accessor, const Parameters ¶meters) + : db_accessor_(db_accessor), parameters(parameters) {} + + bool PostVisit(ScanAll &) override { + cardinality_ *= db_accessor_->VerticesCount(); + // ScanAll performs some work for every element that is produced + IncrementCost(CostParam::kScanAll); + return true; + } + + bool PostVisit(ScanAllByLabel &scan_all_by_label) override { + cardinality_ *= db_accessor_->VerticesCount(scan_all_by_label.label_); + // ScanAll performs some work for every element that is produced + IncrementCost(CostParam::kScanAllByLabel); + return true; + } + + bool PostVisit(ScanAllByLabelPropertyValue &logical_op) override { + // This cardinality estimation depends on the property value (expression). 
+ // If it's a constant, we can evaluate cardinality exactly, otherwise + // we estimate + auto property_value = ConstPropertyValue(logical_op.expression_); + double factor = 1.0; + if (property_value) + // get the exact influence based on ScanAll(label, property, value) + factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_, property_value.value()); + else + // estimate the influence as ScanAll(label, property) * filtering + factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_) * CardParam::kFilter; + + cardinality_ *= factor; + + // ScanAll performs some work for every element that is produced + IncrementCost(CostParam::MakeScanAllByLabelPropertyValue); + return true; + } + + bool PostVisit(ScanAllByLabelPropertyRange &logical_op) override { + // this cardinality estimation depends on Bound expressions. + // if they are literals we can evaluate cardinality properly + auto lower = BoundToPropertyValue(logical_op.lower_bound_); + auto upper = BoundToPropertyValue(logical_op.upper_bound_); + + double factor = 1.0; + if (upper || lower) + // if we have either Bound<PropertyValue>, use the value index + factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_, lower, upper); + else + // no values, but we still have the label + factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_); + + // if we failed to take either bound from the op into account, then apply + // the filtering constant to the factor (a double, so the fraction is not truncated) + if ((logical_op.upper_bound_ && !upper) || (logical_op.lower_bound_ && !lower)) factor *= CardParam::kFilter; + + cardinality_ *= factor; + + // ScanAll performs some work for every element that is produced + IncrementCost(CostParam::MakeScanAllByLabelPropertyRange); + return true; + } + + bool PostVisit(ScanAllByLabelProperty &logical_op) override { + const auto factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_); + cardinality_ *= factor; + 
IncrementCost(CostParam::MakeScanAllByLabelProperty); + return true; + } + + // TODO: Cost estimate ScanAllById? + +// For the given op first increments the cardinality and then cost. +#define POST_VISIT_CARD_FIRST(NAME) \ + bool PostVisit(NAME &) override { \ + cardinality_ *= CardParam::k##NAME; \ + IncrementCost(CostParam::k##NAME); \ + return true; \ + } + + POST_VISIT_CARD_FIRST(Expand); + POST_VISIT_CARD_FIRST(ExpandVariable); + +#undef POST_VISIT_CARD_FIRST + +// For the given op first increments the cost and then cardinality. +#define POST_VISIT_COST_FIRST(LOGICAL_OP, PARAM_NAME) \ + bool PostVisit(LOGICAL_OP &) override { \ + IncrementCost(CostParam::PARAM_NAME); \ + cardinality_ *= CardParam::PARAM_NAME; \ + return true; \ + } + + POST_VISIT_COST_FIRST(Filter, kFilter) + POST_VISIT_COST_FIRST(EdgeUniquenessFilter, kEdgeUniquenessFilter); + +#undef POST_VISIT_COST_FIRST + + bool PostVisit(Unwind &unwind) override { + // Unwind cost depends more on the number of lists that get unwound + // much less on the number of outputs + // for that reason first increment cost, then modify cardinality + IncrementCost(CostParam::kUnwind); + + // try to determine how many values will be yielded by Unwind + // if the Unwind expression is a list literal, we can deduce cardinality + // exactly, otherwise we approximate + double unwind_value; + if (auto *literal = utils::Downcast<query::v2::ListLiteral>(unwind.input_expression_)) + unwind_value = literal->elements_.size(); + else + unwind_value = MiscParam::kUnwindNoLiteral; + + cardinality_ *= unwind_value; + return true; + } + + bool PostVisit(Foreach &foreach) override { + // Foreach cost depends both on the number elements in the list that get unwound + // as well as the total clauses that get called for each unwounded element. + // First estimate cardinality and then increment the cost. 
+ + double foreach_elements{0}; + if (auto *literal = utils::Downcast<query::v2::ListLiteral>(foreach.expression_)) { + foreach_elements = literal->elements_.size(); + } else { + foreach_elements = MiscParam::kForeachNoLiteral; + } + + cardinality_ *= foreach_elements; + IncrementCost(CostParam::kForeach); + return true; + } + + bool Visit(Once &) override { return true; } + + auto cost() const { return cost_; } + auto cardinality() const { return cardinality_; } + + private: + // cost estimation that gets accumulated as the visitor + // tours the logical plan + double cost_{0}; + + // cardinality estimation (how many times an operator gets executed) + // cardinality is a double to make it easier to work with + double cardinality_{1}; + + // accessor used for cardinality estimates in ScanAll and ScanAllByLabel + TDbAccessor *db_accessor_; + const Parameters ¶meters; + + void IncrementCost(double param) { cost_ += param * cardinality_; } + + // converts an optional ScanAll range bound into a property value + // if the bound is present and is a constant expression convertible to + // a property value. otherwise returns nullopt + std::optional<utils::Bound<storage::v3::PropertyValue>> BoundToPropertyValue( + std::optional<ScanAllByLabelPropertyRange::Bound> bound) { + if (bound) { + auto property_value = ConstPropertyValue(bound->value()); + if (property_value) return utils::Bound<storage::v3::PropertyValue>(*property_value, bound->type()); + } + return std::nullopt; + } + + // If the expression is a constant property value, it is returned. Otherwise, + // return nullopt. 
+ std::optional<storage::v3::PropertyValue> ConstPropertyValue(const Expression *expression) { + if (auto *literal = utils::Downcast<const PrimitiveLiteral>(expression)) { + return literal->value_; + } else if (auto *param_lookup = utils::Downcast<const ParameterLookup>(expression)) { + return parameters.AtTokenPosition(param_lookup->token_position_); + } + return std::nullopt; + } +}; + +/** Returns the estimated cost of the given plan. */ +template <class TDbAccessor> +double EstimatePlanCost(TDbAccessor *db, const Parameters ¶meters, LogicalOperator &plan) { + CostEstimator<TDbAccessor> estimator(db, parameters); + plan.Accept(estimator); + return estimator.cost(); +} + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp new file mode 100644 index 000000000..4dd0bf693 --- /dev/null +++ b/src/query/v2/plan/operator.cpp @@ -0,0 +1,4081 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "query/v2/plan/operator.hpp" + +#include <algorithm> +#include <cstdint> +#include <limits> +#include <queue> +#include <random> +#include <string> +#include <tuple> +#include <type_traits> +#include <unordered_map> +#include <unordered_set> +#include <utility> + +#include <cppitertools/chain.hpp> +#include <cppitertools/imap.hpp> + +#include "query/v2/context.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/semantic/symbol_table.hpp" +#include "query/v2/interpret/eval.hpp" +#include "query/v2/path.hpp" +#include "query/v2/plan/scoped_profile.hpp" +#include "query/v2/procedure/cypher_types.hpp" +#include "query/v2/procedure/mg_procedure_impl.hpp" +#include "query/v2/procedure/module.hpp" +#include "storage/v3/property_value.hpp" +#include "utils/algorithm.hpp" +#include "utils/csv_parsing.hpp" +#include "utils/event_counter.hpp" +#include "utils/exceptions.hpp" +#include "utils/fnv.hpp" +#include "utils/likely.hpp" +#include "utils/logging.hpp" +#include "utils/memory.hpp" +#include "utils/pmr/unordered_map.hpp" +#include "utils/pmr/unordered_set.hpp" +#include "utils/pmr/vector.hpp" +#include "utils/readable_size.hpp" +#include "utils/string.hpp" +#include "utils/temporal.hpp" + +// macro for the default implementation of LogicalOperator::Accept +// that accepts the visitor and visits it's input_ operator +#define ACCEPT_WITH_INPUT(class_name) \ + bool class_name::Accept(HierarchicalLogicalOperatorVisitor &visitor) { \ + if (visitor.PreVisit(*this)) { \ + input_->Accept(visitor); \ + } \ + return visitor.PostVisit(*this); \ + } + +#define WITHOUT_SINGLE_INPUT(class_name) \ + bool class_name::HasSingleInput() const { return false; } \ + std::shared_ptr<LogicalOperator> class_name::input() const { \ + LOG_FATAL("Operator " #class_name " has no single input!"); \ + } \ + void class_name::set_input(std::shared_ptr<LogicalOperator>) { \ + LOG_FATAL("Operator 
" #class_name " has no single input!"); \ + } + +namespace EventCounter { +extern const Event OnceOperator; +extern const Event CreateNodeOperator; +extern const Event CreateExpandOperator; +extern const Event ScanAllOperator; +extern const Event ScanAllByLabelOperator; +extern const Event ScanAllByLabelPropertyRangeOperator; +extern const Event ScanAllByLabelPropertyValueOperator; +extern const Event ScanAllByLabelPropertyOperator; +extern const Event ScanAllByIdOperator; +extern const Event ExpandOperator; +extern const Event ExpandVariableOperator; +extern const Event ConstructNamedPathOperator; +extern const Event FilterOperator; +extern const Event ProduceOperator; +extern const Event DeleteOperator; +extern const Event SetPropertyOperator; +extern const Event SetPropertiesOperator; +extern const Event SetLabelsOperator; +extern const Event RemovePropertyOperator; +extern const Event RemoveLabelsOperator; +extern const Event EdgeUniquenessFilterOperator; +extern const Event AccumulateOperator; +extern const Event AggregateOperator; +extern const Event SkipOperator; +extern const Event LimitOperator; +extern const Event OrderByOperator; +extern const Event MergeOperator; +extern const Event OptionalOperator; +extern const Event UnwindOperator; +extern const Event DistinctOperator; +extern const Event UnionOperator; +extern const Event CartesianOperator; +extern const Event CallProcedureOperator; +extern const Event ForeachOperator; +} // namespace EventCounter + +namespace memgraph::query::v2::plan { + +namespace { + +// Custom equality function for a vector of typed values. +// Used in unordered_maps in Aggregate and Distinct operators. 
+struct TypedValueVectorEqual { + template <class TAllocator> + bool operator()(const std::vector<TypedValue, TAllocator> &left, + const std::vector<TypedValue, TAllocator> &right) const { + MG_ASSERT(left.size() == right.size(), + "TypedValueVector comparison should only be done over vectors " + "of the same size"); + return std::equal(left.begin(), left.end(), right.begin(), TypedValue::BoolEqual{}); + } +}; + +// Returns boolean result of evaluating filter expression. Null is treated as +// false. Other non boolean values raise a QueryRuntimeException. +bool EvaluateFilter(ExpressionEvaluator &evaluator, Expression *filter) { + TypedValue result = filter->Accept(evaluator); + // Null is treated like false. + if (result.IsNull()) return false; + if (result.type() != TypedValue::Type::Bool) + throw QueryRuntimeException("Filter expression must evaluate to bool or null, got {}.", result.type()); + return result.ValueBool(); +} + +template <typename T> +uint64_t ComputeProfilingKey(const T *obj) { + static_assert(sizeof(T *) == sizeof(uint64_t)); + return reinterpret_cast<uint64_t>(obj); +} + +} // namespace + +#define SCOPED_PROFILE_OP(name) ScopedProfile profile{ComputeProfilingKey(this), name, &context}; + +bool Once::OnceCursor::Pull(Frame &, ExecutionContext &context) { + SCOPED_PROFILE_OP("Once"); + + if (!did_pull_) { + did_pull_ = true; + return true; + } + return false; +} + +UniqueCursorPtr Once::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::OnceOperator); + + return MakeUniqueCursorPtr<OnceCursor>(mem); +} + +WITHOUT_SINGLE_INPUT(Once); + +void Once::OnceCursor::Shutdown() {} + +void Once::OnceCursor::Reset() { did_pull_ = false; } + +CreateNode::CreateNode(const std::shared_ptr<LogicalOperator> &input, const NodeCreationInfo &node_info) + : input_(input ? input : std::make_shared<Once>()), node_info_(node_info) {} + +// Creates a vertex on this GraphDb. Returns a reference to vertex placed on the +// frame. 
+VertexAccessor &CreateLocalVertex(const NodeCreationInfo &node_info, Frame *frame, ExecutionContext &context) { + auto &dba = *context.db_accessor; + auto new_node = dba.InsertVertex(); + context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; + for (auto label : node_info.labels) { + auto maybe_error = new_node.AddLabel(label); + if (maybe_error.HasError()) { + switch (maybe_error.GetError()) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set a label on a deleted node."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a label."); + } + } + context.execution_stats[ExecutionStats::Key::CREATED_LABELS] += 1; + } + // Evaluator should use the latest accessors, as modified in this query, when + // setting properties on new nodes. + ExpressionEvaluator evaluator(frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::NEW); + // TODO: PropsSetChecked allocates a PropertyValue, make it use context.memory + // when we update PropertyValue with custom allocator. 
+ if (const auto *node_info_properties = std::get_if<PropertiesMapList>(&node_info.properties)) { + for (const auto &[key, value_expression] : *node_info_properties) { + PropsSetChecked(&new_node, key, value_expression->Accept(evaluator)); + } + } else { + auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info.properties)); + for (const auto &[key, value] : property_map.ValueMap()) { + auto property_id = dba.NameToProperty(key); + PropsSetChecked(&new_node, property_id, value); + } + } + + (*frame)[node_info.symbol] = new_node; + return (*frame)[node_info.symbol].ValueVertex(); +} + +ACCEPT_WITH_INPUT(CreateNode) + +UniqueCursorPtr CreateNode::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::CreateNodeOperator); + + return MakeUniqueCursorPtr<CreateNodeCursor>(mem, *this, mem); +} + +std::vector<Symbol> CreateNode::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.emplace_back(node_info_.symbol); + return symbols; +} + +CreateNode::CreateNodeCursor::CreateNodeCursor(const CreateNode &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} + +bool CreateNode::CreateNodeCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("CreateNode"); + + if (input_cursor_->Pull(frame, context)) { + auto created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); + if (context.trigger_context_collector) { + context.trigger_context_collector->RegisterCreatedObject(created_vertex); + } + return true; + } + + return false; +} + +void CreateNode::CreateNodeCursor::Shutdown() { input_cursor_->Shutdown(); } + +void CreateNode::CreateNodeCursor::Reset() { input_cursor_->Reset(); } + +CreateExpand::CreateExpand(const NodeCreationInfo &node_info, const EdgeCreationInfo &edge_info, + const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node) + : 
node_info_(node_info), + edge_info_(edge_info), + input_(input ? input : std::make_shared<Once>()), + input_symbol_(input_symbol), + existing_node_(existing_node) {} + +ACCEPT_WITH_INPUT(CreateExpand) + +UniqueCursorPtr CreateExpand::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::CreateNodeOperator); + + return MakeUniqueCursorPtr<CreateExpandCursor>(mem, *this, mem); +} + +std::vector<Symbol> CreateExpand::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.emplace_back(node_info_.symbol); + symbols.emplace_back(edge_info_.symbol); + return symbols; +} + +CreateExpand::CreateExpandCursor::CreateExpandCursor(const CreateExpand &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} + +namespace { + +EdgeAccessor CreateEdge(const EdgeCreationInfo &edge_info, DbAccessor *dba, VertexAccessor *from, VertexAccessor *to, + Frame *frame, ExpressionEvaluator *evaluator) { + auto maybe_edge = dba->InsertEdge(from, to, edge_info.edge_type); + if (maybe_edge.HasValue()) { + auto &edge = *maybe_edge; + if (const auto *properties = std::get_if<PropertiesMapList>(&edge_info.properties)) { + for (const auto &[key, value_expression] : *properties) { + PropsSetChecked(&edge, key, value_expression->Accept(*evaluator)); + } + } else { + auto property_map = evaluator->Visit(*std::get<ParameterLookup *>(edge_info.properties)); + for (const auto &[key, value] : property_map.ValueMap()) { + auto property_id = dba->NameToProperty(key); + PropsSetChecked(&edge, property_id, value); + } + } + + (*frame)[edge_info.symbol] = edge; + } else { + switch (maybe_edge.GetError()) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to create an edge on a deleted node."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case 
storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when creating an edge."); + } + } + + return *maybe_edge; +} + +} // namespace + +bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("CreateExpand"); + + if (!input_cursor_->Pull(frame, context)) return false; + + // get the origin vertex + TypedValue &vertex_value = frame[self_.input_symbol_]; + ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + auto &v1 = vertex_value.ValueVertex(); + + // Similarly to CreateNode, newly created edges and nodes should use the + // storage::v3::View::NEW. + // E.g. we pickup new properties: `CREATE (n {p: 42}) -[:r {ep: n.p}]-> ()` + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::NEW); + + // get the destination vertex (possibly an existing node) + auto &v2 = OtherVertex(frame, context); + + // create an edge between the two nodes + auto *dba = context.db_accessor; + + auto created_edge = [&] { + switch (self_.edge_info_.direction) { + case EdgeAtom::Direction::IN: + return CreateEdge(self_.edge_info_, dba, &v2, &v1, &frame, &evaluator); + case EdgeAtom::Direction::OUT: + // in the case of an undirected CreateExpand we choose an arbitrary + // direction. 
this is used in the MERGE clause + // it is not allowed in the CREATE clause, and the semantic + // checker needs to ensure it doesn't reach this point + case EdgeAtom::Direction::BOTH: + return CreateEdge(self_.edge_info_, dba, &v1, &v2, &frame, &evaluator); + } + }(); + + context.execution_stats[ExecutionStats::Key::CREATED_EDGES] += 1; + if (context.trigger_context_collector) { + context.trigger_context_collector->RegisterCreatedObject(created_edge); + } + + return true; +} + +void CreateExpand::CreateExpandCursor::Shutdown() { input_cursor_->Shutdown(); } + +void CreateExpand::CreateExpandCursor::Reset() { input_cursor_->Reset(); } + +VertexAccessor &CreateExpand::CreateExpandCursor::OtherVertex(Frame &frame, ExecutionContext &context) { + if (self_.existing_node_) { + TypedValue &dest_node_value = frame[self_.node_info_.symbol]; + ExpectType(self_.node_info_.symbol, dest_node_value, TypedValue::Type::Vertex); + return dest_node_value.ValueVertex(); + } else { + auto &created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); + if (context.trigger_context_collector) { + context.trigger_context_collector->RegisterCreatedObject(created_vertex); + } + return created_vertex; + } +} + +template <class TVerticesFun> +class ScanAllCursor : public Cursor { + public: + explicit ScanAllCursor(Symbol output_symbol, UniqueCursorPtr input_cursor, TVerticesFun get_vertices, + const char *op_name) + : output_symbol_(output_symbol), + input_cursor_(std::move(input_cursor)), + get_vertices_(std::move(get_vertices)), + op_name_(op_name) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP(op_name_); + + if (MustAbort(context)) throw HintedAbortError(); + + while (!vertices_ || vertices_it_.value() == vertices_.value().end()) { + if (!input_cursor_->Pull(frame, context)) return false; + // We need a getter function, because in case of exhausting a lazy + // iterable, we cannot simply reset it by calling begin(). 
+ auto next_vertices = get_vertices_(frame, context); + if (!next_vertices) continue; + // Since vertices iterator isn't nothrow_move_assignable, we have to use + // the roundabout assignment + emplace, instead of simple: + // vertices _ = get_vertices_(frame, context); + vertices_.emplace(std::move(next_vertices.value())); + vertices_it_.emplace(vertices_.value().begin()); + } + + frame[output_symbol_] = *vertices_it_.value(); + ++vertices_it_.value(); + return true; + } + + void Shutdown() override { input_cursor_->Shutdown(); } + + void Reset() override { + input_cursor_->Reset(); + vertices_ = std::nullopt; + vertices_it_ = std::nullopt; + } + + private: + const Symbol output_symbol_; + const UniqueCursorPtr input_cursor_; + TVerticesFun get_vertices_; + std::optional<typename std::result_of<TVerticesFun(Frame &, ExecutionContext &)>::type::value_type> vertices_; + std::optional<decltype(vertices_.value().begin())> vertices_it_; + const char *op_name_; +}; + +ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::v3::View view) + : input_(input ? 
input : std::make_shared<Once>()), output_symbol_(output_symbol), view_(view) {} + +ACCEPT_WITH_INPUT(ScanAll) + +UniqueCursorPtr ScanAll::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ScanAllOperator); + + auto vertices = [this](Frame &, ExecutionContext &context) { + auto *db = context.db_accessor; + return std::make_optional(db->Vertices(view_)); + }; + return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem), + std::move(vertices), "ScanAll"); +} + +std::vector<Symbol> ScanAll::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.emplace_back(output_symbol_); + return symbols; +} + +ScanAllByLabel::ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, + storage::v3::LabelId label, storage::v3::View view) + : ScanAll(input, output_symbol, view), label_(label) {} + +ACCEPT_WITH_INPUT(ScanAllByLabel) + +UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ScanAllByLabelOperator); + + auto vertices = [this](Frame &, ExecutionContext &context) { + auto *db = context.db_accessor; + return std::make_optional(db->Vertices(view_, label_)); + }; + return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem), + std::move(vertices), "ScanAllByLabel"); +} + +// TODO(buda): Implement ScanAllByLabelProperty operator to iterate over +// vertices that have the label and some value for the given property. 
+ +ScanAllByLabelPropertyRange::ScanAllByLabelPropertyRange(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, storage::v3::LabelId label, + storage::v3::PropertyId property, + const std::string &property_name, + std::optional<Bound> lower_bound, + std::optional<Bound> upper_bound, storage::v3::View view) + : ScanAll(input, output_symbol, view), + label_(label), + property_(property), + property_name_(property_name), + lower_bound_(lower_bound), + upper_bound_(upper_bound) { + MG_ASSERT(lower_bound_ || upper_bound_, "Only one bound can be left out"); +} + +ACCEPT_WITH_INPUT(ScanAllByLabelPropertyRange) + +UniqueCursorPtr ScanAllByLabelPropertyRange::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ScanAllByLabelPropertyRangeOperator); + + auto vertices = [this](Frame &frame, ExecutionContext &context) + -> std::optional<decltype(context.db_accessor->Vertices(view_, label_, property_, std::nullopt, std::nullopt))> { + auto *db = context.db_accessor; + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_); + auto convert = [&evaluator](const auto &bound) -> std::optional<utils::Bound<storage::v3::PropertyValue>> { + if (!bound) return std::nullopt; + const auto &value = bound->value()->Accept(evaluator); + try { + const auto &property_value = storage::v3::PropertyValue(value); + switch (property_value.type()) { + case storage::v3::PropertyValue::Type::Bool: + case storage::v3::PropertyValue::Type::List: + case storage::v3::PropertyValue::Type::Map: + // Prevent indexed lookup with something that would fail if we did + // the original filter with `operator<`. Note, for some reason, + // Cypher does not support comparing boolean values. 
+ throw QueryRuntimeException("Invalid type {} for '<'.", value.type()); + case storage::v3::PropertyValue::Type::Null: + case storage::v3::PropertyValue::Type::Int: + case storage::v3::PropertyValue::Type::Double: + case storage::v3::PropertyValue::Type::String: + case storage::v3::PropertyValue::Type::TemporalData: + // These are all fine, there's also Point, Date and Time data types + // which were added to Cypher, but we don't have support for those + // yet. + return std::make_optional(utils::Bound<storage::v3::PropertyValue>(property_value, bound->type())); + } + } catch (const TypedValueException &) { + throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type()); + } + }; + auto maybe_lower = convert(lower_bound_); + auto maybe_upper = convert(upper_bound_); + // If any bound is null, then the comparison would result in nulls. This + // is treated as not satisfying the filter, so return no vertices. + if (maybe_lower && maybe_lower->value().IsNull()) return std::nullopt; + if (maybe_upper && maybe_upper->value().IsNull()) return std::nullopt; + return std::make_optional(db->Vertices(view_, label_, property_, maybe_lower, maybe_upper)); + }; + return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem), + std::move(vertices), "ScanAllByLabelPropertyRange"); +} + +ScanAllByLabelPropertyValue::ScanAllByLabelPropertyValue(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, storage::v3::LabelId label, + storage::v3::PropertyId property, + const std::string &property_name, Expression *expression, + storage::v3::View view) + : ScanAll(input, output_symbol, view), + label_(label), + property_(property), + property_name_(property_name), + expression_(expression) { + DMG_ASSERT(expression, "Expression is not optional."); +} + +ACCEPT_WITH_INPUT(ScanAllByLabelPropertyValue) + +UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource *mem) const { + 
EventCounter::IncrementCounter(EventCounter::ScanAllByLabelPropertyValueOperator); + + auto vertices = + [this](Frame &frame, ExecutionContext &context) -> std::optional<decltype(context.db_accessor->Vertices( + view_, label_, property_, storage::v3::PropertyValue()))> { + auto *db = context.db_accessor; + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_); + auto value = expression_->Accept(evaluator); + if (value.IsNull()) return std::nullopt; + if (!value.IsPropertyValue()) { + throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type()); + } + return std::make_optional(db->Vertices(view_, label_, property_, storage::v3::PropertyValue(value))); + }; + return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem), + std::move(vertices), "ScanAllByLabelPropertyValue"); +} + +ScanAllByLabelProperty::ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, + storage::v3::LabelId label, storage::v3::PropertyId property, + const std::string &property_name, storage::v3::View view) + : ScanAll(input, output_symbol, view), label_(label), property_(property), property_name_(property_name) {} + +ACCEPT_WITH_INPUT(ScanAllByLabelProperty) + +UniqueCursorPtr ScanAllByLabelProperty::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ScanAllByLabelPropertyOperator); + + auto vertices = [this](Frame &frame, ExecutionContext &context) { + auto *db = context.db_accessor; + return std::make_optional(db->Vertices(view_, label_, property_)); + }; + return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem), + std::move(vertices), "ScanAllByLabelProperty"); +} + +ScanAllById::ScanAllById(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression, + storage::v3::View view) + : ScanAll(input, 
output_symbol, view), expression_(expression) { + MG_ASSERT(expression); +} + +ACCEPT_WITH_INPUT(ScanAllById) + +UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ScanAllByIdOperator); + + auto vertices = [this](Frame &frame, ExecutionContext &context) -> std::optional<std::vector<VertexAccessor>> { + auto *db = context.db_accessor; + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_); + auto value = expression_->Accept(evaluator); + if (!value.IsNumeric()) return std::nullopt; + int64_t id = value.IsInt() ? value.ValueInt() : value.ValueDouble(); + if (value.IsDouble() && id != value.ValueDouble()) return std::nullopt; + auto maybe_vertex = db->FindVertex(storage::v3::Gid::FromInt(id), view_); + if (!maybe_vertex) return std::nullopt; + return std::vector<VertexAccessor>{*maybe_vertex}; + }; + return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem), + std::move(vertices), "ScanAllById"); +} + +namespace { +bool CheckExistingNode(const VertexAccessor &new_node, const Symbol &existing_node_sym, Frame &frame) { + const TypedValue &existing_node = frame[existing_node_sym]; + if (existing_node.IsNull()) return false; + ExpectType(existing_node_sym, existing_node, TypedValue::Type::Vertex); + return existing_node.ValueVertex() == new_node; +} + +template <class TEdges> +auto UnwrapEdgesResult(storage::v3::Result<TEdges> &&result) { + if (result.HasError()) { + switch (result.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to get relationships of a deleted node."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get relationships from a node that doesn't exist."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::SERIALIZATION_ERROR: + case 
storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Unexpected error when accessing relationships."); + } + } + return std::move(*result); +} + +} // namespace + +Expand::Expand(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol, + Symbol edge_symbol, EdgeAtom::Direction direction, + const std::vector<storage::v3::EdgeTypeId> &edge_types, bool existing_node, storage::v3::View view) + : input_(input ? input : std::make_shared<Once>()), + input_symbol_(input_symbol), + common_{node_symbol, edge_symbol, direction, edge_types, existing_node}, + view_(view) {} + +ACCEPT_WITH_INPUT(Expand) + +UniqueCursorPtr Expand::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ExpandOperator); + + return MakeUniqueCursorPtr<ExpandCursor>(mem, *this, mem); +} + +std::vector<Symbol> Expand::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.emplace_back(common_.node_symbol); + symbols.emplace_back(common_.edge_symbol); + return symbols; +} + +Expand::ExpandCursor::ExpandCursor(const Expand &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} + +bool Expand::ExpandCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("Expand"); + + // A helper function for expanding a node from an edge. 
+ auto pull_node = [this, &frame](const EdgeAccessor &new_edge, EdgeAtom::Direction direction) { + if (self_.common_.existing_node) return; + switch (direction) { + case EdgeAtom::Direction::IN: + frame[self_.common_.node_symbol] = new_edge.From(); + break; + case EdgeAtom::Direction::OUT: + frame[self_.common_.node_symbol] = new_edge.To(); + break; + case EdgeAtom::Direction::BOTH: + LOG_FATAL("Must indicate exact expansion direction here"); + } + }; + + while (true) { + if (MustAbort(context)) throw HintedAbortError(); + // attempt to get a value from the incoming edges + if (in_edges_ && *in_edges_it_ != in_edges_->end()) { + auto edge = *(*in_edges_it_)++; + frame[self_.common_.edge_symbol] = edge; + pull_node(edge, EdgeAtom::Direction::IN); + return true; + } + + // attempt to get a value from the outgoing edges + if (out_edges_ && *out_edges_it_ != out_edges_->end()) { + auto edge = *(*out_edges_it_)++; + // when expanding in EdgeAtom::Direction::BOTH directions + // we should do only one expansion for cycles, and it was + // already done in the block above + if (self_.common_.direction == EdgeAtom::Direction::BOTH && edge.IsCycle()) continue; + frame[self_.common_.edge_symbol] = edge; + pull_node(edge, EdgeAtom::Direction::OUT); + return true; + } + + // If we are here, either the edges have not been initialized, + // or they have been exhausted. Attempt to initialize the edges. + if (!InitEdges(frame, context)) return false; + + // we have re-initialized the edges, continue with the loop + } +} + +void Expand::ExpandCursor::Shutdown() { input_cursor_->Shutdown(); } + +void Expand::ExpandCursor::Reset() { + input_cursor_->Reset(); + in_edges_ = std::nullopt; + in_edges_it_ = std::nullopt; + out_edges_ = std::nullopt; + out_edges_it_ = std::nullopt; +} + +bool Expand::ExpandCursor::InitEdges(Frame &frame, ExecutionContext &context) { + // Input Vertex could be null if it is created by a failed optional match. 
In + // those cases we skip that input pull and continue with the next. + while (true) { + if (!input_cursor_->Pull(frame, context)) return false; + TypedValue &vertex_value = frame[self_.input_symbol_]; + + // Null check due to possible failed optional match. + if (vertex_value.IsNull()) continue; + + ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + auto &vertex = vertex_value.ValueVertex(); + + auto direction = self_.common_.direction; + if (direction == EdgeAtom::Direction::IN || direction == EdgeAtom::Direction::BOTH) { + if (self_.common_.existing_node) { + TypedValue &existing_node = frame[self_.common_.node_symbol]; + // old_node_value may be Null when using optional matching + if (!existing_node.IsNull()) { + ExpectType(self_.common_.node_symbol, existing_node, TypedValue::Type::Vertex); + in_edges_.emplace( + UnwrapEdgesResult(vertex.InEdges(self_.view_, self_.common_.edge_types, existing_node.ValueVertex()))); + } + } else { + in_edges_.emplace(UnwrapEdgesResult(vertex.InEdges(self_.view_, self_.common_.edge_types))); + } + if (in_edges_) { + in_edges_it_.emplace(in_edges_->begin()); + } + } + + if (direction == EdgeAtom::Direction::OUT || direction == EdgeAtom::Direction::BOTH) { + if (self_.common_.existing_node) { + TypedValue &existing_node = frame[self_.common_.node_symbol]; + // old_node_value may be Null when using optional matching + if (!existing_node.IsNull()) { + ExpectType(self_.common_.node_symbol, existing_node, TypedValue::Type::Vertex); + out_edges_.emplace( + UnwrapEdgesResult(vertex.OutEdges(self_.view_, self_.common_.edge_types, existing_node.ValueVertex()))); + } + } else { + out_edges_.emplace(UnwrapEdgesResult(vertex.OutEdges(self_.view_, self_.common_.edge_types))); + } + if (out_edges_) { + out_edges_it_.emplace(out_edges_->begin()); + } + } + + return true; + } +} + +ExpandVariable::ExpandVariable(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol, + Symbol edge_symbol, 
EdgeAtom::Type type, EdgeAtom::Direction direction, + const std::vector<storage::v3::EdgeTypeId> &edge_types, bool is_reverse, + Expression *lower_bound, Expression *upper_bound, bool existing_node, + ExpansionLambda filter_lambda, std::optional<ExpansionLambda> weight_lambda, + std::optional<Symbol> total_weight) + : input_(input ? input : std::make_shared<Once>()), + input_symbol_(input_symbol), + common_{node_symbol, edge_symbol, direction, edge_types, existing_node}, + type_(type), + is_reverse_(is_reverse), + lower_bound_(lower_bound), + upper_bound_(upper_bound), + filter_lambda_(filter_lambda), + weight_lambda_(weight_lambda), + total_weight_(total_weight) { + DMG_ASSERT(type_ == EdgeAtom::Type::DEPTH_FIRST || type_ == EdgeAtom::Type::BREADTH_FIRST || + type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH, + "ExpandVariable can only be used with breadth first, depth first or " + "weighted shortest path type"); + DMG_ASSERT(!(type_ == EdgeAtom::Type::BREADTH_FIRST && is_reverse), "Breadth first expansion can't be reversed"); +} + +ACCEPT_WITH_INPUT(ExpandVariable) + +std::vector<Symbol> ExpandVariable::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.emplace_back(common_.node_symbol); + symbols.emplace_back(common_.edge_symbol); + return symbols; +} + +namespace { + +/** + * Helper function that returns an iterable over + * <EdgeAtom::Direction, EdgeAccessor> pairs + * for the given params. + * + * @param vertex - The vertex to expand from. + * @param direction - Expansion direction. All directions (IN, OUT, BOTH) + * are supported. + * @param memory - Used to allocate the result. + * @return See above. 
+ */ +auto ExpandFromVertex(const VertexAccessor &vertex, EdgeAtom::Direction direction, + const std::vector<storage::v3::EdgeTypeId> &edge_types, utils::MemoryResource *memory) { + // wraps an EdgeAccessor into a pair <accessor, direction> + auto wrapper = [](EdgeAtom::Direction direction, auto &&edges) { + return iter::imap([direction](const auto &edge) { return std::make_pair(edge, direction); }, + std::forward<decltype(edges)>(edges)); + }; + + storage::v3::View view = storage::v3::View::OLD; + utils::pmr::vector<decltype(wrapper(direction, *vertex.InEdges(view, edge_types)))> chain_elements(memory); + + if (direction != EdgeAtom::Direction::OUT) { + auto edges = UnwrapEdgesResult(vertex.InEdges(view, edge_types)); + if (edges.begin() != edges.end()) { + chain_elements.emplace_back(wrapper(EdgeAtom::Direction::IN, std::move(edges))); + } + } + if (direction != EdgeAtom::Direction::IN) { + auto edges = UnwrapEdgesResult(vertex.OutEdges(view, edge_types)); + if (edges.begin() != edges.end()) { + chain_elements.emplace_back(wrapper(EdgeAtom::Direction::OUT, std::move(edges))); + } + } + + // TODO: Investigate whether itertools perform heap allocation? 
+ return iter::chain.from_iterable(std::move(chain_elements)); +} + +} // namespace + +class ExpandVariableCursor : public Cursor { + public: + ExpandVariableCursor(const ExpandVariable &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)), edges_(mem), edges_it_(mem) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("ExpandVariable"); + + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + while (true) { + if (Expand(frame, context)) return true; + + if (PullInput(frame, context)) { + // if lower bound is zero we also yield empty paths + if (lower_bound_ == 0) { + auto &start_vertex = frame[self_.input_symbol_].ValueVertex(); + if (!self_.common_.existing_node) { + frame[self_.common_.node_symbol] = start_vertex; + return true; + } else if (CheckExistingNode(start_vertex, self_.common_.node_symbol, frame)) { + return true; + } + } + // if lower bound is not zero, we just continue, the next + // loop iteration will attempt to expand and we're good + } else + return false; + // else continue with the loop, try to expand again + // because we successfully pulled from the input + } + } + + void Shutdown() override { input_cursor_->Shutdown(); } + + void Reset() override { + input_cursor_->Reset(); + edges_.clear(); + edges_it_.clear(); + } + + private: + const ExpandVariable &self_; + const UniqueCursorPtr input_cursor_; + // bounds. 
in the cursor they are not optional but set to + // default values if missing in the ExpandVariable operator + // initialize to arbitrary values, they should only be used + // after a successful pull from the input + int64_t upper_bound_{-1}; + int64_t lower_bound_{-1}; + + // a stack of edge iterables corresponding to the level/depth of + // the expansion currently being Pulled + using ExpandEdges = decltype(ExpandFromVertex(std::declval<VertexAccessor>(), EdgeAtom::Direction::IN, + self_.common_.edge_types, utils::NewDeleteResource())); + + utils::pmr::vector<ExpandEdges> edges_; + // an iterator indicating the position in the corresponding edges_ element + utils::pmr::vector<decltype(edges_.begin()->begin())> edges_it_; + + /** + * Helper function that Pulls from the input vertex and + * makes iteration over its edges possible. + * + * @return If the Pull succeeded. If not, this ExpandVariableCursor + * is exhausted. + */ + bool PullInput(Frame &frame, ExecutionContext &context) { + // Input Vertex could be null if it is created by a failed optional match. + // In those cases we skip that input pull and continue with the next. + while (true) { + if (MustAbort(context)) throw HintedAbortError(); + if (!input_cursor_->Pull(frame, context)) return false; + TypedValue &vertex_value = frame[self_.input_symbol_]; + + // Null check due to possible failed optional match. + if (vertex_value.IsNull()) continue; + + ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + auto &vertex = vertex_value.ValueVertex(); + + // Evaluate the upper and lower bounds. 
+ ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + auto calc_bound = [&evaluator](auto &bound) { + auto value = EvaluateInt(&evaluator, bound, "Variable expansion bound"); + if (value < 0) throw QueryRuntimeException("Variable expansion bound must be a non-negative integer."); + return value; + }; + + lower_bound_ = self_.lower_bound_ ? calc_bound(self_.lower_bound_) : 1; + upper_bound_ = self_.upper_bound_ ? calc_bound(self_.upper_bound_) : std::numeric_limits<int64_t>::max(); + + if (upper_bound_ > 0) { + auto *memory = edges_.get_allocator().GetMemoryResource(); + edges_.emplace_back(ExpandFromVertex(vertex, self_.common_.direction, self_.common_.edge_types, memory)); + edges_it_.emplace_back(edges_.back().begin()); + } + + // reset the frame value to an empty edge list + auto *pull_memory = context.evaluation_context.memory; + frame[self_.common_.edge_symbol] = TypedValue::TVector(pull_memory); + + return true; + } + } + + // Helper function for appending an edge to the list on the frame. + void AppendEdge(const EdgeAccessor &new_edge, utils::pmr::vector<TypedValue> *edges_on_frame) { + // We are placing an edge on the frame. It is possible that there already + // exists an edge on the frame for this level. If so first remove it. + DMG_ASSERT(edges_.size() > 0, "Edges are empty"); + if (self_.is_reverse_) { + // TODO: This is inefficient, we should look into replacing + // vector with something else for TypedValue::List. 
+ size_t diff = edges_on_frame->size() - std::min(edges_on_frame->size(), edges_.size() - 1U); + if (diff > 0U) edges_on_frame->erase(edges_on_frame->begin(), edges_on_frame->begin() + diff); + edges_on_frame->emplace(edges_on_frame->begin(), new_edge); + } else { + edges_on_frame->resize(std::min(edges_on_frame->size(), edges_.size() - 1U)); + edges_on_frame->emplace_back(new_edge); + } + } + + /** + * Performs a single expansion for the current state of this + * ExpandVariableCursor. + * + * @return True if the expansion was a success and this Cursor's + * consumer can consume it. False if the expansion failed. In that + * case no more expansions are available from the current input + * vertex and another Pull from the input cursor should be performed. + */ + bool Expand(Frame &frame, ExecutionContext &context) { + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + // Some expansions might not be valid due to edge uniqueness and + // existing_node criterions, so expand in a loop until either the input + // vertex is exhausted or a valid variable-length expansion is available. + while (true) { + if (MustAbort(context)) throw HintedAbortError(); + // pop from the stack while there is stuff to pop and the current + // level is exhausted + while (!edges_.empty() && edges_it_.back() == edges_.back().end()) { + edges_.pop_back(); + edges_it_.pop_back(); + } + + // check if we exhausted everything, if so return false + if (edges_.empty()) return false; + + // we use this a lot + auto &edges_on_frame = frame[self_.common_.edge_symbol].ValueList(); + + // it is possible that edges_on_frame does not contain as many + // elements as edges_ due to edge-uniqueness (when a whole layer + // gets exhausted but no edges are valid). 
for that reason only + // pop from edges_on_frame if they contain enough elements + if (self_.is_reverse_) { + auto diff = edges_on_frame.size() - std::min(edges_on_frame.size(), edges_.size()); + if (diff > 0) { + edges_on_frame.erase(edges_on_frame.begin(), edges_on_frame.begin() + diff); + } + } else { + edges_on_frame.resize(std::min(edges_on_frame.size(), edges_.size())); + } + + // if we are here, we have a valid stack, + // get the edge, increase the relevant iterator + auto current_edge = *edges_it_.back()++; + + // Check edge-uniqueness. + bool found_existing = + std::any_of(edges_on_frame.begin(), edges_on_frame.end(), + [&current_edge](const TypedValue &edge) { return current_edge.first == edge.ValueEdge(); }); + if (found_existing) continue; + + AppendEdge(current_edge.first, &edges_on_frame); + VertexAccessor current_vertex = + current_edge.second == EdgeAtom::Direction::IN ? current_edge.first.From() : current_edge.first.To(); + + if (!self_.common_.existing_node) { + frame[self_.common_.node_symbol] = current_vertex; + } + + // Skip expanding out of filtered expansion. + frame[self_.filter_lambda_.inner_edge_symbol] = current_edge.first; + frame[self_.filter_lambda_.inner_node_symbol] = current_vertex; + if (self_.filter_lambda_.expression && !EvaluateFilter(evaluator, self_.filter_lambda_.expression)) continue; + + // we are doing depth-first search, so place the current + // edge's expansions onto the stack, if we should continue to expand + if (upper_bound_ > static_cast<int64_t>(edges_.size())) { + auto *memory = edges_.get_allocator().GetMemoryResource(); + edges_.emplace_back( + ExpandFromVertex(current_vertex, self_.common_.direction, self_.common_.edge_types, memory)); + edges_it_.emplace_back(edges_.back().begin()); + } + + if (self_.common_.existing_node && !CheckExistingNode(current_vertex, self_.common_.node_symbol, frame)) continue; + + // We only yield true if we satisfy the lower bound. 
+ if (static_cast<int64_t>(edges_on_frame.size()) >= lower_bound_) + return true; + else + continue; + } + } +}; + +class STShortestPathCursor : public query::v2::plan::Cursor { + public: + STShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self_.input()->MakeCursor(mem)) { + MG_ASSERT(self_.common_.existing_node, + "s-t shortest path algorithm should only " + "be used when `existing_node` flag is " + "set!"); + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("STShortestPath"); + + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + while (input_cursor_->Pull(frame, context)) { + const auto &source_tv = frame[self_.input_symbol_]; + const auto &sink_tv = frame[self_.common_.node_symbol]; + + // It is possible that source or sink vertex is Null due to optional + // matching. + if (source_tv.IsNull() || sink_tv.IsNull()) continue; + + const auto &source = source_tv.ValueVertex(); + const auto &sink = sink_tv.ValueVertex(); + + int64_t lower_bound = + self_.lower_bound_ ? EvaluateInt(&evaluator, self_.lower_bound_, "Min depth in breadth-first expansion") : 1; + int64_t upper_bound = self_.upper_bound_ + ? 
EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in breadth-first expansion") + : std::numeric_limits<int64_t>::max(); + + if (upper_bound < 1 || lower_bound > upper_bound) continue; + + if (FindPath(*context.db_accessor, source, sink, lower_bound, upper_bound, &frame, &evaluator, context)) { + return true; + } + } + return false; + } + + void Shutdown() override { input_cursor_->Shutdown(); } + + void Reset() override { input_cursor_->Reset(); } + + private: + const ExpandVariable &self_; + UniqueCursorPtr input_cursor_; + + using VertexEdgeMapT = utils::pmr::unordered_map<VertexAccessor, std::optional<EdgeAccessor>>; + + void ReconstructPath(const VertexAccessor &midpoint, const VertexEdgeMapT &in_edge, const VertexEdgeMapT &out_edge, + Frame *frame, utils::MemoryResource *pull_memory) { + utils::pmr::vector<TypedValue> result(pull_memory); + auto last_vertex = midpoint; + while (true) { + const auto &last_edge = in_edge.at(last_vertex); + if (!last_edge) break; + last_vertex = last_edge->From() == last_vertex ? last_edge->To() : last_edge->From(); + result.emplace_back(*last_edge); + } + std::reverse(result.begin(), result.end()); + last_vertex = midpoint; + while (true) { + const auto &last_edge = out_edge.at(last_vertex); + if (!last_edge) break; + last_vertex = last_edge->From() == last_vertex ? 
last_edge->To() : last_edge->From(); + result.emplace_back(*last_edge); + } + frame->at(self_.common_.edge_symbol) = std::move(result); + } + + bool ShouldExpand(const VertexAccessor &vertex, const EdgeAccessor &edge, Frame *frame, + ExpressionEvaluator *evaluator) { + if (!self_.filter_lambda_.expression) return true; + + frame->at(self_.filter_lambda_.inner_node_symbol) = vertex; + frame->at(self_.filter_lambda_.inner_edge_symbol) = edge; + + TypedValue result = self_.filter_lambda_.expression->Accept(*evaluator); + if (result.IsNull()) return false; + if (result.IsBool()) return result.ValueBool(); + + throw QueryRuntimeException("Expansion condition must evaluate to boolean or null"); + } + + bool FindPath(const DbAccessor &dba, const VertexAccessor &source, const VertexAccessor &sink, int64_t lower_bound, + int64_t upper_bound, Frame *frame, ExpressionEvaluator *evaluator, const ExecutionContext &context) { + using utils::Contains; + + if (source == sink) return false; + + // We expand from both directions, both from the source and the sink. + // Expansions meet at the middle of the path if it exists. This should + // perform better for real-world like graphs where the expansion front + // grows exponentially, effectively reducing the exponent by half. + + auto *pull_memory = evaluator->GetMemoryResource(); + // Holds vertices at the current level of expansion from the source + // (sink). + utils::pmr::vector<VertexAccessor> source_frontier(pull_memory); + utils::pmr::vector<VertexAccessor> sink_frontier(pull_memory); + + // Holds vertices we can expand to from `source_frontier` + // (`sink_frontier`). + utils::pmr::vector<VertexAccessor> source_next(pull_memory); + utils::pmr::vector<VertexAccessor> sink_next(pull_memory); + + // Maps each vertex we visited expanding from the source (sink) to the + // edge used. Necessary for path reconstruction. 
+ VertexEdgeMapT in_edge(pull_memory); + VertexEdgeMapT out_edge(pull_memory); + + size_t current_length = 0; + + source_frontier.emplace_back(source); + in_edge[source] = std::nullopt; + sink_frontier.emplace_back(sink); + out_edge[sink] = std::nullopt; + + while (true) { + if (MustAbort(context)) throw HintedAbortError(); + // Top-down step (expansion from the source). + ++current_length; + if (current_length > upper_bound) return false; + + for (const auto &vertex : source_frontier) { + if (self_.common_.direction != EdgeAtom::Direction::IN) { + auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::v3::View::OLD, self_.common_.edge_types)); + for (const auto &edge : out_edges) { + if (ShouldExpand(edge.To(), edge, frame, evaluator) && !Contains(in_edge, edge.To())) { + in_edge.emplace(edge.To(), edge); + if (Contains(out_edge, edge.To())) { + if (current_length >= lower_bound) { + ReconstructPath(edge.To(), in_edge, out_edge, frame, pull_memory); + return true; + } else { + return false; + } + } + source_next.push_back(edge.To()); + } + } + } + if (self_.common_.direction != EdgeAtom::Direction::OUT) { + auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::v3::View::OLD, self_.common_.edge_types)); + for (const auto &edge : in_edges) { + if (ShouldExpand(edge.From(), edge, frame, evaluator) && !Contains(in_edge, edge.From())) { + in_edge.emplace(edge.From(), edge); + if (Contains(out_edge, edge.From())) { + if (current_length >= lower_bound) { + ReconstructPath(edge.From(), in_edge, out_edge, frame, pull_memory); + return true; + } else { + return false; + } + } + source_next.push_back(edge.From()); + } + } + } + } + + if (source_next.empty()) return false; + source_frontier.clear(); + std::swap(source_frontier, source_next); + + // Bottom-up step (expansion from the sink). 
+ ++current_length; + if (current_length > upper_bound) return false; + + // When expanding from the sink we have to be careful which edge + // endpoint we pass to `should_expand`, because everything is + // reversed. + for (const auto &vertex : sink_frontier) { + if (self_.common_.direction != EdgeAtom::Direction::OUT) { + auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::v3::View::OLD, self_.common_.edge_types)); + for (const auto &edge : out_edges) { + if (ShouldExpand(vertex, edge, frame, evaluator) && !Contains(out_edge, edge.To())) { + out_edge.emplace(edge.To(), edge); + if (Contains(in_edge, edge.To())) { + if (current_length >= lower_bound) { + ReconstructPath(edge.To(), in_edge, out_edge, frame, pull_memory); + return true; + } else { + return false; + } + } + sink_next.push_back(edge.To()); + } + } + } + if (self_.common_.direction != EdgeAtom::Direction::IN) { + auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::v3::View::OLD, self_.common_.edge_types)); + for (const auto &edge : in_edges) { + if (ShouldExpand(vertex, edge, frame, evaluator) && !Contains(out_edge, edge.From())) { + out_edge.emplace(edge.From(), edge); + if (Contains(in_edge, edge.From())) { + if (current_length >= lower_bound) { + ReconstructPath(edge.From(), in_edge, out_edge, frame, pull_memory); + return true; + } else { + return false; + } + } + sink_next.push_back(edge.From()); + } + } + } + } + + if (sink_next.empty()) return false; + sink_frontier.clear(); + std::swap(sink_frontier, sink_next); + } + } +}; + +class SingleSourceShortestPathCursor : public query::v2::plan::Cursor { + public: + SingleSourceShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem) + : self_(self), + input_cursor_(self_.input()->MakeCursor(mem)), + processed_(mem), + to_visit_current_(mem), + to_visit_next_(mem) { + MG_ASSERT(!self_.common_.existing_node, + "Single source shortest path algorithm " + "should not be used when `existing_node` " + "flag is set, s-t 
shortest path algorithm " + "should be used instead!"); + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("SingleSourceShortestPath"); + + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + + // for the given (edge, vertex) pair checks if they satisfy the + // "where" condition. if so, places them in the to_visit_ structure. + auto expand_pair = [this, &evaluator, &frame](EdgeAccessor edge, VertexAccessor vertex) { + // if we already processed the given vertex it doesn't get expanded + if (processed_.find(vertex) != processed_.end()) return; + + frame[self_.filter_lambda_.inner_edge_symbol] = edge; + frame[self_.filter_lambda_.inner_node_symbol] = vertex; + + if (self_.filter_lambda_.expression) { + TypedValue result = self_.filter_lambda_.expression->Accept(evaluator); + switch (result.type()) { + case TypedValue::Type::Null: + return; + case TypedValue::Type::Bool: + if (!result.ValueBool()) return; + break; + default: + throw QueryRuntimeException("Expansion condition must evaluate to boolean or null."); + } + } + to_visit_next_.emplace_back(edge, vertex); + processed_.emplace(vertex, edge); + }; + + // populates the to_visit_next_ structure with expansions + // from the given vertex. skips expansions that don't satisfy + // the "where" condition. 
+ auto expand_from_vertex = [this, &expand_pair](const auto &vertex) { + if (self_.common_.direction != EdgeAtom::Direction::IN) { + auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::v3::View::OLD, self_.common_.edge_types)); + for (const auto &edge : out_edges) expand_pair(edge, edge.To()); + } + if (self_.common_.direction != EdgeAtom::Direction::OUT) { + auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::v3::View::OLD, self_.common_.edge_types)); + for (const auto &edge : in_edges) expand_pair(edge, edge.From()); + } + }; + + // do it all in a loop because we skip some elements + while (true) { + if (MustAbort(context)) throw HintedAbortError(); + // if we have nothing to visit on the current depth, switch to next + if (to_visit_current_.empty()) to_visit_current_.swap(to_visit_next_); + + // if current is still empty, it means both are empty, so pull from + // input + if (to_visit_current_.empty()) { + if (!input_cursor_->Pull(frame, context)) return false; + + to_visit_current_.clear(); + to_visit_next_.clear(); + processed_.clear(); + + const auto &vertex_value = frame[self_.input_symbol_]; + // it is possible that the vertex is Null due to optional matching + if (vertex_value.IsNull()) continue; + lower_bound_ = self_.lower_bound_ + ? EvaluateInt(&evaluator, self_.lower_bound_, "Min depth in breadth-first expansion") + : 1; + upper_bound_ = self_.upper_bound_ + ? 
EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in breadth-first expansion") + : std::numeric_limits<int64_t>::max(); + + if (upper_bound_ < 1 || lower_bound_ > upper_bound_) continue; + + const auto &vertex = vertex_value.ValueVertex(); + processed_.emplace(vertex, std::nullopt); + expand_from_vertex(vertex); + + // go back to loop start and see if we expanded anything + continue; + } + + // take the next expansion from the queue + auto expansion = to_visit_current_.back(); + to_visit_current_.pop_back(); + + // create the frame value for the edges + auto *pull_memory = context.evaluation_context.memory; + utils::pmr::vector<TypedValue> edge_list(pull_memory); + edge_list.emplace_back(expansion.first); + auto last_vertex = expansion.second; + while (true) { + const EdgeAccessor &last_edge = edge_list.back().ValueEdge(); + last_vertex = last_edge.From() == last_vertex ? last_edge.To() : last_edge.From(); + // origin_vertex must be in processed + const auto &previous_edge = processed_.find(last_vertex)->second; + if (!previous_edge) break; + + edge_list.emplace_back(previous_edge.value()); + } + + // expand only if what we've just expanded is less then max depth + if (static_cast<int64_t>(edge_list.size()) < upper_bound_) expand_from_vertex(expansion.second); + + if (static_cast<int64_t>(edge_list.size()) < lower_bound_) continue; + + frame[self_.common_.node_symbol] = expansion.second; + + // place edges on the frame in the correct order + std::reverse(edge_list.begin(), edge_list.end()); + frame[self_.common_.edge_symbol] = std::move(edge_list); + + return true; + } + } + + void Shutdown() override { input_cursor_->Shutdown(); } + + void Reset() override { + input_cursor_->Reset(); + processed_.clear(); + to_visit_next_.clear(); + to_visit_current_.clear(); + } + + private: + const ExpandVariable &self_; + const UniqueCursorPtr input_cursor_; + + // Depth bounds. Calculated on each pull from the input, the initial value + // is irrelevant. 
+ int64_t lower_bound_{-1}; + int64_t upper_bound_{-1}; + + // maps vertices to the edge they got expanded from. it is an optional + // edge because the root does not get expanded from anything. + // contains visited vertices as well as those scheduled to be visited. + utils::pmr::unordered_map<VertexAccessor, std::optional<EdgeAccessor>> processed_; + // edge/vertex pairs we have yet to visit, for current and next depth + utils::pmr::vector<std::pair<EdgeAccessor, VertexAccessor>> to_visit_current_; + utils::pmr::vector<std::pair<EdgeAccessor, VertexAccessor>> to_visit_next_; +}; + +class ExpandWeightedShortestPathCursor : public query::v2::plan::Cursor { + public: + ExpandWeightedShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem) + : self_(self), + input_cursor_(self_.input_->MakeCursor(mem)), + total_cost_(mem), + previous_(mem), + yielded_vertices_(mem), + pq_(mem) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("ExpandWeightedShortestPath"); + + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + auto create_state = [this](const VertexAccessor &vertex, int64_t depth) { + return std::make_pair(vertex, upper_bound_set_ ? depth : 0); + }; + + // For the given (edge, vertex, weight, depth) tuple checks if they + // satisfy the "where" condition. if so, places them in the priority + // queue. 
+ auto expand_pair = [this, &evaluator, &frame, &create_state](const EdgeAccessor &edge, const VertexAccessor &vertex, + const TypedValue &total_weight, int64_t depth) { + auto *memory = evaluator.GetMemoryResource(); + if (self_.filter_lambda_.expression) { + frame[self_.filter_lambda_.inner_edge_symbol] = edge; + frame[self_.filter_lambda_.inner_node_symbol] = vertex; + + if (!EvaluateFilter(evaluator, self_.filter_lambda_.expression)) return; + } + + frame[self_.weight_lambda_->inner_edge_symbol] = edge; + frame[self_.weight_lambda_->inner_node_symbol] = vertex; + + TypedValue current_weight = self_.weight_lambda_->expression->Accept(evaluator); + + if (!current_weight.IsNumeric() && !current_weight.IsDuration()) { + throw QueryRuntimeException("Calculated weight must be numeric or a Duration, got {}.", current_weight.type()); + } + + const auto is_valid_numeric = [&] { + return current_weight.IsNumeric() && (current_weight >= TypedValue(0, memory)).ValueBool(); + }; + + const auto is_valid_duration = [&] { + return current_weight.IsDuration() && (current_weight >= TypedValue(utils::Duration(0), memory)).ValueBool(); + }; + + if (!is_valid_numeric() && !is_valid_duration()) { + throw QueryRuntimeException("Calculated weight must be non-negative!"); + } + + auto next_state = create_state(vertex, depth); + + TypedValue next_weight = std::invoke([&] { + if (total_weight.IsNull()) { + return current_weight; + } + + ValidateWeightTypes(current_weight, total_weight); + + return TypedValue(current_weight, memory) + total_weight; + }); + + auto found_it = total_cost_.find(next_state); + if (found_it != total_cost_.end() && (found_it->second.IsNull() || (found_it->second <= next_weight).ValueBool())) + return; + + pq_.push({next_weight, depth + 1, vertex, edge}); + }; + + // Populates the priority queue structure with expansions + // from the given vertex. skips expansions that don't satisfy + // the "where" condition. 
+ auto expand_from_vertex = [this, &expand_pair](const VertexAccessor &vertex, const TypedValue &weight, + int64_t depth) { + if (self_.common_.direction != EdgeAtom::Direction::IN) { + auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::v3::View::OLD, self_.common_.edge_types)); + for (const auto &edge : out_edges) { + expand_pair(edge, edge.To(), weight, depth); + } + } + if (self_.common_.direction != EdgeAtom::Direction::OUT) { + auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::v3::View::OLD, self_.common_.edge_types)); + for (const auto &edge : in_edges) { + expand_pair(edge, edge.From(), weight, depth); + } + } + }; + + while (true) { + if (MustAbort(context)) throw HintedAbortError(); + if (pq_.empty()) { + if (!input_cursor_->Pull(frame, context)) return false; + const auto &vertex_value = frame[self_.input_symbol_]; + if (vertex_value.IsNull()) continue; + auto vertex = vertex_value.ValueVertex(); + if (self_.common_.existing_node) { + const auto &node = frame[self_.common_.node_symbol]; + // Due to optional matching the existing node could be null. + // Skip expansion for such nodes. + if (node.IsNull()) continue; + } + if (self_.upper_bound_) { + upper_bound_ = EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in weighted shortest path expansion"); + upper_bound_set_ = true; + } else { + upper_bound_ = std::numeric_limits<int64_t>::max(); + upper_bound_set_ = false; + } + if (upper_bound_ < 1) + throw QueryRuntimeException( + "Maximum depth in weighted shortest path expansion must be at " + "least 1."); + + // Clear existing data structures. + previous_.clear(); + total_cost_.clear(); + yielded_vertices_.clear(); + + pq_.push({TypedValue(), 0, vertex, std::nullopt}); + // We are adding the starting vertex to the set of yielded vertices + // because we don't want to yield paths that end with the starting + // vertex. 
+ yielded_vertices_.insert(vertex); + } + + while (!pq_.empty()) { + if (MustAbort(context)) throw HintedAbortError(); + auto [current_weight, current_depth, current_vertex, current_edge] = pq_.top(); + pq_.pop(); + + auto current_state = create_state(current_vertex, current_depth); + + // Check if the vertex has already been processed. + if (total_cost_.find(current_state) != total_cost_.end()) { + continue; + } + previous_.emplace(current_state, current_edge); + total_cost_.emplace(current_state, current_weight); + + // Expand only if what we've just expanded is less than max depth. + if (current_depth < upper_bound_) expand_from_vertex(current_vertex, current_weight, current_depth); + + // If we yielded a path for a vertex already, make the expansion but + // don't return the path again. + if (yielded_vertices_.find(current_vertex) != yielded_vertices_.end()) continue; + + // Reconstruct the path. + auto last_vertex = current_vertex; + auto last_depth = current_depth; + auto *pull_memory = context.evaluation_context.memory; + utils::pmr::vector<TypedValue> edge_list(pull_memory); + while (true) { + // Origin_vertex must be in previous. + const auto &previous_edge = previous_.find(create_state(last_vertex, last_depth))->second; + if (!previous_edge) break; + last_vertex = previous_edge->From() == last_vertex ? previous_edge->To() : previous_edge->From(); + last_depth--; + edge_list.emplace_back(previous_edge.value()); + } + + // Place destination node on the frame, handle existence flag. + if (self_.common_.existing_node) { + const auto &node = frame[self_.common_.node_symbol]; + if ((node != TypedValue(current_vertex, pull_memory)).ValueBool()) + continue; + else + // Prevent expanding other paths, because we found the + // shortest to existing node. + ClearQueue(); + } else { + frame[self_.common_.node_symbol] = current_vertex; + } + + if (!self_.is_reverse_) { + // Place edges on the frame in the correct order. 
+          std::reverse(edge_list.begin(), edge_list.end());
+        }
+        frame[self_.common_.edge_symbol] = std::move(edge_list);
+        frame[self_.total_weight_.value()] = current_weight;
+        yielded_vertices_.insert(current_vertex);
+        return true;
+      }
+    }
+  }
+
+  void Shutdown() override { input_cursor_->Shutdown(); }
+
+  // Resets the input cursor and discards all state accumulated for the
+  // current source vertex, including any pending queue entries.
+  void Reset() override {
+    input_cursor_->Reset();
+    previous_.clear();
+    total_cost_.clear();
+    yielded_vertices_.clear();
+    ClearQueue();
+  }
+
+ private:
+  const ExpandVariable &self_;
+  const UniqueCursorPtr input_cursor_;
+
+  // Upper bound on the path length.
+  int64_t upper_bound_{-1};
+  // True when the query specified an upper bound; re-evaluated on every pull
+  // from the input.
+  bool upper_bound_set_{false};
+
+  // Hash for the (vertex, depth) search state used as key in the maps below.
+  struct WspStateHash {
+    size_t operator()(const std::pair<VertexAccessor, int64_t> &key) const {
+      return utils::HashCombine<VertexAccessor, int64_t>{}(key.first, key.second);
+    }
+  };
+
+  // Maps a (vertex, depth) state to the total weight it got in expansion.
+  utils::pmr::unordered_map<std::pair<VertexAccessor, int64_t>, TypedValue, WspStateHash> total_cost_;
+
+  // Maps a (vertex, depth) state to the edge used to reach it. The value is
+  // std::nullopt for the state the expansion started from.
+  utils::pmr::unordered_map<std::pair<VertexAccessor, int64_t>, std::optional<EdgeAccessor>, WspStateHash> previous_;
+
+  // Keeps track of vertices for which we yielded a path already.
+  utils::pmr::unordered_set<VertexAccessor> yielded_vertices_;
+
+  /// Checks that two weights can be combined or compared with each other.
+  ///
+  /// Both weights have to be of the same kind: either both numeric or both a
+  /// Duration. Each operand must be checked against the other one; checking
+  /// an operand only against itself would let a numeric/Duration mix through.
+  ///
+  /// @param lhs first weight
+  /// @param rhs second weight
+  /// @throw QueryRuntimeException if the weights are not of the same kind.
+  static void ValidateWeightTypes(const TypedValue &lhs, const TypedValue &rhs) {
+    const bool both_numeric = lhs.IsNumeric() && rhs.IsNumeric();
+    const bool both_duration = lhs.IsDuration() && rhs.IsDuration();
+    if (!(both_numeric || both_duration)) {
+      throw QueryRuntimeException(utils::MessageWithLink(
+          "All weights should be of the same type, either numeric or a Duration. Please update the weight "
+          "expression or the filter expression.",
+          "https://memgr.ph/wsp"));
+    }
+  }
+
+  // Priority queue comparator. Keep lowest weight on top of the queue.
+ class PriorityQueueComparator { + public: + bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &lhs, + const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &rhs) { + const auto &lhs_weight = std::get<0>(lhs); + const auto &rhs_weight = std::get<0>(rhs); + // Null defines minimum value for all types + if (lhs_weight.IsNull()) { + return false; + } + + if (rhs_weight.IsNull()) { + return true; + } + + ValidateWeightTypes(lhs_weight, rhs_weight); + return (lhs_weight > rhs_weight).ValueBool(); + } + }; + + std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>, + utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>>, + PriorityQueueComparator> + pq_; + + void ClearQueue() { + while (!pq_.empty()) pq_.pop(); + } +}; + +UniqueCursorPtr ExpandVariable::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ExpandVariableOperator); + + switch (type_) { + case EdgeAtom::Type::BREADTH_FIRST: + if (common_.existing_node) { + return MakeUniqueCursorPtr<STShortestPathCursor>(mem, *this, mem); + } else { + return MakeUniqueCursorPtr<SingleSourceShortestPathCursor>(mem, *this, mem); + } + case EdgeAtom::Type::DEPTH_FIRST: + return MakeUniqueCursorPtr<ExpandVariableCursor>(mem, *this, mem); + case EdgeAtom::Type::WEIGHTED_SHORTEST_PATH: + return MakeUniqueCursorPtr<ExpandWeightedShortestPathCursor>(mem, *this, mem); + case EdgeAtom::Type::SINGLE: + LOG_FATAL("ExpandVariable should not be planned for a single expansion!"); + } +} + +class ConstructNamedPathCursor : public Cursor { + public: + ConstructNamedPathCursor(const ConstructNamedPath &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self_.input()->MakeCursor(mem)) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("ConstructNamedPath"); + + if 
(!input_cursor_->Pull(frame, context)) return false; + + auto symbol_it = self_.path_elements_.begin(); + DMG_ASSERT(symbol_it != self_.path_elements_.end(), "Named path must contain at least one node"); + + const auto &start_vertex = frame[*symbol_it++]; + auto *pull_memory = context.evaluation_context.memory; + // In an OPTIONAL MATCH everything could be Null. + if (start_vertex.IsNull()) { + frame[self_.path_symbol_] = TypedValue(pull_memory); + return true; + } + + DMG_ASSERT(start_vertex.IsVertex(), "First named path element must be a vertex"); + query::v2::Path path(start_vertex.ValueVertex(), pull_memory); + + // If the last path element symbol was for an edge list, then + // the next symbol is a vertex and it should not append to the path + // because + // expansion already did it. + bool last_was_edge_list = false; + + for (; symbol_it != self_.path_elements_.end(); symbol_it++) { + const auto &expansion = frame[*symbol_it]; + // We can have Null (OPTIONAL MATCH), a vertex, an edge, or an edge + // list (variable expand or BFS). + switch (expansion.type()) { + case TypedValue::Type::Null: + frame[self_.path_symbol_] = TypedValue(pull_memory); + return true; + case TypedValue::Type::Vertex: + if (!last_was_edge_list) path.Expand(expansion.ValueVertex()); + last_was_edge_list = false; + break; + case TypedValue::Type::Edge: + path.Expand(expansion.ValueEdge()); + break; + case TypedValue::Type::List: { + last_was_edge_list = true; + // We need to expand all edges in the list and intermediary + // vertices. 
+          const auto &edges = expansion.ValueList();
+          for (const auto &edge_value : edges) {
+            const auto &edge = edge_value.ValueEdge();
+            const auto &from = edge.From();
+            // Continue the path with whichever endpoint of the edge is not
+            // the current last vertex of the path.
+            if (path.vertices().back() == from)
+              path.Expand(edge, edge.To());
+            else
+              path.Expand(edge, from);
+          }
+          break;
+        }
+        default:
+          LOG_FATAL("Unsupported type in named path construction");
+
+          break;
+      }
+    }
+
+    frame[self_.path_symbol_] = path;
+    return true;
+  }
+
+  void Shutdown() override { input_cursor_->Shutdown(); }
+
+  void Reset() override { input_cursor_->Reset(); }
+
+ private:
+  // Held by reference, like the other cursors in this file do with their
+  // operator, so that creating a cursor does not copy the whole operator.
+  const ConstructNamedPath &self_;
+  const UniqueCursorPtr input_cursor_;
+};
+
+ACCEPT_WITH_INPUT(ConstructNamedPath)
+
+/// Creates the cursor that builds the named path for this operator.
+UniqueCursorPtr ConstructNamedPath::MakeCursor(utils::MemoryResource *mem) const {
+  EventCounter::IncrementCounter(EventCounter::ConstructNamedPathOperator);
+
+  return MakeUniqueCursorPtr<ConstructNamedPathCursor>(mem, *this, mem);
+}
+
+/// Returns the symbols modified by the input operator, followed by the path
+/// symbol this operator writes to the frame.
+std::vector<Symbol> ConstructNamedPath::ModifiedSymbols(const SymbolTable &table) const {
+  auto symbols = input_->ModifiedSymbols(table);
+  symbols.emplace_back(path_symbol_);
+  return symbols;
+}
+
+Filter::Filter(const std::shared_ptr<LogicalOperator> &input, Expression *expression)
+    : input_(input ?
input : std::make_shared<Once>()), expression_(expression) {} + +ACCEPT_WITH_INPUT(Filter) + +UniqueCursorPtr Filter::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::FilterOperator); + + return MakeUniqueCursorPtr<FilterCursor>(mem, *this, mem); +} + +std::vector<Symbol> Filter::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } + +Filter::FilterCursor::FilterCursor(const Filter &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self_.input_->MakeCursor(mem)) {} + +bool Filter::FilterCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("Filter"); + + // Like all filters, newly set values should not affect filtering of old + // nodes and edges. + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + while (input_cursor_->Pull(frame, context)) { + if (EvaluateFilter(evaluator, self_.expression_)) return true; + } + return false; +} + +void Filter::FilterCursor::Shutdown() { input_cursor_->Shutdown(); } + +void Filter::FilterCursor::Reset() { input_cursor_->Reset(); } + +Produce::Produce(const std::shared_ptr<LogicalOperator> &input, const std::vector<NamedExpression *> &named_expressions) + : input_(input ? 
input : std::make_shared<Once>()), named_expressions_(named_expressions) {} + +ACCEPT_WITH_INPUT(Produce) + +UniqueCursorPtr Produce::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ProduceOperator); + + return MakeUniqueCursorPtr<ProduceCursor>(mem, *this, mem); +} + +std::vector<Symbol> Produce::OutputSymbols(const SymbolTable &symbol_table) const { + std::vector<Symbol> symbols; + for (const auto &named_expr : named_expressions_) { + symbols.emplace_back(symbol_table.at(*named_expr)); + } + return symbols; +} + +std::vector<Symbol> Produce::ModifiedSymbols(const SymbolTable &table) const { return OutputSymbols(table); } + +Produce::ProduceCursor::ProduceCursor(const Produce &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self_.input_->MakeCursor(mem)) {} + +bool Produce::ProduceCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("Produce"); + + if (input_cursor_->Pull(frame, context)) { + // Produce should always yield the latest results. 
+ ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::NEW); + for (auto named_expr : self_.named_expressions_) named_expr->Accept(evaluator); + + return true; + } + return false; +} + +void Produce::ProduceCursor::Shutdown() { input_cursor_->Shutdown(); } + +void Produce::ProduceCursor::Reset() { input_cursor_->Reset(); } + +Delete::Delete(const std::shared_ptr<LogicalOperator> &input_, const std::vector<Expression *> &expressions, + bool detach_) + : input_(input_), expressions_(expressions), detach_(detach_) {} + +ACCEPT_WITH_INPUT(Delete) + +UniqueCursorPtr Delete::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::DeleteOperator); + + return MakeUniqueCursorPtr<DeleteCursor>(mem, *this, mem); +} + +std::vector<Symbol> Delete::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } + +Delete::DeleteCursor::DeleteCursor(const Delete &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self_.input_->MakeCursor(mem)) {} + +bool Delete::DeleteCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("Delete"); + + if (!input_cursor_->Pull(frame, context)) return false; + + // Delete should get the latest information, this way it is also possible + // to delete newly added nodes and edges. 
+ ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::NEW); + auto *pull_memory = context.evaluation_context.memory; + // collect expressions results so edges can get deleted before vertices + // this is necessary because an edge that gets deleted could block vertex + // deletion + utils::pmr::vector<TypedValue> expression_results(pull_memory); + expression_results.reserve(self_.expressions_.size()); + for (Expression *expression : self_.expressions_) { + expression_results.emplace_back(expression->Accept(evaluator)); + } + + auto &dba = *context.db_accessor; + // delete edges first + for (TypedValue &expression_result : expression_results) { + if (MustAbort(context)) throw HintedAbortError(); + if (expression_result.type() == TypedValue::Type::Edge) { + auto maybe_value = dba.RemoveEdge(&expression_result.ValueEdge()); + if (maybe_value.HasError()) { + switch (maybe_value.GetError()) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when deleting an edge."); + } + } + context.execution_stats[ExecutionStats::Key::DELETED_EDGES] += 1; + if (context.trigger_context_collector && maybe_value.GetValue()) { + context.trigger_context_collector->RegisterDeletedObject(*maybe_value.GetValue()); + } + } + } + + // delete vertices + for (TypedValue &expression_result : expression_results) { + if (MustAbort(context)) throw HintedAbortError(); + switch (expression_result.type()) { + case TypedValue::Type::Vertex: { + auto &va = expression_result.ValueVertex(); + if (self_.detach_) { + auto res = dba.DetachRemoveVertex(&va); + if (res.HasError()) { + switch (res.GetError()) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw 
TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when deleting a node."); + } + } + + context.execution_stats[ExecutionStats::Key::DELETED_NODES] += 1; + if (*res) { + context.execution_stats[ExecutionStats::Key::DELETED_EDGES] += static_cast<int64_t>((*res)->second.size()); + } + std::invoke([&] { + if (!context.trigger_context_collector || !*res) { + return; + } + + context.trigger_context_collector->RegisterDeletedObject((*res)->first); + if (!context.trigger_context_collector->ShouldRegisterDeletedObject<query::v2::EdgeAccessor>()) { + return; + } + for (const auto &edge : (*res)->second) { + context.trigger_context_collector->RegisterDeletedObject(edge); + } + }); + } else { + auto res = dba.RemoveVertex(&va); + if (res.HasError()) { + switch (res.GetError()) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::VERTEX_HAS_EDGES: + throw RemoveAttachedVertexException(); + case storage::v3::Error::DELETED_OBJECT: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when deleting a node."); + } + } + context.execution_stats[ExecutionStats::Key::DELETED_NODES] += 1; + if (context.trigger_context_collector && res.GetValue()) { + context.trigger_context_collector->RegisterDeletedObject(*res.GetValue()); + } + } + break; + } + + // skip Edges (already deleted) and Nulls (can occur in optional + // match) + case TypedValue::Type::Edge: + case TypedValue::Type::Null: + break; + // check we're not trying to delete anything except vertices and edges + default: + throw QueryRuntimeException("Only edges and vertices can be deleted."); + } + } + + return true; +} + +void Delete::DeleteCursor::Shutdown() 
{ input_cursor_->Shutdown(); } + +void Delete::DeleteCursor::Reset() { input_cursor_->Reset(); } + +SetProperty::SetProperty(const std::shared_ptr<LogicalOperator> &input, storage::v3::PropertyId property, + PropertyLookup *lhs, Expression *rhs) + : input_(input), property_(property), lhs_(lhs), rhs_(rhs) {} + +ACCEPT_WITH_INPUT(SetProperty) + +UniqueCursorPtr SetProperty::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::SetPropertyOperator); + + return MakeUniqueCursorPtr<SetPropertyCursor>(mem, *this, mem); +} + +std::vector<Symbol> SetProperty::ModifiedSymbols(const SymbolTable &table) const { + return input_->ModifiedSymbols(table); +} + +SetProperty::SetPropertyCursor::SetPropertyCursor(const SetProperty &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} + +bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("SetProperty"); + + if (!input_cursor_->Pull(frame, context)) return false; + + // Set, just like Create needs to see the latest changes. 
+ ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::NEW); + TypedValue lhs = self_.lhs_->expression_->Accept(evaluator); + TypedValue rhs = self_.rhs_->Accept(evaluator); + + switch (lhs.type()) { + case TypedValue::Type::Vertex: { + auto old_value = PropsSetChecked(&lhs.ValueVertex(), self_.property_, rhs); + context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; + if (context.trigger_context_collector) { + // rhs cannot be moved because it was created with the allocator that is only valid during current pull + context.trigger_context_collector->RegisterSetObjectProperty(lhs.ValueVertex(), self_.property_, + TypedValue{std::move(old_value)}, TypedValue{rhs}); + } + break; + } + case TypedValue::Type::Edge: { + auto old_value = PropsSetChecked(&lhs.ValueEdge(), self_.property_, rhs); + context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; + if (context.trigger_context_collector) { + // rhs cannot be moved because it was created with the allocator that is only valid during current pull + context.trigger_context_collector->RegisterSetObjectProperty(lhs.ValueEdge(), self_.property_, + TypedValue{std::move(old_value)}, TypedValue{rhs}); + } + break; + } + case TypedValue::Type::Null: + // Skip setting properties on Null (can occur in optional match). + break; + case TypedValue::Type::Map: + // Semantically modifying a map makes sense, but it's not supported due + // to all the copying we do (when PropertyValue -> TypedValue and in + // ExpressionEvaluator). So even though we set a map property here, that + // is never visible to the user and it's not stored. 
+ // TODO: fix above described bug + default: + throw QueryRuntimeException("Properties can only be set on edges and vertices."); + } + return true; +} + +void SetProperty::SetPropertyCursor::Shutdown() { input_cursor_->Shutdown(); } + +void SetProperty::SetPropertyCursor::Reset() { input_cursor_->Reset(); } + +SetProperties::SetProperties(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Expression *rhs, Op op) + : input_(input), input_symbol_(input_symbol), rhs_(rhs), op_(op) {} + +ACCEPT_WITH_INPUT(SetProperties) + +UniqueCursorPtr SetProperties::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::SetPropertiesOperator); + + return MakeUniqueCursorPtr<SetPropertiesCursor>(mem, *this, mem); +} + +std::vector<Symbol> SetProperties::ModifiedSymbols(const SymbolTable &table) const { + return input_->ModifiedSymbols(table); +} + +SetProperties::SetPropertiesCursor::SetPropertiesCursor(const SetProperties &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} + +namespace { + +template <typename T> +concept AccessorWithProperties = requires(T value, storage::v3::PropertyId property_id, + storage::v3::PropertyValue property_value) { + { + value.ClearProperties() + } -> std::same_as<storage::v3::Result<std::map<storage::v3::PropertyId, storage::v3::PropertyValue>>>; + {value.SetProperty(property_id, property_value)}; +}; + +/// Helper function that sets the given values on either a Vertex or an Edge. 
+/// +/// @tparam TRecordAccessor Either RecordAccessor<Vertex> or +/// RecordAccessor<Edge> +template <AccessorWithProperties TRecordAccessor> +void SetPropertiesOnRecord(TRecordAccessor *record, const TypedValue &rhs, SetProperties::Op op, + ExecutionContext *context) { + std::optional<std::map<storage::v3::PropertyId, storage::v3::PropertyValue>> old_values; + const bool should_register_change = + context->trigger_context_collector && + context->trigger_context_collector->ShouldRegisterObjectPropertyChange<TRecordAccessor>(); + if (op == SetProperties::Op::REPLACE) { + auto maybe_value = record->ClearProperties(); + if (maybe_value.HasError()) { + switch (maybe_value.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set properties on a deleted graph element."); + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Can't set property because properties on edges are disabled."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting properties."); + } + } + + if (should_register_change) { + old_values.emplace(std::move(*maybe_value)); + } + } + + auto get_props = [](const auto &record) { + auto maybe_props = record.Properties(storage::v3::View::NEW); + if (maybe_props.HasError()) { + switch (maybe_props.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to get properties from a deleted object."); + case storage::v3::Error::NONEXISTENT_OBJECT: + throw query::v2::QueryRuntimeException("Trying to get properties from an object that doesn't exist."); + case storage::v3::Error::SERIALIZATION_ERROR: + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Unexpected error when getting properties."); + } 
+ } + return *maybe_props; + }; + + auto register_set_property = [&](auto &&returned_old_value, auto key, auto &&new_value) { + auto old_value = [&]() -> storage::v3::PropertyValue { + if (!old_values) { + return std::forward<decltype(returned_old_value)>(returned_old_value); + } + + if (auto it = old_values->find(key); it != old_values->end()) { + return std::move(it->second); + } + + return {}; + }(); + + context->trigger_context_collector->RegisterSetObjectProperty( + *record, key, TypedValue(std::move(old_value)), TypedValue(std::forward<decltype(new_value)>(new_value))); + }; + + auto set_props = [&, record](auto properties) { + for (auto &kv : properties) { + auto maybe_error = record->SetProperty(kv.first, kv.second); + if (maybe_error.HasError()) { + switch (maybe_error.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set properties on a deleted graph element."); + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Can't set property because properties on edges are disabled."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting properties."); + } + } + + if (should_register_change) { + register_set_property(std::move(*maybe_error), kv.first, std::move(kv.second)); + } + } + }; + + switch (rhs.type()) { + case TypedValue::Type::Edge: + set_props(get_props(rhs.ValueEdge())); + break; + case TypedValue::Type::Vertex: + set_props(get_props(rhs.ValueVertex())); + break; + case TypedValue::Type::Map: { + for (const auto &kv : rhs.ValueMap()) { + auto key = context->db_accessor->NameToProperty(kv.first); + auto old_value = PropsSetChecked(record, key, kv.second); + if (should_register_change) { + register_set_property(std::move(old_value), key, kv.second); + } + } + break; + } + default: + throw 
QueryRuntimeException( + "Right-hand side in SET expression must be a node, an edge or a " + "map."); + } + + if (should_register_change && old_values) { + // register removed properties + for (auto &[property_id, property_value] : *old_values) { + context->trigger_context_collector->RegisterRemovedObjectProperty(*record, property_id, + TypedValue(std::move(property_value))); + } + } +} + +} // namespace + +bool SetProperties::SetPropertiesCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("SetProperties"); + + if (!input_cursor_->Pull(frame, context)) return false; + + TypedValue &lhs = frame[self_.input_symbol_]; + + // Set, just like Create needs to see the latest changes. + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::NEW); + TypedValue rhs = self_.rhs_->Accept(evaluator); + + switch (lhs.type()) { + case TypedValue::Type::Vertex: + SetPropertiesOnRecord(&lhs.ValueVertex(), rhs, self_.op_, &context); + break; + case TypedValue::Type::Edge: + SetPropertiesOnRecord(&lhs.ValueEdge(), rhs, self_.op_, &context); + break; + case TypedValue::Type::Null: + // Skip setting properties on Null (can occur in optional match). 
+ break; + default: + throw QueryRuntimeException("Properties can only be set on edges and vertices."); + } + return true; +} + +void SetProperties::SetPropertiesCursor::Shutdown() { input_cursor_->Shutdown(); } + +void SetProperties::SetPropertiesCursor::Reset() { input_cursor_->Reset(); } + +SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, + const std::vector<storage::v3::LabelId> &labels) + : input_(input), input_symbol_(input_symbol), labels_(labels) {} + +ACCEPT_WITH_INPUT(SetLabels) + +UniqueCursorPtr SetLabels::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::SetLabelsOperator); + + return MakeUniqueCursorPtr<SetLabelsCursor>(mem, *this, mem); +} + +std::vector<Symbol> SetLabels::ModifiedSymbols(const SymbolTable &table) const { + return input_->ModifiedSymbols(table); +} + +SetLabels::SetLabelsCursor::SetLabelsCursor(const SetLabels &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} + +bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("SetLabels"); + + if (!input_cursor_->Pull(frame, context)) return false; + + TypedValue &vertex_value = frame[self_.input_symbol_]; + // Skip setting labels on Null (can occur in optional match). 
+ if (vertex_value.IsNull()) return true; + ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + auto &vertex = vertex_value.ValueVertex(); + for (auto label : self_.labels_) { + auto maybe_value = vertex.AddLabel(label); + if (maybe_value.HasError()) { + switch (maybe_value.GetError()) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set a label on a deleted node."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a label."); + } + } + + if (context.trigger_context_collector && *maybe_value) { + context.trigger_context_collector->RegisterSetVertexLabel(vertex, label); + } + } + + return true; +} + +void SetLabels::SetLabelsCursor::Shutdown() { input_cursor_->Shutdown(); } + +void SetLabels::SetLabelsCursor::Reset() { input_cursor_->Reset(); } + +RemoveProperty::RemoveProperty(const std::shared_ptr<LogicalOperator> &input, storage::v3::PropertyId property, + PropertyLookup *lhs) + : input_(input), property_(property), lhs_(lhs) {} + +ACCEPT_WITH_INPUT(RemoveProperty) + +UniqueCursorPtr RemoveProperty::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::RemovePropertyOperator); + + return MakeUniqueCursorPtr<RemovePropertyCursor>(mem, *this, mem); +} + +std::vector<Symbol> RemoveProperty::ModifiedSymbols(const SymbolTable &table) const { + return input_->ModifiedSymbols(table); +} + +RemoveProperty::RemovePropertyCursor::RemovePropertyCursor(const RemoveProperty &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} + +bool RemoveProperty::RemovePropertyCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("RemoveProperty"); + + if (!input_cursor_->Pull(frame, 
context)) return false; + + // Remove, just like Delete needs to see the latest changes. + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::NEW); + TypedValue lhs = self_.lhs_->expression_->Accept(evaluator); + + auto remove_prop = [property = self_.property_, &context](auto *record) { + auto maybe_old_value = record->RemoveProperty(property); + if (maybe_old_value.HasError()) { + switch (maybe_old_value.GetError()) { + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to remove a property on a deleted graph element."); + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException( + "Can't remove property because properties on edges are " + "disabled."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when removing property."); + } + } + + if (context.trigger_context_collector) { + context.trigger_context_collector->RegisterRemovedObjectProperty(*record, property, + TypedValue(std::move(*maybe_old_value))); + } + }; + + switch (lhs.type()) { + case TypedValue::Type::Vertex: + remove_prop(&lhs.ValueVertex()); + break; + case TypedValue::Type::Edge: + remove_prop(&lhs.ValueEdge()); + break; + case TypedValue::Type::Null: + // Skip removing properties on Null (can occur in optional match). 
+ break; + default: + throw QueryRuntimeException("Properties can only be removed from vertices and edges."); + } + return true; +} + +void RemoveProperty::RemovePropertyCursor::Shutdown() { input_cursor_->Shutdown(); } + +void RemoveProperty::RemovePropertyCursor::Reset() { input_cursor_->Reset(); } + +RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, + const std::vector<storage::v3::LabelId> &labels) + : input_(input), input_symbol_(input_symbol), labels_(labels) {} + +ACCEPT_WITH_INPUT(RemoveLabels) + +UniqueCursorPtr RemoveLabels::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::RemoveLabelsOperator); + + return MakeUniqueCursorPtr<RemoveLabelsCursor>(mem, *this, mem); +} + +std::vector<Symbol> RemoveLabels::ModifiedSymbols(const SymbolTable &table) const { + return input_->ModifiedSymbols(table); +} + +RemoveLabels::RemoveLabelsCursor::RemoveLabelsCursor(const RemoveLabels &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} + +bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("RemoveLabels"); + + if (!input_cursor_->Pull(frame, context)) return false; + + TypedValue &vertex_value = frame[self_.input_symbol_]; + // Skip removing labels on Null (can occur in optional match). 
+ if (vertex_value.IsNull()) return true; + ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + auto &vertex = vertex_value.ValueVertex(); + for (auto label : self_.labels_) { + auto maybe_value = vertex.RemoveLabel(label); + if (maybe_value.HasError()) { + switch (maybe_value.GetError()) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to remove labels from a deleted node."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when removing labels from a node."); + } + } + + context.execution_stats[ExecutionStats::Key::DELETED_LABELS] += 1; + if (context.trigger_context_collector && *maybe_value) { + context.trigger_context_collector->RegisterRemovedVertexLabel(vertex, label); + } + } + + return true; +} + +void RemoveLabels::RemoveLabelsCursor::Shutdown() { input_cursor_->Shutdown(); } + +void RemoveLabels::RemoveLabelsCursor::Reset() { input_cursor_->Reset(); } + +EdgeUniquenessFilter::EdgeUniquenessFilter(const std::shared_ptr<LogicalOperator> &input, Symbol expand_symbol, + const std::vector<Symbol> &previous_symbols) + : input_(input), expand_symbol_(expand_symbol), previous_symbols_(previous_symbols) {} + +ACCEPT_WITH_INPUT(EdgeUniquenessFilter) + +UniqueCursorPtr EdgeUniquenessFilter::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::EdgeUniquenessFilterOperator); + + return MakeUniqueCursorPtr<EdgeUniquenessFilterCursor>(mem, *this, mem); +} + +std::vector<Symbol> EdgeUniquenessFilter::ModifiedSymbols(const SymbolTable &table) const { + return input_->ModifiedSymbols(table); +} + +EdgeUniquenessFilter::EdgeUniquenessFilterCursor::EdgeUniquenessFilterCursor(const EdgeUniquenessFilter &self, + utils::MemoryResource *mem) + : 
self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} + +namespace { +/** + * Returns true if: + * - a and b are either edge or edge-list values, and there + * is at least one matching edge in the two values + */ +bool ContainsSameEdge(const TypedValue &a, const TypedValue &b) { + auto compare_to_list = [](const TypedValue &list, const TypedValue &other) { + for (const TypedValue &list_elem : list.ValueList()) + if (ContainsSameEdge(list_elem, other)) return true; + return false; + }; + + if (a.type() == TypedValue::Type::List) return compare_to_list(a, b); + if (b.type() == TypedValue::Type::List) return compare_to_list(b, a); + + return a.ValueEdge() == b.ValueEdge(); +} +} // namespace + +bool EdgeUniquenessFilter::EdgeUniquenessFilterCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("EdgeUniquenessFilter"); + + auto expansion_ok = [&]() { + const auto &expand_value = frame[self_.expand_symbol_]; + for (const auto &previous_symbol : self_.previous_symbols_) { + const auto &previous_value = frame[previous_symbol]; + // This shouldn't raise a TypedValueException, because the planner + // makes sure these are all of the expected type. In case they are not + // an error should be raised long before this code is executed. 
+ if (ContainsSameEdge(previous_value, expand_value)) return false; + } + return true; + }; + + while (input_cursor_->Pull(frame, context)) + if (expansion_ok()) return true; + return false; +} + +void EdgeUniquenessFilter::EdgeUniquenessFilterCursor::Shutdown() { input_cursor_->Shutdown(); } + +void EdgeUniquenessFilter::EdgeUniquenessFilterCursor::Reset() { input_cursor_->Reset(); } + +Accumulate::Accumulate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &symbols, + bool advance_command) + : input_(input), symbols_(symbols), advance_command_(advance_command) {} + +ACCEPT_WITH_INPUT(Accumulate) + +std::vector<Symbol> Accumulate::ModifiedSymbols(const SymbolTable &) const { return symbols_; } + +class AccumulateCursor : public Cursor { + public: + AccumulateCursor(const Accumulate &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)), cache_(mem) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("Accumulate"); + + auto &dba = *context.db_accessor; + // cache all the input + if (!pulled_all_input_) { + while (input_cursor_->Pull(frame, context)) { + utils::pmr::vector<TypedValue> row(cache_.get_allocator().GetMemoryResource()); + row.reserve(self_.symbols_.size()); + for (const Symbol &symbol : self_.symbols_) row.emplace_back(frame[symbol]); + cache_.emplace_back(std::move(row)); + } + pulled_all_input_ = true; + cache_it_ = cache_.begin(); + + if (self_.advance_command_) dba.AdvanceCommand(); + } + + if (MustAbort(context)) throw HintedAbortError(); + if (cache_it_ == cache_.end()) return false; + auto row_it = (cache_it_++)->begin(); + for (const Symbol &symbol : self_.symbols_) frame[symbol] = *row_it++; + return true; + } + + void Shutdown() override { input_cursor_->Shutdown(); } + + void Reset() override { + input_cursor_->Reset(); + cache_.clear(); + cache_it_ = cache_.begin(); + pulled_all_input_ = false; + } + + private: + const Accumulate &self_; + 
const UniqueCursorPtr input_cursor_; + utils::pmr::vector<utils::pmr::vector<TypedValue>> cache_; + decltype(cache_.begin()) cache_it_ = cache_.begin(); + bool pulled_all_input_{false}; +}; + +UniqueCursorPtr Accumulate::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::AccumulateOperator); + + return MakeUniqueCursorPtr<AccumulateCursor>(mem, *this, mem); +} + +Aggregate::Aggregate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Aggregate::Element> &aggregations, + const std::vector<Expression *> &group_by, const std::vector<Symbol> &remember) + : input_(input ? input : std::make_shared<Once>()), + aggregations_(aggregations), + group_by_(group_by), + remember_(remember) {} + +ACCEPT_WITH_INPUT(Aggregate) + +std::vector<Symbol> Aggregate::ModifiedSymbols(const SymbolTable &) const { + auto symbols = remember_; + for (const auto &elem : aggregations_) symbols.push_back(elem.output_sym); + return symbols; +} + +namespace { +/** Returns the default TypedValue for an Aggregation element. 
+ * This value is valid both for returning when there are no inputs + * to the aggregation op, and for initializing an aggregation result + * when there are */ +TypedValue DefaultAggregationOpValue(const Aggregate::Element &element, utils::MemoryResource *memory) { + switch (element.op) { + case Aggregation::Op::COUNT: + return TypedValue(0, memory); + case Aggregation::Op::SUM: + case Aggregation::Op::MIN: + case Aggregation::Op::MAX: + case Aggregation::Op::AVG: + return TypedValue(memory); + case Aggregation::Op::COLLECT_LIST: + return TypedValue(TypedValue::TVector(memory)); + case Aggregation::Op::COLLECT_MAP: + return TypedValue(TypedValue::TMap(memory)); + } +} +} // namespace + +class AggregateCursor : public Cursor { + public: + AggregateCursor(const Aggregate &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self_.input_->MakeCursor(mem)), aggregation_(mem) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("Aggregate"); + + if (!pulled_all_input_) { + ProcessAll(&frame, &context); + pulled_all_input_ = true; + aggregation_it_ = aggregation_.begin(); + + // in case there is no input and no group_bys we need to return true + // just this once + if (aggregation_.empty() && self_.group_by_.empty()) { + auto *pull_memory = context.evaluation_context.memory; + // place default aggregation values on the frame + for (const auto &elem : self_.aggregations_) + frame[elem.output_sym] = DefaultAggregationOpValue(elem, pull_memory); + // place null as remember values on the frame + for (const Symbol &remember_sym : self_.remember_) frame[remember_sym] = TypedValue(pull_memory); + return true; + } + } + + if (aggregation_it_ == aggregation_.end()) return false; + + // place aggregation values on the frame + auto aggregation_values_it = aggregation_it_->second.values_.begin(); + for (const auto &aggregation_elem : self_.aggregations_) + frame[aggregation_elem.output_sym] = *aggregation_values_it++; + + // place 
remember values on the frame + auto remember_values_it = aggregation_it_->second.remember_.begin(); + for (const Symbol &remember_sym : self_.remember_) frame[remember_sym] = *remember_values_it++; + + aggregation_it_++; + return true; + } + + void Shutdown() override { input_cursor_->Shutdown(); } + + void Reset() override { + input_cursor_->Reset(); + aggregation_.clear(); + aggregation_it_ = aggregation_.begin(); + pulled_all_input_ = false; + } + + private: + // Data structure for a single aggregation cache. + // Does NOT include the group-by values since those are a key in the + // aggregation map. The vectors in an AggregationValue contain one element for + // each aggregation in this LogicalOp. + struct AggregationValue { + explicit AggregationValue(utils::MemoryResource *mem) : counts_(mem), values_(mem), remember_(mem) {} + + // how many input rows have been aggregated in respective values_ element so + // far + // TODO: The counting value type should be changed to an unsigned type once + // TypedValue can support signed integer values larger than 64bits so that + // precision isn't lost. + utils::pmr::vector<int64_t> counts_; + // aggregated values. Initially Null (until at least one input row with a + // valid value gets processed) + utils::pmr::vector<TypedValue> values_; + // remember values. 
+ utils::pmr::vector<TypedValue> remember_; + }; + + const Aggregate &self_; + const UniqueCursorPtr input_cursor_; + // storage for aggregated data + // map key is the vector of group-by values + // map value is an AggregationValue struct + utils::pmr::unordered_map<utils::pmr::vector<TypedValue>, AggregationValue, + // use FNV collection hashing specialized for a + // vector of TypedValues + utils::FnvCollection<utils::pmr::vector<TypedValue>, TypedValue, TypedValue::Hash>, + // custom equality + TypedValueVectorEqual> + aggregation_; + // iterator over the accumulated cache + decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin(); + // this LogicalOp pulls all from the input on its first pull + // this switch tracks if this has been performed + bool pulled_all_input_{false}; + + /** + * Pulls from the input operator until exhausted and aggregates the + * results. If the input operator is not provided, a single call + * to ProcessOne is issued. + * + * Accumulation automatically groups the results so that `aggregation_` + * cache cardinality depends on the number of + * aggregation results, and not on the number of inputs. 
+ */ + void ProcessAll(Frame *frame, ExecutionContext *context) { + ExpressionEvaluator evaluator(frame, context->symbol_table, context->evaluation_context, context->db_accessor, + storage::v3::View::NEW); + while (input_cursor_->Pull(*frame, *context)) { + ProcessOne(*frame, &evaluator); + } + + // calculate AVG aggregations (so far they have only been summed) + for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) { + if (self_.aggregations_[pos].op != Aggregation::Op::AVG) continue; + for (auto &kv : aggregation_) { + AggregationValue &agg_value = kv.second; + auto count = agg_value.counts_[pos]; + auto *pull_memory = context->evaluation_context.memory; + if (count > 0) { + agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory); + } + } + } + } + + /** + * Performs a single accumulation. + */ + void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) { + auto *mem = aggregation_.get_allocator().GetMemoryResource(); + utils::pmr::vector<TypedValue> group_by(mem); + group_by.reserve(self_.group_by_.size()); + for (Expression *expression : self_.group_by_) { + group_by.emplace_back(expression->Accept(*evaluator)); + } + auto &agg_value = aggregation_.try_emplace(std::move(group_by), mem).first->second; + EnsureInitialized(frame, &agg_value); + Update(evaluator, &agg_value); + } + + /** Ensures the new AggregationValue has been initialized. This means + * that the value vector is filled with one default value per aggregation, + * counts are set to 0 and remember values are remembered. 
+ */ + void EnsureInitialized(const Frame &frame, AggregateCursor::AggregationValue *agg_value) const { + if (!agg_value->values_.empty()) return; + + for (const auto &agg_elem : self_.aggregations_) { + auto *mem = agg_value->values_.get_allocator().GetMemoryResource(); + agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem)); + } + agg_value->counts_.resize(self_.aggregations_.size(), 0); + + for (const Symbol &remember_sym : self_.remember_) agg_value->remember_.push_back(frame[remember_sym]); + } + + /** Updates the given AggregationValue with new data. Assumes that + * the AggregationValue has been initialized */ + void Update(ExpressionEvaluator *evaluator, AggregateCursor::AggregationValue *agg_value) { + DMG_ASSERT(self_.aggregations_.size() == agg_value->values_.size(), + "Expected as much AggregationValue.values_ as there are " + "aggregations."); + DMG_ASSERT(self_.aggregations_.size() == agg_value->counts_.size(), + "Expected as much AggregationValue.counts_ as there are " + "aggregations."); + + // we iterate over counts, values and aggregation info at the same time + auto count_it = agg_value->counts_.begin(); + auto value_it = agg_value->values_.begin(); + auto agg_elem_it = self_.aggregations_.begin(); + for (; count_it < agg_value->counts_.end(); count_it++, value_it++, agg_elem_it++) { + // COUNT(*) is the only case where input expression is optional + // handle it here + auto input_expr_ptr = agg_elem_it->value; + if (!input_expr_ptr) { + *count_it += 1; + *value_it = *count_it; + continue; + } + + TypedValue input_value = input_expr_ptr->Accept(*evaluator); + + // Aggregations skip Null input values. + if (input_value.IsNull()) continue; + const auto &agg_op = agg_elem_it->op; + *count_it += 1; + if (*count_it == 1) { + // first value, nothing to aggregate. check type, set and continue. 
+ switch (agg_op) { + case Aggregation::Op::MIN: + case Aggregation::Op::MAX: + *value_it = input_value; + EnsureOkForMinMax(input_value); + break; + case Aggregation::Op::SUM: + case Aggregation::Op::AVG: + *value_it = input_value; + EnsureOkForAvgSum(input_value); + break; + case Aggregation::Op::COUNT: + *value_it = 1; + break; + case Aggregation::Op::COLLECT_LIST: + value_it->ValueList().push_back(input_value); + break; + case Aggregation::Op::COLLECT_MAP: + auto key = agg_elem_it->key->Accept(*evaluator); + if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string."); + value_it->ValueMap().emplace(key.ValueString(), input_value); + break; + } + continue; + } + + // aggregation of existing values + switch (agg_op) { + case Aggregation::Op::COUNT: + *value_it = *count_it; + break; + case Aggregation::Op::MIN: { + EnsureOkForMinMax(input_value); + try { + TypedValue comparison_result = input_value < *value_it; + // since we skip nulls we either have a valid comparison, or + // an exception was just thrown above + // safe to assume a bool TypedValue + if (comparison_result.ValueBool()) *value_it = input_value; + } catch (const TypedValueException &) { + throw QueryRuntimeException("Unable to get MIN of '{}' and '{}'.", input_value.type(), value_it->type()); + } + break; + } + case Aggregation::Op::MAX: { + // all comments as for Op::Min + EnsureOkForMinMax(input_value); + try { + TypedValue comparison_result = input_value > *value_it; + if (comparison_result.ValueBool()) *value_it = input_value; + } catch (const TypedValueException &) { + throw QueryRuntimeException("Unable to get MAX of '{}' and '{}'.", input_value.type(), value_it->type()); + } + break; + } + case Aggregation::Op::AVG: + // for averaging we sum first and divide by count once all + // the input has been processed + case Aggregation::Op::SUM: + EnsureOkForAvgSum(input_value); + *value_it = *value_it + input_value; + break; + case 
Aggregation::Op::COLLECT_LIST: + value_it->ValueList().push_back(input_value); + break; + case Aggregation::Op::COLLECT_MAP: + auto key = agg_elem_it->key->Accept(*evaluator); + if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string."); + value_it->ValueMap().emplace(key.ValueString(), input_value); + break; + } // end switch over Aggregation::Op enum + } // end loop over all aggregations + } + + /** Checks if the given TypedValue is legal in MIN and MAX. If not + * an appropriate exception is thrown. */ + void EnsureOkForMinMax(const TypedValue &value) const { + switch (value.type()) { + case TypedValue::Type::Bool: + case TypedValue::Type::Int: + case TypedValue::Type::Double: + case TypedValue::Type::String: + return; + default: + throw QueryRuntimeException( + "Only boolean, numeric and string values are allowed in " + "MIN and MAX aggregations."); + } + } + + /** Checks if the given TypedValue is legal in AVG and SUM. If not + * an appropriate exception is thrown. */ + void EnsureOkForAvgSum(const TypedValue &value) const { + switch (value.type()) { + case TypedValue::Type::Int: + case TypedValue::Type::Double: + return; + default: + throw QueryRuntimeException("Only numeric values allowed in SUM and AVG aggregations."); + } + } +}; + +UniqueCursorPtr Aggregate::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::AggregateOperator); + + return MakeUniqueCursorPtr<AggregateCursor>(mem, *this, mem); +} + +Skip::Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression) + : input_(input), expression_(expression) {} + +ACCEPT_WITH_INPUT(Skip) + +UniqueCursorPtr Skip::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::SkipOperator); + + return MakeUniqueCursorPtr<SkipCursor>(mem, *this, mem); +} + +std::vector<Symbol> Skip::OutputSymbols(const SymbolTable &symbol_table) const { + // Propagate this to potential Produce. 
+ return input_->OutputSymbols(symbol_table); +} + +std::vector<Symbol> Skip::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } + +Skip::SkipCursor::SkipCursor(const Skip &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self_.input_->MakeCursor(mem)) {} + +bool Skip::SkipCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("Skip"); + + while (input_cursor_->Pull(frame, context)) { + if (to_skip_ == -1) { + // First successful pull from the input, evaluate the skip expression. + // The skip expression doesn't contain identifiers so graph view + // parameter is not important. + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + TypedValue to_skip = self_.expression_->Accept(evaluator); + if (to_skip.type() != TypedValue::Type::Int) + throw QueryRuntimeException("Number of elements to skip must be an integer."); + + to_skip_ = to_skip.ValueInt(); + if (to_skip_ < 0) throw QueryRuntimeException("Number of elements to skip must be non-negative."); + } + + if (skipped_++ < to_skip_) continue; + return true; + } + return false; +} + +void Skip::SkipCursor::Shutdown() { input_cursor_->Shutdown(); } + +void Skip::SkipCursor::Reset() { + input_cursor_->Reset(); + to_skip_ = -1; + skipped_ = 0; +} + +Limit::Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression) + : input_(input), expression_(expression) {} + +ACCEPT_WITH_INPUT(Limit) + +UniqueCursorPtr Limit::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::LimitOperator); + + return MakeUniqueCursorPtr<LimitCursor>(mem, *this, mem); +} + +std::vector<Symbol> Limit::OutputSymbols(const SymbolTable &symbol_table) const { + // Propagate this to potential Produce. 
+ return input_->OutputSymbols(symbol_table); +} + +std::vector<Symbol> Limit::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } + +Limit::LimitCursor::LimitCursor(const Limit &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self_.input_->MakeCursor(mem)) {} + +bool Limit::LimitCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("Limit"); + + // We need to evaluate the limit expression before the first input Pull + // because it might be 0 and thereby we shouldn't Pull from input at all. + // We can do this before Pulling from the input because the limit expression + // is not allowed to contain any identifiers. + if (limit_ == -1) { + // Limit expression doesn't contain identifiers so graph view is not + // important. + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + TypedValue limit = self_.expression_->Accept(evaluator); + if (limit.type() != TypedValue::Type::Int) + throw QueryRuntimeException("Limit on number of returned elements must be an integer."); + + limit_ = limit.ValueInt(); + if (limit_ < 0) throw QueryRuntimeException("Limit on number of returned elements must be non-negative."); + } + + // check we have not exceeded the limit before pulling + if (pulled_++ >= limit_) return false; + + return input_cursor_->Pull(frame, context); +} + +void Limit::LimitCursor::Shutdown() { input_cursor_->Shutdown(); } + +void Limit::LimitCursor::Reset() { + input_cursor_->Reset(); + limit_ = -1; + pulled_ = 0; +} + +OrderBy::OrderBy(const std::shared_ptr<LogicalOperator> &input, const std::vector<SortItem> &order_by, + const std::vector<Symbol> &output_symbols) + : input_(input), output_symbols_(output_symbols) { + // split the order_by vector into two vectors of orderings and expressions + std::vector<Ordering> ordering; + ordering.reserve(order_by.size()); + order_by_.reserve(order_by.size()); + 
for (const auto &ordering_expression_pair : order_by) { + ordering.emplace_back(ordering_expression_pair.ordering); + order_by_.emplace_back(ordering_expression_pair.expression); + } + compare_ = TypedValueVectorCompare(ordering); +} + +ACCEPT_WITH_INPUT(OrderBy) + +std::vector<Symbol> OrderBy::OutputSymbols(const SymbolTable &symbol_table) const { + // Propagate this to potential Produce. + return input_->OutputSymbols(symbol_table); +} + +std::vector<Symbol> OrderBy::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } + +class OrderByCursor : public Cursor { + public: + OrderByCursor(const OrderBy &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self_.input_->MakeCursor(mem)), cache_(mem) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("OrderBy"); + + if (!did_pull_all_) { + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + auto *mem = cache_.get_allocator().GetMemoryResource(); + while (input_cursor_->Pull(frame, context)) { + // collect the order_by elements + utils::pmr::vector<TypedValue> order_by(mem); + order_by.reserve(self_.order_by_.size()); + for (auto expression_ptr : self_.order_by_) { + order_by.emplace_back(expression_ptr->Accept(evaluator)); + } + + // collect the output elements + utils::pmr::vector<TypedValue> output(mem); + output.reserve(self_.output_symbols_.size()); + for (const Symbol &output_sym : self_.output_symbols_) output.emplace_back(frame[output_sym]); + + cache_.push_back(Element{std::move(order_by), std::move(output)}); + } + + std::sort(cache_.begin(), cache_.end(), [this](const auto &pair1, const auto &pair2) { + return self_.compare_(pair1.order_by, pair2.order_by); + }); + + did_pull_all_ = true; + cache_it_ = cache_.begin(); + } + + if (cache_it_ == cache_.end()) return false; + + if (MustAbort(context)) throw HintedAbortError(); + + // place 
the output values on the frame + DMG_ASSERT(self_.output_symbols_.size() == cache_it_->remember.size(), + "Number of values does not match the number of output symbols " + "in OrderBy"); + auto output_sym_it = self_.output_symbols_.begin(); + for (const TypedValue &output : cache_it_->remember) frame[*output_sym_it++] = output; + + cache_it_++; + return true; + } + void Shutdown() override { input_cursor_->Shutdown(); } + + void Reset() override { + input_cursor_->Reset(); + did_pull_all_ = false; + cache_.clear(); + cache_it_ = cache_.begin(); + } + + private: + struct Element { + utils::pmr::vector<TypedValue> order_by; + utils::pmr::vector<TypedValue> remember; + }; + + const OrderBy &self_; + const UniqueCursorPtr input_cursor_; + bool did_pull_all_{false}; + // a cache of elements pulled from the input + // the cache is filled and sorted (only on first elem) on first Pull + utils::pmr::vector<Element> cache_; + // iterator over the cache_, maintains state between Pulls + decltype(cache_.begin()) cache_it_ = cache_.begin(); +}; + +UniqueCursorPtr OrderBy::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::OrderByOperator); + + return MakeUniqueCursorPtr<OrderByCursor>(mem, *this, mem); +} + +Merge::Merge(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &merge_match, + const std::shared_ptr<LogicalOperator> &merge_create) + : input_(input ? 
input : std::make_shared<Once>()), merge_match_(merge_match), merge_create_(merge_create) {} + +bool Merge::Accept(HierarchicalLogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + input_->Accept(visitor) && merge_match_->Accept(visitor) && merge_create_->Accept(visitor); + } + return visitor.PostVisit(*this); +} + +UniqueCursorPtr Merge::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::MergeOperator); + + return MakeUniqueCursorPtr<MergeCursor>(mem, *this, mem); +} + +std::vector<Symbol> Merge::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + // Match and create branches should have the same symbols, so just take one + // of them. + auto my_symbols = merge_match_->OutputSymbols(table); + symbols.insert(symbols.end(), my_symbols.begin(), my_symbols.end()); + return symbols; +} + +Merge::MergeCursor::MergeCursor(const Merge &self, utils::MemoryResource *mem) + : input_cursor_(self.input_->MakeCursor(mem)), + merge_match_cursor_(self.merge_match_->MakeCursor(mem)), + merge_create_cursor_(self.merge_create_->MakeCursor(mem)) {} + +bool Merge::MergeCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("Merge"); + + while (true) { + if (pull_input_) { + if (input_cursor_->Pull(frame, context)) { + // after a successful input from the input + // reset merge_match (it's expand iterators maintain state) + // and merge_create (could have a Once at the beginning) + merge_match_cursor_->Reset(); + merge_create_cursor_->Reset(); + } else + // input is exhausted, we're done + return false; + } + + // pull from the merge_match cursor + if (merge_match_cursor_->Pull(frame, context)) { + // if successful, next Pull from this should not pull_input_ + pull_input_ = false; + return true; + } else { + // failed to Pull from the merge_match cursor + if (pull_input_) { + // if we have just now pulled from the input + // and failed to pull from 
// Implements OPTIONAL semantics: for every row pulled from the input, yields
// each row the optional branch produces; when the optional branch produces
// nothing for a freshly pulled input row, yields that row once with all
// `optional_symbols_` set to Null.
// Returns false only once the input cursor is exhausted.
bool Optional::OptionalCursor::Pull(Frame &frame, ExecutionContext &context) {
  SCOPED_PROFILE_OP("Optional");

  while (true) {
    if (pull_input_) {
      if (input_cursor_->Pull(frame, context)) {
        // after a successful pull from the input
        // reset optional_ (its expand iterators maintain state)
        optional_cursor_->Reset();
      } else
        // input is exhausted, we're done
        return false;
    }

    // pull from the optional_ cursor
    if (optional_cursor_->Pull(frame, context)) {
      // if successful, next Pull from this should not pull_input_
      pull_input_ = false;
      return true;
    } else {
      // failed to Pull from the optional_ cursor
      if (pull_input_) {
        // we have just now pulled from the input and failed to pull
        // from optional_, so set the optional symbols to Null, ensure
        // that the input gets pulled next time and return true
        for (const Symbol &sym : self_.optional_symbols_) frame[sym] = TypedValue(context.evaluation_context.memory);
        // pull_input_ is already true on this path; the assignment only
        // makes the intent explicit
        pull_input_ = true;
        return true;
      }
      // we have exhausted optional_cursor_ after 1 or more successful Pulls
      // attempt next input_cursor_ pull
      pull_input_ = true;
      continue;
    }
  }
}
input : std::make_shared<Once>()), + input_expression_(input_expression), + output_symbol_(output_symbol) {} + +ACCEPT_WITH_INPUT(Unwind) + +std::vector<Symbol> Unwind::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.emplace_back(output_symbol_); + return symbols; +} + +class UnwindCursor : public Cursor { + public: + UnwindCursor(const Unwind &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)), input_value_(mem) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("Unwind"); + while (true) { + if (MustAbort(context)) throw HintedAbortError(); + // if we reached the end of our list of values + // pull from the input + if (input_value_it_ == input_value_.end()) { + if (!input_cursor_->Pull(frame, context)) return false; + + // successful pull from input, initialize value and iterator + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::OLD); + TypedValue input_value = self_.input_expression_->Accept(evaluator); + if (input_value.type() != TypedValue::Type::List) + throw QueryRuntimeException("Argument of UNWIND must be a list, but '{}' was provided.", input_value.type()); + // Copy the evaluted input_value_list to our vector. 
+ input_value_ = input_value.ValueList(); + input_value_it_ = input_value_.begin(); + } + + // if we reached the end of our list of values goto back to top + if (input_value_it_ == input_value_.end()) continue; + + frame[self_.output_symbol_] = *input_value_it_++; + return true; + } + } + + void Shutdown() override { input_cursor_->Shutdown(); } + + void Reset() override { + input_cursor_->Reset(); + input_value_.clear(); + input_value_it_ = input_value_.end(); + } + + private: + const Unwind &self_; + const UniqueCursorPtr input_cursor_; + // typed values we are unwinding and yielding + utils::pmr::vector<TypedValue> input_value_; + // current position in input_value_ + decltype(input_value_)::iterator input_value_it_ = input_value_.end(); +}; + +UniqueCursorPtr Unwind::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::UnwindOperator); + + return MakeUniqueCursorPtr<UnwindCursor>(mem, *this, mem); +} + +class DistinctCursor : public Cursor { + public: + DistinctCursor(const Distinct &self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self.input_->MakeCursor(mem)), seen_rows_(mem) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("Distinct"); + + while (true) { + if (!input_cursor_->Pull(frame, context)) return false; + + utils::pmr::vector<TypedValue> row(seen_rows_.get_allocator().GetMemoryResource()); + row.reserve(self_.value_symbols_.size()); + for (const auto &symbol : self_.value_symbols_) row.emplace_back(frame[symbol]); + if (seen_rows_.insert(std::move(row)).second) return true; + } + } + + void Shutdown() override { input_cursor_->Shutdown(); } + + void Reset() override { + input_cursor_->Reset(); + seen_rows_.clear(); + } + + private: + const Distinct &self_; + const UniqueCursorPtr input_cursor_; + // a set of already seen rows + utils::pmr::unordered_set<utils::pmr::vector<TypedValue>, + // use FNV collection hashing specialized for a + // vector of 
TypedValue + utils::FnvCollection<utils::pmr::vector<TypedValue>, TypedValue, TypedValue::Hash>, + TypedValueVectorEqual> + seen_rows_; +}; + +Distinct::Distinct(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &value_symbols) + : input_(input ? input : std::make_shared<Once>()), value_symbols_(value_symbols) {} + +ACCEPT_WITH_INPUT(Distinct) + +UniqueCursorPtr Distinct::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::DistinctOperator); + + return MakeUniqueCursorPtr<DistinctCursor>(mem, *this, mem); +} + +std::vector<Symbol> Distinct::OutputSymbols(const SymbolTable &symbol_table) const { + // Propagate this to potential Produce. + return input_->OutputSymbols(symbol_table); +} + +std::vector<Symbol> Distinct::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } + +Union::Union(const std::shared_ptr<LogicalOperator> &left_op, const std::shared_ptr<LogicalOperator> &right_op, + const std::vector<Symbol> &union_symbols, const std::vector<Symbol> &left_symbols, + const std::vector<Symbol> &right_symbols) + : left_op_(left_op), + right_op_(right_op), + union_symbols_(union_symbols), + left_symbols_(left_symbols), + right_symbols_(right_symbols) {} + +UniqueCursorPtr Union::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::UnionOperator); + + return MakeUniqueCursorPtr<Union::UnionCursor>(mem, *this, mem); +} + +bool Union::Accept(HierarchicalLogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + if (left_op_->Accept(visitor)) { + right_op_->Accept(visitor); + } + } + return visitor.PostVisit(*this); +} + +std::vector<Symbol> Union::OutputSymbols(const SymbolTable &) const { return union_symbols_; } + +std::vector<Symbol> Union::ModifiedSymbols(const SymbolTable &) const { return union_symbols_; } + +WITHOUT_SINGLE_INPUT(Union); + +Union::UnionCursor::UnionCursor(const Union &self, utils::MemoryResource 
// Cursor producing the cartesian product of the left and right operators.
// On the first Pull all frames produced by the left operator are copied into
// `left_op_frames_`. Afterwards, for every row pulled from the right operator,
// each stored left frame is replayed onto the frame in turn.
class CartesianCursor : public Cursor {
 public:
  CartesianCursor(const Cartesian &self, utils::MemoryResource *mem)
      : self_(self),
        left_op_frames_(mem),
        right_op_frame_(mem),
        left_op_cursor_(self.left_op_->MakeCursor(mem)),
        right_op_cursor_(self_.right_op_->MakeCursor(mem)) {
    MG_ASSERT(left_op_cursor_ != nullptr, "CartesianCursor: Missing left operator cursor.");
    MG_ASSERT(right_op_cursor_ != nullptr, "CartesianCursor: Missing right operator cursor.");
  }

  // Returns true with one (left, right) combination on the frame, or false
  // when the left side is empty or the right cursor is exhausted.
  bool Pull(Frame &frame, ExecutionContext &context) override {
    SCOPED_PROFILE_OP("Cartesian");

    if (!cartesian_pull_initialized_) {
      // Pull all left_op frames.
      while (left_op_cursor_->Pull(frame, context)) {
        left_op_frames_.emplace_back(frame.elems().begin(), frame.elems().end());
      }

      // We're setting the iterator to 'end' here so it pulls the right
      // cursor.
      left_op_frames_it_ = left_op_frames_.end();
      cartesian_pull_initialized_ = true;
    }

    // If left operator yielded zero results there is no cartesian product.
    if (left_op_frames_.empty()) {
      return false;
    }

    // Copies the values of `symbols` from a stored frame snapshot back onto
    // the live frame, indexing the snapshot by each symbol's position.
    auto restore_frame = [&frame](const auto &symbols, const auto &restore_from) {
      for (const auto &symbol : symbols) {
        frame[symbol] = restore_from[symbol.position()];
      }
    };

    if (left_op_frames_it_ == left_op_frames_.end()) {
      // Advance right_op_cursor_.
      if (!right_op_cursor_->Pull(frame, context)) return false;

      // Snapshot the frame so the right-hand values can be restored on the
      // following Pulls, then start replaying the left frames from the front.
      right_op_frame_.assign(frame.elems().begin(), frame.elems().end());
      left_op_frames_it_ = left_op_frames_.begin();
    } else {
      // Make sure right_op_cursor last pulled results are on frame.
      restore_frame(self_.right_symbols_, right_op_frame_);
    }

    if (MustAbort(context)) throw HintedAbortError();

    restore_frame(self_.left_symbols_, *left_op_frames_it_);
    left_op_frames_it_++;
    return true;
  }

  void Shutdown() override {
    left_op_cursor_->Shutdown();
    right_op_cursor_->Shutdown();
  }

  // Resets both child cursors and drops all stored frames so that the next
  // Pull collects the left side again.
  void Reset() override {
    left_op_cursor_->Reset();
    right_op_cursor_->Reset();
    right_op_frame_.clear();
    left_op_frames_.clear();
    left_op_frames_it_ = left_op_frames_.end();
    cartesian_pull_initialized_ = false;
  }

 private:
  const Cartesian &self_;
  // Snapshots of every frame produced by the left operator.
  utils::pmr::vector<utils::pmr::vector<TypedValue>> left_op_frames_;
  // Snapshot of the frame after the most recent right operator pull.
  utils::pmr::vector<TypedValue> right_op_frame_;
  const UniqueCursorPtr left_op_cursor_;
  const UniqueCursorPtr right_op_cursor_;
  // Next left frame to combine with the current right frame; assigned during
  // the first Pull and in Reset.
  utils::pmr::vector<utils::pmr::vector<TypedValue>>::iterator left_op_frames_it_;
  bool cartesian_pull_initialized_{false};
};
// Constructs the operator from the symbols its rows are bound to and the
// callback that supplies the rows. Both arguments are taken by value and
// moved into the corresponding members.
// NOTE(review): the callback's optional return presumably signals the end of
// the stream when empty — confirm against the cursor that invokes it.
OutputTableStream::OutputTableStream(
    std::vector<Symbol> output_symbols,
    std::function<std::optional<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback)
    : output_symbols_(std::move(output_symbols)), callback_(std::move(callback)) {}
+ void Reset() override { throw utils::NotYetImplemented("OutputTableStreamCursor::Reset"); } + + void Shutdown() override {} + + private: + const OutputTableStream *self_; +}; + +UniqueCursorPtr OutputTableStream::MakeCursor(utils::MemoryResource *mem) const { + return MakeUniqueCursorPtr<OutputTableStreamCursor>(mem, this); +} + +CallProcedure::CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name, std::vector<Expression *> args, + std::vector<std::string> fields, std::vector<Symbol> symbols, Expression *memory_limit, + size_t memory_scale, bool is_write) + : input_(input ? input : std::make_shared<Once>()), + procedure_name_(name), + arguments_(args), + result_fields_(fields), + result_symbols_(symbols), + memory_limit_(memory_limit), + memory_scale_(memory_scale), + is_write_(is_write) {} + +ACCEPT_WITH_INPUT(CallProcedure); + +std::vector<Symbol> CallProcedure::OutputSymbols(const SymbolTable &) const { return result_symbols_; } + +std::vector<Symbol> CallProcedure::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.insert(symbols.end(), result_symbols_.begin(), result_symbols_.end()); + return symbols; +} + +void CallProcedure::IncrementCounter(const std::string &procedure_name) { + procedure_counters_.WithLock([&](auto &counters) { ++counters[procedure_name]; }); +} + +std::unordered_map<std::string, int64_t> CallProcedure::GetAndResetCounters() { + auto counters = procedure_counters_.Lock(); + auto ret = std::move(*counters); + counters->clear(); + return ret; +} + +namespace { + +void CallCustomProcedure(const std::string_view fully_qualified_procedure_name, const mgp_proc &proc, + const std::vector<Expression *> &args, mgp_graph &graph, ExpressionEvaluator *evaluator, + utils::MemoryResource *memory, std::optional<size_t> memory_limit, mgp_result *result) { + static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>, + "Expected mgp_value to use custom allocator 
and makes STL " + "containers aware of that"); + // Build and type check procedure arguments. + mgp_list proc_args(memory); + std::vector<TypedValue> args_list; + args_list.reserve(args.size()); + for (auto *expression : args) { + args_list.emplace_back(expression->Accept(*evaluator)); + } + procedure::ConstructArguments(args_list, proc, fully_qualified_procedure_name, proc_args, graph); + if (memory_limit) { + SPDLOG_INFO("Running '{}' with memory limit of {}", fully_qualified_procedure_name, + utils::GetReadableSize(*memory_limit)); + utils::LimitedMemoryResource limited_mem(memory, *memory_limit); + mgp_memory proc_memory{&limited_mem}; + MG_ASSERT(result->signature == &proc.results); + // TODO: What about cross library boundary exceptions? OMG C++?! + proc.cb(&proc_args, &graph, result, &proc_memory); + size_t leaked_bytes = limited_mem.GetAllocatedBytes(); + if (leaked_bytes > 0U) { + spdlog::warn("Query procedure '{}' leaked {} *tracked* bytes", fully_qualified_procedure_name, leaked_bytes); + } + } else { + // TODO: Add a tracking MemoryResource without limits, so that we report + // memory leaks in procedure. + mgp_memory proc_memory{memory}; + MG_ASSERT(result->signature == &proc.results); + // TODO: What about cross library boundary exceptions? OMG C++?! + proc.cb(&proc_args, &graph, result, &proc_memory); + } +} + +} // namespace + +class CallProcedureCursor : public Cursor { + const CallProcedure *self_; + UniqueCursorPtr input_cursor_; + mgp_result result_; + decltype(result_.rows.end()) result_row_it_{result_.rows.end()}; + size_t result_signature_size_{0}; + + public: + CallProcedureCursor(const CallProcedure *self, utils::MemoryResource *mem) + : self_(self), + input_cursor_(self_->input_->MakeCursor(mem)), + // result_ needs to live throughout multiple Pull evaluations, until all + // rows are produced. Therefore, we use the memory dedicated for the + // whole execution. 
+ result_(nullptr, mem) { + MG_ASSERT(self_->result_fields_.size() == self_->result_symbols_.size(), "Incorrectly constructed CallProcedure"); + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("CallProcedure"); + + if (MustAbort(context)) throw HintedAbortError(); + + // We need to fetch new procedure results after pulling from input. + // TODO: Look into openCypher's distinction between procedures returning an + // empty result set vs procedures which return `void`. We currently don't + // have procedures registering what they return. + // This `while` loop will skip over empty results. + while (result_row_it_ == result_.rows.end()) { + if (!input_cursor_->Pull(frame, context)) return false; + result_.signature = nullptr; + result_.rows.clear(); + result_.error_msg.reset(); + // It might be a good idea to resolve the procedure name once, at the + // start. Unfortunately, this could deadlock if we tried to invoke a + // procedure from a module (read lock) and reload a module (write lock) + // inside the same execution thread. Also, our RWLock is setup so that + // it's not possible for a single thread to request multiple read locks. + // Builtin module registration in query/procedure/module.cpp depends on + // this locking scheme. + const auto &maybe_found = procedure::FindProcedure(procedure::gModuleRegistry, self_->procedure_name_, + context.evaluation_context.memory); + if (!maybe_found) { + throw QueryRuntimeException("There is no procedure named '{}'.", self_->procedure_name_); + } + const auto &[module, proc] = *maybe_found; + if (proc->info.is_write != self_->is_write_) { + auto get_proc_type_str = [](bool is_write) { return is_write ? 
"write" : "read"; }; + throw QueryRuntimeException("The procedure named '{}' was a {} procedure, but changed to be a {} procedure.", + self_->procedure_name_, get_proc_type_str(self_->is_write_), + get_proc_type_str(proc->info.is_write)); + } + const auto graph_view = proc->info.is_write ? storage::v3::View::NEW : storage::v3::View::OLD; + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + graph_view); + + result_.signature = &proc->results; + // Use evaluation memory, as invoking a procedure is akin to a simple + // evaluation of an expression. + // TODO: This will probably need to be changed when we add support for + // generator like procedures which yield a new result on each invocation. + auto *memory = context.evaluation_context.memory; + auto memory_limit = EvaluateMemoryLimit(&evaluator, self_->memory_limit_, self_->memory_scale_); + auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context); + CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit, + &result_); + + // Reset result_.signature to nullptr, because outside of this scope we + // will no longer hold a lock on the `module`. If someone were to reload + // it, the pointer would be invalid. + result_signature_size_ = result_.signature->size(); + result_.signature = nullptr; + if (result_.error_msg) { + throw QueryRuntimeException("{}: {}", self_->procedure_name_, *result_.error_msg); + } + result_row_it_ = result_.rows.begin(); + } + + const auto &values = result_row_it_->values; + // Check that the row has all fields as required by the result signature. + // C API guarantees that it's impossible to set fields which are not part of + // the result record, but it does not gurantee that some may be missing. See + // `mgp_result_record_insert`. 
+ if (values.size() != result_signature_size_) { + throw QueryRuntimeException( + "Procedure '{}' did not yield all fields as required by its " + "signature.", + self_->procedure_name_); + } + for (size_t i = 0; i < self_->result_fields_.size(); ++i) { + std::string_view field_name(self_->result_fields_[i]); + auto result_it = values.find(field_name); + if (result_it == values.end()) { + throw QueryRuntimeException("Procedure '{}' did not yield a record with '{}' field.", self_->procedure_name_, + field_name); + } + frame[self_->result_symbols_[i]] = result_it->second; + } + ++result_row_it_; + + return true; + } + + void Reset() override { + result_.rows.clear(); + result_.error_msg.reset(); + input_cursor_->Reset(); + } + + void Shutdown() override {} +}; + +UniqueCursorPtr CallProcedure::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::CallProcedureOperator); + CallProcedure::IncrementCounter(procedure_name_); + + return MakeUniqueCursorPtr<CallProcedureCursor>(mem, this, mem); +} + +LoadCsv::LoadCsv(std::shared_ptr<LogicalOperator> input, Expression *file, bool with_header, bool ignore_bad, + Expression *delimiter, Expression *quote, Symbol row_var) + : input_(input ? 
input : (std::make_shared<Once>())), + file_(file), + with_header_(with_header), + ignore_bad_(ignore_bad), + delimiter_(delimiter), + quote_(quote), + row_var_(row_var) { + MG_ASSERT(file_, "Something went wrong - '{}' member file_ shouldn't be a nullptr", __func__); +} + +bool LoadCsv::Accept(HierarchicalLogicalOperatorVisitor &visitor) { return false; }; + +class LoadCsvCursor; + +std::vector<Symbol> LoadCsv::OutputSymbols(const SymbolTable &sym_table) const { return {row_var_}; }; + +std::vector<Symbol> LoadCsv::ModifiedSymbols(const SymbolTable &sym_table) const { + auto symbols = input_->ModifiedSymbols(sym_table); + symbols.push_back(row_var_); + return symbols; +}; + +namespace { +// copy-pasted from interpreter.cpp +TypedValue EvaluateOptionalExpression(Expression *expression, ExpressionEvaluator *eval) { + return expression ? expression->Accept(*eval) : TypedValue(); +} + +auto ToOptionalString(ExpressionEvaluator *evaluator, Expression *expression) -> std::optional<utils::pmr::string> { + const auto evaluated_expr = EvaluateOptionalExpression(expression, evaluator); + if (evaluated_expr.IsString()) { + return utils::pmr::string(evaluated_expr.ValueString(), utils::NewDeleteResource()); + } + return std::nullopt; +}; + +TypedValue CsvRowToTypedList(csv::Reader::Row row) { + auto *mem = row.get_allocator().GetMemoryResource(); + auto typed_columns = utils::pmr::vector<TypedValue>(mem); + typed_columns.reserve(row.size()); + for (auto &column : row) { + typed_columns.emplace_back(std::move(column)); + } + return TypedValue(typed_columns, mem); +} + +TypedValue CsvRowToTypedMap(csv::Reader::Row row, csv::Reader::Header header) { + // a valid row has the same number of elements as the header + auto *mem = row.get_allocator().GetMemoryResource(); + utils::pmr::map<utils::pmr::string, TypedValue> m(mem); + for (auto i = 0; i < row.size(); ++i) { + m.emplace(std::move(header[i]), std::move(row[i])); + } + return TypedValue(m, mem); +} + +} // namespace + +class 
LoadCsvCursor : public Cursor { + const LoadCsv *self_; + const UniqueCursorPtr input_cursor_; + bool input_is_once_; + std::optional<csv::Reader> reader_{}; + + public: + LoadCsvCursor(const LoadCsv *self, utils::MemoryResource *mem) + : self_(self), input_cursor_(self_->input_->MakeCursor(mem)) { + input_is_once_ = dynamic_cast<Once *>(self_->input_.get()); + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("LoadCsv"); + + if (MustAbort(context)) throw HintedAbortError(); + + // ToDo(the-joksim): + // - this is an ungodly hack because the pipeline of creating a plan + // doesn't allow evaluating the expressions contained in self_->file_, + // self_->delimiter_, and self_->quote_ earlier (say, in the interpreter.cpp) + // without massacring the code even worse than I did here + if (UNLIKELY(!reader_)) { + reader_ = MakeReader(&context.evaluation_context); + } + + bool input_pulled = input_cursor_->Pull(frame, context); + + // If the input is Once, we have to keep going until we read all the rows, + // regardless of whether the pull on Once returned false. + // If we have e.g. MATCH(n) LOAD CSV ... AS x SET n.name = x.name, then we + // have to read at most cardinality(n) rows (but we can read less and stop + // pulling MATCH). 
+ if (!input_is_once_ && !input_pulled) return false; + + if (auto row = reader_->GetNextRow(context.evaluation_context.memory)) { + if (!reader_->HasHeader()) { + frame[self_->row_var_] = CsvRowToTypedList(std::move(*row)); + } else { + frame[self_->row_var_] = CsvRowToTypedMap( + std::move(*row), csv::Reader::Header(reader_->GetHeader(), context.evaluation_context.memory)); + } + return true; + } + + return false; + } + + void Reset() override { input_cursor_->Reset(); } + void Shutdown() override { input_cursor_->Shutdown(); } + + private: + csv::Reader MakeReader(EvaluationContext *eval_context) { + Frame frame(0); + SymbolTable symbol_table; + DbAccessor *dba = nullptr; + auto evaluator = ExpressionEvaluator(&frame, symbol_table, *eval_context, dba, storage::v3::View::OLD); + + auto maybe_file = ToOptionalString(&evaluator, self_->file_); + auto maybe_delim = ToOptionalString(&evaluator, self_->delimiter_); + auto maybe_quote = ToOptionalString(&evaluator, self_->quote_); + + // No need to check if maybe_file is std::nullopt, as the parser makes sure + // we can't get a nullptr for the 'file_' member in the LoadCsv clause. + // Note that the reader has to be given its own memory resource, as it + // persists between pulls, so it can't use the evalutation context memory + // resource. 
+ return csv::Reader( + *maybe_file, + csv::Reader::Config(self_->with_header_, self_->ignore_bad_, std::move(maybe_delim), std::move(maybe_quote)), + utils::NewDeleteResource()); + } +}; + +UniqueCursorPtr LoadCsv::MakeCursor(utils::MemoryResource *mem) const { + return MakeUniqueCursorPtr<LoadCsvCursor>(mem, this, mem); +}; + +class ForeachCursor : public Cursor { + public: + explicit ForeachCursor(const Foreach &foreach, utils::MemoryResource *mem) + : loop_variable_symbol_(foreach.loop_variable_symbol_), + input_(foreach.input_->MakeCursor(mem)), + updates_(foreach.update_clauses_->MakeCursor(mem)), + expression(foreach.expression_) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP(op_name_); + + if (!input_->Pull(frame, context)) { + return false; + } + + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::v3::View::NEW); + TypedValue expr_result = expression->Accept(evaluator); + + if (expr_result.IsNull()) { + return true; + } + + if (!expr_result.IsList()) { + throw QueryRuntimeException("FOREACH expression must resolve to a list, but got '{}'.", expr_result.type()); + } + + const auto &cache_ = expr_result.ValueList(); + for (const auto &index : cache_) { + frame[loop_variable_symbol_] = index; + while (updates_->Pull(frame, context)) { + } + ResetUpdates(); + } + + return true; + } + + void Shutdown() override { input_->Shutdown(); } + + void ResetUpdates() { updates_->Reset(); } + + void Reset() override { + input_->Reset(); + ResetUpdates(); + } + + private: + const Symbol loop_variable_symbol_; + const UniqueCursorPtr input_; + const UniqueCursorPtr updates_; + Expression *expression; + const char *op_name_{"Foreach"}; +}; + +Foreach::Foreach(std::shared_ptr<LogicalOperator> input, std::shared_ptr<LogicalOperator> updates, Expression *expr, + Symbol loop_variable_symbol) + : input_(input ? 
std::move(input) : std::make_shared<Once>()), + update_clauses_(std::move(updates)), + expression_(expr), + loop_variable_symbol_(loop_variable_symbol) {} + +UniqueCursorPtr Foreach::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ForeachOperator); + return MakeUniqueCursorPtr<ForeachCursor>(mem, *this, mem); +} + +std::vector<Symbol> Foreach::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.emplace_back(loop_variable_symbol_); + return symbols; +} + +bool Foreach::Accept(HierarchicalLogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + input_->Accept(visitor); + update_clauses_->Accept(visitor); + } + return visitor.PostVisit(*this); +} + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/operator.lcp b/src/query/v2/plan/operator.lcp new file mode 100644 index 000000000..393529b77 --- /dev/null +++ b/src/query/v2/plan/operator.lcp @@ -0,0 +1,2305 @@ +;; Copyright 2022 Memgraph Ltd. +;; +;; Use of this software is governed by the Business Source License +;; included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +;; License, and you may not use this file except in compliance with the Business Source License. +;; +;; As of the Change Date specified in that file, in accordance with +;; the Business Source License, use of this software will be governed +;; by the Apache License, Version 2.0, included in the file +;; licenses/APL.txt. 
+ +#>cpp +/** @file */ + +#pragma once + +#include <memory> +#include <optional> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <variant> +#include <vector> + +#include "query/v2/common.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/semantic/symbol.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/id_types.hpp" +#include "utils/bound.hpp" +#include "utils/fnv.hpp" +#include "utils/memory.hpp" +#include "utils/visitor.hpp" +#include "utils/logging.hpp" +cpp<# + +(lcp:namespace memgraph) +(lcp:namespace query) +(lcp:namespace v2) + +#>cpp +struct ExecutionContext; +class ExpressionEvaluator; +class Frame; +class SymbolTable; +cpp<# + +(lcp:namespace plan) + +#>cpp +/// Base class for iteration cursors of @c LogicalOperator classes. +/// +/// Each @c LogicalOperator must produce a concrete @c Cursor, which provides +/// the iteration mechanism. +class Cursor { + public: + /// Run an iteration of a @c LogicalOperator. + /// + /// Since operators may be chained, the iteration may pull results from + /// multiple operators. + /// + /// @param Frame May be read from or written to while performing the + /// iteration. + /// @param ExecutionContext Used to get the position of symbols in frame and + /// other information. + /// + /// @throws QueryRuntimeException if something went wrong with execution + virtual bool Pull(Frame &, ExecutionContext &) = 0; + + /// Resets the Cursor to its initial state. + virtual void Reset() = 0; + + /// Perform cleanup which may throw an exception + virtual void Shutdown() = 0; + + virtual ~Cursor() {} +}; + +/// unique_ptr to Cursor managed with a custom deleter. +/// This allows us to use utils::MemoryResource for allocation. +using UniqueCursorPtr = std::unique_ptr<Cursor, std::function<void(Cursor *)>>; + +template <class TCursor, class... 
TArgs> +std::unique_ptr<Cursor, std::function<void(Cursor *)>> MakeUniqueCursorPtr( + utils::Allocator<TCursor> allocator, TArgs &&... args) { + auto *ptr = allocator.allocate(1); + try { + auto *cursor = new (ptr) TCursor(std::forward<TArgs>(args)...); + return std::unique_ptr<Cursor, std::function<void(Cursor *)>>( + cursor, [allocator](Cursor *base_ptr) mutable { + auto *p = static_cast<TCursor *>(base_ptr); + p->~TCursor(); + allocator.deallocate(p, 1); + }); + } catch (...) { + allocator.deallocate(ptr, 1); + throw; + } +} + +class Once; +class CreateNode; +class CreateExpand; +class ScanAll; +class ScanAllByLabel; +class ScanAllByLabelPropertyRange; +class ScanAllByLabelPropertyValue; +class ScanAllByLabelProperty; +class ScanAllById; +class Expand; +class ExpandVariable; +class ConstructNamedPath; +class Filter; +class Produce; +class Delete; +class SetProperty; +class SetProperties; +class SetLabels; +class RemoveProperty; +class RemoveLabels; +class EdgeUniquenessFilter; +class Accumulate; +class Aggregate; +class Skip; +class Limit; +class OrderBy; +class Merge; +class Optional; +class Unwind; +class Distinct; +class Union; +class Cartesian; +class CallProcedure; +class LoadCsv; +class Foreach; + +using LogicalOperatorCompositeVisitor = utils::CompositeVisitor< + Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, + ScanAllByLabelPropertyRange, ScanAllByLabelPropertyValue, + ScanAllByLabelProperty, ScanAllById, + Expand, ExpandVariable, ConstructNamedPath, Filter, Produce, Delete, + SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, + EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, + Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv, Foreach>; + +using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>; + +/** + * @brief Base class for hierarhical visitors of @c LogicalOperator class + * hierarchy. 
+ */ +class HierarchicalLogicalOperatorVisitor + : public LogicalOperatorCompositeVisitor, + public LogicalOperatorLeafVisitor { + public: + using LogicalOperatorCompositeVisitor::PostVisit; + using LogicalOperatorCompositeVisitor::PreVisit; + using LogicalOperatorLeafVisitor::Visit; + using typename LogicalOperatorLeafVisitor::ReturnType; +}; +cpp<# + +(lcp:define-class logical-operator ("utils::Visitable<HierarchicalLogicalOperatorVisitor>") + () + (:abstractp t) + (:documentation + "Base class for logical operators. + +Each operator describes an operation, which is to be performed on the +database. Operators are iterated over using a @c Cursor. Various operators +can serve as inputs to others and thus a sequence of operations is formed.") + (:public + #>cpp + virtual ~LogicalOperator() {} + + /** Construct a @c Cursor which is used to run this operator. + * + * @param utils::MemoryResource Memory resource used for allocations during + * the lifetime of the returned Cursor. + */ + virtual UniqueCursorPtr MakeCursor(utils::MemoryResource *) const = 0; + + /** Return @c Symbol vector where the query results will be stored. + * + * Currently, output symbols are generated in @c Produce @c Union and + * @c CallProcedure operators. @c Skip, @c Limit, @c OrderBy and @c Distinct + * propagate the symbols from @c Produce (if it exists as input operator). + * + * @param SymbolTable used to find symbols for expressions. + * @return std::vector<Symbol> used for results. + */ + virtual std::vector<Symbol> OutputSymbols(const SymbolTable &) const { + return std::vector<Symbol>(); + } + + /** + * Symbol vector whose values are modified by this operator sub-tree. + * + * This is different than @c OutputSymbols, because it returns all of the + * modified symbols, including those that may not be returned as the + * result of the query. Note that the modified symbols will not contain + * those that should not be read after the operator is processed. 
+ * + * For example, `MATCH (n)-[e]-(m) RETURN n AS l` will generate `ScanAll (n) > + * Expand (e, m) > Produce (l)`. The modified symbols on Produce sub-tree will + * be `l`, the same as output symbols, because it isn't valid to read `n`, `e` + * nor `m` after Produce. On the other hand, modified symbols from Expand + * contain `e` and `m`, as well as `n`, while output symbols are empty. + * Modified symbols from ScanAll contain only `n`, while output symbols are + * also empty. + */ + virtual std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const = 0; + + /** + * Returns true if the operator takes only one input operator. + * NOTE: When this method returns true, you may use `input` and `set_input` + * methods. + */ + virtual bool HasSingleInput() const = 0; + + /** + * Returns the input operator if it has any. + * NOTE: This should only be called if `HasSingleInput() == true`. + */ + virtual std::shared_ptr<LogicalOperator> input() const = 0; + /** + * Set a different input on this operator. + * NOTE: This should only be called if `HasSingleInput() == true`. 
+ */ + virtual void set_input(std::shared_ptr<LogicalOperator>) = 0; + + struct SaveHelper { + std::vector<LogicalOperator *> saved_ops; + }; + + struct LoadHelper { + AstStorage ast_storage; + std::vector<std::pair<uint64_t, std::shared_ptr<LogicalOperator>>> + loaded_ops; + }; + + struct SlkLoadHelper { + AstStorage ast_storage; + std::vector<std::shared_ptr<LogicalOperator>> loaded_ops; + }; + cpp<#) + (:serialize + (:slk :base t + :save-args '((helper "query::v2::plan::LogicalOperator::SaveHelper *")) + :load-args '((helper "query::v2::plan::LogicalOperator::SlkLoadHelper *")))) + (:type-info :base t) + (:clone :args '((storage "AstStorage *")) + :base t)) + +(defun slk-save-ast-pointer (member) + #>cpp + query::v2::SaveAstPointer(self.${member}, builder); + cpp<#) + +(defun slk-load-ast-pointer (type) + (lambda (member) + #>cpp + self->${member} = query::v2::LoadAstPointer<query::v2::${type}>( + &helper->ast_storage, reader); + cpp<#)) + +(defun slk-save-ast-vector (member) + #>cpp + size_t size = self.${member}.size(); + slk::Save(size, builder); + for (const auto *val : self.${member}) { + query::v2::SaveAstPointer(val, builder); + } + cpp<#) + +(defun slk-load-ast-vector (type) + (lambda (member) + #>cpp + size_t size = 0; + slk::Load(&size, reader); + self->${member}.resize(size); + for (size_t i = 0; + i < size; + ++i) { + self->${member}[i] = query::v2::LoadAstPointer<query::v2::${type}>( + &helper->ast_storage, reader); + } + cpp<#)) + +(defun slk-save-operator-pointer (member) + #>cpp + slk::Save<query::v2::plan::LogicalOperator>(self.${member}, builder, + &helper->saved_ops, + [&helper](const auto &val, + auto *builder) { + slk::Save(val, builder, helper); + }); + cpp<#) + +(defun slk-load-operator-pointer (member) + #>cpp + slk::Load<query::v2::plan::LogicalOperator>(&self->${member}, reader, &helper->loaded_ops, + [&helper](auto *op, auto *reader) { + slk::ConstructAndLoad(op, reader, helper); + }); + cpp<#) + +(lcp:define-class once 
(logical-operator)
+  ((symbols "std::vector<Symbol>" :scope :public))
+  (:documentation
+   "A logical operator whose Cursor returns true on the first Pull
+and false on every following Pull.")
+  (:public
+   #>cpp
+   Once(std::vector<Symbol> symbols = {}) : symbols_{std::move(symbols)} {}
+   DEFVISITABLE(HierarchicalLogicalOperatorVisitor);
+   UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+   std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override {
+     return symbols_;
+   }
+
+   bool HasSingleInput() const override;
+   std::shared_ptr<LogicalOperator> input() const override;
+   void set_input(std::shared_ptr<LogicalOperator>) override;
+   cpp<#)
+  (:private
+   #>cpp
+   class OnceCursor : public Cursor {
+    public:
+     OnceCursor() {}
+     bool Pull(Frame &, ExecutionContext &) override;
+     void Shutdown() override;
+     void Reset() override;
+
+    private:
+     bool did_pull_{false};
+   };
+   cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+;; Emits the SLK save code for a property list: the element count followed by
+;; each (property id, expression pointer) pair.
+;; NOTE(review): node-creation-info and edge-creation-info attach this function
+;; to a member declared as std::variant<PropertiesMapList, ParameterLookup *>,
+;; and std::variant has no size() or iteration. Confirm whether this code is
+;; ever generated for those members.
+(defun slk-save-properties (member)
+  #>cpp
+  size_t size = self.${member}.size();
+  slk::Save(size, builder);
+  for (const auto &kv : self.${member}) {
+    slk::Save(kv.first, builder);
+    query::v2::SaveAstPointer(kv.second, builder);
+  }
+  cpp<#)
+
+;; Emits the SLK load code matching slk-save-properties: reads the element
+;; count, then each (property id, expression pointer) pair.
+;; NOTE(review): same concern as for slk-save-properties, resize() and
+;; operator[] are called on a member that the structs declare as std::variant.
+(defun slk-load-properties (member)
+  #>cpp
+  size_t size = 0;
+  slk::Load(&size, reader);
+  self->${member}.resize(size);
+  for (size_t i = 0; i < size; ++i) {
+    storage::v3::PropertyId prop;
+    slk::Load(&prop, reader);
+    auto *expr = query::v2::LoadAstPointer<query::v2::Expression>(
+        &helper->ast_storage, reader);
+    self->${member}[i] = {prop, expr};
+  }
+  cpp<#)
+
+;; Emits the clone code for a std::variant<PropertiesMapList, ParameterLookup *>
+;; member: a PropertiesMapList is cloned element by element (null expressions
+;; stay null), a ParameterLookup is cloned as a whole.
+;; NOTE(review): std::get<PropertiesMapList> on the destination throws unless
+;; the destination already holds a PropertiesMapList; presumably it is freshly
+;; default-constructed at this point. Confirm against the generated Clone.
+(defun clone-variant-properties (source destination)
+  #>cpp
+    if (const auto *props = std::get_if<PropertiesMapList>(&${source})) {
+      auto &destination_props = std::get<PropertiesMapList>(${destination});
+      destination_props.resize(props->size());
+      for (size_t i0 = 0; i0 < props->size(); ++i0) {
+        {
+          storage::v3::PropertyId first1 = (*props)[i0].first;
+          Expression *second2;
+          second2 = (*props)[i0].second ?
(*props)[i0].second->Clone(storage) : nullptr; + destination_props[i0] = std::make_pair(std::move(first1), std::move(second2)); + } + } + } else { + ${destination} = std::get<ParameterLookup *>(${source})->Clone(storage); + } + cpp<#) + +#>cpp +using PropertiesMapList = std::vector<std::pair<storage::v3::PropertyId, Expression *>>; +cpp<# + +(lcp:define-struct node-creation-info () + ((symbol "Symbol") + (labels "std::vector<storage::v3::LabelId>") + (properties "std::variant<PropertiesMapList, ParameterLookup *>" + :slk-save #'slk-save-properties + :slk-load #'slk-load-properties + :clone #'clone-variant-properties)) + (:serialize (:slk :save-args '((helper "query::v2::plan::LogicalOperator::SaveHelper *")) + :load-args '((helper "query::v2::plan::LogicalOperator::SlkLoadHelper *")))) + (:clone :args '((storage "AstStorage *"))) + (:public + #>cpp + NodeCreationInfo() = default; + + NodeCreationInfo( + Symbol symbol, std::vector<storage::v3::LabelId> labels, + std::variant<PropertiesMapList, ParameterLookup *> properties) + : symbol{std::move(symbol)}, labels{std::move(labels)}, properties{std::move(properties)} {}; + + NodeCreationInfo(Symbol symbol, std::vector<storage::v3::LabelId> labels, + PropertiesMapList properties) + : symbol{std::move(symbol)}, labels{std::move(labels)}, properties{std::move(properties)} {}; + + NodeCreationInfo(Symbol symbol, std::vector<storage::v3::LabelId> labels, ParameterLookup* properties) + : symbol{std::move(symbol)}, labels{std::move(labels)}, properties{properties} {}; + cpp<#)) + +(lcp:define-class create-node (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (node-info "NodeCreationInfo" :scope :public + :slk-save (lambda (m) + #>cpp + slk::Save(self.${m}, builder, helper); + cpp<#) + :slk-load (lambda (m) + #>cpp + slk::Load(&self->${m}, reader, helper); + cpp<#))) + (:documentation + "Operator for creating a 
node. + +This op is used both for creating a single node (`CREATE` statement without +a preceding `MATCH`), or multiple nodes (`MATCH ... CREATE` or +`CREATE (), () ...`). + +@sa CreateExpand") + (:public + #>cpp + CreateNode() {} + + /** + * @param input Optional. If @c nullptr, then a single node will be + * created (a single successful @c Cursor::Pull from this op's @c Cursor). + * If a valid input, then a node will be created for each + * successful pull from the given input. + * @param node_info @c NodeCreationInfo + */ + CreateNode(const std::shared_ptr<LogicalOperator> &input, + const NodeCreationInfo &node_info); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class CreateNodeCursor : public Cursor { + public: + CreateNodeCursor(const CreateNode &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const CreateNode &self_; + const UniqueCursorPtr input_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-struct edge-creation-info () + ((symbol "Symbol") + (properties "std::variant<PropertiesMapList, ParameterLookup *>" + :slk-save #'slk-save-properties + :slk-load #'slk-load-properties + :clone #'clone-variant-properties) + (edge-type "::storage::v3::EdgeTypeId") + (direction "::EdgeAtom::Direction" :initval "EdgeAtom::Direction::BOTH")) + (:serialize (:slk :save-args '((helper "query::v2::plan::LogicalOperator::SaveHelper *")) + :load-args '((helper "query::v2::plan::LogicalOperator::SlkLoadHelper *")))) + (:clone :args '((storage 
"AstStorage *"))) + (:public + #>cpp + EdgeCreationInfo() = default; + + EdgeCreationInfo(Symbol symbol, std::variant<PropertiesMapList, ParameterLookup *> properties, + storage::v3::EdgeTypeId edge_type, EdgeAtom::Direction direction) + : symbol{std::move(symbol)}, properties{std::move(properties)}, edge_type{edge_type}, direction{direction} {}; + + EdgeCreationInfo(Symbol symbol, PropertiesMapList properties, + storage::v3::EdgeTypeId edge_type, EdgeAtom::Direction direction) + : symbol{std::move(symbol)}, properties{std::move(properties)}, edge_type{edge_type}, direction{direction} {}; + + EdgeCreationInfo(Symbol symbol, ParameterLookup* properties, + storage::v3::EdgeTypeId edge_type, EdgeAtom::Direction direction) + : symbol{std::move(symbol)}, properties{properties}, edge_type{edge_type}, direction{direction} {}; + cpp<#)) + +(lcp:define-class create-expand (logical-operator) + ((node-info "NodeCreationInfo" :scope :public + :slk-save (lambda (m) + #>cpp + slk::Save(self.${m}, builder, helper); + cpp<#) + :slk-load (lambda (m) + #>cpp + slk::Load(&self->${m}, reader, helper); + cpp<#)) + (edge-info "EdgeCreationInfo" :scope :public + :slk-save (lambda (m) + #>cpp + slk::Save(self.${m}, builder, helper); + cpp<#) + :slk-load (lambda (m) + #>cpp + slk::Load(&self->${m}, reader, helper); + cpp<#)) + ;; the input op and the symbol under which the op's result + ;; can be found in the frame + (input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (input-symbol "Symbol" :scope :public) + (existing-node :bool :scope :public :documentation + "if the given node atom refers to an existing node (either matched or created)")) + (:documentation + "Operator for creating edges and destination nodes. + +This operator extends already created nodes with an edge. If the other node +on the edge does not exist, it will be created. 
For example, in `MATCH (n) +CREATE (n) -[r:r]-> (n)` query, this operator will create just the edge `r`. +In `MATCH (n) CREATE (n) -[r:r]-> (m)` query, the operator will create both +the edge `r` and the node `m`. In case of `CREATE (n) -[r:r]-> (m)` the +first node `n` is created by @c CreateNode operator, while @c CreateExpand +will create the edge `r` and `m`. Similarly, multiple @c CreateExpand are +chained in cases when longer paths need creating. + +@sa CreateNode") + (:public + #>cpp + CreateExpand() {} + + /** @brief Construct @c CreateExpand. + * + * @param node_info @c NodeCreationInfo at the end of the edge. + * Used to create a node, unless it refers to an existing one. + * @param edge_info @c EdgeCreationInfo for the edge to be created. + * @param input Optional. Previous @c LogicalOperator which will be pulled. + * For each successful @c Cursor::Pull, this operator will create an + * expansion. + * @param input_symbol @c Symbol for the node at the start of the edge. + * @param existing_node @c bool indicating whether the @c node_atom refers to + * an existing node. If @c false, the operator will also create the node. 
+ */ + CreateExpand(const NodeCreationInfo &node_info, + const EdgeCreationInfo &edge_info, + const std::shared_ptr<LogicalOperator> &input, + Symbol input_symbol, bool existing_node); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class CreateExpandCursor : public Cursor { + public: + CreateExpandCursor(const CreateExpand &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const CreateExpand &self_; + const UniqueCursorPtr input_cursor_; + + // Get the existing node (if existing_node_ == true), or create a new node + VertexAccessor &OtherVertex(Frame &frame, ExecutionContext &context); + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class scan-all (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (output-symbol "Symbol" :scope :public) + (view "::storage::v3::View" :scope :public + :documentation + "Controls which graph state is used to produce vertices. + +If @c storage::v3::View::OLD, @c ScanAll will produce vertices visible in the +previous graph state, before modifications done by current transaction & +command. With @c storage::v3::View::NEW, all vertices will be produced the current +transaction sees along with their modifications.")) + + (:documentation + "Operator which iterates over all the nodes currently in the database. +When given an input (optional), does a cartesian product. + +It accepts an optional input. 
If provided then this op scans all the nodes
+currently in the database for each successful Pull from its input, thereby
+producing a cartesian product of input Pulls and database elements.
+
+ScanAll can either iterate over the previous graph state (state before
+the current transaction+command) or over current state. This is controlled
+with a constructor argument.
+
+@sa ScanAllByLabel
+@sa ScanAllByLabelPropertyRange
+@sa ScanAllByLabelPropertyValue")
+  (:public
+   #>cpp
+   ScanAll() {}
+   ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
+           storage::v3::View view = storage::v3::View::OLD);
+   bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+   UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+   std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+   bool HasSingleInput() const override { return true; }
+   std::shared_ptr<LogicalOperator> input() const override { return input_; }
+   void set_input(std::shared_ptr<LogicalOperator> input) override {
+     input_ = input;
+   }
+   cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+(lcp:define-class scan-all-by-label (scan-all)
+  ((label "::storage::v3::LabelId" :scope :public))
+  (:documentation
+   "Behaves like @c ScanAll, but this operator produces only vertices with
+given label.
+
+@sa ScanAll
+@sa ScanAllByLabelPropertyRange
+@sa ScanAllByLabelPropertyValue")
+  (:public
+   #>cpp
+   ScanAllByLabel() {}
+   ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input,
+                  Symbol output_symbol, storage::v3::LabelId label,
+                  storage::v3::View view = storage::v3::View::OLD);
+   bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+   UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+   cpp<#)
+  (:serialize (:slk))
+  (:clone))
+
+;; Emits the SLK save code for a std::optional<Bound> member: a presence flag,
+;; then (only when present) the bound type as one byte (0 = inclusive,
+;; 1 = exclusive) and the bound expression.
+;; The snippet deliberately has no early `return`: if it is spliced into a
+;; function body shared with other members, returning would skip every member
+;; saved after this one.
+(defun slk-save-optional-bound (member)
+  #>cpp
+  slk::Save(static_cast<bool>(self.${member}), builder);
+  if (self.${member}) {
+    const auto &bound = *self.${member};
+    uint8_t bound_type = 0;
+    switch (bound.type()) {
+      case utils::BoundType::INCLUSIVE:
+        bound_type = 0;
+        break;
+      case utils::BoundType::EXCLUSIVE:
+        bound_type = 1;
+        break;
+    }
+    slk::Save(bound_type, builder);
+    query::v2::SaveAstPointer(bound.value(), builder);
+  }
+  cpp<#)
+
+;; Emits the SLK load code matching slk-save-optional-bound. An unknown bound
+;; type byte raises slk::SlkDecodeException. Like the save side, it falls
+;; through instead of returning so that later members are still loaded.
+(defun slk-load-optional-bound (member)
+  #>cpp
+  bool has_bound = false;
+  slk::Load(&has_bound, reader);
+  if (!has_bound) {
+    self->${member} = std::nullopt;
+  } else {
+    uint8_t bound_type_value;
+    slk::Load(&bound_type_value, reader);
+    utils::BoundType bound_type;
+    switch (bound_type_value) {
+      case static_cast<uint8_t>(0):
+        bound_type = utils::BoundType::INCLUSIVE;
+        break;
+      case static_cast<uint8_t>(1):
+        bound_type = utils::BoundType::EXCLUSIVE;
+        break;
+      default:
+        throw slk::SlkDecodeException("Loading unknown BoundType");
+    }
+    auto *value = query::v2::LoadAstPointer<query::v2::Expression>(
+        &helper->ast_storage, reader);
+    self->${member}.emplace(utils::Bound<query::v2::Expression *>(value, bound_type));
+  }
+  cpp<#)
+
+;; Emits the clone code for a std::optional<Bound> member: the bound expression
+;; is cloned into `storage` and the bound type is kept; an empty optional stays
+;; empty.
+(defun clone-optional-bound (source dest)
+  #>cpp
+  if (${source}) {
+    ${dest}.emplace(utils::Bound<Expression *>(
+                        ${source}->value()->Clone(storage),
+                        ${source}->type()));
+  } else {
+    ${dest} = std::nullopt;
+  }
+  cpp<#)
+
+(lcp:define-class scan-all-by-label-property-range (scan-all)
+  ((label "::storage::v3::LabelId" :scope :public)
+   (property "::storage::v3::PropertyId" :scope :public)
+   (property-name "std::string" :scope :public)
+   (lower-bound "std::optional<Bound>" :scope :public
+                :slk-save #'slk-save-optional-bound
+                :slk-load #'slk-load-optional-bound
+                :clone #'clone-optional-bound)
+   (upper-bound "std::optional<Bound>" :scope :public
+                :slk-save #'slk-save-optional-bound
+                :slk-load #'slk-load-optional-bound
+                :clone #'clone-optional-bound))
+  (:documentation
+   "Behaves like @c ScanAll, but produces only vertices with given label and
+property value which is inside a range (inclusive or exclusive).
+
+@sa ScanAll
+@sa ScanAllByLabel
+@sa ScanAllByLabelPropertyValue")
+  (:public
+   #>cpp
+   /** Bound with expression which when evaluated produces the bound value. */
+   using Bound = utils::Bound<Expression *>;
+   ScanAllByLabelPropertyRange() {}
+   /**
+    * Constructs the operator for given label and property value in range
+    * (each bound is inclusive or exclusive, as given by its @c Bound).
+    *
+    * Range bounds are optional, but only one bound can be left out.
+    *
+    * @param input Preceding operator which will serve as the input.
+    * @param output_symbol Symbol where the vertices will be stored.
+    * @param label Label which the vertex must have.
+    * @param property Property from which the value will be looked up.
+    * @param lower_bound Optional lower @c Bound.
+    * @param upper_bound Optional upper @c Bound.
+    * @param view storage::v3::View used when obtaining vertices.
+ */ + ScanAllByLabelPropertyRange(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, storage::v3::LabelId label, + storage::v3::PropertyId property, + const std::string &property_name, + std::optional<Bound> lower_bound, + std::optional<Bound> upper_bound, + storage::v3::View view = storage::v3::View::OLD); + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class scan-all-by-label-property-value (scan-all) + ((label "::storage::v3::LabelId" :scope :public) + (property "::storage::v3::PropertyId" :scope :public) + (property-name "std::string" :scope :public) + (expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:documentation + "Behaves like @c ScanAll, but produces only vertices with given label and +property value. + +@sa ScanAll +@sa ScanAllByLabel +@sa ScanAllByLabelPropertyRange") + (:public + #>cpp + ScanAllByLabelPropertyValue() {} + /** + * Constructs the operator for given label and property value. + * + * @param input Preceding operator which will serve as the input. + * @param output_symbol Symbol where the vertices will be stored. + * @param label Label which the vertex must have. + * @param property Property from which the value will be looked up from. + * @param expression Expression producing the value of the vertex property. + * @param view storage::v3::View used when obtaining vertices. 
+ */ + ScanAllByLabelPropertyValue(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, storage::v3::LabelId label, + storage::v3::PropertyId property, + const std::string &property_name, + Expression *expression, + storage::v3::View view = storage::v3::View::OLD); + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class scan-all-by-label-property (scan-all) + ((label "::storage::v3::LabelId" :scope :public) + (property "::storage::v3::PropertyId" :scope :public) + (property-name "std::string" :scope :public) + (expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + + (:documentation + "Behaves like @c ScanAll, but this operator produces only vertices with +given label and property. + +@sa ScanAll +@sa ScanAllByLabelPropertyRange +@sa ScanAllByLabelPropertyValue") + (:public + #>cpp + ScanAllByLabelProperty() {} + ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, storage::v3::LabelId label, + storage::v3::PropertyId property, + const std::string &property_name, + storage::v3::View view = storage::v3::View::OLD); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + cpp<#) + (:serialize (:slk)) + (:clone)) + + + +(lcp:define-class scan-all-by-id (scan-all) + ((expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:documentation + "ScanAll producing a single node with ID equal to evaluated expression") + (:public + #>cpp + ScanAllById() {} + ScanAllById(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, Expression *expression, + storage::v3::View view = storage::v3::View::OLD); + + bool 
Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-struct expand-common () + ( + ;; info on what's getting expanded + (node-symbol "Symbol" + :documentation "Symbol pointing to the node to be expanded. +This is where the new node will be stored.") + (edge-symbol "Symbol" + :documentation "Symbol for the edges to be expanded. +This is where a TypedValue containing a list of expanded edges will be stored.") + (direction "::EdgeAtom::Direction" + :documentation "EdgeAtom::Direction determining the direction of edge +expansion. The direction is relative to the starting vertex for each expansion.") + (edge-types "std::vector<storage::v3::EdgeTypeId>" + :documentation "storage::v3::EdgeTypeId specifying which edges we want +to expand. If empty, all edges are valid. If not empty, only edges with one of +the given types are valid.") + (existing-node :bool :documentation "If the given node atom refers to a symbol +that has already been expanded and should be just validated in the frame.")) + (:serialize (:slk))) + +(lcp:define-class expand (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (input-symbol "Symbol" :scope :public) + (common "ExpandCommon" :scope :public) + (view "::storage::v3::View" :scope :public + :documentation + "State from which the input node should get expanded.")) + (:documentation + "Expansion operator. For a node existing in the frame it +expands one edge and one node and places them on the frame. + +This class does not handle node/edge filtering based on +properties, labels and edge types. However, it does handle +filtering on existing node / edge.
+ +Filtering on existing means that for a pattern that references +an already declared node or edge (for example in +MATCH (a) MATCH (a)--(b)), +only expansions that match defined equalities are successfully +pulled.") + (:public + #>cpp + /** + * Creates an expansion. All parameters except input and input_symbol are + * forwarded to @c ExpandCommon and are documented there. + * + * @param input Optional logical operator that precedes this one. + * @param input_symbol Symbol that points to a VertexAccessor in the frame + * that expansion should emanate from. + */ + Expand(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, + Symbol node_symbol, Symbol edge_symbol, EdgeAtom::Direction direction, + const std::vector<storage::v3::EdgeTypeId> &edge_types, bool existing_node, + storage::v3::View view); + + Expand() {} + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + + class ExpandCursor : public Cursor { + public: + ExpandCursor(const Expand &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + using InEdgeT = std::remove_reference_t<decltype( + *std::declval<VertexAccessor>().InEdges(storage::v3::View::OLD))>; + using InEdgeIteratorT = decltype(std::declval<InEdgeT>().begin()); + using OutEdgeT = std::remove_reference_t<decltype( + *std::declval<VertexAccessor>().OutEdges(storage::v3::View::OLD))>; + using OutEdgeIteratorT = decltype(std::declval<OutEdgeT>().begin()); + + const Expand &self_; + const UniqueCursorPtr input_cursor_; + + // The iterable over edges and the
current edge iterator are referenced via + // optional because they can not be initialized in the constructor of + // this class. They are initialized once for each pull from the input. + std::optional<InEdgeT> in_edges_; + std::optional<InEdgeIteratorT> in_edges_it_; + std::optional<OutEdgeT> out_edges_; + std::optional<OutEdgeIteratorT> out_edges_it_; + + bool InitEdges(Frame &, ExecutionContext &); + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-struct expansion-lambda () + ((inner-edge-symbol "Symbol" :documentation "Currently expanded edge symbol.") + (inner-node-symbol "Symbol" :documentation "Currently expanded node symbol.") + (expression "Expression *" :documentation "Expression used in lambda during expansion." + :slk-save #'slk-save-ast-pointer + :slk-load (lambda (member) + #>cpp + self->${member} = query::v2::LoadAstPointer<query::v2::Expression>( + ast_storage, reader); + cpp<#))) + (:serialize (:slk :load-args '((ast-storage "query::v2::AstStorage *")))) + (:clone :args '((storage "AstStorage *")))) + +(lcp:define-class expand-variable (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (input-symbol "Symbol" :scope :public) + (common "ExpandCommon" :scope :public) + (type "::EdgeAtom::Type" :scope :public) + (is-reverse :bool :scope :public :documentation + "True if the path should be written as expanding from node_symbol to input_symbol.") + (lower-bound "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Optional lower bound of the variable length expansion, defaults are (1, inf)") + (upper-bound "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression") + :documentation "Optional upper bound of the variable length expansion, defaults are (1, inf)") + (filter-lambda "ExpansionLambda" + :scope 
:public + :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, &helper->ast_storage); + cpp<#)) + (weight-lambda "std::optional<ExpansionLambda>" :scope :public + :slk-load (lambda (member) + #>cpp + bool has_value; + slk::Load(&has_value, reader); + if (!has_value) { + self->${member} = std::nullopt; + return; + } + query::v2::plan::ExpansionLambda lambda; + slk::Load(&lambda, reader, &helper->ast_storage); + self->${member}.emplace(lambda); + cpp<#)) + (total-weight "std::optional<Symbol>" :scope :public)) + (:documentation + "Variable-length expansion operator. For a node existing in +the frame it expands a variable number of edges and places them +(in a list-type TypedValue), as well as the final destination node, +on the frame. + +This class does not handle node/edge filtering based on +properties, labels and edge types. However, it does handle +filtering on existing node / edge. Additionally it handles +edge-uniqueness (cyphermorphism) because it's not feasible to do +later. + +Filtering on existing means that for a pattern that references +an already declared node or edge (for example in +MATCH (a) MATCH (a)--(b)), +only expansions that match defined equalities are successfully +pulled.") + (:public + #>cpp + ExpandVariable() {} + + /** + * Creates a variable-length expansion. Most params are forwarded + * to the @c ExpandCommon constructor, and are documented there. + * + * Expansion length bounds are both inclusive (as in Neo's Cypher + * implementation). + * + * @param input Optional logical operator that precedes this one. + * @param input_symbol Symbol that points to a VertexAccessor in the frame + * that expansion should emanate from. + * @param type - Either Type::DEPTH_FIRST (default variable-length expansion), + * or Type::BREADTH_FIRST. + * @param is_reverse Set to `true` if the edges written on frame should expand + * from `node_symbol` to `input_symbol`. Opposed to the usual expanding + * from `input_symbol` to `node_symbol`.
+ * @param lower_bound An optional indicator of the minimum number of edges + * that get expanded (inclusive). + * @param upper_bound An optional indicator of the maximum number of edges + * that get expanded (inclusive). + * @param inner_edge_symbol Like `inner_node_symbol` + * @param inner_node_symbol For each expansion the node expanded into is + * assigned to this symbol so it can be evaluated by the 'where' + * expression. + * @param filter_ The filter that must be satisfied for an expansion to + * succeed. Can use inner(node/edge) symbols. If nullptr, it is ignored. + */ + ExpandVariable(const std::shared_ptr<LogicalOperator> &input, + Symbol input_symbol, Symbol node_symbol, Symbol edge_symbol, + EdgeAtom::Type type, EdgeAtom::Direction direction, + const std::vector<storage::v3::EdgeTypeId> &edge_types, + bool is_reverse, Expression *lower_bound, + Expression *upper_bound, bool existing_node, + ExpansionLambda filter_lambda, + std::optional<ExpansionLambda> weight_lambda, + std::optional<Symbol> total_weight); + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + // the Cursors are not declared in the header because + // their edges_ and edges_it_ are decltyped using a helper function + // that should be inaccessible (private class function won't compile) + friend class ExpandVariableCursor; + friend class ExpandWeightedShortestPathCursor; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class construct-named-path (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer +
:slk-load #'slk-load-operator-pointer) + (path-symbol "Symbol" :scope :public) + (path-elements "std::vector<Symbol>" :scope :public)) + (:documentation + "Constructs a named path from its elements and places it on the frame.") + (:public + #>cpp + ConstructNamedPath() {} + ConstructNamedPath(const std::shared_ptr<LogicalOperator> &input, + Symbol path_symbol, + const std::vector<Symbol> &path_elements) + : input_(input), + path_symbol_(path_symbol), + path_elements_(path_elements) {} + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class filter (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:documentation + "Filter whose Pull returns true only when the given expression +evaluates into true. 
+ +The given expression is assumed to return either NULL (treated as false) or +a boolean value.") + (:public + #>cpp + Filter() {} + + Filter(const std::shared_ptr<LogicalOperator> &input_, + Expression *expression_); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class FilterCursor : public Cursor { + public: + FilterCursor(const Filter &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Filter &self_; + const UniqueCursorPtr input_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class produce (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (named-expressions "std::vector<NamedExpression *>" :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "NamedExpression"))) + (:documentation + "A logical operator that places an arbitrary number +of named expressions on the frame (the logical operator +for the RETURN clause). + +Supports optional input. When the input is provided, +it is Pulled from and the Produce succeeds once for +every input Pull (typically a MATCH/RETURN query). 
+When the input is not provided (typically a standalone +RETURN clause) the Produce's pull succeeds exactly once.") + (:public + #>cpp + Produce() {} + + Produce(const std::shared_ptr<LogicalOperator> &input, + const std::vector<NamedExpression *> &named_expressions); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class ProduceCursor : public Cursor { + public: + ProduceCursor(const Produce &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Produce &self_; + const UniqueCursorPtr input_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class delete (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (expressions "std::vector<Expression *>" :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Expression")) + (detach :bool :scope :public :documentation + "Whether the vertex should be detached before deletion. If not detached, + and has connections, an error is raised when deleting edges.")) + (:documentation + "Operator for deleting vertices and edges. 
+ +Has a flag for using DETACH DELETE when deleting vertices.") + (:public + #>cpp + Delete() {} + + Delete(const std::shared_ptr<LogicalOperator> &input_, + const std::vector<Expression *> &expressions, bool detach_); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class DeleteCursor : public Cursor { + public: + DeleteCursor(const Delete &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Delete &self_; + const UniqueCursorPtr input_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class set-property (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (property "::storage::v3::PropertyId" :scope :public) + (lhs "PropertyLookup *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "PropertyLookup")) + (rhs "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:documentation + "Logical operator for setting a single property on a single vertex or edge. 
+ +The property value is an expression that must evaluate to some type that +can be stored (a TypedValue that can be converted to PropertyValue).") + (:public + #>cpp + SetProperty() {} + + SetProperty(const std::shared_ptr<LogicalOperator> &input, + storage::v3::PropertyId property, PropertyLookup *lhs, + Expression *rhs); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class SetPropertyCursor : public Cursor { + public: + SetPropertyCursor(const SetProperty &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const SetProperty &self_; + const UniqueCursorPtr input_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class set-properties (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (input-symbol "Symbol" :scope :public) + (rhs "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (op "Op" :scope :public)) + (:documentation + "Logical operator for setting the whole property set on a vertex or an edge. + +The value being set is an expression that must evaluate to a vertex, edge or +map (literal or parameter). + +Supports setting (replacing the whole properties set with another) and +updating.") + (:public + (lcp:define-enum op + (update replace) + (:documentation "Defines how setting the properties works. 
+ +@c UPDATE means that the current property set is augmented with additional +ones (existing props of the same name are replaced), while @c REPLACE means +that the old properties are discarded and replaced with new ones.") + (:serialize)) + + #>cpp + SetProperties() {} + + SetProperties(const std::shared_ptr<LogicalOperator> &input, + Symbol input_symbol, Expression *rhs, Op op); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class SetPropertiesCursor : public Cursor { + public: + SetPropertiesCursor(const SetProperties &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const SetProperties &self_; + const UniqueCursorPtr input_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class set-labels (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (input-symbol "Symbol" :scope :public) + (labels "std::vector<storage::v3::LabelId>" :scope :public)) + (:documentation + "Logical operator for setting an arbitrary number of labels on a Vertex. 
+ +It does NOT remove labels that are already set on that Vertex.") + (:public + #>cpp + SetLabels() {} + + SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, + const std::vector<storage::v3::LabelId> &labels); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class SetLabelsCursor : public Cursor { + public: + SetLabelsCursor(const SetLabels &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const SetLabels &self_; + const UniqueCursorPtr input_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class remove-property (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (property "::storage::v3::PropertyId" :scope :public) + (lhs "PropertyLookup *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "PropertyLookup"))) + (:documentation + "Logical operator for removing a property from an edge or a vertex.") + (:public + #>cpp + RemoveProperty() {} + + RemoveProperty(const std::shared_ptr<LogicalOperator> &input, + storage::v3::PropertyId property, PropertyLookup *lhs); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> 
input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class RemovePropertyCursor : public Cursor { + public: + RemovePropertyCursor(const RemoveProperty &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const RemoveProperty &self_; + const UniqueCursorPtr input_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class remove-labels (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (input-symbol "Symbol" :scope :public) + (labels "std::vector<storage::v3::LabelId>" :scope :public)) + (:documentation + "Logical operator for removing an arbitrary number of labels on a Vertex. + +If a label does not exist on a Vertex, nothing happens.") + (:public + #>cpp + RemoveLabels() {} + + RemoveLabels(const std::shared_ptr<LogicalOperator> &input, + Symbol input_symbol, const std::vector<storage::v3::LabelId> &labels); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class RemoveLabelsCursor : public Cursor { + public: + RemoveLabelsCursor(const RemoveLabels &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const RemoveLabels &self_; + const UniqueCursorPtr input_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + 
+(lcp:define-class edge-uniqueness-filter (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (expand-symbol "Symbol" :scope :public) + (previous-symbols "std::vector<Symbol>" :scope :public)) + (:documentation + "Filter whose Pull returns true only when the given expand_symbol frame +value (the latest expansion) is not equal to any of the previous_symbols frame +values. + +Used for implementing Cyphermorphism. +Isomorphism is vertex-uniqueness. It means that two different vertices in a +pattern can not map to the same data vertex. +Cyphermorphism is edge-uniqueness (the above explanation applies). By default +Neo4j uses Cyphermorphism (that's where the name stems from, it is not a valid +graph-theory term). + +Supports variable-length-edges (uniqueness comparisons between edges and +edge lists).") + (:public + #>cpp + EdgeUniquenessFilter() {} + + EdgeUniquenessFilter(const std::shared_ptr<LogicalOperator> &input, + Symbol expand_symbol, + const std::vector<Symbol> &previous_symbols); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class EdgeUniquenessFilterCursor : public Cursor { + public: + EdgeUniquenessFilterCursor(const EdgeUniquenessFilter &, + utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const EdgeUniquenessFilter &self_; + const UniqueCursorPtr input_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class
accumulate (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (symbols "std::vector<Symbol>" :scope :public) + (advance-command :bool :scope :public)) + (:documentation + "Pulls everything from the input before passing it through. +Optionally advances the command after accumulation and before emitting. + +On the first Pull from this operator's Cursor the input Cursor will be Pulled +until it is empty. The results will be accumulated in the temporary cache. Once +the input Cursor is empty, this operator's Cursor will start returning cached +stuff from its Pull. + +This technique is used for ensuring all the operations from the +previous logical operator have been performed before exposing data +to the next. A typical use case is a `MATCH--SET--RETURN` +query in which every SET iteration must be performed before +RETURN starts iterating (see Memgraph Wiki for detailed reasoning). + +IMPORTANT: This operator does not cache all the results but only those +elements from the frame whose symbols (frame positions) it was given. +All other frame positions will contain undefined junk after this +operator has executed, and should not be used. + +This operator can also advance the command after the accumulation and +before emitting. If the command gets advanced, every value that +has been cached will be reconstructed before Pull returns. + +@param input Input @c LogicalOperator. 
+@param symbols A vector of Symbols that need to be accumulated + and exposed to the next op.") + (:public + #>cpp + Accumulate() {} + + Accumulate(const std::shared_ptr<LogicalOperator> &input, + const std::vector<Symbol> &symbols, bool advance_command = false); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class aggregate (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (aggregations "std::vector<Element>" :scope :public + :slk-save (lambda (member) + #>cpp + size_t size = self.${member}.size(); + slk::Save(size, builder); + for (const auto &v : self.${member}) { + slk::Save(v, builder, helper); + } + cpp<#) + :slk-load (lambda (member) + #>cpp + size_t size; + slk::Load(&size, reader); + self->${member}.resize(size); + for (size_t i = 0; + i < size; + ++i) { + slk::Load(&self->${member}[i], reader, helper); + } + cpp<#)) + (group-by "std::vector<Expression *>" :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Expression")) + (remember "std::vector<Symbol>" :scope :public)) + (:documentation + "Performs an arbitrary number of aggregations of data +from the given input grouped by the given criteria. + +Aggregations are defined by triples that define +(input data expression, type of aggregation, output symbol). +Input data is grouped based on the given set of named +expressions. Grouping is done on unique values. 
+ +IMPORTANT: +Operators taking their input from an aggregation are only +allowed to use frame values that are either aggregation +outputs or group-by named-expressions. All other frame +elements are in an undefined state after aggregation.") + (:public + (lcp:define-struct element () + ((value "Expression *" + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (key "Expression *" + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (op "::Aggregation::Op") + (output-sym "Symbol")) + (:documentation + "An aggregation element, contains: + (input data expression, key expression - only used in COLLECT_MAP, type of + aggregation, output symbol).") + (:serialize (:slk :save-args '((helper "query::v2::plan::LogicalOperator::SaveHelper *")) + :load-args '((helper "query::v2::plan::LogicalOperator::SlkLoadHelper *")))) + (:clone :args '((storage "AstStorage *")))) + #>cpp + Aggregate() = default; + Aggregate(const std::shared_ptr<LogicalOperator> &input, + const std::vector<Element> &aggregations, + const std::vector<Expression *> &group_by, + const std::vector<Symbol> &remember); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class skip (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:documentation + "Skips a number of Pulls 
from the input op. + +The given expression determines how many Pulls from the input +should be skipped (ignored). +All other successful Pulls from the +input are simply passed through. + +The given expression is evaluated after the first Pull from +the input, and only once. Neo does not allow this expression +to contain identifiers, and neither does Memgraph, but this +operator's implementation does not expect this.") + (:public + #>cpp + Skip() {} + + Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class SkipCursor : public Cursor { + public: + SkipCursor(const Skip &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Skip &self_; + const UniqueCursorPtr input_cursor_; + // init to_skip_ to -1, indicating + // that it's still unknown (input has not been Pulled yet) + int64_t to_skip_{-1}; + int64_t skipped_{0}; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class limit (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + (:documentation + "Limits the number of Pulls from the input op. 
+ +The given expression determines how many +input Pulls should be passed through. The input is not +Pulled once this limit is reached. Note that this has +implications: the out-of-bounds input Pulls are never +evaluated. + +The limit expression must NOT use anything from the +Frame. It is evaluated before the first Pull from the +input. This is consistent with Neo (they don't allow +identifiers in limit expressions), and it's necessary +when limit evaluates to 0 (because 0 Pulls from the +input should be performed).") + (:public + #>cpp + Limit() {} + + Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class LimitCursor : public Cursor { + public: + LimitCursor(const Limit &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Limit &self_; + UniqueCursorPtr input_cursor_; + // init limit_ to -1, indicating + // that it's still unknown (Cursor has not been Pulled yet) + int64_t limit_{-1}; + int64_t pulled_{0}; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class order-by (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (compare "TypedValueVectorCompare" :scope :public) + (order-by "std::vector<Expression *>" :scope :public + :slk-save #'slk-save-ast-vector + :slk-load 
(slk-load-ast-vector "Expression")) + (output-symbols "std::vector<Symbol>" :scope :public)) + (:documentation + "Logical operator for ordering (sorting) results. + +Sorts the input rows based on an arbitrary number of +Expressions. Ascending or descending ordering can be chosen +for each independently (not providing enough orderings +results in a runtime error). + +For each row an arbitrary number of Frame elements can be +remembered. Only these elements (defined by their Symbols) +are valid for usage after the OrderBy operator.") + (:public + #>cpp + OrderBy() {} + + OrderBy(const std::shared_ptr<LogicalOperator> &input, + const std::vector<SortItem> &order_by, + const std::vector<Symbol> &output_symbols); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class merge (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (merge-match "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (merge-create "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer)) + (:documentation + "Merge operator. For every successful Pull from the +input operator a Pull from the merge_match is attempted. All +successful Pulls from the merge_match are passed on as output.
+If merge_match Pull does not yield any elements, a single Pull +from the merge_create op is performed. + +The input logical op is optional. If false (nullptr) +it will be replaced by a Once op. + +For an argumentation of this implementation see the wiki +documentation.") + (:public + #>cpp + Merge() {} + + Merge(const std::shared_ptr<LogicalOperator> &input, + const std::shared_ptr<LogicalOperator> &merge_match, + const std::shared_ptr<LogicalOperator> &merge_create); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + // TODO: Consider whether we want to treat Merge as having single input. It + // makes sense that we do, because other branches are executed depending on + // the input. + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class MergeCursor : public Cursor { + public: + MergeCursor(const Merge &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const UniqueCursorPtr input_cursor_; + const UniqueCursorPtr merge_match_cursor_; + const UniqueCursorPtr merge_create_cursor_; + + // indicates if the next Pull from this cursor + // should perform a pull from input_cursor_ + // this is true when: + // - first Pulling from this cursor + // - previous Pull from this cursor exhausted the merge_match_cursor + bool pull_input_{true}; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class optional (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (optional 
"std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (optional-symbols "std::vector<Symbol>" :scope :public)) + (:documentation + "Optional operator. Used for optional match. For every +successful Pull from the input branch a Pull from the optional +branch is attempted (and Pulled from till exhausted). If zero +Pulls succeed from the optional branch, the Optional operator +sets the optional symbols to TypedValue::Null on the Frame +and returns true, once.") + (:public + #>cpp + Optional() {} + + Optional(const std::shared_ptr<LogicalOperator> &input, + const std::shared_ptr<LogicalOperator> &optional, + const std::vector<Symbol> &optional_symbols); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class OptionalCursor : public Cursor { + public: + OptionalCursor(const Optional &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Optional &self_; + const UniqueCursorPtr input_cursor_; + const UniqueCursorPtr optional_cursor_; + // indicates if the next Pull from this cursor should + // perform a Pull from the input_cursor_ + // this is true when: + // - first pulling from this Cursor + // - previous Pull from this cursor exhausted the optional_cursor_ + bool pull_input_{true}; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class unwind (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + 
:slk-load #'slk-load-operator-pointer) + (input-expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (output-symbol "Symbol" :scope :public)) + (:documentation + "Takes a list TypedValue as its input and yields each +element as its output. + +Input is optional (unwind can be the first clause in a query).") + (:public + #>cpp + Unwind() {} + + Unwind(const std::shared_ptr<LogicalOperator> &input, + Expression *input_expression_, Symbol output_symbol); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { + return true; } + std::shared_ptr<LogicalOperator> input() const override { + return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class distinct (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (value-symbols "std::vector<Symbol>" :scope :public)) + (:documentation + "Ensures that only distinct rows are yielded. +This implementation accepts a vector of Symbols +which define a row. Only those Symbols are valid +for use in operators following Distinct.
+ +This implementation maintains input ordering.") + (:public + #>cpp + Distinct() {} + + Distinct(const std::shared_ptr<LogicalOperator> &input, + const std::vector<Symbol> &value_symbols); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class union (logical-operator) + ((left-op "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (right-op "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (union-symbols "std::vector<Symbol>" :scope :public) + (left-symbols "std::vector<Symbol>" :scope :public) + (right-symbols "std::vector<Symbol>" :scope :public)) + (:documentation + "A logical operator that applies UNION operator on inputs and places the +result on the frame. 
+ +This operator takes two inputs, a vector of symbols for the result, and vectors +of symbols used by each of the inputs.") + (:public + #>cpp + Union() {} + + Union(const std::shared_ptr<LogicalOperator> &left_op, + const std::shared_ptr<LogicalOperator> &right_op, + const std::vector<Symbol> &union_symbols, + const std::vector<Symbol> &left_symbols, + const std::vector<Symbol> &right_symbols); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator>) override; + cpp<#) + (:private + #>cpp + class UnionCursor : public Cursor { + public: + UnionCursor(const Union &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Union &self_; + const UniqueCursorPtr left_cursor_, right_cursor_; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + +;; TODO: We should probably output this operator in regular planner, not just +;; distributed planner. +(lcp:define-class cartesian (logical-operator) + ((left-op "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (left-symbols "std::vector<Symbol>" :scope :public) + (right-op "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (right-symbols "std::vector<Symbol>" :scope :public)) + (:documentation + "Operator for producing a Cartesian product from 2 input branches") + (:public + #>cpp + Cartesian() {} + /** Construct the operator with left input branch and right input branch. 
*/ + Cartesian(const std::shared_ptr<LogicalOperator> &left_op, + const std::vector<Symbol> &left_symbols, + const std::shared_ptr<LogicalOperator> &right_op, + const std::vector<Symbol> &right_symbols) + : left_op_(left_op), + left_symbols_(left_symbols), + right_op_(right_op), + right_symbols_(right_symbols) {} + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator>) override; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class output-table (logical-operator) + ((output-symbols "std::vector<Symbol>" :scope :public :dont-save t) + (callback "std::function<std::vector<std::vector<TypedValue>>(Frame *, ExecutionContext *)>" + :scope :public :dont-save t :clone :copy)) + (:documentation "An operator that outputs a table, producing a single row on each pull") + (:public + #>cpp + OutputTable() {} + OutputTable( + std::vector<Symbol> output_symbols, + std::function<std::vector<std::vector<TypedValue>>(Frame *, ExecutionContext *)> + callback); + OutputTable(std::vector<Symbol> output_symbols, + std::vector<std::vector<TypedValue>> rows); + + bool Accept(HierarchicalLogicalOperatorVisitor &) override { + LOG_FATAL("OutputTable operator should not be visited!"); + } + + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override { + return output_symbols_; + } + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override { + return output_symbols_; + } + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator> input) override; + cpp<#) + (:serialize (:slk)) + (:clone)) + 
+(lcp:define-class output-table-stream (logical-operator) + ((output-symbols "std::vector<Symbol>" :scope :public :dont-save t) + (callback "std::function<std::optional<std::vector<TypedValue>>(Frame *, ExecutionContext *)>" + :scope :public :dont-save t :clone :copy)) + (:documentation "An operator that outputs a table, producing a single row on each pull. +This class is different from @c OutputTable in that its callback doesn't fetch all rows +at once. Instead, each call of the callback should return a single row of the table.") + (:public + #>cpp + OutputTableStream() {} + OutputTableStream( + std::vector<Symbol> output_symbols, + std::function<std::optional<std::vector<TypedValue>>(Frame *, ExecutionContext *)> + callback); + + bool Accept(HierarchicalLogicalOperatorVisitor &) override { + LOG_FATAL("OutputTableStream operator should not be visited!"); + } + + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override { + return output_symbols_; + } + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override { + return output_symbols_; + } + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator> input) override; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class call-procedure (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (procedure-name "std::string" :scope :public) + (arguments "std::vector<Expression *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Expression")) + (result-fields "std::vector<std::string>" :scope :public) + (result-symbols "std::vector<Symbol>" :scope :public) + (memory-limit "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer 
"Expression")) + (memory-scale "size_t" :initval "1024U" :scope :public) + (is_write :bool :scope :public)) + (:public + #>cpp + CallProcedure() = default; + CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name, + std::vector<Expression *> arguments, + std::vector<std::string> fields, std::vector<Symbol> symbols, + Expression *memory_limit, size_t memory_scale, bool is_write); + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + + static void IncrementCounter(const std::string &procedure_name); + static std::unordered_map<std::string, int64_t> GetAndResetCounters(); + cpp<#) + (:private + #>cpp + inline static utils::Synchronized<std::unordered_map<std::string, int64_t>, utils::SpinLock> procedure_counters_; + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class load-csv (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (file "Expression *" :scope :public) + (with_header "bool" :scope :public) + (ignore_bad "bool" :scope :public) + (delimiter "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (quote "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (row_var "Symbol" :scope :public)) + (:public + #>cpp + LoadCsv() = default; + LoadCsv(std::shared_ptr<LogicalOperator> input, Expression *file, bool with_header, 
bool ignore_bad, + Expression* delimiter, Expression* quote, Symbol row_var); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:define-class foreach (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (update-clauses "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (loop-variable-symbol "Symbol" :scope :public)) + + (:documentation + "Iterates over a collection of elements and applies one or more update +clauses. 
+") + (:public + #>cpp + Foreach() = default; + Foreach(std::shared_ptr<LogicalOperator> input, + std::shared_ptr<LogicalOperator> updates, + Expression *named_expr, + Symbol loop_variable_symbol); + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = std::move(input); + } + cpp<#) + (:serialize (:slk)) + (:clone)) + +(lcp:pop-namespace) ;; plan +(lcp:pop-namespace) ;; v2 +(lcp:pop-namespace) ;; query +(lcp:pop-namespace) ;; memgraph diff --git a/src/query/v2/plan/planner.hpp b/src/query/v2/plan/planner.hpp new file mode 100644 index 000000000..fe9e88f32 --- /dev/null +++ b/src/query/v2/plan/planner.hpp @@ -0,0 +1,158 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +/// @file +/// This file is an entry point for invoking various planners via the following +/// API: +/// * `MakeLogicalPlanForSingleQuery` +/// * `MakeLogicalPlan` + +#pragma once + +#include "query/v2/plan/cost_estimator.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/plan/preprocess.hpp" +#include "query/v2/plan/pretty_print.hpp" +#include "query/v2/plan/rewrite/index_lookup.hpp" +#include "query/v2/plan/rule_based_planner.hpp" +#include "query/v2/plan/variable_start_planner.hpp" +#include "query/v2/plan/vertex_count_cache.hpp" + +namespace memgraph::query::v2 { + +class AstStorage; +class SymbolTable; + +namespace plan { + +class PostProcessor final { + Parameters parameters_; + + public: + using ProcessedPlan = std::unique_ptr<LogicalOperator>; + + explicit PostProcessor(const Parameters &parameters) : parameters_(parameters) {} + + template <class TPlanningContext> + std::unique_ptr<LogicalOperator> Rewrite(std::unique_ptr<LogicalOperator> plan, TPlanningContext *context) { + return RewriteWithIndexLookup(std::move(plan), context->symbol_table, context->ast_storage, context->db); + } + + template <class TVertexCounts> + double EstimatePlanCost(const std::unique_ptr<LogicalOperator> &plan, TVertexCounts *vertex_counts) { + return query::v2::plan::EstimatePlanCost(vertex_counts, parameters_, *plan); + } + + template <class TPlanningContext> + std::unique_ptr<LogicalOperator> MergeWithCombinator(std::unique_ptr<LogicalOperator> curr_op, + std::unique_ptr<LogicalOperator> last_op, const Tree &combinator, + TPlanningContext *context) { + if (const auto *union_ = utils::Downcast<const CypherUnion>(&combinator)) { + return std::unique_ptr<LogicalOperator>( + impl::GenUnion(*union_, std::move(last_op), std::move(curr_op), *context->symbol_table)); + } + throw utils::NotYetImplemented("query combinator"); + } + + template <class TPlanningContext> + std::unique_ptr<LogicalOperator> MakeDistinct(std::unique_ptr<LogicalOperator> last_op, TPlanningContext
*context) { + auto output_symbols = last_op->OutputSymbols(*context->symbol_table); + return std::make_unique<Distinct>(std::move(last_op), output_symbols); + } +}; + +/// @brief Generates the LogicalOperator tree for a single query and returns the +/// resulting plan. +/// +/// @tparam TPlanner Type of the planner used for generation. +/// @tparam TDbAccessor Type of the database accessor used for generation. +/// @param single_query_parts vector of @c SingleQueryPart from the single query +/// @param context PlanningContext used for generating plans. +/// @return @c PlanResult which depends on the @c TPlanner used. +/// +/// @sa PlanningContext +/// @sa RuleBasedPlanner +/// @sa VariableStartPlanner +template <template <class> class TPlanner, class TDbAccessor> +auto MakeLogicalPlanForSingleQuery(std::vector<SingleQueryPart> single_query_parts, + PlanningContext<TDbAccessor> *context) { + context->bound_symbols.clear(); + return TPlanner<PlanningContext<TDbAccessor>>(context).Plan(single_query_parts); +} + +/// Generates the LogicalOperator tree and returns the resulting plan. +/// +/// @tparam TPlanningContext Type of the context used. +/// @tparam TPlanPostProcess Type of the plan post processor used. +/// +/// @param context PlanningContext used for generating plans. +/// @param post_process performs plan rewrites and cost estimation. +/// @param use_variable_planner boolean flag to choose which planner to use. +/// +/// @return pair consisting of the final `TPlanPostProcess::ProcessedPlan` and +/// the estimated cost of that plan as a `double`.
+template <class TPlanningContext, class TPlanPostProcess> +auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process, bool use_variable_planner) { + auto query_parts = CollectQueryParts(*context->symbol_table, *context->ast_storage, context->query); + auto &vertex_counts = *context->db; + double total_cost = 0; + + using ProcessedPlan = typename TPlanPostProcess::ProcessedPlan; + ProcessedPlan last_plan; + + for (const auto &query_part : query_parts.query_parts) { + std::optional<ProcessedPlan> curr_plan; + double min_cost = std::numeric_limits<double>::max(); + + if (use_variable_planner) { + auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(query_part.single_query_parts, context); + for (auto plan : plans) { + // Plans are generated lazily and the current plan will disappear, so + // it's ok to move it. + auto rewritten_plan = post_process->Rewrite(std::move(plan), context); + double cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts); + if (!curr_plan || cost < min_cost) { + curr_plan.emplace(std::move(rewritten_plan)); + min_cost = cost; + } + } + } else { + auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(query_part.single_query_parts, context); + auto rewritten_plan = post_process->Rewrite(std::move(plan), context); + min_cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts); + curr_plan.emplace(std::move(rewritten_plan)); + } + + total_cost += min_cost; + if (query_part.query_combinator) { + last_plan = post_process->MergeWithCombinator(std::move(*curr_plan), std::move(last_plan), + *query_part.query_combinator, context); + } else { + last_plan = std::move(*curr_plan); + } + } + + if (query_parts.distinct) { + last_plan = post_process->MakeDistinct(std::move(last_plan), context); + } + + return std::make_pair(std::move(last_plan), total_cost); +} + +template <class TPlanningContext> +auto MakeLogicalPlan(TPlanningContext *context, const Parameters &parameters, bool
use_variable_planner) { + PostProcessor post_processor(parameters); + return MakeLogicalPlan(context, &post_processor, use_variable_planner); +} + +} // namespace plan + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/plan/preprocess.cpp b/src/query/v2/plan/preprocess.cpp new file mode 100644 index 000000000..80f1935da --- /dev/null +++ b/src/query/v2/plan/preprocess.cpp @@ -0,0 +1,599 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <algorithm> +#include <functional> +#include <stack> +#include <type_traits> +#include <unordered_map> +#include <variant> + +#include "query/v2/exceptions.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/ast/ast_visitor.hpp" +#include "query/v2/plan/preprocess.hpp" +#include "utils/typeinfo.hpp" + +namespace memgraph::query::v2::plan { + +namespace { + +void ForEachPattern(Pattern &pattern, std::function<void(NodeAtom *)> base, + std::function<void(NodeAtom *, EdgeAtom *, NodeAtom *)> collect) { + DMG_ASSERT(!pattern.atoms_.empty(), "Missing atoms in pattern"); + auto atoms_it = pattern.atoms_.begin(); + auto current_node = utils::Downcast<NodeAtom>(*atoms_it++); + DMG_ASSERT(current_node, "First pattern atom is not a node"); + base(current_node); + // Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)* + while (atoms_it != pattern.atoms_.end()) { + auto edge = utils::Downcast<EdgeAtom>(*atoms_it++); + DMG_ASSERT(edge, "Expected an edge atom in pattern."); + 
DMG_ASSERT(atoms_it != pattern.atoms_.end(), "Edge atom should not end the pattern."); + auto prev_node = current_node; + current_node = utils::Downcast<NodeAtom>(*atoms_it++); + DMG_ASSERT(current_node, "Expected a node atom in pattern."); + collect(prev_node, edge, current_node); + } +} + +// Converts multiple Patterns to Expansions. Each Pattern can contain an +// arbitrarily long chain of nodes and edges. The conversion to an Expansion is +// done by splitting a pattern into triplets (node1, edge, node2). The triplets +// conserve the semantics of the pattern. For example, in a pattern: +// (m) -[e]- (n) -[f]- (o) the same can be achieved with: +// (m) -[e]- (n), (n) -[f]- (o). +// This representation makes it easier to permute from which node or edge we +// want to start expanding. +std::vector<Expansion> NormalizePatterns(const SymbolTable &symbol_table, const std::vector<Pattern *> &patterns) { + std::vector<Expansion> expansions; + auto ignore_node = [&](auto *) {}; + auto collect_expansion = [&](auto *prev_node, auto *edge, auto *current_node) { + UsedSymbolsCollector collector(symbol_table); + if (edge->IsVariable()) { + if (edge->lower_bound_) edge->lower_bound_->Accept(collector); + if (edge->upper_bound_) edge->upper_bound_->Accept(collector); + if (edge->filter_lambda_.expression) edge->filter_lambda_.expression->Accept(collector); + // Remove symbols which are bound by lambda arguments. 
+ collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_edge)); + collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_node)); + if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) { + collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_edge)); + collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_node)); + } + } + expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, false, collector.symbols_, current_node}); + }; + for (const auto &pattern : patterns) { + if (pattern->atoms_.size() == 1U) { + auto *node = utils::Downcast<NodeAtom>(pattern->atoms_[0]); + DMG_ASSERT(node, "First pattern atom is not a node"); + expansions.emplace_back(Expansion{node}); + } else { + ForEachPattern(*pattern, ignore_node, collect_expansion); + } + } + return expansions; +} + +// Fills the given Matching, by converting the Match patterns to normalized +// representation as Expansions. Filters used in the Match are also collected, +// as well as edge symbols which determine Cyphermorphism. Collecting filters +// will lift them out of a pattern and generate new expressions (just like they +// were in a Where clause). +void AddMatching(const std::vector<Pattern *> &patterns, Where *where, SymbolTable &symbol_table, AstStorage &storage, + Matching &matching) { + auto expansions = NormalizePatterns(symbol_table, patterns); + std::unordered_set<Symbol> edge_symbols; + for (const auto &expansion : expansions) { + // Matching may already have some expansions, so offset our index. + const size_t expansion_ix = matching.expansions.size(); + // Map node1 symbol to expansion + const auto &node1_sym = symbol_table.at(*expansion.node1->identifier_); + matching.node_symbol_to_expansions[node1_sym].insert(expansion_ix); + // Add node1 to all symbols. 
+ matching.expansion_symbols.insert(node1_sym); + if (expansion.edge) { + const auto &edge_sym = symbol_table.at(*expansion.edge->identifier_); + // Fill edge symbols for Cyphermorphism. + edge_symbols.insert(edge_sym); + // Map node2 symbol to expansion + const auto &node2_sym = symbol_table.at(*expansion.node2->identifier_); + matching.node_symbol_to_expansions[node2_sym].insert(expansion_ix); + // Add edge and node2 to all symbols + matching.expansion_symbols.insert(edge_sym); + matching.expansion_symbols.insert(node2_sym); + } + matching.expansions.push_back(expansion); + } + if (!edge_symbols.empty()) { + matching.edge_symbols.emplace_back(edge_symbols); + } + for (auto *pattern : patterns) { + matching.filters.CollectPatternFilters(*pattern, symbol_table, storage); + if (pattern->identifier_->user_declared_) { + std::vector<Symbol> path_elements; + for (auto *pattern_atom : pattern->atoms_) + path_elements.emplace_back(symbol_table.at(*pattern_atom->identifier_)); + matching.named_paths.emplace(symbol_table.at(*pattern->identifier_), std::move(path_elements)); + } + } + if (where) { + matching.filters.CollectWhereFilter(*where, symbol_table); + } +} +void AddMatching(const Match &match, SymbolTable &symbol_table, AstStorage &storage, Matching &matching) { + return AddMatching(match.patterns_, match.where_, symbol_table, storage, matching); +} + +auto SplitExpressionOnAnd(Expression *expression) { + // TODO: Think about converting all filtering expression into CNF to improve + // the granularity of filters which can be stand alone. 
+ std::vector<Expression *> expressions; + std::stack<Expression *> pending_expressions; + pending_expressions.push(expression); + while (!pending_expressions.empty()) { + auto *current_expression = pending_expressions.top(); + pending_expressions.pop(); + if (auto *and_op = utils::Downcast<AndOperator>(current_expression)) { + pending_expressions.push(and_op->expression1_); + pending_expressions.push(and_op->expression2_); + } else { + expressions.push_back(current_expression); + } + } + return expressions; +} + +} // namespace + +PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property, + Expression *value, Type type) + : symbol_(symbol), property_(property), type_(type), value_(value) { + MG_ASSERT(type != Type::RANGE); + UsedSymbolsCollector collector(symbol_table); + value->Accept(collector); + is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol); +} + +PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property, + const std::optional<PropertyFilter::Bound> &lower_bound, + const std::optional<PropertyFilter::Bound> &upper_bound) + : symbol_(symbol), property_(property), type_(Type::RANGE), lower_bound_(lower_bound), upper_bound_(upper_bound) { + UsedSymbolsCollector collector(symbol_table); + if (lower_bound) { + lower_bound->value()->Accept(collector); + } + if (upper_bound) { + upper_bound->value()->Accept(collector); + } + is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol); +} + +PropertyFilter::PropertyFilter(const Symbol &symbol, PropertyIx property, Type type) + : symbol_(symbol), property_(property), type_(type) { + // As this constructor is used for property filters where + // we don't have to evaluate the filter expression, we set + // the is_symbol_in_value_ to false, although the filter + // expression may actually contain the symbol whose property + // we may be looking up. 
+} + +IdFilter::IdFilter(const SymbolTable &symbol_table, const Symbol &symbol, Expression *value) + : symbol_(symbol), value_(value) { + MG_ASSERT(value); + UsedSymbolsCollector collector(symbol_table); + value->Accept(collector); + is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol); +} + +void Filters::EraseFilter(const FilterInfo &filter) { + // TODO: Ideally, we want to determine the equality of both expression trees, + // instead of a simple pointer compare. + all_filters_.erase(std::remove_if(all_filters_.begin(), all_filters_.end(), + [&filter](const auto &f) { return f.expression == filter.expression; }), + all_filters_.end()); +} + +void Filters::EraseLabelFilter(const Symbol &symbol, LabelIx label, std::vector<Expression *> *removed_filters) { + for (auto filter_it = all_filters_.begin(); filter_it != all_filters_.end();) { + if (filter_it->type != FilterInfo::Type::Label) { + ++filter_it; + continue; + } + if (!utils::Contains(filter_it->used_symbols, symbol)) { + ++filter_it; + continue; + } + auto label_it = std::find(filter_it->labels.begin(), filter_it->labels.end(), label); + if (label_it == filter_it->labels.end()) { + ++filter_it; + continue; + } + filter_it->labels.erase(label_it); + DMG_ASSERT(!utils::Contains(filter_it->labels, label), "Didn't expect duplicated labels"); + if (filter_it->labels.empty()) { + // If there are no labels to filter, then erase the whole FilterInfo. 
+ if (removed_filters) { + removed_filters->push_back(filter_it->expression); + } + filter_it = all_filters_.erase(filter_it); + } else { + ++filter_it; + } + } +} + +void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, AstStorage &storage) { + UsedSymbolsCollector collector(symbol_table); + auto add_properties_variable = [&](EdgeAtom *atom) { + const auto &symbol = symbol_table.at(*atom->identifier_); + if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&atom->properties_)) { + for (auto &prop_pair : *properties) { + // We need to store two property-lookup filters in all_filters. One is + // used for inlining property filters into variable expansion, and + // utilizes the inner_edge symbol. The other is used for post-expansion + // filtering and does not use the inner_edge symbol, but the edge symbol + // (a list of edges). + { + collector.symbols_.clear(); + prop_pair.second->Accept(collector); + collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_node)); + collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_edge)); + // First handle the inline property filter. + auto *property_lookup = storage.Create<PropertyLookup>(atom->filter_lambda_.inner_edge, prop_pair.first); + auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second); + // Currently, variable expand has no gains if we set PropertyFilter. + all_filters_.emplace_back(FilterInfo{FilterInfo::Type::Generic, prop_equal, collector.symbols_}); + } + { + collector.symbols_.clear(); + prop_pair.second->Accept(collector); + collector.symbols_.insert(symbol); // PropertyLookup uses the symbol. + // Now handle the post-expansion filter. + // Create a new identifier and a symbol which will be filled in All. 
+ auto *identifier = storage.Create<Identifier>(atom->identifier_->name_, atom->identifier_->user_declared_) + ->MapTo(symbol_table.CreateSymbol(atom->identifier_->name_, false)); + // Create an equality expression and store it in all_filters_. + auto *property_lookup = storage.Create<PropertyLookup>(identifier, prop_pair.first); + auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second); + // Currently, variable expand has no gains if we set PropertyFilter. + all_filters_.emplace_back( + FilterInfo{FilterInfo::Type::Generic, + storage.Create<All>(identifier, atom->identifier_, storage.Create<Where>(prop_equal)), + collector.symbols_}); + } + } + return; + } + throw SemanticException("Property map matching not supported in MATCH/MERGE clause!"); + }; + auto add_properties = [&](auto *atom) { + const auto &symbol = symbol_table.at(*atom->identifier_); + if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&atom->properties_)) { + for (auto &prop_pair : *properties) { + // Create an equality expression and store it in all_filters_. + auto *property_lookup = storage.Create<PropertyLookup>(atom->identifier_, prop_pair.first); + auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second); + collector.symbols_.clear(); + prop_equal->Accept(collector); + FilterInfo filter_info{FilterInfo::Type::Property, prop_equal, collector.symbols_}; + // Store a PropertyFilter on the value of the property. + filter_info.property_filter.emplace(symbol_table, symbol, prop_pair.first, prop_pair.second, + PropertyFilter::Type::EQUAL); + all_filters_.emplace_back(filter_info); + } + return; + } + throw SemanticException("Property map matching not supported in MATCH/MERGE clause!"); + }; + auto add_node_filter = [&](NodeAtom *node) { + const auto &node_symbol = symbol_table.at(*node->identifier_); + if (!node->labels_.empty()) { + // Create a LabelsTest and store it. 
+ auto *labels_test = storage.Create<LabelsTest>(node->identifier_, node->labels_); + auto label_filter = FilterInfo{FilterInfo::Type::Label, labels_test, std::unordered_set<Symbol>{node_symbol}}; + label_filter.labels = node->labels_; + all_filters_.emplace_back(label_filter); + } + add_properties(node); + }; + auto add_expand_filter = [&](NodeAtom *, EdgeAtom *edge, NodeAtom *node) { + if (edge->IsVariable()) + add_properties_variable(edge); + else + add_properties(edge); + add_node_filter(node); + }; + ForEachPattern(pattern, add_node_filter, add_expand_filter); +} + +// Adds the where filter expression to `all_filters_` and collects additional +// information for potential property and label indexing. +void Filters::CollectWhereFilter(Where &where, const SymbolTable &symbol_table) { + CollectFilterExpression(where.expression_, symbol_table); +} + +// Adds the expression to `all_filters_` and collects additional +// information for potential property and label indexing. +void Filters::CollectFilterExpression(Expression *expr, const SymbolTable &symbol_table) { + auto filters = SplitExpressionOnAnd(expr); + for (const auto &filter : filters) { + AnalyzeAndStoreFilter(filter, symbol_table); + } +} + +// Analyzes the filter expression by collecting information on filtering labels +// and properties to be used with indexing. 
+void Filters::AnalyzeAndStoreFilter(Expression *expr, const SymbolTable &symbol_table) { + using Bound = PropertyFilter::Bound; + UsedSymbolsCollector collector(symbol_table); + expr->Accept(collector); + auto make_filter = [&collector, &expr](FilterInfo::Type type) { return FilterInfo{type, expr, collector.symbols_}; }; + auto get_property_lookup = [](auto *maybe_lookup, auto *&prop_lookup, auto *&ident) -> bool { + return (prop_lookup = utils::Downcast<PropertyLookup>(maybe_lookup)) && + (ident = utils::Downcast<Identifier>(prop_lookup->expression_)); + }; + // Checks if maybe_lookup is a property lookup, stores it as a + // PropertyFilter and returns true. If it isn't, returns false. + auto add_prop_equal = [&](auto *maybe_lookup, auto *val_expr) -> bool { + PropertyLookup *prop_lookup = nullptr; + Identifier *ident = nullptr; + if (get_property_lookup(maybe_lookup, prop_lookup, ident)) { + auto filter = make_filter(FilterInfo::Type::Property); + filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr, + PropertyFilter::Type::EQUAL); + all_filters_.emplace_back(filter); + return true; + } + return false; + }; + // Like add_prop_equal, but for adding regex match property filter. + auto add_prop_regex_match = [&](auto *maybe_lookup, auto *val_expr) -> bool { + PropertyLookup *prop_lookup = nullptr; + Identifier *ident = nullptr; + if (get_property_lookup(maybe_lookup, prop_lookup, ident)) { + auto filter = make_filter(FilterInfo::Type::Property); + filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr, + PropertyFilter::Type::REGEX_MATCH); + all_filters_.emplace_back(filter); + return true; + } + return false; + }; + // Checks if either the expr1 and expr2 are property lookups, adds them as + // PropertyFilter and returns true. Otherwise, returns false. 
+ auto add_prop_greater = [&](auto *expr1, auto *expr2, auto bound_type) -> bool { + PropertyLookup *prop_lookup = nullptr; + Identifier *ident = nullptr; + bool is_prop_filter = false; + if (get_property_lookup(expr1, prop_lookup, ident)) { + // n.prop > value + auto filter = make_filter(FilterInfo::Type::Property); + filter.property_filter.emplace(symbol_table, symbol_table.at(*ident), prop_lookup->property_, + Bound(expr2, bound_type), std::nullopt); + all_filters_.emplace_back(filter); + is_prop_filter = true; + } + if (get_property_lookup(expr2, prop_lookup, ident)) { + // value > n.prop + auto filter = make_filter(FilterInfo::Type::Property); + filter.property_filter.emplace(symbol_table, symbol_table.at(*ident), prop_lookup->property_, std::nullopt, + Bound(expr1, bound_type)); + all_filters_.emplace_back(filter); + is_prop_filter = true; + } + return is_prop_filter; + }; + // Check if maybe_id_fun is ID invocation on an indentifier and add it as + // IdFilter. + auto add_id_equal = [&](auto *maybe_id_fun, auto *val_expr) -> bool { + auto *id_fun = utils::Downcast<Function>(maybe_id_fun); + if (!id_fun) return false; + if (id_fun->function_name_ != kId) return false; + if (id_fun->arguments_.size() != 1U) return false; + auto *ident = utils::Downcast<Identifier>(id_fun->arguments_.front()); + if (!ident) return false; + auto filter = make_filter(FilterInfo::Type::Id); + filter.id_filter.emplace(symbol_table, symbol_table.at(*ident), val_expr); + all_filters_.emplace_back(filter); + return true; + }; + // Checks if maybe_lookup is a property lookup, stores it as a + // PropertyFilter and returns true. If it isn't, returns false. 
+ auto add_prop_in_list = [&](auto *maybe_lookup, auto *val_expr) -> bool { + if (!utils::Downcast<ListLiteral>(val_expr)) return false; + PropertyLookup *prop_lookup = nullptr; + Identifier *ident = nullptr; + if (get_property_lookup(maybe_lookup, prop_lookup, ident)) { + auto filter = make_filter(FilterInfo::Type::Property); + filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr, + PropertyFilter::Type::IN); + all_filters_.emplace_back(filter); + return true; + } + return false; + }; + + // Checks whether maybe_prop_not_null_check is the null check on a property, + // ("prop IS NOT NULL"), stores it as a PropertyFilter if it is, and returns + // true. If it isn't returns false. + auto add_prop_is_not_null_check = [&](auto *maybe_is_not_null_check) -> bool { + // Strip away the outer NOT operator, and figure out + // whether the inner expression is of the form "prop IS NULL" + if (!maybe_is_not_null_check) { + return false; + } + + auto *maybe_is_null_check = utils::Downcast<IsNullOperator>(maybe_is_not_null_check->expression_); + if (!maybe_is_null_check) { + return false; + } + PropertyLookup *prop_lookup = nullptr; + Identifier *ident = nullptr; + + if (!get_property_lookup(maybe_is_null_check->expression_, prop_lookup, ident)) { + return false; + } + + auto filter = make_filter(FilterInfo::Type::Property); + filter.property_filter = + PropertyFilter(symbol_table.at(*ident), prop_lookup->property_, PropertyFilter::Type::IS_NOT_NULL); + all_filters_.emplace_back(filter); + return true; + }; + // We are only interested to see the insides of And, because Or prevents + // indexing since any labels and properties found there may be optional. + DMG_ASSERT(!utils::IsSubtype(*expr, AndOperator::kType), "Expected AndOperators have been split."); + if (auto *labels_test = utils::Downcast<LabelsTest>(expr)) { + // Since LabelsTest may contain any expression, we can only use the + // simplest test on an identifier. 
+ if (utils::Downcast<Identifier>(labels_test->expression_)) { + auto filter = make_filter(FilterInfo::Type::Label); + filter.labels = labels_test->labels_; + all_filters_.emplace_back(filter); + } else { + all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic)); + } + } else if (auto *eq = utils::Downcast<EqualOperator>(expr)) { + // Try to get property equality test from the top expressions. + // Unfortunately, we cannot go deeper inside Equal, because chained equals + // need not correspond to And. For example, `(n.prop = value) = false)`: + // EQ + // / \ + // EQ false -- top expressions + // / \ + // n.prop value + // Here the `prop` may be different than `value` resulting in `false`. This + // would compare with the top level `false`, producing `true`. Therefore, it + // is incorrect to pick up `n.prop = value` for scanning by property index. + bool is_prop_filter = add_prop_equal(eq->expression1_, eq->expression2_); + // And reversed. + is_prop_filter |= add_prop_equal(eq->expression2_, eq->expression1_); + // Try to get ID equality filter. + bool is_id_filter = add_id_equal(eq->expression1_, eq->expression2_); + is_id_filter |= add_id_equal(eq->expression2_, eq->expression1_); + if (!is_prop_filter && !is_id_filter) { + // No special filter was added, so just store a generic filter. 
+ all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic)); + } + } else if (auto *regex_match = utils::Downcast<RegexMatch>(expr)) { + if (!add_prop_regex_match(regex_match->string_expr_, regex_match->regex_)) { + all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic)); + } + } else if (auto *gt = utils::Downcast<GreaterOperator>(expr)) { + if (!add_prop_greater(gt->expression1_, gt->expression2_, Bound::Type::EXCLUSIVE)) { + all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic)); + } + } else if (auto *ge = utils::Downcast<GreaterEqualOperator>(expr)) { + if (!add_prop_greater(ge->expression1_, ge->expression2_, Bound::Type::INCLUSIVE)) { + all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic)); + } + } else if (auto *lt = utils::Downcast<LessOperator>(expr)) { + // Like greater, but in reverse. + if (!add_prop_greater(lt->expression2_, lt->expression1_, Bound::Type::EXCLUSIVE)) { + all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic)); + } + } else if (auto *le = utils::Downcast<LessEqualOperator>(expr)) { + // Like greater equal, but in reverse. + if (!add_prop_greater(le->expression2_, le->expression1_, Bound::Type::INCLUSIVE)) { + all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic)); + } + } else if (auto *in = utils::Downcast<InListOperator>(expr)) { + // IN isn't equivalent to Equal because IN isn't a symmetric operator. The + // IN filter is captured here only if the property lookup occurs on the + // left side of the operator. In that case, it's valid to do the IN list + // optimization during the index lookup rewrite phase. 
+ if (!add_prop_in_list(in->expression1_, in->expression2_)) { + all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic)); + } + } else if (auto *is_not_null = utils::Downcast<NotOperator>(expr)) { + if (!add_prop_is_not_null_check(is_not_null)) { + all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic)); + } + } else { + all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic)); + } + // TODO: Collect comparisons like `expr1 < n.prop < expr2` for potential + // indexing by range. Note, that the generated Ast uses AND for chained + // relation operators. Therefore, `expr1 < n.prop < expr2` will be represented + // as `expr1 < n.prop AND n.prop < expr2`. +} + +static void ParseForeach(query::v2::Foreach &foreach, SingleQueryPart &query_part, AstStorage &storage, + SymbolTable &symbol_table) { + for (auto *clause : foreach.clauses_) { + if (auto *merge = utils::Downcast<query::v2::Merge>(clause)) { + query_part.merge_matching.emplace_back(Matching{}); + AddMatching({merge->pattern_}, nullptr, symbol_table, storage, query_part.merge_matching.back()); + } else if (auto *nested = utils::Downcast<query::v2::Foreach>(clause)) { + ParseForeach(*nested, query_part, storage, symbol_table); + } + } +} + +// Converts a Query to multiple QueryParts. In the process new Ast nodes may be +// created, e.g. filter expressions. 
+std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table, AstStorage &storage, + SingleQuery *single_query) { + std::vector<SingleQueryPart> query_parts(1); + auto *query_part = &query_parts.back(); + for (auto &clause : single_query->clauses_) { + if (auto *match = utils::Downcast<Match>(clause)) { + if (match->optional_) { + query_part->optional_matching.emplace_back(Matching{}); + AddMatching(*match, symbol_table, storage, query_part->optional_matching.back()); + } else { + DMG_ASSERT(query_part->optional_matching.empty(), "Match clause cannot follow optional match."); + AddMatching(*match, symbol_table, storage, query_part->matching); + } + } else { + query_part->remaining_clauses.push_back(clause); + if (auto *merge = utils::Downcast<query::v2::Merge>(clause)) { + query_part->merge_matching.emplace_back(Matching{}); + AddMatching({merge->pattern_}, nullptr, symbol_table, storage, query_part->merge_matching.back()); + } else if (auto *foreach = utils::Downcast<query::v2::Foreach>(clause)) { + ParseForeach(*foreach, *query_part, storage, symbol_table); + } else if (utils::IsSubtype(*clause, With::kType) || utils::IsSubtype(*clause, query::v2::Unwind::kType) || + utils::IsSubtype(*clause, query::v2::CallProcedure::kType) || + utils::IsSubtype(*clause, query::v2::LoadCsv::kType)) { + // This query part is done, continue with a new one. 
+ query_parts.emplace_back(SingleQueryPart{}); + query_part = &query_parts.back(); + } else if (utils::IsSubtype(*clause, Return::kType)) { + return query_parts; + } + } + } + return query_parts; +} + +QueryParts CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage, CypherQuery *query) { + std::vector<QueryPart> query_parts; + + auto *single_query = query->single_query_; + MG_ASSERT(single_query, "Expected at least a single query"); + query_parts.push_back(QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query)}); + + bool distinct = false; + for (auto *cypher_union : query->cypher_unions_) { + if (cypher_union->distinct_) { + distinct = true; + } + + auto *single_query = cypher_union->single_query_; + MG_ASSERT(single_query, "Expected UNION to have a query"); + query_parts.push_back(QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query), cypher_union}); + } + return QueryParts{query_parts, distinct}; +} + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/preprocess.hpp b/src/query/v2/plan/preprocess.hpp new file mode 100644 index 000000000..619f27f58 --- /dev/null +++ b/src/query/v2/plan/preprocess.hpp @@ -0,0 +1,360 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +/// @file +#pragma once + +#include <optional> +#include <set> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/semantic/symbol_table.hpp" +#include "query/v2/plan/operator.hpp" + +namespace memgraph::query::v2::plan { + +/// Collects symbols from identifiers found in visited AST nodes. +class UsedSymbolsCollector : public HierarchicalTreeVisitor { + public: + explicit UsedSymbolsCollector(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {} + + using HierarchicalTreeVisitor::PostVisit; + using HierarchicalTreeVisitor::PreVisit; + using HierarchicalTreeVisitor::Visit; + + bool PostVisit(All &all) override { + // Remove the symbol which is bound by all, because we are only interested + // in free (unbound) symbols. + symbols_.erase(symbol_table_.at(*all.identifier_)); + return true; + } + + bool PostVisit(Single &single) override { + // Remove the symbol which is bound by single, because we are only + // interested in free (unbound) symbols. + symbols_.erase(symbol_table_.at(*single.identifier_)); + return true; + } + + bool PostVisit(Any &any) override { + // Remove the symbol which is bound by any, because we are only interested + // in free (unbound) symbols. + symbols_.erase(symbol_table_.at(*any.identifier_)); + return true; + } + + bool PostVisit(None &none) override { + // Remove the symbol which is bound by none, because we are only interested + // in free (unbound) symbols. + symbols_.erase(symbol_table_.at(*none.identifier_)); + return true; + } + + bool PostVisit(Reduce &reduce) override { + // Remove the symbols bound by reduce, because we are only interested + // in free (unbound) symbols. 
+ symbols_.erase(symbol_table_.at(*reduce.accumulator_)); + symbols_.erase(symbol_table_.at(*reduce.identifier_)); + return true; + } + + bool Visit(Identifier &ident) override { + symbols_.insert(symbol_table_.at(ident)); + return true; + } + + bool Visit(PrimitiveLiteral &) override { return true; } + bool Visit(ParameterLookup &) override { return true; } + + std::unordered_set<Symbol> symbols_; + const SymbolTable &symbol_table_; +}; + +/// Normalized representation of a pattern that needs to be matched. +struct Expansion { + /// The first node in the expansion, it can be a single node. + NodeAtom *node1 = nullptr; + /// Optional edge which connects the 2 nodes. + EdgeAtom *edge = nullptr; + /// Direction of the edge, it may be flipped compared to original + /// @c EdgeAtom during plan generation. + EdgeAtom::Direction direction = EdgeAtom::Direction::BOTH; + /// True if the direction and nodes were flipped. + bool is_flipped = false; + /// Set of symbols found inside the range expressions of a variable path edge. + std::unordered_set<Symbol> symbols_in_range{}; + /// Optional node at the other end of an edge. If the expansion + /// contains an edge, then this node is required. + NodeAtom *node2 = nullptr; +}; + +/// Stores the symbols and expression used to filter a property. +class PropertyFilter { + public: + using Bound = ScanAllByLabelPropertyRange::Bound; + + /// Depending on type, this PropertyFilter may be a value equality, regex + /// matched value or a range with lower and (or) upper bounds, IN list filter. + enum class Type { EQUAL, REGEX_MATCH, RANGE, IN, IS_NOT_NULL }; + + /// Construct with Expression being the equality or regex match check. + PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx, Expression *, Type); + /// Construct the range based filter. 
+ PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx, const std::optional<Bound> &, + const std::optional<Bound> &); + /// Construct a filter without an expression that produces a value. + /// Used for the "PROP IS NOT NULL" filter, and can be used for any + /// property filter that doesn't need to use an expression to produce + /// values that should be filtered further. + PropertyFilter(const Symbol &, PropertyIx, Type); + + /// Symbol whose property is looked up. + Symbol symbol_; + PropertyIx property_; + Type type_; + /// True if the same symbol is used in expressions for value or bounds. + bool is_symbol_in_value_ = false; + /// Expression which when evaluated produces the value a property must + /// equal or regex match depending on type_. + Expression *value_ = nullptr; + /// Expressions which produce lower and upper bounds for a property. + std::optional<Bound> lower_bound_{}; + std::optional<Bound> upper_bound_{}; +}; + +/// Filtering by ID, for example `MATCH (n) WHERE id(n) = 42 ...` +class IdFilter { + public: + /// Construct with Expression being the required value for ID. + IdFilter(const SymbolTable &, const Symbol &, Expression *); + + /// Symbol whose id is looked up. + Symbol symbol_; + /// Expression which when evaluted produces the value an ID must satisfy. + Expression *value_; + /// True if the same symbol is used in expressions for value. + bool is_symbol_in_value_{false}; +}; + +/// Stores additional information for a filter expression. +struct FilterInfo { + /// A FilterInfo can be a generic filter expression or a specific filtering + /// applied for labels or a property. Non generic types contain extra + /// information which can be used to produce indexed scans of graph + /// elements. + enum class Type { Generic, Label, Property, Id }; + + Type type; + /// The original filter expression which must be satisfied. + Expression *expression; + /// Set of used symbols by the filter @c expression. 
+ std::unordered_set<Symbol> used_symbols; + /// Labels for Type::Label filtering. + std::vector<LabelIx> labels; + /// Property information for Type::Property filtering. + std::optional<PropertyFilter> property_filter; + /// Information for Type::Id filtering. + std::optional<IdFilter> id_filter; +}; + +/// Stores information on filters used inside the @c Matching of a @c QueryPart. +/// +/// Info is stored as a list of FilterInfo objects corresponding to all filter +/// expressions that should be generated. +class Filters final { + public: + using iterator = std::vector<FilterInfo>::iterator; + using const_iterator = std::vector<FilterInfo>::const_iterator; + + auto begin() { return all_filters_.begin(); } + auto begin() const { return all_filters_.begin(); } + auto end() { return all_filters_.end(); } + auto end() const { return all_filters_.end(); } + + auto empty() const { return all_filters_.empty(); } + + auto erase(iterator pos) { return all_filters_.erase(pos); } + auto erase(const_iterator pos) { return all_filters_.erase(pos); } + auto erase(iterator first, iterator last) { return all_filters_.erase(first, last); } + auto erase(const_iterator first, const_iterator last) { return all_filters_.erase(first, last); } + + auto FilteredLabels(const Symbol &symbol) const { + std::unordered_set<LabelIx> labels; + for (const auto &filter : all_filters_) { + if (filter.type == FilterInfo::Type::Label && utils::Contains(filter.used_symbols, symbol)) { + MG_ASSERT(filter.used_symbols.size() == 1U, "Expected a single used symbol for label filter"); + labels.insert(filter.labels.begin(), filter.labels.end()); + } + } + return labels; + } + + /// Remove a filter; may invalidate iterators. + /// Removal is done by comparing only the expression, so that multiple + /// FilterInfo objects using the same original expression are removed. + void EraseFilter(const FilterInfo &); + + /// Remove a label filter for symbol; may invalidate iterators. 
+ /// If removed_filters is not nullptr, fills the vector with original + /// `Expression *` which are now completely removed. + void EraseLabelFilter(const Symbol &, LabelIx, std::vector<Expression *> *removed_filters = nullptr); + + /// Returns a vector of FilterInfo for properties. + auto PropertyFilters(const Symbol &symbol) const { + std::vector<FilterInfo> filters; + for (const auto &filter : all_filters_) { + if (filter.type == FilterInfo::Type::Property && filter.property_filter->symbol_ == symbol) { + filters.push_back(filter); + } + } + return filters; + } + + /// Return a vector of FilterInfo for ID equality filtering. + auto IdFilters(const Symbol &symbol) const { + std::vector<FilterInfo> filters; + for (const auto &filter : all_filters_) { + if (filter.type == FilterInfo::Type::Id && filter.id_filter->symbol_ == symbol) { + filters.push_back(filter); + } + } + return filters; + } + + /// Collects filtering information from a pattern. + /// + /// Goes through all the atoms in a pattern and generates filter expressions + /// for found labels, properties and edge types. The generated expressions are + /// stored. + void CollectPatternFilters(Pattern &, SymbolTable &, AstStorage &); + + /// Collects filtering information from a where expression. + /// + /// Takes the where expression and stores it, then analyzes the expression for + /// additional information. The additional information is used to populate + /// label filters and property filters, so that indexed scanning can use it. + void CollectWhereFilter(Where &, const SymbolTable &); + + /// Collects filtering information from an expression. + /// + /// Takes the expression and stores it, then analyzes the expression for + /// additional information. The additional information is used to populate + /// label filters and property filters, so that indexed scanning can use it. 
+ void CollectFilterExpression(Expression *, const SymbolTable &); + + private: + void AnalyzeAndStoreFilter(Expression *, const SymbolTable &); + + std::vector<FilterInfo> all_filters_; +}; + +/// Normalized representation of a single or multiple Match clauses. +/// +/// For example, `MATCH (a :Label) -[e1]- (b) -[e2]- (c) MATCH (n) -[e3]- (m) +/// WHERE c.prop < 42` will produce the following. +/// Expansions will store `(a) -[e1]-(b)`, `(b) -[e2]- (c)` and +/// `(n) -[e3]- (m)`. +/// Edge symbols for Cyphermorphism will only contain the set `{e1, e2}` for the +/// first `MATCH` and the set `{e3}` for the second. +/// Filters will contain 2 pairs. One for testing `:Label` on symbol `a` and the +/// other obtained from `WHERE` on symbol `c`. +struct Matching { + /// All expansions that need to be performed across @c Match clauses. + std::vector<Expansion> expansions; + /// Symbols for edges established in match, used to ensure Cyphermorphism. + /// + /// There are multiple sets, because each Match clause determines a single + /// set. + std::vector<std::unordered_set<Symbol>> edge_symbols; + /// Information on used filter expressions while matching. + Filters filters; + /// Maps node symbols to expansions which bind them. + std::unordered_map<Symbol, std::set<size_t>> node_symbol_to_expansions{}; + /// Maps named path symbols to a vector of Symbols that define its pattern. + std::unordered_map<Symbol, std::vector<Symbol>> named_paths{}; + /// All node and edge symbols across all expansions (from all matches). + std::unordered_set<Symbol> expansion_symbols{}; +}; + +/// @brief Represents a read (+ write) part of a query. Parts are split on +/// `WITH` clauses. +/// +/// Each part ends with either: +/// +/// * `RETURN` clause; +/// * `WITH` clause; +/// * `UNWIND` clause; +/// * `CALL` clause or +/// * any of the write clauses. 
+/// +/// For a query `MATCH (n) MERGE (n) -[e]- (m) SET n.x = 42 MERGE (l)` the +/// generated SingleQueryPart will have `matching` generated for the `MATCH`. +/// `remaining_clauses` will contain `Merge`, `SetProperty` and `Merge` clauses +/// in that exact order. The pattern inside the first `MERGE` will be used to +/// generate the first `merge_matching` element, and the second `MERGE` pattern +/// will produce the second `merge_matching` element. This way, if someone +/// traverses `remaining_clauses`, the order of appearance of `Merge` clauses is +/// in the same order as their respective `merge_matching` elements. +/// An exception to the above rule is Foreach. Its update clauses will not be contained in +/// the `remaining_clauses`, but rather inside the foreach itself. The order guarantee is not +/// violated because the update clauses of the foreach are immediately processed in +/// the `RuleBasedPlanner` as if they were pushed into the `remaining_clauses`. +struct SingleQueryPart { + /// @brief All `MATCH` clauses merged into one @c Matching. + Matching matching; + /// @brief Each `OPTIONAL MATCH` converted to @c Matching. + std::vector<Matching> optional_matching{}; + /// @brief @c Matching for each `MERGE` clause. + /// + /// Storing the normalized pattern of a @c Merge does not preclude storing the + /// @c Merge clause itself inside `remaining_clauses`. The reason is that we + /// need to have access to other parts of the clause, such as `SET` clauses + /// which need to be run. + /// + /// Since @c Merge is contained in `remaining_clauses`, this vector contains + /// matching in the same order as @c Merge appears. + /// + /// @c Foreach does not violate this guarantee. However, update clauses are not stored + /// in the `remaining_clauses` but rather in the `Foreach` itself and are guaranteed + /// to be processed in the same order by the semantics of the `RuleBasedPlanner`. 
+ std::vector<Matching> merge_matching{}; + /// @brief All the remaining clauses (without @c Match). + std::vector<Clause *> remaining_clauses{}; +}; + +/// Holds query parts of a single query together with the optional information +/// about the combinator used between this single query and the previous one. +struct QueryPart { + std::vector<SingleQueryPart> single_query_parts = {}; + /// Optional AST query combinator node + Tree *query_combinator = nullptr; +}; + +/// Holds query parts of all single queries together with the information +/// whether or not the resulting set should contain distinct elements. +struct QueryParts { + std::vector<QueryPart> query_parts = {}; + /// Distinct flag, determined by the query combinator + bool distinct = false; +}; + +/// @brief Convert the AST to multiple @c QueryParts. +/// +/// This function will normalize patterns inside @c Match and @c Merge clauses +/// and do some other preprocessing in order to generate multiple @c QueryPart +/// structures. @c AstStorage and @c SymbolTable may be used to create new +/// AST nodes. +QueryParts CollectQueryParts(SymbolTable &, AstStorage &, CypherQuery *); + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/pretty_print.cpp b/src/query/v2/plan/pretty_print.cpp new file mode 100644 index 000000000..361cd89f1 --- /dev/null +++ b/src/query/v2/plan/pretty_print.cpp @@ -0,0 +1,910 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "query/v2/plan/pretty_print.hpp" +#include <variant> + +#include "query/v2/db_accessor.hpp" +#include "query/v2/frontend/ast/pretty_print.hpp" +#include "utils/string.hpp" + +namespace memgraph::query::v2::plan { + +PlanPrinter::PlanPrinter(const DbAccessor *dba, std::ostream *out) : dba_(dba), out_(out) {} + +#define PRE_VISIT(TOp) \ + bool PlanPrinter::PreVisit(TOp &) { \ + WithPrintLn([](auto &out) { out << "* " << #TOp; }); \ + return true; \ + } + +PRE_VISIT(CreateNode); + +bool PlanPrinter::PreVisit(CreateExpand &op) { + WithPrintLn([&](auto &out) { + out << "* CreateExpand (" << op.input_symbol_.name() << ")" + << (op.edge_info_.direction == query::v2::EdgeAtom::Direction::IN ? "<-" : "-") << "[" + << op.edge_info_.symbol.name() << ":" << dba_->EdgeTypeToName(op.edge_info_.edge_type) << "]" + << (op.edge_info_.direction == query::v2::EdgeAtom::Direction::OUT ? "->" : "-") << "(" + << op.node_info_.symbol.name() << ")"; + }); + return true; +} + +PRE_VISIT(Delete); + +bool PlanPrinter::PreVisit(query::v2::plan::ScanAll &op) { + WithPrintLn([&](auto &out) { + out << "* ScanAll" + << " (" << op.output_symbol_.name() << ")"; + }); + return true; +} + +bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabel &op) { + WithPrintLn([&](auto &out) { + out << "* ScanAllByLabel" + << " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << ")"; + }); + return true; +} + +bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelPropertyValue &op) { + WithPrintLn([&](auto &out) { + out << "* ScanAllByLabelPropertyValue" + << " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {" + << dba_->PropertyToName(op.property_) << "})"; + }); + return true; +} + +bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelPropertyRange &op) { + WithPrintLn([&](auto &out) { + out << "* ScanAllByLabelPropertyRange" + << " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {" + << 
dba_->PropertyToName(op.property_) << "})"; + }); + return true; +} + +bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelProperty &op) { + WithPrintLn([&](auto &out) { + out << "* ScanAllByLabelProperty" + << " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {" + << dba_->PropertyToName(op.property_) << "})"; + }); + return true; +} + +bool PlanPrinter::PreVisit(ScanAllById &op) { + WithPrintLn([&](auto &out) { + out << "* ScanAllById" + << " (" << op.output_symbol_.name() << ")"; + }); + return true; +} + +bool PlanPrinter::PreVisit(query::v2::plan::Expand &op) { + WithPrintLn([&](auto &out) { + *out_ << "* Expand (" << op.input_symbol_.name() << ")" + << (op.common_.direction == query::v2::EdgeAtom::Direction::IN ? "<-" : "-") << "[" + << op.common_.edge_symbol.name(); + utils::PrintIterable(*out_, op.common_.edge_types, "|", [this](auto &stream, const auto &edge_type) { + stream << ":" << dba_->EdgeTypeToName(edge_type); + }); + *out_ << "]" << (op.common_.direction == query::v2::EdgeAtom::Direction::OUT ? "->" : "-") << "(" + << op.common_.node_symbol.name() << ")"; + }); + return true; +} + +bool PlanPrinter::PreVisit(query::v2::plan::ExpandVariable &op) { + using Type = query::v2::EdgeAtom::Type; + WithPrintLn([&](auto &out) { + *out_ << "* "; + switch (op.type_) { + case Type::DEPTH_FIRST: + *out_ << "ExpandVariable"; + break; + case Type::BREADTH_FIRST: + *out_ << (op.common_.existing_node ? "STShortestPath" : "BFSExpand"); + break; + case Type::WEIGHTED_SHORTEST_PATH: + *out_ << "WeightedShortestPath"; + break; + case Type::SINGLE: + LOG_FATAL("Unexpected ExpandVariable::type_"); + } + *out_ << " (" << op.input_symbol_.name() << ")" + << (op.common_.direction == query::v2::EdgeAtom::Direction::IN ? 
"<-" : "-") << "[" + << op.common_.edge_symbol.name(); + utils::PrintIterable(*out_, op.common_.edge_types, "|", [this](auto &stream, const auto &edge_type) { + stream << ":" << dba_->EdgeTypeToName(edge_type); + }); + *out_ << "]" << (op.common_.direction == query::v2::EdgeAtom::Direction::OUT ? "->" : "-") << "(" + << op.common_.node_symbol.name() << ")"; + }); + return true; +} + +bool PlanPrinter::PreVisit(query::v2::plan::Produce &op) { + WithPrintLn([&](auto &out) { + out << "* Produce {"; + utils::PrintIterable(out, op.named_expressions_, ", ", [](auto &out, const auto &nexpr) { out << nexpr->name_; }); + out << "}"; + }); + return true; +} + +PRE_VISIT(ConstructNamedPath); +PRE_VISIT(Filter); +PRE_VISIT(SetProperty); +PRE_VISIT(SetProperties); +PRE_VISIT(SetLabels); +PRE_VISIT(RemoveProperty); +PRE_VISIT(RemoveLabels); +PRE_VISIT(EdgeUniquenessFilter); +PRE_VISIT(Accumulate); + +bool PlanPrinter::PreVisit(query::v2::plan::Aggregate &op) { + WithPrintLn([&](auto &out) { + out << "* Aggregate {"; + utils::PrintIterable(out, op.aggregations_, ", ", + [](auto &out, const auto &aggr) { out << aggr.output_sym.name(); }); + out << "} {"; + utils::PrintIterable(out, op.remember_, ", ", [](auto &out, const auto &sym) { out << sym.name(); }); + out << "}"; + }); + return true; +} + +PRE_VISIT(Skip); +PRE_VISIT(Limit); + +bool PlanPrinter::PreVisit(query::v2::plan::OrderBy &op) { + WithPrintLn([&op](auto &out) { + out << "* OrderBy {"; + utils::PrintIterable(out, op.output_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); }); + out << "}"; + }); + return true; +} + +bool PlanPrinter::PreVisit(query::v2::plan::Merge &op) { + WithPrintLn([](auto &out) { out << "* Merge"; }); + Branch(*op.merge_match_, "On Match"); + Branch(*op.merge_create_, "On Create"); + op.input_->Accept(*this); + return false; +} + +bool PlanPrinter::PreVisit(query::v2::plan::Optional &op) { + WithPrintLn([](auto &out) { out << "* Optional"; }); + Branch(*op.optional_); + 
op.input_->Accept(*this); + return false; +} + +PRE_VISIT(Unwind); +PRE_VISIT(Distinct); + +bool PlanPrinter::PreVisit(query::v2::plan::Union &op) { + WithPrintLn([&op](auto &out) { + out << "* Union {"; + utils::PrintIterable(out, op.left_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); }); + out << " : "; + utils::PrintIterable(out, op.right_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); }); + out << "}"; + }); + Branch(*op.right_op_); + op.left_op_->Accept(*this); + return false; +} + +bool PlanPrinter::PreVisit(query::v2::plan::CallProcedure &op) { + WithPrintLn([&op](auto &out) { + out << "* CallProcedure<" << op.procedure_name_ << "> {"; + utils::PrintIterable(out, op.result_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); }); + out << "}"; + }); + return true; +} + +bool PlanPrinter::PreVisit(query::v2::plan::LoadCsv &op) { + WithPrintLn([&op](auto &out) { out << "* LoadCsv {" << op.row_var_.name() << "}"; }); + return true; +} + +bool PlanPrinter::Visit(query::v2::plan::Once & /*op*/) { + WithPrintLn([](auto &out) { out << "* Once"; }); + return true; +} + +bool PlanPrinter::PreVisit(query::v2::plan::Cartesian &op) { + WithPrintLn([&op](auto &out) { + out << "* Cartesian {"; + utils::PrintIterable(out, op.left_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); }); + out << " : "; + utils::PrintIterable(out, op.right_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); }); + out << "}"; + }); + Branch(*op.right_op_); + op.left_op_->Accept(*this); + return false; +} + +bool PlanPrinter::PreVisit(query::v2::plan::Foreach &op) { + WithPrintLn([](auto &out) { out << "* Foreach"; }); + Branch(*op.update_clauses_); + op.input_->Accept(*this); + return false; +} +#undef PRE_VISIT + +bool PlanPrinter::DefaultPreVisit() { + WithPrintLn([](auto &out) { out << "* Unknown operator!"; }); + return true; +} + +void PlanPrinter::Branch(query::v2::plan::LogicalOperator &op, const 
std::string &branch_name) { + WithPrintLn([&](auto &out) { out << "|\\ " << branch_name; }); + ++depth_; + op.Accept(*this); + --depth_; +} + +void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root, std::ostream *out) { + PlanPrinter printer(&dba, out); + // FIXME(mtomic): We should make visitors that take const arguments. + const_cast<LogicalOperator *>(plan_root)->Accept(printer); +} + +nlohmann::json PlanToJson(const DbAccessor &dba, const LogicalOperator *plan_root) { + impl::PlanToJsonVisitor visitor(&dba); + // FIXME(mtomic): We should make visitors that take const arguments. + const_cast<LogicalOperator *>(plan_root)->Accept(visitor); + return visitor.output(); +} + +namespace impl { + +/////////////////////////////////////////////////////////////////////////////// +// +// PlanToJsonVisitor implementation +// +// The JSON formatted plan is consumed (or will be) by Memgraph Lab, and +// therefore should not be changed before synchronizing with whoever is +// maintaining Memgraph Lab. Hopefully, one day integration tests will exist and +// there will be no need to be super careful. + +using nlohmann::json; + +//////////////////////////// HELPER FUNCTIONS ///////////////////////////////// +// TODO: It would be nice to have enum->string functions auto-generated. 
+std::string ToString(EdgeAtom::Direction dir) { + switch (dir) { + case EdgeAtom::Direction::BOTH: + return "both"; + case EdgeAtom::Direction::IN: + return "in"; + case EdgeAtom::Direction::OUT: + return "out"; + } +} + +std::string ToString(EdgeAtom::Type type) { + switch (type) { + case EdgeAtom::Type::BREADTH_FIRST: + return "bfs"; + case EdgeAtom::Type::DEPTH_FIRST: + return "dfs"; + case EdgeAtom::Type::WEIGHTED_SHORTEST_PATH: + return "wsp"; + case EdgeAtom::Type::SINGLE: + return "single"; + } +} + +std::string ToString(Ordering ord) { + switch (ord) { + case Ordering::ASC: + return "asc"; + case Ordering::DESC: + return "desc"; + } +} + +json ToJson(Expression *expression) { + std::stringstream sstr; + PrintExpression(expression, &sstr); + return sstr.str(); +} + +json ToJson(const utils::Bound<Expression *> &bound) { + json json; + switch (bound.type()) { + case utils::BoundType::INCLUSIVE: + json["type"] = "inclusive"; + break; + case utils::BoundType::EXCLUSIVE: + json["type"] = "exclusive"; + break; + } + + json["value"] = ToJson(bound.value()); + + return json; +} + +json ToJson(const Symbol &symbol) { return symbol.name(); } + +json ToJson(storage::v3::EdgeTypeId edge_type, const DbAccessor &dba) { return dba.EdgeTypeToName(edge_type); } + +json ToJson(storage::v3::LabelId label, const DbAccessor &dba) { return dba.LabelToName(label); } + +json ToJson(storage::v3::PropertyId property, const DbAccessor &dba) { return dba.PropertyToName(property); } + +json ToJson(NamedExpression *nexpr) { + json json; + json["expression"] = ToJson(nexpr->expression_); + json["name"] = nexpr->name_; + return json; +} + +json ToJson(const std::vector<std::pair<storage::v3::PropertyId, Expression *>> &properties, const DbAccessor &dba) { + json json; + for (const auto &prop_pair : properties) { + json.emplace(ToJson(prop_pair.first, dba), ToJson(prop_pair.second)); + } + return json; +} + +json ToJson(const NodeCreationInfo &node_info, const DbAccessor &dba) { + json 
self; + self["symbol"] = ToJson(node_info.symbol); + self["labels"] = ToJson(node_info.labels, dba); + const auto *props = std::get_if<PropertiesMapList>(&node_info.properties); + self["properties"] = ToJson(props ? *props : PropertiesMapList{}, dba); + return self; +} + +json ToJson(const EdgeCreationInfo &edge_info, const DbAccessor &dba) { + json self; + self["symbol"] = ToJson(edge_info.symbol); + const auto *props = std::get_if<PropertiesMapList>(&edge_info.properties); + self["properties"] = ToJson(props ? *props : PropertiesMapList{}, dba); + self["edge_type"] = ToJson(edge_info.edge_type, dba); + self["direction"] = ToString(edge_info.direction); + return self; +} + +json ToJson(const Aggregate::Element &elem) { + json json; + if (elem.value) { + json["value"] = ToJson(elem.value); + } + if (elem.key) { + json["key"] = ToJson(elem.key); + } + json["op"] = utils::ToLowerCase(Aggregation::OpToString(elem.op)); + json["output_symbol"] = ToJson(elem.output_sym); + return json; +} +////////////////////////// END HELPER FUNCTIONS //////////////////////////////// + +bool PlanToJsonVisitor::Visit(Once &) { + json self; + self["name"] = "Once"; + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(ScanAll &op) { + json self; + self["name"] = "ScanAll"; + self["output_symbol"] = ToJson(op.output_symbol_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(ScanAllByLabel &op) { + json self; + self["name"] = "ScanAllByLabel"; + self["label"] = ToJson(op.label_, *dba_); + self["output_symbol"] = ToJson(op.output_symbol_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(ScanAllByLabelPropertyRange &op) { + json self; + self["name"] = "ScanAllByLabelPropertyRange"; + self["label"] = ToJson(op.label_, *dba_); + self["property"] = ToJson(op.property_, 
*dba_); + self["lower_bound"] = op.lower_bound_ ? ToJson(*op.lower_bound_) : json(); + self["upper_bound"] = op.upper_bound_ ? ToJson(*op.upper_bound_) : json(); + self["output_symbol"] = ToJson(op.output_symbol_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(ScanAllByLabelPropertyValue &op) { + json self; + self["name"] = "ScanAllByLabelPropertyValue"; + self["label"] = ToJson(op.label_, *dba_); + self["property"] = ToJson(op.property_, *dba_); + self["expression"] = ToJson(op.expression_); + self["output_symbol"] = ToJson(op.output_symbol_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(ScanAllByLabelProperty &op) { + json self; + self["name"] = "ScanAllByLabelProperty"; + self["label"] = ToJson(op.label_, *dba_); + self["property"] = ToJson(op.property_, *dba_); + self["output_symbol"] = ToJson(op.output_symbol_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(ScanAllById &op) { + json self; + self["name"] = "ScanAllById"; + self["output_symbol"] = ToJson(op.output_symbol_); + op.input_->Accept(*this); + self["input"] = PopOutput(); + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(CreateNode &op) { + json self; + self["name"] = "CreateNode"; + self["node_info"] = ToJson(op.node_info_, *dba_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(CreateExpand &op) { + json self; + self["name"] = "CreateExpand"; + self["input_symbol"] = ToJson(op.input_symbol_); + self["node_info"] = ToJson(op.node_info_, *dba_); + self["edge_info"] = ToJson(op.edge_info_, *dba_); + self["existing_node"] = op.existing_node_; + + op.input_->Accept(*this); + 
self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Expand &op) { + json self; + self["name"] = "Expand"; + self["input_symbol"] = ToJson(op.input_symbol_); + self["node_symbol"] = ToJson(op.common_.node_symbol); + self["edge_symbol"] = ToJson(op.common_.edge_symbol); + self["edge_types"] = ToJson(op.common_.edge_types, *dba_); + self["direction"] = ToString(op.common_.direction); + self["existing_node"] = op.common_.existing_node; + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(ExpandVariable &op) { + json self; + self["name"] = "ExpandVariable"; + self["input_symbol"] = ToJson(op.input_symbol_); + self["node_symbol"] = ToJson(op.common_.node_symbol); + self["edge_symbol"] = ToJson(op.common_.edge_symbol); + self["edge_types"] = ToJson(op.common_.edge_types, *dba_); + self["direction"] = ToString(op.common_.direction); + self["type"] = ToString(op.type_); + self["is_reverse"] = op.is_reverse_; + self["lower_bound"] = op.lower_bound_ ? ToJson(op.lower_bound_) : json(); + self["upper_bound"] = op.upper_bound_ ? ToJson(op.upper_bound_) : json(); + self["existing_node"] = op.common_.existing_node; + + self["filter_lambda"] = op.filter_lambda_.expression ? 
ToJson(op.filter_lambda_.expression) : json(); + + if (op.type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) { + self["weight_lambda"] = ToJson(op.weight_lambda_->expression); + self["total_weight_symbol"] = ToJson(*op.total_weight_); + } + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(ConstructNamedPath &op) { + json self; + self["name"] = "ConstructNamedPath"; + self["path_symbol"] = ToJson(op.path_symbol_); + self["path_elements"] = ToJson(op.path_elements_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Filter &op) { + json self; + self["name"] = "Filter"; + self["expression"] = ToJson(op.expression_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Produce &op) { + json self; + self["name"] = "Produce"; + self["named_expressions"] = ToJson(op.named_expressions_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Delete &op) { + json self; + self["name"] = "Delete"; + self["expressions"] = ToJson(op.expressions_); + self["detach"] = op.detach_; + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(SetProperty &op) { + json self; + self["name"] = "SetProperty"; + self["property"] = ToJson(op.property_, *dba_); + self["lhs"] = ToJson(op.lhs_); + self["rhs"] = ToJson(op.rhs_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(SetProperties &op) { + json self; + self["name"] = "SetProperties"; + self["input_symbol"] = ToJson(op.input_symbol_); + self["rhs"] = ToJson(op.rhs_); + + switch 
(op.op_) { + case SetProperties::Op::UPDATE: + self["op"] = "update"; + break; + case SetProperties::Op::REPLACE: + self["op"] = "replace"; + break; + } + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(SetLabels &op) { + json self; + self["name"] = "SetLabels"; + self["input_symbol"] = ToJson(op.input_symbol_); + self["labels"] = ToJson(op.labels_, *dba_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(RemoveProperty &op) { + json self; + self["name"] = "RemoveProperty"; + self["property"] = ToJson(op.property_, *dba_); + self["lhs"] = ToJson(op.lhs_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(RemoveLabels &op) { + json self; + self["name"] = "RemoveLabels"; + self["input_symbol"] = ToJson(op.input_symbol_); + self["labels"] = ToJson(op.labels_, *dba_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(EdgeUniquenessFilter &op) { + json self; + self["name"] = "EdgeUniquenessFilter"; + self["expand_symbol"] = ToJson(op.expand_symbol_); + self["previous_symbols"] = ToJson(op.previous_symbols_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Accumulate &op) { + json self; + self["name"] = "Accumulate"; + self["symbols"] = ToJson(op.symbols_); + self["advance_command"] = op.advance_command_; + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Aggregate &op) { + json self; + self["name"] = "Aggregate"; + self["aggregations"] = ToJson(op.aggregations_); + self["group_by"] = 
ToJson(op.group_by_); + self["remember"] = ToJson(op.remember_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Skip &op) { + json self; + self["name"] = "Skip"; + self["expression"] = ToJson(op.expression_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Limit &op) { + json self; + self["name"] = "Limit"; + self["expression"] = ToJson(op.expression_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(OrderBy &op) { + json self; + self["name"] = "OrderBy"; + + for (size_t i = 0; i < op.order_by_.size(); ++i) { + json json; + json["ordering"] = ToString(op.compare_.ordering_[i]); + json["expression"] = ToJson(op.order_by_[i]); + self["order_by"].push_back(json); + } + self["output_symbols"] = ToJson(op.output_symbols_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Merge &op) { + json self; + self["name"] = "Merge"; + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + op.merge_match_->Accept(*this); + self["merge_match"] = PopOutput(); + + op.merge_create_->Accept(*this); + self["merge_create"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Optional &op) { + json self; + self["name"] = "Optional"; + self["optional_symbols"] = ToJson(op.optional_symbols_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + op.optional_->Accept(*this); + self["optional"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Unwind &op) { + json self; + self["name"] = "Unwind"; + self["output_symbol"] = ToJson(op.output_symbol_); + self["input_expression"] = 
ToJson(op.input_expression_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(query::v2::plan::CallProcedure &op) { + json self; + self["name"] = "CallProcedure"; + self["procedure_name"] = op.procedure_name_; + self["arguments"] = ToJson(op.arguments_); + self["result_fields"] = op.result_fields_; + self["result_symbols"] = ToJson(op.result_symbols_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(query::v2::plan::LoadCsv &op) { + json self; + self["name"] = "LoadCsv"; + self["file"] = ToJson(op.file_); + self["with_header"] = op.with_header_; + self["ignore_bad"] = op.ignore_bad_; + self["delimiter"] = ToJson(op.delimiter_); + self["quote"] = ToJson(op.quote_); + self["row_variable"] = ToJson(op.row_var_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Distinct &op) { + json self; + self["name"] = "Distinct"; + self["value_symbols"] = ToJson(op.value_symbols_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Union &op) { + json self; + self["name"] = "Union"; + self["union_symbols"] = ToJson(op.union_symbols_); + self["left_symbols"] = ToJson(op.left_symbols_); + self["right_symbols"] = ToJson(op.right_symbols_); + + op.left_op_->Accept(*this); + self["left_op"] = PopOutput(); + + op.right_op_->Accept(*this); + self["right_op"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +bool PlanToJsonVisitor::PreVisit(Cartesian &op) { + json self; + self["name"] = "Cartesian"; + self["left_symbols"] = ToJson(op.left_symbols_); + self["right_symbols"] = ToJson(op.right_symbols_); + + op.left_op_->Accept(*this); + self["left_op"] = PopOutput(); + + 
op.right_op_->Accept(*this); + self["right_op"] = PopOutput(); + + output_ = std::move(self); + return false; +} +bool PlanToJsonVisitor::PreVisit(Foreach &op) { + json self; + self["name"] = "Foreach"; + self["loop_variable_symbol"] = ToJson(op.loop_variable_symbol_); + self["expression"] = ToJson(op.expression_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + op.update_clauses_->Accept(*this); + self["update_clauses"] = PopOutput(); + + output_ = std::move(self); + return false; +} + +} // namespace impl + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/pretty_print.hpp b/src/query/v2/plan/pretty_print.hpp new file mode 100644 index 000000000..5708a97c5 --- /dev/null +++ b/src/query/v2/plan/pretty_print.hpp @@ -0,0 +1,230 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +#pragma once + +#include <iostream> + +#include <json/json.hpp> + +#include "query/v2/plan/operator.hpp" + +namespace memgraph::query::v2 { +class DbAccessor; + +namespace plan { + +class LogicalOperator; + +/// Pretty print a `LogicalOperator` plan to a `std::ostream`. +/// DbAccessor is needed for resolving label and property names. +/// Note that `plan_root` isn't modified, but we can't take it as a const +/// because we don't have support for visiting a const LogicalOperator. 
+void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root, std::ostream *out); + +/// Overload of `PrettyPrint` which defaults the `std::ostream` to `std::cout`. +inline void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root) { + PrettyPrint(dba, plan_root, &std::cout); +} + +/// Convert a `LogicalOperator` plan to a JSON representation. +/// DbAccessor is needed for resolving label and property names. +nlohmann::json PlanToJson(const DbAccessor &dba, const LogicalOperator *plan_root); + +class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor { + public: + using HierarchicalLogicalOperatorVisitor::PostVisit; + using HierarchicalLogicalOperatorVisitor::PreVisit; + using HierarchicalLogicalOperatorVisitor::Visit; + + PlanPrinter(const DbAccessor *dba, std::ostream *out); + + bool DefaultPreVisit() override; + + bool PreVisit(CreateNode &) override; + bool PreVisit(CreateExpand &) override; + bool PreVisit(Delete &) override; + + bool PreVisit(SetProperty &) override; + bool PreVisit(SetProperties &) override; + bool PreVisit(SetLabels &) override; + + bool PreVisit(RemoveProperty &) override; + bool PreVisit(RemoveLabels &) override; + + bool PreVisit(ScanAll &) override; + bool PreVisit(ScanAllByLabel &) override; + bool PreVisit(ScanAllByLabelPropertyValue &) override; + bool PreVisit(ScanAllByLabelPropertyRange &) override; + bool PreVisit(ScanAllByLabelProperty &) override; + bool PreVisit(ScanAllById &) override; + + bool PreVisit(Expand &) override; + bool PreVisit(ExpandVariable &) override; + + bool PreVisit(ConstructNamedPath &) override; + + bool PreVisit(Filter &) override; + bool PreVisit(EdgeUniquenessFilter &) override; + + bool PreVisit(Merge &) override; + bool PreVisit(Optional &) override; + bool PreVisit(Cartesian &) override; + + bool PreVisit(Produce &) override; + bool PreVisit(Accumulate &) override; + bool PreVisit(Aggregate &) override; + bool PreVisit(Skip &) override; + bool PreVisit(Limit &) 
override; + bool PreVisit(OrderBy &) override; + bool PreVisit(Distinct &) override; + bool PreVisit(Union &) override; + + bool PreVisit(Unwind &) override; + bool PreVisit(CallProcedure &) override; + bool PreVisit(LoadCsv &) override; + bool PreVisit(Foreach &) override; + + bool Visit(Once &) override; + + /// Call fun with output stream. The stream is prefixed with amount of spaces + /// corresponding to the current depth_. + template <class TFun> + void WithPrintLn(TFun fun) { + *out_ << " "; + for (int64_t i = 0; i < depth_; ++i) { + *out_ << "| "; + } + fun(*out_); + *out_ << std::endl; + } + + /// Forward this printer to another operator branch by incrementing the depth + /// and printing the branch name. + void Branch(LogicalOperator &op, const std::string &branch_name = ""); + + int64_t depth_{0}; + const DbAccessor *dba_{nullptr}; + std::ostream *out_{nullptr}; +}; + +namespace impl { + +std::string ToString(EdgeAtom::Direction dir); + +std::string ToString(EdgeAtom::Type type); + +std::string ToString(Ordering ord); + +nlohmann::json ToJson(Expression *expression); + +nlohmann::json ToJson(const utils::Bound<Expression *> &bound); + +nlohmann::json ToJson(const Symbol &symbol); + +nlohmann::json ToJson(storage::v3::EdgeTypeId edge_type, const DbAccessor &dba); + +nlohmann::json ToJson(storage::v3::LabelId label, const DbAccessor &dba); + +nlohmann::json ToJson(storage::v3::PropertyId property, const DbAccessor &dba); + +nlohmann::json ToJson(NamedExpression *nexpr); + +nlohmann::json ToJson(const std::vector<std::pair<storage::v3::PropertyId, Expression *>> &properties, + const DbAccessor &dba); + +nlohmann::json ToJson(const NodeCreationInfo &node_info, const DbAccessor &dba); + +nlohmann::json ToJson(const EdgeCreationInfo &edge_info, const DbAccessor &dba); + +nlohmann::json ToJson(const Aggregate::Element &elem); + +template <class T, class... 
Args> +nlohmann::json ToJson(const std::vector<T> &items, Args &&...args) { + nlohmann::json json; + for (const auto &item : items) { + json.emplace_back(ToJson(item, std::forward<Args>(args)...)); + } + return json; +} + +class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor { + public: + explicit PlanToJsonVisitor(const DbAccessor *dba) : dba_(dba) {} + + using HierarchicalLogicalOperatorVisitor::PostVisit; + using HierarchicalLogicalOperatorVisitor::PreVisit; + using HierarchicalLogicalOperatorVisitor::Visit; + + bool PreVisit(CreateNode &) override; + bool PreVisit(CreateExpand &) override; + bool PreVisit(Delete &) override; + + bool PreVisit(SetProperty &) override; + bool PreVisit(SetProperties &) override; + bool PreVisit(SetLabels &) override; + + bool PreVisit(RemoveProperty &) override; + bool PreVisit(RemoveLabels &) override; + + bool PreVisit(Expand &) override; + bool PreVisit(ExpandVariable &) override; + + bool PreVisit(ConstructNamedPath &) override; + + bool PreVisit(Merge &) override; + bool PreVisit(Optional &) override; + + bool PreVisit(Filter &) override; + bool PreVisit(EdgeUniquenessFilter &) override; + bool PreVisit(Cartesian &) override; + + bool PreVisit(ScanAll &) override; + bool PreVisit(ScanAllByLabel &) override; + bool PreVisit(ScanAllByLabelPropertyRange &) override; + bool PreVisit(ScanAllByLabelPropertyValue &) override; + bool PreVisit(ScanAllByLabelProperty &) override; + bool PreVisit(ScanAllById &) override; + + bool PreVisit(Produce &) override; + bool PreVisit(Accumulate &) override; + bool PreVisit(Aggregate &) override; + bool PreVisit(Skip &) override; + bool PreVisit(Limit &) override; + bool PreVisit(OrderBy &) override; + bool PreVisit(Distinct &) override; + bool PreVisit(Union &) override; + + bool PreVisit(Unwind &) override; + bool PreVisit(Foreach &) override; + bool PreVisit(CallProcedure &) override; + bool PreVisit(LoadCsv &) override; + + bool Visit(Once &) override; + + 
nlohmann::json output() { return output_; } + + protected: + nlohmann::json output_; + const DbAccessor *dba_; + + nlohmann::json PopOutput() { + nlohmann::json tmp; + tmp.swap(output_); + return tmp; + } +}; + +} // namespace impl + +} // namespace plan +} // namespace memgraph::query::v2 diff --git a/src/query/v2/plan/profile.cpp b/src/query/v2/plan/profile.cpp new file mode 100644 index 000000000..26c5280d4 --- /dev/null +++ b/src/query/v2/plan/profile.cpp @@ -0,0 +1,166 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "query/v2/plan/profile.hpp" + +#include <algorithm> +#include <chrono> + +#include <fmt/format.h> +#include <json/json.hpp> + +#include "query/v2/context.hpp" +#include "utils/likely.hpp" + +namespace memgraph::query::v2::plan { + +namespace { + +unsigned long long IndividualCycles(const ProfilingStats &cumulative_stats) { + return cumulative_stats.num_cycles - std::accumulate(cumulative_stats.children.begin(), + cumulative_stats.children.end(), 0ULL, + [](auto acc, auto &stats) { return acc + stats.num_cycles; }); +} + +double RelativeTime(unsigned long long num_cycles, unsigned long long total_cycles) { + return static_cast<double>(num_cycles) / total_cycles; +} + +double AbsoluteTime(unsigned long long num_cycles, unsigned long long total_cycles, + std::chrono::duration<double> total_time) { + return (RelativeTime(num_cycles, total_cycles) * static_cast<std::chrono::duration<double, std::milli>>(total_time)) + .count(); +} + +} // namespace + +////////////////////////////////////////////////////////////////////////////// +// +// ProfilingStatsToTable + +namespace { + +class ProfilingStatsToTableHelper { + public: + ProfilingStatsToTableHelper(unsigned long long total_cycles, std::chrono::duration<double> total_time) + : total_cycles_(total_cycles), total_time_(total_time) {} + + void Output(const ProfilingStats &cumulative_stats) { + auto cycles = IndividualCycles(cumulative_stats); + + rows_.emplace_back(std::vector<TypedValue>{ + TypedValue(FormatOperator(cumulative_stats.name)), TypedValue(cumulative_stats.actual_hits), + TypedValue(FormatRelativeTime(cycles)), TypedValue(FormatAbsoluteTime(cycles))}); + + for (size_t i = 1; i < cumulative_stats.children.size(); ++i) { + Branch(cumulative_stats.children[i]); + } + + if (cumulative_stats.children.size() >= 1) { + Output(cumulative_stats.children[0]); + } + } + + std::vector<std::vector<TypedValue>> rows() { return rows_; } + + private: + void Branch(const ProfilingStats &cumulative_stats) { + 
rows_.emplace_back(std::vector<TypedValue>{TypedValue("|\\"), TypedValue(""), TypedValue(""), TypedValue("")}); + + ++depth_; + Output(cumulative_stats); + --depth_; + } + + std::string Format(const char *str) { + std::ostringstream ss; + for (int64_t i = 0; i < depth_; ++i) { + ss << "| "; + } + ss << str; + return ss.str(); + } + + std::string Format(const std::string &str) { return Format(str.c_str()); } + + std::string FormatOperator(const char *str) { return Format(std::string("* ") + str); } + + std::string FormatRelativeTime(unsigned long long num_cycles) { + return fmt::format("{: 10.6f} %", RelativeTime(num_cycles, total_cycles_) * 100); + } + + std::string FormatAbsoluteTime(unsigned long long num_cycles) { + return fmt::format("{: 10.6f} ms", AbsoluteTime(num_cycles, total_cycles_, total_time_)); + } + + int64_t depth_{0}; + std::vector<std::vector<TypedValue>> rows_; + unsigned long long total_cycles_; + std::chrono::duration<double> total_time_; +}; + +} // namespace + +std::vector<std::vector<TypedValue>> ProfilingStatsToTable(const ProfilingStatsWithTotalTime &stats) { + ProfilingStatsToTableHelper helper{stats.cumulative_stats.num_cycles, stats.total_time}; + helper.Output(stats.cumulative_stats); + return helper.rows(); +} + +////////////////////////////////////////////////////////////////////////////// +// +// ProfilingStatsToJson + +namespace { + +class ProfilingStatsToJsonHelper { + private: + using json = nlohmann::json; + + public: + ProfilingStatsToJsonHelper(unsigned long long total_cycles, std::chrono::duration<double> total_time) + : total_cycles_(total_cycles), total_time_(total_time) {} + + void Output(const ProfilingStats &cumulative_stats) { return Output(cumulative_stats, &json_); } + + json ToJson() { return json_; } + + private: + void Output(const ProfilingStats &cumulative_stats, json *obj) { + auto cycles = IndividualCycles(cumulative_stats); + + obj->emplace("name", cumulative_stats.name); + obj->emplace("actual_hits", 
cumulative_stats.actual_hits); + obj->emplace("relative_time", RelativeTime(cycles, total_cycles_)); + obj->emplace("absolute_time", AbsoluteTime(cycles, total_cycles_, total_time_)); + obj->emplace("children", json::array()); + + for (size_t i = 0; i < cumulative_stats.children.size(); ++i) { + json child; + Output(cumulative_stats.children[i], &child); + obj->at("children").emplace_back(std::move(child)); + } + } + + json json_; + unsigned long long total_cycles_; + std::chrono::duration<double> total_time_; +}; + +} // namespace + +nlohmann::json ProfilingStatsToJson(const ProfilingStatsWithTotalTime &stats) { + ProfilingStatsToJsonHelper helper{stats.cumulative_stats.num_cycles, stats.total_time}; + helper.Output(stats.cumulative_stats); + return helper.ToJson(); +} + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/profile.hpp b/src/query/v2/plan/profile.hpp new file mode 100644 index 000000000..a84cc94c6 --- /dev/null +++ b/src/query/v2/plan/profile.hpp @@ -0,0 +1,47 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <cstdint> +#include <vector> + +#include <json/json.hpp> + +#include "query/v2/typed_value.hpp" + +namespace memgraph::query::v2 { + +namespace plan { + +/** + * Stores profiling statistics for a single logical operator. 
+ */ +struct ProfilingStats { + int64_t actual_hits{0}; + unsigned long long num_cycles{0}; + uint64_t key{0}; + const char *name{nullptr}; + // TODO: This should use the allocator for query execution + std::vector<ProfilingStats> children; +}; + +struct ProfilingStatsWithTotalTime { + ProfilingStats cumulative_stats{}; + std::chrono::duration<double> total_time{}; +}; + +std::vector<std::vector<TypedValue>> ProfilingStatsToTable(const ProfilingStatsWithTotalTime &stats); + +nlohmann::json ProfilingStatsToJson(const ProfilingStatsWithTotalTime &stats); + +} // namespace plan +} // namespace memgraph::query::v2 diff --git a/src/query/v2/plan/read_write_type_checker.cpp b/src/query/v2/plan/read_write_type_checker.cpp new file mode 100644 index 000000000..6cc38cedf --- /dev/null +++ b/src/query/v2/plan/read_write_type_checker.cpp @@ -0,0 +1,128 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "query/v2/plan/read_write_type_checker.hpp" + +#define PRE_VISIT(TOp, RWType, continue_visiting) \ + bool ReadWriteTypeChecker::PreVisit(TOp &op) { \ + UpdateType(RWType); \ + return continue_visiting; \ + } + +namespace memgraph::query::v2::plan { + +PRE_VISIT(CreateNode, RWType::W, true) +PRE_VISIT(CreateExpand, RWType::R, true) +PRE_VISIT(Delete, RWType::W, true) + +PRE_VISIT(SetProperty, RWType::W, true) +PRE_VISIT(SetProperties, RWType::W, true) +PRE_VISIT(SetLabels, RWType::W, true) + +PRE_VISIT(RemoveProperty, RWType::W, true) +PRE_VISIT(RemoveLabels, RWType::W, true) + +PRE_VISIT(ScanAll, RWType::R, true) +PRE_VISIT(ScanAllByLabel, RWType::R, true) +PRE_VISIT(ScanAllByLabelPropertyRange, RWType::R, true) +PRE_VISIT(ScanAllByLabelPropertyValue, RWType::R, true) +PRE_VISIT(ScanAllByLabelProperty, RWType::R, true) +PRE_VISIT(ScanAllById, RWType::R, true) + +PRE_VISIT(Expand, RWType::R, true) +PRE_VISIT(ExpandVariable, RWType::R, true) + +PRE_VISIT(ConstructNamedPath, RWType::R, true) + +PRE_VISIT(Filter, RWType::NONE, true) +PRE_VISIT(EdgeUniquenessFilter, RWType::NONE, true) + +PRE_VISIT(Merge, RWType::RW, false) +PRE_VISIT(Optional, RWType::NONE, true) + +bool ReadWriteTypeChecker::PreVisit(Cartesian &op) { + op.left_op_->Accept(*this); + op.right_op_->Accept(*this); + return false; +} + +PRE_VISIT(Produce, RWType::NONE, true) +PRE_VISIT(Accumulate, RWType::NONE, true) +PRE_VISIT(Aggregate, RWType::NONE, true) +PRE_VISIT(Skip, RWType::NONE, true) +PRE_VISIT(Limit, RWType::NONE, true) +PRE_VISIT(OrderBy, RWType::NONE, true) +PRE_VISIT(Distinct, RWType::NONE, true) + +bool ReadWriteTypeChecker::PreVisit(Union &op) { + op.left_op_->Accept(*this); + op.right_op_->Accept(*this); + return false; +} + +PRE_VISIT(Unwind, RWType::NONE, true) + +bool ReadWriteTypeChecker::PreVisit(CallProcedure &op) { + if (op.is_write_) { + UpdateType(RWType::RW); + return false; + } + UpdateType(RWType::R); + return true; +} + +bool 
ReadWriteTypeChecker::PreVisit([[maybe_unused]] Foreach &op) { + UpdateType(RWType::RW); + return false; +} + +#undef PRE_VISIT + +bool ReadWriteTypeChecker::Visit(Once &op) { return false; } + +void ReadWriteTypeChecker::UpdateType(RWType op_type) { + // Update type only if it's not the NONE type and the current operator's type + // is different than the one that's currently inferred. + if (type != RWType::NONE && type != op_type) { + type = RWType::RW; + } + // Stop inference because RW is the most "dominant" type, i.e. it isn't + // affected by the type of nodes in the plan appearing after the node for + // which the type is set to RW. + if (type == RWType::RW) { + return; + } + if (type == RWType::NONE && op_type != RWType::NONE) { + type = op_type; + } +} + +void ReadWriteTypeChecker::InferRWType(LogicalOperator &root) { root.Accept(*this); } + +std::string ReadWriteTypeChecker::TypeToString(const RWType type) { + switch (type) { + // Unfortunately, neo4j Java drivers do not allow query types that differ + // from the ones defined by neo4j. We'll keep using the NONE type internally + // but we'll convert it to "rw" to keep in line with the neo4j definition. + // Oddly enough, but not surprisingly, Python drivers don't have any problems + // with non-neo4j query types. + case RWType::NONE: + return "rw"; + case RWType::R: + return "r"; + case RWType::W: + return "w"; + case RWType::RW: + return "rw"; + } +} + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/read_write_type_checker.hpp b/src/query/v2/plan/read_write_type_checker.hpp new file mode 100644 index 000000000..a3c2f1a46 --- /dev/null +++ b/src/query/v2/plan/read_write_type_checker.hpp @@ -0,0 +1,95 @@ +// Copyright 2022 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "query/v2/plan/operator.hpp" + +namespace memgraph::query::v2::plan { + +class ReadWriteTypeChecker : public virtual HierarchicalLogicalOperatorVisitor { + public: + ReadWriteTypeChecker() = default; + + ReadWriteTypeChecker(const ReadWriteTypeChecker &) = delete; + ReadWriteTypeChecker(ReadWriteTypeChecker &&) = delete; + + ReadWriteTypeChecker &operator=(const ReadWriteTypeChecker &) = delete; + ReadWriteTypeChecker &operator=(ReadWriteTypeChecker &&) = delete; + + using HierarchicalLogicalOperatorVisitor::PostVisit; + using HierarchicalLogicalOperatorVisitor::PreVisit; + using HierarchicalLogicalOperatorVisitor::Visit; + + // NONE type describes an operator whose action neither reads nor writes from + // the database (e.g. Produce or Once). + // R type describes an operator whose action involves reading from the + // database. + // W type describes an operator whose action involves writing to the + // database. + // RW type describes an operator whose action involves both reading and + // writing to the database. 
+ enum class RWType : uint8_t { NONE, R, W, RW }; + + RWType type{RWType::NONE}; + void InferRWType(LogicalOperator &root); + static std::string TypeToString(const RWType type); + + bool PreVisit(CreateNode &) override; + bool PreVisit(CreateExpand &) override; + bool PreVisit(Delete &) override; + + bool PreVisit(SetProperty &) override; + bool PreVisit(SetProperties &) override; + bool PreVisit(SetLabels &) override; + + bool PreVisit(RemoveProperty &) override; + bool PreVisit(RemoveLabels &) override; + + bool PreVisit(ScanAll &) override; + bool PreVisit(ScanAllByLabel &) override; + bool PreVisit(ScanAllByLabelPropertyValue &) override; + bool PreVisit(ScanAllByLabelPropertyRange &) override; + bool PreVisit(ScanAllByLabelProperty &) override; + bool PreVisit(ScanAllById &) override; + + bool PreVisit(Expand &) override; + bool PreVisit(ExpandVariable &) override; + + bool PreVisit(ConstructNamedPath &) override; + + bool PreVisit(Filter &) override; + bool PreVisit(EdgeUniquenessFilter &) override; + + bool PreVisit(Merge &) override; + bool PreVisit(Optional &) override; + bool PreVisit(Cartesian &) override; + + bool PreVisit(Produce &) override; + bool PreVisit(Accumulate &) override; + bool PreVisit(Aggregate &) override; + bool PreVisit(Skip &) override; + bool PreVisit(Limit &) override; + bool PreVisit(OrderBy &) override; + bool PreVisit(Distinct &) override; + bool PreVisit(Union &) override; + + bool PreVisit(Unwind &) override; + bool PreVisit(CallProcedure &) override; + bool PreVisit(Foreach &) override; + + bool Visit(Once &) override; + + private: + void UpdateType(RWType op_type); +}; + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/rewrite/index_lookup.cpp b/src/query/v2/plan/rewrite/index_lookup.cpp new file mode 100644 index 000000000..795f15fa4 --- /dev/null +++ b/src/query/v2/plan/rewrite/index_lookup.cpp @@ -0,0 +1,50 @@ +// Copyright 2022 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/plan/rewrite/index_lookup.hpp" + +#include "utils/flag_validation.hpp" + +DEFINE_VALIDATED_HIDDEN_int64(query_vertex_count_to_expand_existing, 10, + "Maximum count of indexed vertices which provoke " + "indexed lookup and then expand to existing, instead of " + "a regular expand. Default is 10, to turn off use -1.", + FLAG_IN_RANGE(-1, std::numeric_limits<std::int64_t>::max())); + +namespace memgraph::query::v2::plan::impl { + +Expression *RemoveAndExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove) { + auto *and_op = utils::Downcast<AndOperator>(expr); + if (!and_op) return expr; + if (utils::Contains(exprs_to_remove, and_op)) { + return nullptr; + } + if (utils::Contains(exprs_to_remove, and_op->expression1_)) { + and_op->expression1_ = nullptr; + } + if (utils::Contains(exprs_to_remove, and_op->expression2_)) { + and_op->expression2_ = nullptr; + } + and_op->expression1_ = RemoveAndExpressions(and_op->expression1_, exprs_to_remove); + and_op->expression2_ = RemoveAndExpressions(and_op->expression2_, exprs_to_remove); + if (!and_op->expression1_ && !and_op->expression2_) { + return nullptr; + } + if (and_op->expression1_ && !and_op->expression2_) { + return and_op->expression1_; + } + if (and_op->expression2_ && !and_op->expression1_) { + return and_op->expression2_; + } + return and_op; +} + +} // namespace memgraph::query::v2::plan::impl diff --git a/src/query/v2/plan/rewrite/index_lookup.hpp 
b/src/query/v2/plan/rewrite/index_lookup.hpp new file mode 100644 index 000000000..57ddba54e --- /dev/null +++ b/src/query/v2/plan/rewrite/index_lookup.hpp @@ -0,0 +1,668 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +/// This file provides a plan rewriter which replaces `Filter` and `ScanAll` +/// operations with `ScanAllBy<Index>` if possible. The public entrypoint is +/// `RewriteWithIndexLookup`. + +#pragma once + +#include <algorithm> +#include <memory> +#include <optional> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include <gflags/gflags.h> + +#include "query/v2/plan/operator.hpp" +#include "query/v2/plan/preprocess.hpp" + +DECLARE_int64(query_vertex_count_to_expand_existing); + +namespace memgraph::query::v2::plan { + +namespace impl { + +// Return the new root expression after removing the given expressions from the +// given expression tree. 
+Expression *RemoveAndExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove); + +template <class TDbAccessor> +class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { + public: + IndexLookupRewriter(SymbolTable *symbol_table, AstStorage *ast_storage, TDbAccessor *db) + : symbol_table_(symbol_table), ast_storage_(ast_storage), db_(db) {} + + using HierarchicalLogicalOperatorVisitor::PostVisit; + using HierarchicalLogicalOperatorVisitor::PreVisit; + using HierarchicalLogicalOperatorVisitor::Visit; + + bool Visit(Once &) override { return true; } + + bool PreVisit(Filter &op) override { + prev_ops_.push_back(&op); + filters_.CollectFilterExpression(op.expression_, *symbol_table_); + return true; + } + + // Remove no longer needed Filter in PostVisit, this should be the last thing + // Filter::Accept does, so it should be safe to remove the last reference and + // free the memory. + bool PostVisit(Filter &op) override { + prev_ops_.pop_back(); + op.expression_ = RemoveAndExpressions(op.expression_, filter_exprs_for_removal_); + if (!op.expression_ || utils::Contains(filter_exprs_for_removal_, op.expression_)) { + SetOnParent(op.input()); + } + return true; + } + + bool PreVisit(ScanAll &op) override { + prev_ops_.push_back(&op); + return true; + } + + // Replace ScanAll with ScanAllBy<Index> in PostVisit, because removal of + // ScanAll may remove the last reference and thus free the memory. PostVisit + // should be the last thing ScanAll::Accept does, so it should be safe. + bool PostVisit(ScanAll &scan) override { + prev_ops_.pop_back(); + auto indexed_scan = GenScanByIndex(scan); + if (indexed_scan) { + SetOnParent(std::move(indexed_scan)); + } + return true; + } + + bool PreVisit(Expand &op) override { + prev_ops_.push_back(&op); + return true; + } + + // See if it might be better to do ScanAllBy<Index> of the destination and + // then do Expand to existing. 
+ bool PostVisit(Expand &expand) override { + prev_ops_.pop_back(); + if (expand.common_.existing_node) { + return true; + } + ScanAll dst_scan(expand.input(), expand.common_.node_symbol, expand.view_); + auto indexed_scan = GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing); + if (indexed_scan) { + expand.set_input(std::move(indexed_scan)); + expand.common_.existing_node = true; + } + return true; + } + + bool PreVisit(ExpandVariable &op) override { + prev_ops_.push_back(&op); + return true; + } + + // See if it might be better to do ScanAllBy<Index> of the destination and + // then do ExpandVariable to existing. + bool PostVisit(ExpandVariable &expand) override { + prev_ops_.pop_back(); + if (expand.common_.existing_node) { + return true; + } + std::unique_ptr<ScanAll> indexed_scan; + ScanAll dst_scan(expand.input(), expand.common_.node_symbol, storage::v3::View::OLD); + // With expand to existing we only get real gains with BFS, because we use a + // different algorithm then, so prefer expand to existing. + if (expand.type_ == EdgeAtom::Type::BREADTH_FIRST) { + // TODO: Perhaps take average node degree into consideration, instead of + // unconditionally creating an indexed scan. + indexed_scan = GenScanByIndex(dst_scan); + } else { + indexed_scan = GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing); + } + if (indexed_scan) { + expand.set_input(std::move(indexed_scan)); + expand.common_.existing_node = true; + } + return true; + } + + // The following operators may only use index lookup in filters inside of + // their own branches. So we handle them all the same. + // * Input operator is visited with the current visitor. + // * Custom operator branches are visited with a new visitor. 
+ + bool PreVisit(Merge &op) override { + prev_ops_.push_back(&op); + op.input()->Accept(*this); + RewriteBranch(&op.merge_match_); + return false; + } + + bool PostVisit(Merge &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Optional &op) override { + prev_ops_.push_back(&op); + op.input()->Accept(*this); + RewriteBranch(&op.optional_); + return false; + } + + bool PostVisit(Optional &) override { + prev_ops_.pop_back(); + return true; + } + + // Rewriting Cartesian assumes that the input plan will have Filter operations + // as soon as they are possible. Therefore we do not track filters above + // Cartesian because they should be irrelevant. + // + // For example, the following plan is not expected to be an input to + // IndexLookupRewriter. + // + // Filter n.prop = 16 + // | + // Cartesian + // | + // |\ + // | ScanAll (n) + // | + // ScanAll (m) + // + // Instead, the equivalent set of operations should be done this way: + // + // Cartesian + // | + // |\ + // | Filter n.prop = 16 + // | | + // | ScanAll (n) + // | + // ScanAll (m) + bool PreVisit(Cartesian &op) override { + prev_ops_.push_back(&op); + RewriteBranch(&op.left_op_); + RewriteBranch(&op.right_op_); + return false; + } + + bool PostVisit(Cartesian &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Union &op) override { + prev_ops_.push_back(&op); + RewriteBranch(&op.left_op_); + RewriteBranch(&op.right_op_); + return false; + } + + bool PostVisit(Union &) override { + prev_ops_.pop_back(); + return true; + } + + // The remaining operators should work by just traversing into their input. 
+ + bool PreVisit(CreateNode &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(CreateNode &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(CreateExpand &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(CreateExpand &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ScanAllByLabel &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllByLabel &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ScanAllByLabelPropertyRange &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllByLabelPropertyRange &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ScanAllByLabelPropertyValue &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllByLabelPropertyValue &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ScanAllByLabelProperty &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllByLabelProperty &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ScanAllById &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllById &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ConstructNamedPath &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ConstructNamedPath &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Produce &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Produce &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Delete &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Delete &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(SetProperty &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(SetProperty 
&) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(SetProperties &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(SetProperties &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(SetLabels &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(SetLabels &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(RemoveProperty &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(RemoveProperty &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(RemoveLabels &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(RemoveLabels &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(EdgeUniquenessFilter &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(EdgeUniquenessFilter &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Accumulate &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Accumulate &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Aggregate &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Aggregate &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Skip &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Skip &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Limit &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Limit &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(OrderBy &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(OrderBy &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Unwind &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Unwind &) override { + prev_ops_.pop_back(); + return true; + 
} + + bool PreVisit(Distinct &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Distinct &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(CallProcedure &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(CallProcedure &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Foreach &op) override { + prev_ops_.push_back(&op); + return false; + } + + bool PostVisit(Foreach &) override { + prev_ops_.pop_back(); + return true; + } + + std::shared_ptr<LogicalOperator> new_root_; + + private: + SymbolTable *symbol_table_; + AstStorage *ast_storage_; + TDbAccessor *db_; + // Collected filters, pending for examination if they can be used for advanced + // lookup operations (by index, node ID, ...). + Filters filters_; + // Expressions which no longer need a plain Filter operator. + std::unordered_set<Expression *> filter_exprs_for_removal_; + std::vector<LogicalOperator *> prev_ops_; + + struct LabelPropertyIndex { + LabelIx label; + // FilterInfo with PropertyFilter. 
+ FilterInfo filter; + int64_t vertex_count; + }; + + bool DefaultPreVisit() override { throw utils::NotYetImplemented("optimizing index lookup"); } + + void SetOnParent(const std::shared_ptr<LogicalOperator> &input) { + MG_ASSERT(input); + if (prev_ops_.empty()) { + MG_ASSERT(!new_root_); + new_root_ = input; + return; + } + prev_ops_.back()->set_input(input); + } + + void RewriteBranch(std::shared_ptr<LogicalOperator> *branch) { + IndexLookupRewriter<TDbAccessor> rewriter(symbol_table_, ast_storage_, db_); + (*branch)->Accept(rewriter); + if (rewriter.new_root_) { + *branch = rewriter.new_root_; + } + } + + storage::v3::LabelId GetLabel(LabelIx label) { return db_->NameToLabel(label.name); } + + storage::v3::PropertyId GetProperty(PropertyIx prop) { return db_->NameToProperty(prop.name); } + + std::optional<LabelIx> FindBestLabelIndex(const std::unordered_set<LabelIx> &labels) { + MG_ASSERT(!labels.empty(), "Trying to find the best label without any labels."); + std::optional<LabelIx> best_label; + for (const auto &label : labels) { + if (!db_->LabelIndexExists(GetLabel(label))) continue; + if (!best_label) { + best_label = label; + continue; + } + if (db_->VerticesCount(GetLabel(label)) < db_->VerticesCount(GetLabel(*best_label))) best_label = label; + } + return best_label; + } + + // Finds the label-property combination which has indexed the lowest amount of + // vertices. If the index cannot be found, nullopt is returned. 
+ std::optional<LabelPropertyIndex> FindBestLabelPropertyIndex(const Symbol &symbol, + const std::unordered_set<Symbol> &bound_symbols) { + auto are_bound = [&bound_symbols](const auto &used_symbols) { + for (const auto &used_symbol : used_symbols) { + if (!utils::Contains(bound_symbols, used_symbol)) { + return false; + } + } + return true; + }; + std::optional<LabelPropertyIndex> found; + for (const auto &label : filters_.FilteredLabels(symbol)) { + for (const auto &filter : filters_.PropertyFilters(symbol)) { + if (filter.property_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) { + // Skip filter expressions which use the symbol whose property we are + // looking up or aren't bound. We cannot scan by such expressions. For + // example, in `n.a = 2 + n.b` both sides of `=` refer to `n`, so we + // cannot scan `n` by property index. + continue; + } + const auto &property = filter.property_filter->property_; + if (!db_->LabelPropertyIndexExists(GetLabel(label), GetProperty(property))) { + continue; + } + int64_t vertex_count = db_->VerticesCount(GetLabel(label), GetProperty(property)); + auto is_better_type = [&found](PropertyFilter::Type type) { + // Order the types by the most preferred index lookup type. + static const PropertyFilter::Type kFilterTypeOrder[] = { + PropertyFilter::Type::EQUAL, PropertyFilter::Type::RANGE, PropertyFilter::Type::REGEX_MATCH}; + auto *found_sort_ix = std::find(kFilterTypeOrder, kFilterTypeOrder + 3, found->filter.property_filter->type_); + auto *type_sort_ix = std::find(kFilterTypeOrder, kFilterTypeOrder + 3, type); + return type_sort_ix < found_sort_ix; + }; + if (!found || vertex_count < found->vertex_count || + (vertex_count == found->vertex_count && is_better_type(filter.property_filter->type_))) { + found = LabelPropertyIndex{label, filter, vertex_count}; + } + } + } + return found; + } + + // Creates a ScanAll by the best possible index for the `node_symbol`. 
Best + // index is defined as the index with least number of vertices. If the node + // does not have at least a label, no indexed lookup can be created and + // `nullptr` is returned. The operator is chained after `input`. Optional + // `max_vertex_count` controls, whether no operator should be created if the + // vertex count in the best index exceeds this number. In such a case, + // `nullptr` is returned and `input` is not chained. + std::unique_ptr<ScanAll> GenScanByIndex(const ScanAll &scan, + const std::optional<int64_t> &max_vertex_count = std::nullopt) { + const auto &input = scan.input(); + const auto &node_symbol = scan.output_symbol_; + const auto &view = scan.view_; + const auto &modified_symbols = scan.ModifiedSymbols(*symbol_table_); + std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(), modified_symbols.end()); + auto are_bound = [&bound_symbols](const auto &used_symbols) { + for (const auto &used_symbol : used_symbols) { + if (!utils::Contains(bound_symbols, used_symbol)) { + return false; + } + } + return true; + }; + // First, try to see if we can find a vertex by ID. + if (!max_vertex_count || *max_vertex_count >= 1) { + for (const auto &filter : filters_.IdFilters(node_symbol)) { + if (filter.id_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) continue; + auto *value = filter.id_filter->value_; + filter_exprs_for_removal_.insert(filter.expression); + filters_.EraseFilter(filter); + return std::make_unique<ScanAllById>(input, node_symbol, value, view); + } + } + // Now try to see if we can use label+property index. If not, try to use + // just the label index. + const auto labels = filters_.FilteredLabels(node_symbol); + if (labels.empty()) { + // Without labels, we cannot generate any indexed ScanAll. + return nullptr; + } + auto found_index = FindBestLabelPropertyIndex(node_symbol, bound_symbols); + if (found_index && + // Use label+property index if we satisfy max_vertex_count. 
+ (!max_vertex_count || *max_vertex_count >= found_index->vertex_count)) { + // Copy the property filter and then erase it from filters. + const auto prop_filter = *found_index->filter.property_filter; + if (prop_filter.type_ != PropertyFilter::Type::REGEX_MATCH) { + // Remove the original expression from Filter operation only if it's not + // a regex match. In such a case we need to perform the matching even + // after we've scanned the index. + filter_exprs_for_removal_.insert(found_index->filter.expression); + } + filters_.EraseFilter(found_index->filter); + std::vector<Expression *> removed_expressions; + filters_.EraseLabelFilter(node_symbol, found_index->label, &removed_expressions); + filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end()); + if (prop_filter.lower_bound_ || prop_filter.upper_bound_) { + return std::make_unique<ScanAllByLabelPropertyRange>( + input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_), + prop_filter.property_.name, prop_filter.lower_bound_, prop_filter.upper_bound_, view); + } else if (prop_filter.type_ == PropertyFilter::Type::REGEX_MATCH) { + // Generate index scan using the empty string as a lower bound. + Expression *empty_string = ast_storage_->Create<PrimitiveLiteral>(""); + auto lower_bound = utils::MakeBoundInclusive(empty_string); + return std::make_unique<ScanAllByLabelPropertyRange>( + input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_), + prop_filter.property_.name, std::make_optional(lower_bound), std::nullopt, view); + } else if (prop_filter.type_ == PropertyFilter::Type::IN) { + // TODO(buda): ScanAllByLabelProperty + Filter should be considered + // here once the operator and the right cardinality estimation exist. 
+ auto const &symbol = symbol_table_->CreateAnonymousSymbol(); + auto *expression = ast_storage_->Create<Identifier>(symbol.name_); + expression->MapTo(symbol); + auto unwind_operator = std::make_unique<Unwind>(input, prop_filter.value_, symbol); + return std::make_unique<ScanAllByLabelPropertyValue>( + std::move(unwind_operator), node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_), + prop_filter.property_.name, expression, view); + } else if (prop_filter.type_ == PropertyFilter::Type::IS_NOT_NULL) { + return std::make_unique<ScanAllByLabelProperty>(input, node_symbol, GetLabel(found_index->label), + GetProperty(prop_filter.property_), prop_filter.property_.name, + view); + } else { + MG_ASSERT(prop_filter.value_, "Property filter should either have bounds or a value expression."); + return std::make_unique<ScanAllByLabelPropertyValue>(input, node_symbol, GetLabel(found_index->label), + GetProperty(prop_filter.property_), + prop_filter.property_.name, prop_filter.value_, view); + } + } + auto maybe_label = FindBestLabelIndex(labels); + if (!maybe_label) return nullptr; + const auto &label = *maybe_label; + if (max_vertex_count && db_->VerticesCount(GetLabel(label)) > *max_vertex_count) { + // Don't create an indexed lookup, since we have more labeled vertices + // than the allowed count. 
+ return nullptr; + } + std::vector<Expression *> removed_expressions; + filters_.EraseLabelFilter(node_symbol, label, &removed_expressions); + filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end()); + return std::make_unique<ScanAllByLabel>(input, node_symbol, GetLabel(label), view); + } +}; + +} // namespace impl + +template <class TDbAccessor> +std::unique_ptr<LogicalOperator> RewriteWithIndexLookup(std::unique_ptr<LogicalOperator> root_op, + SymbolTable *symbol_table, AstStorage *ast_storage, + TDbAccessor *db) { + impl::IndexLookupRewriter<TDbAccessor> rewriter(symbol_table, ast_storage, db); + root_op->Accept(rewriter); + if (rewriter.new_root_) { + // This shouldn't happen in real use case, because IndexLookupRewriter + // removes Filter operations and they cannot be the root op. In case we + // somehow missed this, raise NotYetImplemented instead of MG_ASSERT + // crashing the application. + throw utils::NotYetImplemented("optimizing index lookup"); + } + return root_op; +} + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/rule_based_planner.cpp b/src/query/v2/plan/rule_based_planner.cpp new file mode 100644 index 000000000..e423966d9 --- /dev/null +++ b/src/query/v2/plan/rule_based_planner.cpp @@ -0,0 +1,594 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "query/v2/plan/rule_based_planner.hpp" + +#include <algorithm> +#include <functional> +#include <limits> +#include <stack> +#include <unordered_set> + +#include "utils/algorithm.hpp" +#include "utils/exceptions.hpp" +#include "utils/logging.hpp" + +namespace memgraph::query::v2::plan { + +namespace { + +bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols, const FilterInfo &filter) { + for (const auto &symbol : filter.used_symbols) { + if (bound_symbols.find(symbol) == bound_symbols.end()) { + return false; + } + } + return true; +} + +// Ast tree visitor which collects the context for a return body. +// The return body of WITH and RETURN clauses consists of: +// +// * named expressions (used to produce results); +// * flag whether the results need to be DISTINCT; +// * optional SKIP expression; +// * optional LIMIT expression and +// * optional ORDER BY expressions. +// +// In addition to the above, we collect information on used symbols, +// aggregations and expressions used for group by. +class ReturnBodyContext : public HierarchicalTreeVisitor { + public: + ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table, const std::unordered_set<Symbol> &bound_symbols, + AstStorage &storage, Where *where = nullptr) + : body_(body), symbol_table_(symbol_table), bound_symbols_(bound_symbols), storage_(storage), where_(where) { + // Collect symbols from named expressions. + output_symbols_.reserve(body_.named_expressions.size()); + if (body.all_identifiers) { + // Expand '*' to expressions and symbols first, so that their results come + // before regular named expressions. + ExpandUserSymbols(); + } + for (auto &named_expr : body_.named_expressions) { + output_symbols_.emplace_back(symbol_table_.at(*named_expr)); + named_expr->Accept(*this); + named_expressions_.emplace_back(named_expr); + } + // Collect symbols used in group by expressions. 
+ if (!aggregations_.empty()) { + UsedSymbolsCollector collector(symbol_table_); + for (auto &group_by : group_by_) { + group_by->Accept(collector); + } + group_by_used_symbols_ = collector.symbols_; + } + if (aggregations_.empty()) { + // Visit order_by and where if we do not have aggregations. This way we + // prevent collecting group_by expressions from order_by and where, which + // would be very wrong. When we have aggregation, order_by and where can + // only use new symbols (ensured in semantic analysis), so we don't care + // about collecting used_symbols. Also, semantic analysis should + // have prevented any aggregations from appearing here. + for (const auto &order_pair : body.order_by) { + order_pair.expression->Accept(*this); + } + if (where) { + where->Accept(*this); + } + MG_ASSERT(aggregations_.empty(), "Unexpected aggregations in ORDER BY or WHERE"); + } + } + + using HierarchicalTreeVisitor::PostVisit; + using HierarchicalTreeVisitor::PreVisit; + using HierarchicalTreeVisitor::Visit; + + bool Visit(PrimitiveLiteral &) override { + has_aggregation_.emplace_back(false); + return true; + } + + private: + template <typename TLiteral, typename TIteratorToExpression> + void PostVisitCollectionLiteral(TLiteral &literal, TIteratorToExpression iterator_to_expression) { + // If there is an aggregation in the list, and there are group-bys, then we + // need to add the group-bys manually. If there are no aggregations, the + // whole list will be added as a group-by. 
+ std::vector<Expression *> literal_group_by; + bool has_aggr = false; + auto it = has_aggregation_.end(); + auto elements_it = literal.elements_.begin(); + std::advance(it, -literal.elements_.size()); + while (it != has_aggregation_.end()) { + if (*it) { + has_aggr = true; + } else { + literal_group_by.emplace_back(iterator_to_expression(elements_it)); + } + elements_it++; + it = has_aggregation_.erase(it); + } + has_aggregation_.emplace_back(has_aggr); + if (has_aggr) { + for (auto expression_ptr : literal_group_by) group_by_.emplace_back(expression_ptr); + } + } + + public: + bool PostVisit(ListLiteral &list_literal) override { + MG_ASSERT(list_literal.elements_.size() <= has_aggregation_.size(), + "Expected as many has_aggregation_ flags as there are list" + "elements."); + PostVisitCollectionLiteral(list_literal, [](auto it) { return *it; }); + return true; + } + + bool PostVisit(MapLiteral &map_literal) override { + MG_ASSERT(map_literal.elements_.size() <= has_aggregation_.size(), + "Expected has_aggregation_ flags as much as there are map elements."); + PostVisitCollectionLiteral(map_literal, [](auto it) { return it->second; }); + return true; + } + + bool PostVisit(All &all) override { + // Remove the symbol which is bound by all, because we are only interested + // in free (unbound) symbols. + used_symbols_.erase(symbol_table_.at(*all.identifier_)); + MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for ALL arguments"); + bool has_aggr = false; + for (int i = 0; i < 3; ++i) { + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + + bool PostVisit(Single &single) override { + // Remove the symbol which is bound by single, because we are only + // interested in free (unbound) symbols. 
+ used_symbols_.erase(symbol_table_.at(*single.identifier_)); + MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for SINGLE arguments"); + bool has_aggr = false; + for (int i = 0; i < 3; ++i) { + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + + bool PostVisit(Any &any) override { + // Remove the symbol which is bound by any, because we are only interested + // in free (unbound) symbols. + used_symbols_.erase(symbol_table_.at(*any.identifier_)); + MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for ANY arguments"); + bool has_aggr = false; + for (int i = 0; i < 3; ++i) { + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + + bool PostVisit(None &none) override { + // Remove the symbol which is bound by none, because we are only interested + // in free (unbound) symbols. + used_symbols_.erase(symbol_table_.at(*none.identifier_)); + MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for NONE arguments"); + bool has_aggr = false; + for (int i = 0; i < 3; ++i) { + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + + bool PostVisit(Reduce &reduce) override { + // Remove the symbols bound by reduce, because we are only interested + // in free (unbound) symbols. 
+ used_symbols_.erase(symbol_table_.at(*reduce.accumulator_)); + used_symbols_.erase(symbol_table_.at(*reduce.identifier_)); + MG_ASSERT(has_aggregation_.size() >= 5U, "Expected 5 has_aggregation_ flags for REDUCE arguments"); + bool has_aggr = false; + for (int i = 0; i < 5; ++i) { + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + + bool PostVisit(Coalesce &coalesce) override { + MG_ASSERT(has_aggregation_.size() >= coalesce.expressions_.size(), + "Expected >= {} has_aggregation_ flags for COALESCE arguments", has_aggregation_.size()); + bool has_aggr = false; + for (size_t i = 0; i < coalesce.expressions_.size(); ++i) { + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + + bool PostVisit(Extract &extract) override { + // Remove the symbol bound by extract, because we are only interested + // in free (unbound) symbols. + used_symbols_.erase(symbol_table_.at(*extract.identifier_)); + MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for EXTRACT arguments"); + bool has_aggr = false; + for (int i = 0; i < 3; ++i) { + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + + bool Visit(Identifier &ident) override { + const auto &symbol = symbol_table_.at(ident); + if (!utils::Contains(output_symbols_, symbol)) { + // Don't pick up new symbols, even though they may be used in ORDER BY or + // WHERE. 
+ used_symbols_.insert(symbol); + } + has_aggregation_.emplace_back(false); + return true; + } + + bool PreVisit(ListSlicingOperator &list_slicing) override { + list_slicing.list_->Accept(*this); + bool list_has_aggr = has_aggregation_.back(); + has_aggregation_.pop_back(); + bool has_aggr = list_has_aggr; + if (list_slicing.lower_bound_) { + list_slicing.lower_bound_->Accept(*this); + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + if (list_slicing.upper_bound_) { + list_slicing.upper_bound_->Accept(*this); + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + if (has_aggr && !list_has_aggr) { + // We need to group by the list expression, because it didn't have an + // aggregation inside. + group_by_.emplace_back(list_slicing.list_); + } + has_aggregation_.emplace_back(has_aggr); + return false; + } + + bool PreVisit(IfOperator &if_operator) override { + if_operator.condition_->Accept(*this); + bool has_aggr = has_aggregation_.back(); + has_aggregation_.pop_back(); + if_operator.then_expression_->Accept(*this); + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + if_operator.else_expression_->Accept(*this); + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + has_aggregation_.emplace_back(has_aggr); + // TODO: Once we allow aggregations here, insert appropriate stuff in + // group_by. 
+ MG_ASSERT(!has_aggr, "Currently aggregations in CASE are not allowed"); + return false; + } + + bool PostVisit(Function &function) override { + MG_ASSERT(function.arguments_.size() <= has_aggregation_.size(), + "Expected as many has_aggregation_ flags as there are" + "function arguments."); + bool has_aggr = false; + auto it = has_aggregation_.end(); + std::advance(it, -function.arguments_.size()); + while (it != has_aggregation_.end()) { + has_aggr = has_aggr || *it; + it = has_aggregation_.erase(it); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + +#define VISIT_BINARY_OPERATOR(BinaryOperator) \ + bool PostVisit(BinaryOperator &op) override { \ + MG_ASSERT(has_aggregation_.size() >= 2U, "Expected at least 2 has_aggregation_ flags."); \ + /* has_aggregation_ stack is reversed, last result is from the 2nd */ \ + /* expression. */ \ + bool aggr2 = has_aggregation_.back(); \ + has_aggregation_.pop_back(); \ + bool aggr1 = has_aggregation_.back(); \ + has_aggregation_.pop_back(); \ + bool has_aggr = aggr1 || aggr2; \ + if (has_aggr && !(aggr1 && aggr2)) { \ + /* Group by the expression which does not contain aggregation. */ \ + /* Possible optimization is to ignore constant value expressions */ \ + group_by_.emplace_back(aggr1 ? op.expression2_ : op.expression1_); \ + } \ + /* Propagate that this whole expression may contain an aggregation. 
*/ \ + has_aggregation_.emplace_back(has_aggr); \ + return true; \ + } + + VISIT_BINARY_OPERATOR(OrOperator) + VISIT_BINARY_OPERATOR(XorOperator) + VISIT_BINARY_OPERATOR(AndOperator) + VISIT_BINARY_OPERATOR(AdditionOperator) + VISIT_BINARY_OPERATOR(SubtractionOperator) + VISIT_BINARY_OPERATOR(MultiplicationOperator) + VISIT_BINARY_OPERATOR(DivisionOperator) + VISIT_BINARY_OPERATOR(ModOperator) + VISIT_BINARY_OPERATOR(NotEqualOperator) + VISIT_BINARY_OPERATOR(EqualOperator) + VISIT_BINARY_OPERATOR(LessOperator) + VISIT_BINARY_OPERATOR(GreaterOperator) + VISIT_BINARY_OPERATOR(LessEqualOperator) + VISIT_BINARY_OPERATOR(GreaterEqualOperator) + VISIT_BINARY_OPERATOR(InListOperator) + VISIT_BINARY_OPERATOR(SubscriptOperator) + +#undef VISIT_BINARY_OPERATOR + + bool PostVisit(Aggregation &aggr) override { + // Aggregation contains a virtual symbol, where the result will be stored. + const auto &symbol = symbol_table_.at(aggr); + aggregations_.emplace_back(Aggregate::Element{aggr.expression1_, aggr.expression2_, aggr.op_, symbol}); + // Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses + // two expressions, so we can have 0, 1 or 2 elements on the + // has_aggregation_stack for this Aggregation expression. + if (aggr.op_ == Aggregation::Op::COLLECT_MAP) has_aggregation_.pop_back(); + if (aggr.expression1_) + has_aggregation_.back() = true; + else + has_aggregation_.emplace_back(true); + // Possible optimization is to skip remembering symbols inside aggregation. + // If and when implementing this, don't forget that Accumulate needs *all* + // the symbols, including those inside aggregation. 
+ return true; + } + + bool PostVisit(NamedExpression &named_expr) override { + MG_ASSERT(has_aggregation_.size() == 1U, "Expected to reduce has_aggregation_ to single boolean."); + if (!has_aggregation_.back()) { + group_by_.emplace_back(named_expr.expression_); + } + has_aggregation_.pop_back(); + return true; + } + + bool Visit(ParameterLookup &) override { + has_aggregation_.emplace_back(false); + return true; + } + + bool PostVisit(RegexMatch ®ex_match) override { + MG_ASSERT(has_aggregation_.size() >= 2U, "Expected 2 has_aggregation_ flags for RegexMatch arguments"); + bool has_aggr = has_aggregation_.back(); + has_aggregation_.pop_back(); + has_aggregation_.back() |= has_aggr; + return true; + } + + // Creates NamedExpression with an Identifier for each user declared symbol. + // This should be used when body.all_identifiers is true, to generate + // expressions for Produce operator. + void ExpandUserSymbols() { + MG_ASSERT(named_expressions_.empty(), "ExpandUserSymbols should be first to fill named_expressions_"); + MG_ASSERT(output_symbols_.empty(), "ExpandUserSymbols should be first to fill output_symbols_"); + for (const auto &symbol : bound_symbols_) { + if (!symbol.user_declared()) { + continue; + } + auto *ident = storage_.Create<Identifier>(symbol.name())->MapTo(symbol); + auto *named_expr = storage_.Create<NamedExpression>(symbol.name(), ident)->MapTo(symbol); + // Fill output expressions and symbols with expanded identifiers. + named_expressions_.emplace_back(named_expr); + output_symbols_.emplace_back(symbol); + used_symbols_.insert(symbol); + // Don't forget to group by expanded identifiers. + group_by_.emplace_back(ident); + } + // Cypher RETURN/WITH * expects to expand '*' sorted by name. 
+ std::sort(output_symbols_.begin(), output_symbols_.end(), + [](const auto &a, const auto &b) { return a.name() < b.name(); }); + std::sort(named_expressions_.begin(), named_expressions_.end(), + [](const auto &a, const auto &b) { return a->name_ < b->name_; }); + } + + // If true, results need to be distinct. + bool distinct() const { return body_.distinct; } + // Named expressions which are used to produce results. + const auto &named_expressions() const { return named_expressions_; } + // Pairs of (Ordering, Expression *) for sorting results. + const auto &order_by() const { return body_.order_by; } + // Optional expression which determines how many results to skip. + auto *skip() const { return body_.skip; } + // Optional expression which determines how many results to produce. + auto *limit() const { return body_.limit; } + // Optional Where clause for filtering. + const auto *where() const { return where_; } + // Set of symbols used inside the visited expressions, including the inside of + // aggregation expression. These only includes old symbols, even though new + // ones may have been used in ORDER BY or WHERE. + const auto &used_symbols() const { return used_symbols_; } + // List of aggregation elements found in expressions. + const auto &aggregations() const { return aggregations_; } + // When there is at least one aggregation element, all the non-aggregate (sub) + // expressions are used for grouping. For example, in `WITH sum(n.a) + 2 * n.b + // AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`. + const auto &group_by() const { return group_by_; } + // Set of symbols used in group by expressions. + const auto &group_by_used_symbols() const { return group_by_used_symbols_; } + // All symbols generated by named expressions. They are collected in order of + // named_expressions. 
+ const auto &output_symbols() const { return output_symbols_; } + + private: + const ReturnBody &body_; + SymbolTable &symbol_table_; + const std::unordered_set<Symbol> &bound_symbols_; + AstStorage &storage_; + const Where *const where_ = nullptr; + std::unordered_set<Symbol> used_symbols_; + std::vector<Symbol> output_symbols_; + std::vector<Aggregate::Element> aggregations_; + std::vector<Expression *> group_by_; + std::unordered_set<Symbol> group_by_used_symbols_; + // Flag stack indicating whether an expression contains an aggregation. A + // stack is needed so that we differentiate the case where a child + // sub-expression has an aggregation, while the other child doesn't. For + // example AST, (+ (sum x) y) + // * (sum x) -- Has an aggregation. + // * y -- Doesn't, we need to group by this. + // * (+ (sum x) y) -- The whole expression has an aggregation, so we don't + // group by it. + std::list<bool> has_aggregation_; + std::vector<NamedExpression *> named_expressions_; +}; + +std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator> input_op, bool advance_command, + const ReturnBodyContext &body, bool accumulate = false) { + std::vector<Symbol> used_symbols(body.used_symbols().begin(), body.used_symbols().end()); + auto last_op = std::move(input_op); + if (accumulate) { + // We only advance the command in Accumulate. This is done for WITH clause, + // when the first part updated the database. RETURN clause may only need an + // accumulation after updates, without advancing the command. + last_op = std::make_unique<Accumulate>(std::move(last_op), used_symbols, advance_command); + } + if (!body.aggregations().empty()) { + // When we have aggregation, SKIP/LIMIT should always come after it. 
+ std::vector<Symbol> remember(body.group_by_used_symbols().begin(), body.group_by_used_symbols().end()); + last_op = std::make_unique<Aggregate>(std::move(last_op), body.aggregations(), body.group_by(), remember); + } + last_op = std::make_unique<Produce>(std::move(last_op), body.named_expressions()); + // Distinct in ReturnBody only makes Produce values unique, so plan after it. + if (body.distinct()) { + last_op = std::make_unique<Distinct>(std::move(last_op), body.output_symbols()); + } + // Like Where, OrderBy can read from symbols established by named expressions + // in Produce, so it must come after it. + if (!body.order_by().empty()) { + last_op = std::make_unique<OrderBy>(std::move(last_op), body.order_by(), body.output_symbols()); + } + // Finally, Skip and Limit must come after OrderBy. + if (body.skip()) { + last_op = std::make_unique<Skip>(std::move(last_op), body.skip()); + } + // Limit is always after Skip. + if (body.limit()) { + last_op = std::make_unique<Limit>(std::move(last_op), body.limit()); + } + // Where may see new symbols so it comes after we generate Produce and in + // general, comes after any OrderBy, Skip or Limit. 
+ if (body.where()) { + last_op = std::make_unique<Filter>(std::move(last_op), body.where()->expression_); + } + return last_op; +} + +} // namespace + +namespace impl { + +Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols, Filters &filters, AstStorage &storage) { + Expression *filter_expr = nullptr; + for (auto filters_it = filters.begin(); filters_it != filters.end();) { + if (HasBoundFilterSymbols(bound_symbols, *filters_it)) { + filter_expr = impl::BoolJoin<AndOperator>(storage, filter_expr, filters_it->expression); + filters_it = filters.erase(filters_it); + } else { + filters_it++; + } + } + return filter_expr; +} + +std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator> last_op, + const std::unordered_set<Symbol> &bound_symbols, Filters &filters, + AstStorage &storage) { + auto *filter_expr = ExtractFilters(bound_symbols, filters, storage); + if (filter_expr) { + last_op = std::make_unique<Filter>(std::move(last_op), filter_expr); + } + return last_op; +} + +std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator> last_op, + std::unordered_set<Symbol> &bound_symbols, + std::unordered_map<Symbol, std::vector<Symbol>> &named_paths) { + auto all_are_bound = [&bound_symbols](const std::vector<Symbol> &syms) { + for (const auto &sym : syms) + if (bound_symbols.find(sym) == bound_symbols.end()) return false; + return true; + }; + for (auto named_path_it = named_paths.begin(); named_path_it != named_paths.end();) { + if (all_are_bound(named_path_it->second)) { + last_op = std::make_unique<ConstructNamedPath>(std::move(last_op), named_path_it->first, + std::move(named_path_it->second)); + bound_symbols.insert(named_path_it->first); + named_path_it = named_paths.erase(named_path_it); + } else { + ++named_path_it; + } + } + + return last_op; +} + +std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op, + SymbolTable &symbol_table, bool is_write, + const 
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) { + // Similar to WITH clause, but we want to accumulate when the query writes to + // the database. This way we handle the case when we want to return + // expressions with the latest updated results. For example, `MATCH (n) -- () + // SET n.prop = n.prop + 1 RETURN n.prop`. If we match same `n` multiple 'k' + // times, we want to return 'k' results where the property value is the same, + // final result of 'k' increments. + bool accumulate = is_write; + bool advance_command = false; + ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage); + return GenReturnBody(std::move(input_op), advance_command, body, accumulate); +} + +std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op, + SymbolTable &symbol_table, bool is_write, + std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) { + // WITH clause is Accumulate/Aggregate (advance_command) + Produce and + // optional Filter. In case of update and aggregation, we want to accumulate + // first, so that when aggregating, we get the latest results. Similar to + // RETURN clause. + bool accumulate = is_write; + // No need to advance the command if we only performed reads. + bool advance_command = is_write; + ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, with.where_); + auto last_op = GenReturnBody(std::move(input_op), advance_command, body, accumulate); + // Reset bound symbols, so that only those in WITH are exposed. 
+ bound_symbols.clear(); + for (const auto &symbol : body.output_symbols()) { + bound_symbols.insert(symbol); + } + return last_op; +} + +std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op, + std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table) { + return std::make_unique<Union>(left_op, right_op, cypher_union.union_symbols_, left_op->OutputSymbols(symbol_table), + right_op->OutputSymbols(symbol_table)); +} + +} // namespace impl + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/rule_based_planner.hpp b/src/query/v2/plan/rule_based_planner.hpp new file mode 100644 index 000000000..16318f8b0 --- /dev/null +++ b/src/query/v2/plan/rule_based_planner.hpp @@ -0,0 +1,561 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +#pragma once + +#include <optional> +#include <variant> + +#include "gflags/gflags.h" + +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/ast/ast_visitor.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/plan/preprocess.hpp" +#include "utils/logging.hpp" +#include "utils/typeinfo.hpp" + +namespace memgraph::query::v2::plan { + +/// @brief Context which contains variables commonly used during planning. +template <class TDbAccessor> +struct PlanningContext { + /// @brief SymbolTable is used to determine inputs and outputs of planned + /// operators. 
+ /// + /// Newly created AST nodes may be added to reference existing symbols. + SymbolTable *symbol_table{nullptr}; + /// @brief The storage is used to create new AST nodes for use in operators. + AstStorage *ast_storage{nullptr}; + /// @brief Cypher query to be planned + CypherQuery *query{nullptr}; + /// @brief TDbAccessor, which may be used to get some information from the + /// database to generate better plans. The accessor is required only to live + /// long enough for the plan generation to finish. + TDbAccessor *db{nullptr}; + /// @brief Symbol set is used to differentiate cycles in pattern matching. + /// During planning, symbols will be added as each operator produces values + /// for them. This way, the operator can be correctly initialized whether to + /// read a symbol or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and + /// write) the first `n`, but the latter `n` would only read the already + /// written information. + std::unordered_set<Symbol> bound_symbols{}; +}; + +template <class TDbAccessor> +auto MakePlanningContext(AstStorage *ast_storage, SymbolTable *symbol_table, CypherQuery *query, TDbAccessor *db) { + return PlanningContext<TDbAccessor>{symbol_table, ast_storage, query, db}; +} + +// Contextual information used for generating match operators. +struct MatchContext { + const Matching &matching; + const SymbolTable &symbol_table; + // Already bound symbols, which are used to determine whether the operator + // should reference them or establish new. This is both read from and written + // to during generation. + std::unordered_set<Symbol> &bound_symbols; + // Determines whether the match should see the new graph state or not. + storage::v3::View view = storage::v3::View::OLD; + // All the newly established symbols in match. + std::vector<Symbol> new_symbols{}; +}; + +namespace impl { + +// These functions are an internal implementation of RuleBasedPlanner. 
To avoid +// writing the whole code inline in this header file, they are declared here and +// defined in the cpp file. + +// Iterates over `Filters` joining them in one expression via +// `AndOperator` if symbols they use are bound.. All the joined filters are +// removed from `Filters`. +Expression *ExtractFilters(const std::unordered_set<Symbol> &, Filters &, AstStorage &); + +std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator>, const std::unordered_set<Symbol> &, + Filters &, AstStorage &); + +/// Utility function for iterating pattern atoms and accumulating a result. +/// +/// Each pattern is of the form `NodeAtom (, EdgeAtom, NodeAtom)*`. Therefore, +/// the `base` function is called on the first `NodeAtom`, while the `collect` +/// is called for the whole triplet. Result of the function is passed to the +/// next call. Final result is returned. +/// +/// Example usage of counting edge atoms in the pattern. +/// +/// auto base = [](NodeAtom *first_node) { return 0; }; +/// auto collect = [](int accum, NodeAtom *prev_node, EdgeAtom *edge, +/// NodeAtom *node) { +/// return accum + 1; +/// }; +/// int edge_count = ReducePattern<int>(pattern, base, collect); +/// +// TODO: It might be a good idea to move this somewhere else, for easier usage +// in other files. 
+template <typename T> +auto ReducePattern(Pattern &pattern, std::function<T(NodeAtom *)> base, + std::function<T(T, NodeAtom *, EdgeAtom *, NodeAtom *)> collect) { + MG_ASSERT(!pattern.atoms_.empty(), "Missing atoms in pattern"); + auto atoms_it = pattern.atoms_.begin(); + auto current_node = utils::Downcast<NodeAtom>(*atoms_it++); + MG_ASSERT(current_node, "First pattern atom is not a node"); + auto last_res = base(current_node); + // Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)* + while (atoms_it != pattern.atoms_.end()) { + auto edge = utils::Downcast<EdgeAtom>(*atoms_it++); + MG_ASSERT(edge, "Expected an edge atom in pattern."); + MG_ASSERT(atoms_it != pattern.atoms_.end(), "Edge atom should not end the pattern."); + auto prev_node = current_node; + current_node = utils::Downcast<NodeAtom>(*atoms_it++); + MG_ASSERT(current_node, "Expected a node atom in pattern."); + last_res = collect(std::move(last_res), prev_node, edge, current_node); + } + return last_res; +} + +// For all given `named_paths` checks if all its symbols have been bound. +// If so, it creates a logical operator for named path generation, binds its +// symbol, removes that path from the collection of unhandled ones and returns +// the new op. Otherwise, returns `last_op`. 
+std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator> last_op, + std::unordered_set<Symbol> &bound_symbols, + std::unordered_map<Symbol, std::vector<Symbol>> &named_paths); + +std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op, + SymbolTable &symbol_table, bool is_write, + const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage); + +std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op, + SymbolTable &symbol_table, bool is_write, + std::unordered_set<Symbol> &bound_symbols, AstStorage &storage); + +std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op, + std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table); + +template <class TBoolOperator> +Expression *BoolJoin(AstStorage &storage, Expression *expr1, Expression *expr2) { + if (expr1 && expr2) { + return storage.Create<TBoolOperator>(expr1, expr2); + } + return expr1 ? expr1 : expr2; +} + +} // namespace impl + +/// @brief Planner which uses hardcoded rules to produce operators. +/// +/// @sa MakeLogicalPlan +template <class TPlanningContext> +class RuleBasedPlanner { + public: + explicit RuleBasedPlanner(TPlanningContext *context) : context_(context) {} + + /// @brief The result of plan generation is the root of the generated operator + /// tree. + using PlanResult = std::unique_ptr<LogicalOperator>; + /// @brief Generates the operator tree based on explicitly set rules. + PlanResult Plan(const std::vector<SingleQueryPart> &query_parts) { + auto &context = *context_; + std::unique_ptr<LogicalOperator> input_op; + // Set to true if a query command writes to the database. 
+ bool is_write = false; + for (const auto &query_part : query_parts) { + MatchContext match_ctx{query_part.matching, *context.symbol_table, context.bound_symbols}; + input_op = PlanMatching(match_ctx, std::move(input_op)); + for (const auto &matching : query_part.optional_matching) { + MatchContext opt_ctx{matching, *context.symbol_table, context.bound_symbols}; + auto match_op = PlanMatching(opt_ctx, nullptr); + if (match_op) { + input_op = std::make_unique<Optional>(std::move(input_op), std::move(match_op), opt_ctx.new_symbols); + } + } + uint64_t merge_id = 0; + for (auto *clause : query_part.remaining_clauses) { + MG_ASSERT(!utils::IsSubtype(*clause, Match::kType), "Unexpected Match in remaining clauses"); + if (auto *ret = utils::Downcast<Return>(clause)) { + input_op = impl::GenReturn(*ret, std::move(input_op), *context.symbol_table, is_write, context.bound_symbols, + *context.ast_storage); + } else if (auto *merge = utils::Downcast<query::v2::Merge>(clause)) { + input_op = GenMerge(*merge, std::move(input_op), query_part.merge_matching[merge_id++]); + // Treat MERGE clause as write, because we do not know if it will + // create anything. + is_write = true; + } else if (auto *with = utils::Downcast<query::v2::With>(clause)) { + input_op = impl::GenWith(*with, std::move(input_op), *context.symbol_table, is_write, context.bound_symbols, + *context.ast_storage); + // WITH clause advances the command, so reset the flag. 
+ is_write = false; + } else if (auto op = HandleWriteClause(clause, input_op, *context.symbol_table, context.bound_symbols)) { + is_write = true; + input_op = std::move(op); + } else if (auto *unwind = utils::Downcast<query::v2::Unwind>(clause)) { + const auto &symbol = context.symbol_table->at(*unwind->named_expression_); + context.bound_symbols.insert(symbol); + input_op = + std::make_unique<plan::Unwind>(std::move(input_op), unwind->named_expression_->expression_, symbol); + } else if (auto *call_proc = utils::Downcast<query::v2::CallProcedure>(clause)) { + std::vector<Symbol> result_symbols; + result_symbols.reserve(call_proc->result_identifiers_.size()); + for (const auto *ident : call_proc->result_identifiers_) { + const auto &sym = context.symbol_table->at(*ident); + context.bound_symbols.insert(sym); + result_symbols.push_back(sym); + } + // TODO: When we add support for write and eager procedures, we will + // need to plan this operator with Accumulate and pass in + // storage::v3::View::NEW. 
+ input_op = std::make_unique<plan::CallProcedure>( + std::move(input_op), call_proc->procedure_name_, call_proc->arguments_, call_proc->result_fields_, + result_symbols, call_proc->memory_limit_, call_proc->memory_scale_, call_proc->is_write_); + } else if (auto *load_csv = utils::Downcast<query::v2::LoadCsv>(clause)) { + const auto &row_sym = context.symbol_table->at(*load_csv->row_var_); + context.bound_symbols.insert(row_sym); + + input_op = + std::make_unique<plan::LoadCsv>(std::move(input_op), load_csv->file_, load_csv->with_header_, + load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_, row_sym); + } else if (auto *foreach = utils::Downcast<query::v2::Foreach>(clause)) { + is_write = true; + input_op = HandleForeachClause(foreach, std::move(input_op), *context.symbol_table, context.bound_symbols, + query_part, merge_id); + } else { + throw utils::NotYetImplemented("clause '{}' conversion to operator(s)", clause->GetTypeInfo().name); + } + } + } + return input_op; + } + + private: + TPlanningContext *context_; + + storage::v3::LabelId GetLabel(LabelIx label) { return context_->db->NameToLabel(label.name); } + + storage::v3::PropertyId GetProperty(PropertyIx prop) { return context_->db->NameToProperty(prop.name); } + + storage::v3::EdgeTypeId GetEdgeType(EdgeTypeIx edge_type) { return context_->db->NameToEdgeType(edge_type.name); } + + std::unique_ptr<LogicalOperator> GenCreate(Create &create, std::unique_ptr<LogicalOperator> input_op, + const SymbolTable &symbol_table, + std::unordered_set<Symbol> &bound_symbols) { + auto last_op = std::move(input_op); + for (auto pattern : create.patterns_) { + last_op = GenCreateForPattern(*pattern, std::move(last_op), symbol_table, bound_symbols); + } + return last_op; + } + + std::unique_ptr<LogicalOperator> GenCreateForPattern(Pattern &pattern, std::unique_ptr<LogicalOperator> input_op, + const SymbolTable &symbol_table, + std::unordered_set<Symbol> &bound_symbols) { + auto node_to_creation_info = [&](const 
NodeAtom &node) { + const auto &node_symbol = symbol_table.at(*node.identifier_); + std::vector<storage::v3::LabelId> labels; + labels.reserve(node.labels_.size()); + for (const auto &label : node.labels_) { + labels.push_back(GetLabel(label)); + } + + auto properties = std::invoke([&]() -> std::variant<PropertiesMapList, ParameterLookup *> { + if (const auto *node_properties = + std::get_if<std::unordered_map<PropertyIx, Expression *>>(&node.properties_)) { + PropertiesMapList vector_props; + vector_props.reserve(node_properties->size()); + for (const auto &kv : *node_properties) { + vector_props.push_back({GetProperty(kv.first), kv.second}); + } + return std::move(vector_props); + } + return std::get<ParameterLookup *>(node.properties_); + }); + return NodeCreationInfo{node_symbol, labels, properties}; + }; + + auto base = [&](NodeAtom *node) -> std::unique_ptr<LogicalOperator> { + const auto &node_symbol = symbol_table.at(*node->identifier_); + if (bound_symbols.insert(node_symbol).second) { + auto node_info = node_to_creation_info(*node); + return std::make_unique<CreateNode>(std::move(input_op), node_info); + } + return std::move(input_op); + }; + + auto collect = [&](std::unique_ptr<LogicalOperator> last_op, NodeAtom *prev_node, EdgeAtom *edge, NodeAtom *node) { + // Store the symbol from the first node as the input to CreateExpand. + const auto &input_symbol = symbol_table.at(*prev_node->identifier_); + // If the expand node was already bound, then we need to indicate this, + // so that CreateExpand only creates an edge. 
+ bool node_existing = false; + if (!bound_symbols.insert(symbol_table.at(*node->identifier_)).second) { + node_existing = true; + } + const auto &edge_symbol = symbol_table.at(*edge->identifier_); + if (!bound_symbols.insert(edge_symbol).second) { + LOG_FATAL("Symbols used for created edges cannot be redeclared."); + } + auto node_info = node_to_creation_info(*node); + auto properties = std::invoke([&]() -> std::variant<PropertiesMapList, ParameterLookup *> { + if (const auto *edge_properties = + std::get_if<std::unordered_map<PropertyIx, Expression *>>(&edge->properties_)) { + PropertiesMapList vector_props; + vector_props.reserve(edge_properties->size()); + for (const auto &kv : *edge_properties) { + vector_props.push_back({GetProperty(kv.first), kv.second}); + } + return std::move(vector_props); + } + return std::get<ParameterLookup *>(edge->properties_); + }); + + MG_ASSERT(edge->edge_types_.size() == 1, "Creating an edge with a single type should be required by syntax"); + EdgeCreationInfo edge_info{edge_symbol, properties, GetEdgeType(edge->edge_types_[0]), edge->direction_}; + return std::make_unique<CreateExpand>(node_info, edge_info, std::move(last_op), input_symbol, node_existing); + }; + + auto last_op = impl::ReducePattern<std::unique_ptr<LogicalOperator>>(pattern, base, collect); + + // If the pattern is named, append the path constructing logical operator. + if (pattern.identifier_->user_declared_) { + std::vector<Symbol> path_elements; + for (const PatternAtom *atom : pattern.atoms_) path_elements.emplace_back(symbol_table.at(*atom->identifier_)); + last_op = std::make_unique<ConstructNamedPath>(std::move(last_op), symbol_table.at(*pattern.identifier_), + path_elements); + } + + return last_op; + } + + // Generate an operator for a clause which writes to the database. Ownership + // of input_op is transferred to the newly created operator. If the clause + // isn't handled, returns nullptr and input_op is left as is. 
+ std::unique_ptr<LogicalOperator> HandleWriteClause(Clause *clause, std::unique_ptr<LogicalOperator> &input_op, + const SymbolTable &symbol_table, + std::unordered_set<Symbol> &bound_symbols) { + if (auto *create = utils::Downcast<Create>(clause)) { + return GenCreate(*create, std::move(input_op), symbol_table, bound_symbols); + } else if (auto *del = utils::Downcast<query::v2::Delete>(clause)) { + return std::make_unique<plan::Delete>(std::move(input_op), del->expressions_, del->detach_); + } else if (auto *set = utils::Downcast<query::v2::SetProperty>(clause)) { + return std::make_unique<plan::SetProperty>(std::move(input_op), GetProperty(set->property_lookup_->property_), + set->property_lookup_, set->expression_); + } else if (auto *set = utils::Downcast<query::v2::SetProperties>(clause)) { + auto op = set->update_ ? plan::SetProperties::Op::UPDATE : plan::SetProperties::Op::REPLACE; + const auto &input_symbol = symbol_table.at(*set->identifier_); + return std::make_unique<plan::SetProperties>(std::move(input_op), input_symbol, set->expression_, op); + } else if (auto *set = utils::Downcast<query::v2::SetLabels>(clause)) { + const auto &input_symbol = symbol_table.at(*set->identifier_); + std::vector<storage::v3::LabelId> labels; + labels.reserve(set->labels_.size()); + for (const auto &label : set->labels_) { + labels.push_back(GetLabel(label)); + } + return std::make_unique<plan::SetLabels>(std::move(input_op), input_symbol, labels); + } else if (auto *rem = utils::Downcast<query::v2::RemoveProperty>(clause)) { + return std::make_unique<plan::RemoveProperty>(std::move(input_op), GetProperty(rem->property_lookup_->property_), + rem->property_lookup_); + } else if (auto *rem = utils::Downcast<query::v2::RemoveLabels>(clause)) { + const auto &input_symbol = symbol_table.at(*rem->identifier_); + std::vector<storage::v3::LabelId> labels; + labels.reserve(rem->labels_.size()); + for (const auto &label : rem->labels_) { + labels.push_back(GetLabel(label)); + } + 
return std::make_unique<plan::RemoveLabels>(std::move(input_op), input_symbol, labels); + } + return nullptr; + } + + std::unique_ptr<LogicalOperator> PlanMatching(MatchContext &match_context, + std::unique_ptr<LogicalOperator> input_op) { + auto &bound_symbols = match_context.bound_symbols; + auto &storage = *context_->ast_storage; + const auto &symbol_table = match_context.symbol_table; + const auto &matching = match_context.matching; + // Copy filters, because we will modify them as we generate Filters. + auto filters = matching.filters; + // Copy the named_paths for the same reason. + auto named_paths = matching.named_paths; + // Try to generate any filters even before the 1st match operator. This + // optimizes the optional match which filters only on symbols bound in + // regular match. + auto last_op = impl::GenFilters(std::move(input_op), bound_symbols, filters, storage); + for (const auto &expansion : matching.expansions) { + const auto &node1_symbol = symbol_table.at(*expansion.node1->identifier_); + if (bound_symbols.insert(node1_symbol).second) { + // We have just bound this symbol, so generate ScanAll which fills it. + last_op = std::make_unique<ScanAll>(std::move(last_op), node1_symbol, match_context.view); + match_context.new_symbols.emplace_back(node1_symbol); + last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage); + last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths); + last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage); + } + // We have an edge, so generate Expand. + if (expansion.edge) { + auto *edge = expansion.edge; + // If the expand symbols were already bound, then we need to indicate + // that they exist. The Expand will then check whether the pattern holds + // instead of writing the expansion to symbols. 
+ const auto &node_symbol = symbol_table.at(*expansion.node2->identifier_); + auto existing_node = utils::Contains(bound_symbols, node_symbol); + const auto &edge_symbol = symbol_table.at(*edge->identifier_); + MG_ASSERT(!utils::Contains(bound_symbols, edge_symbol), "Existing edges are not supported"); + std::vector<storage::v3::EdgeTypeId> edge_types; + edge_types.reserve(edge->edge_types_.size()); + for (const auto &type : edge->edge_types_) { + edge_types.push_back(GetEdgeType(type)); + } + if (edge->IsVariable()) { + std::optional<ExpansionLambda> weight_lambda; + std::optional<Symbol> total_weight; + + if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) { + weight_lambda.emplace(ExpansionLambda{symbol_table.at(*edge->weight_lambda_.inner_edge), + symbol_table.at(*edge->weight_lambda_.inner_node), + edge->weight_lambda_.expression}); + + total_weight.emplace(symbol_table.at(*edge->total_weight_)); + } + + ExpansionLambda filter_lambda; + filter_lambda.inner_edge_symbol = symbol_table.at(*edge->filter_lambda_.inner_edge); + filter_lambda.inner_node_symbol = symbol_table.at(*edge->filter_lambda_.inner_node); + { + // Bind the inner edge and node symbols so they're available for + // inline filtering in ExpandVariable. + bool inner_edge_bound = bound_symbols.insert(filter_lambda.inner_edge_symbol).second; + bool inner_node_bound = bound_symbols.insert(filter_lambda.inner_node_symbol).second; + MG_ASSERT(inner_edge_bound && inner_node_bound, "An inner edge and node can't be bound from before"); + } + // Join regular filters with lambda filter expression, so that they + // are done inline together. Semantic analysis should guarantee that + // lambda filtering uses bound symbols. + filter_lambda.expression = impl::BoolJoin<AndOperator>( + storage, impl::ExtractFilters(bound_symbols, filters, storage), edge->filter_lambda_.expression); + // At this point it's possible we have leftover filters for inline + // filtering (they use the inner symbols. 
If they were not collected, + // we have to remove them manually because no other filter-extraction + // will ever bind them again. + filters.erase(std::remove_if( + filters.begin(), filters.end(), + [e = filter_lambda.inner_edge_symbol, n = filter_lambda.inner_node_symbol](FilterInfo &fi) { + return utils::Contains(fi.used_symbols, e) || utils::Contains(fi.used_symbols, n); + }), + filters.end()); + // Unbind the temporarily bound inner symbols for filtering. + bound_symbols.erase(filter_lambda.inner_edge_symbol); + bound_symbols.erase(filter_lambda.inner_node_symbol); + + if (total_weight) { + bound_symbols.insert(*total_weight); + } + + // TODO: Pass weight lambda. + MG_ASSERT(match_context.view == storage::v3::View::OLD, + "ExpandVariable should only be planned with storage::v3::View::OLD"); + last_op = std::make_unique<ExpandVariable>(std::move(last_op), node1_symbol, node_symbol, edge_symbol, + edge->type_, expansion.direction, edge_types, expansion.is_flipped, + edge->lower_bound_, edge->upper_bound_, existing_node, + filter_lambda, weight_lambda, total_weight); + } else { + last_op = std::make_unique<Expand>(std::move(last_op), node1_symbol, node_symbol, edge_symbol, + expansion.direction, edge_types, existing_node, match_context.view); + } + + // Bind the expanded edge and node. + bound_symbols.insert(edge_symbol); + match_context.new_symbols.emplace_back(edge_symbol); + if (bound_symbols.insert(node_symbol).second) { + match_context.new_symbols.emplace_back(node_symbol); + } + + // Ensure Cyphermorphism (different edge symbols always map to + // different edges). 
+ for (const auto &edge_symbols : matching.edge_symbols) { + if (edge_symbols.find(edge_symbol) == edge_symbols.end()) { + continue; + } + std::vector<Symbol> other_symbols; + for (const auto &symbol : edge_symbols) { + if (symbol == edge_symbol || bound_symbols.find(symbol) == bound_symbols.end()) { + continue; + } + other_symbols.push_back(symbol); + } + if (!other_symbols.empty()) { + last_op = std::make_unique<EdgeUniquenessFilter>(std::move(last_op), edge_symbol, other_symbols); + } + } + last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage); + last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths); + last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage); + } + } + MG_ASSERT(named_paths.empty(), "Expected to generate all named paths"); + // We bound all named path symbols, so just add them to new_symbols. + for (const auto &named_path : matching.named_paths) { + MG_ASSERT(utils::Contains(bound_symbols, named_path.first), "Expected generated named path to have bound symbol"); + match_context.new_symbols.emplace_back(named_path.first); + } + MG_ASSERT(filters.empty(), "Expected to generate all filters"); + return last_op; + } + + auto GenMerge(query::v2::Merge &merge, std::unique_ptr<LogicalOperator> input_op, const Matching &matching) { + // Copy the bound symbol set, because we don't want to use the updated + // version when generating the create part. 
+ std::unordered_set<Symbol> bound_symbols_copy(context_->bound_symbols); + MatchContext match_ctx{matching, *context_->symbol_table, bound_symbols_copy, storage::v3::View::NEW}; + + std::vector<Symbol> bound_symbols(context_->bound_symbols.begin(), context_->bound_symbols.end()); + + auto once_with_symbols = std::make_unique<Once>(bound_symbols); + auto on_match = PlanMatching(match_ctx, std::move(once_with_symbols)); + + once_with_symbols = std::make_unique<Once>(std::move(bound_symbols)); + // Use the original bound_symbols, so we fill it with new symbols. + auto on_create = GenCreateForPattern(*merge.pattern_, std::move(once_with_symbols), *context_->symbol_table, + context_->bound_symbols); + for (auto &set : merge.on_create_) { + on_create = HandleWriteClause(set, on_create, *context_->symbol_table, context_->bound_symbols); + MG_ASSERT(on_create, "Expected SET in MERGE ... ON CREATE"); + } + for (auto &set : merge.on_match_) { + on_match = HandleWriteClause(set, on_match, *context_->symbol_table, context_->bound_symbols); + MG_ASSERT(on_match, "Expected SET in MERGE ... 
ON MATCH"); + } + return std::make_unique<plan::Merge>(std::move(input_op), std::move(on_match), std::move(on_create)); + } + + std::unique_ptr<LogicalOperator> HandleForeachClause(query::v2::Foreach *foreach, + std::unique_ptr<LogicalOperator> input_op, + const SymbolTable &symbol_table, + std::unordered_set<Symbol> &bound_symbols, + const SingleQueryPart &query_part, uint64_t &merge_id) { + const auto &symbol = symbol_table.at(*foreach->named_expression_); + bound_symbols.insert(symbol); + std::unique_ptr<LogicalOperator> op = std::make_unique<plan::Once>(); + for (auto *clause : foreach->clauses_) { + if (auto *nested_for_each = utils::Downcast<query::v2::Foreach>(clause)) { + op = HandleForeachClause(nested_for_each, std::move(op), symbol_table, bound_symbols, query_part, merge_id); + } else if (auto *merge = utils::Downcast<query::v2::Merge>(clause)) { + op = GenMerge(*merge, std::move(op), query_part.merge_matching[merge_id++]); + } else { + op = HandleWriteClause(clause, op, symbol_table, bound_symbols); + } + } + return std::make_unique<plan::Foreach>(std::move(input_op), std::move(op), foreach->named_expression_->expression_, + symbol); + } +}; + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/scoped_profile.hpp b/src/query/v2/plan/scoped_profile.hpp new file mode 100644 index 000000000..e71cbb047 --- /dev/null +++ b/src/query/v2/plan/scoped_profile.hpp @@ -0,0 +1,79 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include <cstdint> + +#include "query/v2/context.hpp" +#include "query/v2/plan/profile.hpp" +#include "utils/likely.hpp" +#include "utils/tsc.hpp" + +namespace memgraph::query::v2::plan { + +/** + * A RAII class used for profiling logical operators. Instances of this class + * update the profiling data stored within the `ExecutionContext` object and build + * up a tree of `ProfilingStats` instances. The structure of the `ProfilingStats` + * tree depends on the `LogicalOperator`s that were executed. + */ +class ScopedProfile { + public: + ScopedProfile(uint64_t key, const char *name, query::v2::ExecutionContext *context) noexcept : context_(context) { + if (UNLIKELY(context_->is_profile_query)) { + root_ = context_->stats_root; + + // Are we the root logical operator? + if (!root_) { + stats_ = &context_->stats; + stats_->key = key; + stats_->name = name; + } else { + stats_ = nullptr; + + // Was this logical operator already hit on one of the previous pulls? + auto it = std::find_if(root_->children.begin(), root_->children.end(), + [key](auto &stats) { return stats.key == key; }); + + if (it == root_->children.end()) { + root_->children.emplace_back(); + stats_ = &root_->children.back(); + stats_->key = key; + stats_->name = name; + } else { + stats_ = &(*it); + } + } + + context_->stats_root = stats_; + stats_->actual_hits++; + start_time_ = utils::ReadTSC(); + } + } + + ~ScopedProfile() noexcept { + if (UNLIKELY(context_->is_profile_query)) { + stats_->num_cycles += utils::ReadTSC() - start_time_; + + // Restore the old root ("pop") + context_->stats_root = root_; + } + } + + private: + query::v2::ExecutionContext *context_; + ProfilingStats *root_{nullptr}; + ProfilingStats *stats_{nullptr}; + unsigned long long start_time_{0}; +}; + +} // namespace memgraph::query::v2::plan diff --git a/src/query/v2/plan/variable_start_planner.cpp b/src/query/v2/plan/variable_start_planner.cpp new file mode 100644 index 000000000..b6d15f73d --- /dev/null +++ 
b/src/query/v2/plan/variable_start_planner.cpp
@@ -0,0 +1,296 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#include "query/v2/plan/variable_start_planner.hpp"
+
+#include <limits>
+#include <queue>
+#include <set>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "utils/flag_validation.hpp"
+#include "utils/logging.hpp"
+
+DEFINE_VALIDATED_HIDDEN_uint64(query_max_plans, 1000U, "Maximum number of generated plans for a query.",
+                               FLAG_IN_RANGE(1, std::numeric_limits<std::uint64_t>::max()));
+
+namespace memgraph::query::v2::plan::impl {
+
+namespace {
+
+// Add applicable expansions for `node_symbol` to `next_expansions`. These
+// expansions are removed from `node_symbol_to_expansions`, while
+// `seen_expansions` and `expanded_symbols` are populated with new data.
+// Expansions which are not bound to `node_symbol` through `node1` are flipped
+// (except for breadth first expansions) before they are queued.
+void AddNextExpansions(const Symbol &node_symbol, const Matching &matching, const SymbolTable &symbol_table,
+                       std::unordered_set<Symbol> &expanded_symbols,
+                       std::unordered_map<Symbol, std::set<size_t>> &node_symbol_to_expansions,
+                       std::unordered_set<size_t> &seen_expansions, std::queue<Expansion> &next_expansions) {
+  auto node_to_expansions_it = node_symbol_to_expansions.find(node_symbol);
+  if (node_to_expansions_it == node_symbol_to_expansions.end()) {
+    return;
+  }
+  // Returns true if the expansion is a regular expand or if it is a variable
+  // path expand, but with bound symbols used inside the range expression.
+  auto can_expand = [&](auto &expansion) {
+    for (const auto &range_symbol : expansion.symbols_in_range) {
+      // If the symbols used in range need to be bound during this whole
+      // expansion, we must check whether they have already been expanded and
+      // therefore bound. If the symbols are not found in the whole expansion,
+      // then the semantic analysis should guarantee that the symbols have been
+      // bound long before we expand.
+      if (matching.expansion_symbols.find(range_symbol) != matching.expansion_symbols.end() &&
+          expanded_symbols.find(range_symbol) == expanded_symbols.end()) {
+        return false;
+      }
+    }
+    return true;
+  };
+  auto &node_expansions = node_to_expansions_it->second;
+  auto node_expansions_it = node_expansions.begin();
+  while (node_expansions_it != node_to_expansions_it->second.end()) {
+    auto expansion_id = *node_expansions_it;
+    if (seen_expansions.find(expansion_id) != seen_expansions.end()) {
+      // Skip and erase seen (already expanded) expansions.
+      node_expansions_it = node_expansions.erase(node_expansions_it);
+      continue;
+    }
+    // Work on a copy, because the expansion may get flipped below.
+    auto expansion = matching.expansions[expansion_id];
+    if (!can_expand(expansion)) {
+      // Skip but save expansions which need other symbols for later.
+      ++node_expansions_it;
+      continue;
+    }
+    if (symbol_table.at(*expansion.node1->identifier_) != node_symbol) {
+      // We are not expanding from node1, so flip the expansion.
+      DMG_ASSERT(expansion.node2 && symbol_table.at(*expansion.node2->identifier_) == node_symbol,
+                 "Expected node_symbol to be bound in node2");
+      if (expansion.edge->type_ != EdgeAtom::Type::BREADTH_FIRST) {
+        // BFS must *not* be flipped. Doing that changes the BFS results.
+        std::swap(expansion.node1, expansion.node2);
+        expansion.is_flipped = true;
+        if (expansion.direction != EdgeAtom::Direction::BOTH) {
+          expansion.direction =
+              expansion.direction == EdgeAtom::Direction::IN ? EdgeAtom::Direction::OUT : EdgeAtom::Direction::IN;
+        }
+      }
+    }
+    seen_expansions.insert(expansion_id);
+    expanded_symbols.insert(symbol_table.at(*expansion.node1->identifier_));
+    if (expansion.edge) {
+      expanded_symbols.insert(symbol_table.at(*expansion.edge->identifier_));
+      expanded_symbols.insert(symbol_table.at(*expansion.node2->identifier_));
+    }
+    next_expansions.emplace(std::move(expansion));
+    node_expansions_it = node_expansions.erase(node_expansions_it);
+  }
+  if (node_expansions.empty()) {
+    node_symbol_to_expansions.erase(node_to_expansions_it);
+  }
+}
+
+// Generates expansions emanating from the start_node by forming a chain. When
+// the chain can no longer be continued, a different starting node is picked
+// among remaining expansions and the process continues. This is done until all
+// matching.expansions are used.
+std::vector<Expansion> ExpansionsFrom(const NodeAtom *start_node, const Matching &matching,
+                                      const SymbolTable &symbol_table) {
+  // Make a copy of node_symbol_to_expansions, because we will modify it as
+  // expansions are chained.
+  auto node_symbol_to_expansions = matching.node_symbol_to_expansions;
+  std::unordered_set<size_t> seen_expansions;
+  std::queue<Expansion> next_expansions;
+  std::unordered_set<Symbol> expanded_symbols({symbol_table.at(*start_node->identifier_)});
+  auto add_next_expansions = [&](const auto *node) {
+    AddNextExpansions(symbol_table.at(*node->identifier_), matching, symbol_table, expanded_symbols,
+                      node_symbol_to_expansions, seen_expansions, next_expansions);
+  };
+  add_next_expansions(start_node);
+  // Potential optimization: expansions and next_expansions could be merged into
+  // a single vector and an index could be used to determine from which should
+  // additional expansions be added.
+  std::vector<Expansion> expansions;
+  while (!next_expansions.empty()) {
+    // Move the expansion out of the queue instead of copying it twice. Only the
+    // node pointers are needed after it is stored, so remember them first.
+    auto expansion = std::move(next_expansions.front());
+    next_expansions.pop();
+    auto *const node1 = expansion.node1;
+    auto *const node2 = expansion.node2;
+    expansions.emplace_back(std::move(expansion));
+    add_next_expansions(node1);
+    if (node2) {
+      add_next_expansions(node2);
+    }
+  }
+  if (!node_symbol_to_expansions.empty()) {
+    // We could pick a new starting expansion, but to avoid runtime
+    // complexity, simply append the remaining expansions. They should have the
+    // correct order, since the original expansions were verified during
+    // semantic analysis.
+    for (size_t i = 0; i < matching.expansions.size(); ++i) {
+      if (seen_expansions.find(i) != seen_expansions.end()) {
+        continue;
+      }
+      expansions.emplace_back(matching.expansions[i]);
+    }
+  }
+  return expansions;
+}
+
+// Collect all unique nodes from expansions. Uniqueness is determined by
+// symbol uniqueness.
+auto ExpansionNodes(const std::vector<Expansion> &expansions, const SymbolTable &symbol_table) {
+  std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual> nodes(expansions.size(), NodeSymbolHash(symbol_table),
+                                                                        NodeSymbolEqual(symbol_table));
+  for (const auto &expansion : expansions) {
+    // TODO: Handle labels and properties from different node atoms.
+    nodes.insert(expansion.node1);
+    if (expansion.node2) {
+      nodes.insert(expansion.node2);
+    }
+  }
+  return nodes;
+}
+
+}  // namespace
+
+// `matching` is taken by value, so move it into the member instead of copying
+// it again. `nodes_` is declared after `matching_` in the class, therefore it
+// is safe (and required after the move) to collect the nodes from `matching_`.
+VaryMatchingStart::VaryMatchingStart(Matching matching, const SymbolTable &symbol_table)
+    : matching_(std::move(matching)),
+      symbol_table_(symbol_table),
+      nodes_(ExpansionNodes(matching_.expansions, symbol_table)) {}
+
+VaryMatchingStart::iterator::iterator(VaryMatchingStart *self, bool is_done)
+    : self_(self),
+      // Use the original matching as the first matching. We are only
+      // interested in changing the expansions part, so the remaining fields
+      // should stay the same. This also produces a matching for the case
+      // when there are no nodes.
+      current_matching_(self->matching_) {
+  if (is_done) {
+    // A past-the-end iterator is only ever compared, never dereferenced, so
+    // don't waste time on generating expansions for it.
+    start_nodes_it_ = self_->nodes_.end();
+    return;
+  }
+  if (!self_->nodes_.empty()) {
+    // Overwrite the original matching expansions with the new ones by
+    // generating it from the first start node.
+    start_nodes_it_ = self_->nodes_.begin();
+    current_matching_.expansions = ExpansionsFrom(**start_nodes_it_, self_->matching_, self_->symbol_table_);
+  }
+  DMG_ASSERT(start_nodes_it_ || self_->nodes_.empty(),
+             "start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
+}
+
+VaryMatchingStart::iterator &VaryMatchingStart::iterator::operator++() {
+  if (!start_nodes_it_) {
+    DMG_ASSERT(self_->nodes_.empty(), "start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
+    start_nodes_it_ = self_->nodes_.end();
+  }
+  if (*start_nodes_it_ == self_->nodes_.end()) {
+    return *this;
+  }
+  ++*start_nodes_it_;
+  // start_nodes_it_ can become equal to `end` and we shouldn't dereference
+  // iterator in that case.
+  if (*start_nodes_it_ == self_->nodes_.end()) {
+    return *this;
+  }
+  const auto &start_node = **start_nodes_it_;
+  current_matching_.expansions = ExpansionsFrom(start_node, self_->matching_, self_->symbol_table_);
+  return *this;
+}
+
+CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(const std::vector<Matching> &matchings,
+                                                            const SymbolTable &symbol_table) {
+  std::vector<VaryMatchingStart> variants;
+  variants.reserve(matchings.size());
+  for (const auto &matching : matchings) {
+    // Construct in place instead of moving a temporary into the vector.
+    variants.emplace_back(matching, symbol_table);
+  }
+  return MakeCartesianProduct(std::move(variants));
+}
+
+VaryQueryPartMatching::VaryQueryPartMatching(SingleQueryPart query_part, const SymbolTable &symbol_table)
+    : query_part_(std::move(query_part)),
+      matchings_(VaryMatchingStart(query_part_.matching, symbol_table)),
+      optional_matchings_(VaryMultiMatchingStarts(query_part_.optional_matching, symbol_table)),
+      merge_matchings_(VaryMultiMatchingStarts(query_part_.merge_matching, symbol_table)) {}
+
+VaryQueryPartMatching::iterator::iterator(const SingleQueryPart &query_part,
+                                          VaryMatchingStart::iterator matchings_begin,
+                                          VaryMatchingStart::iterator matchings_end,
+                                          CartesianProduct<VaryMatchingStart>::iterator optional_begin,
+                                          CartesianProduct<VaryMatchingStart>::iterator optional_end,
+                                          CartesianProduct<VaryMatchingStart>::iterator merge_begin,
+                                          CartesianProduct<VaryMatchingStart>::iterator merge_end)
+    : current_query_part_(query_part),
+      matchings_it_(matchings_begin),
+      matchings_end_(matchings_end),
+      optional_it_(optional_begin),
+      optional_begin_(optional_begin),
+      optional_end_(optional_end),
+      merge_it_(merge_begin),
+      merge_begin_(merge_begin),
+      merge_end_(merge_end) {
+  if (matchings_it_ != matchings_end_) {
+    // Fill the query part with the first variation of matchings
+    SetCurrentQueryPart();
+  }
+}
+
+VaryQueryPartMatching::iterator &VaryQueryPartMatching::iterator::operator++() {
+  // Produce parts for each possible combination. E.g. if we have:
+  //    * matchings (m1) and (m2)
+  //    * optional matchings (o1) and (o2)
+  //    * merge matching (g1)
+  // We want to produce parts for:
+  //    * (m1), (o1), (g1)
+  //    * (m1), (o2), (g1)
+  //    * (m2), (o1), (g1)
+  //    * (m2), (o2), (g1)
+  // Create variations by changing the merge part first.
+  if (merge_it_ != merge_end_) ++merge_it_;
+  // If all merge variations are done, start them from beginning and move to the
+  // next optional matching variation.
+  if (merge_it_ == merge_end_) {
+    merge_it_ = merge_begin_;
+    if (optional_it_ != optional_end_) ++optional_it_;
+  }
+  // If all optional matching variations are done (after exhausting merge
+  // variations), start them from beginning and move to the next regular
+  // matching variation.
+  if (optional_it_ == optional_end_ && merge_it_ == merge_begin_) {
+    optional_it_ = optional_begin_;
+    if (matchings_it_ != matchings_end_) ++matchings_it_;
+  }
+  // We have reached the end, so return;
+  if (matchings_it_ == matchings_end_) return *this;
+  // Fill the query part with the new variation of matchings.
+  SetCurrentQueryPart();
+  return *this;
+}
+
+// Copies the variations currently pointed to by the three iterators into
+// `current_query_part_`.
+void VaryQueryPartMatching::iterator::SetCurrentQueryPart() {
+  current_query_part_.matching = *matchings_it_;
+  DMG_ASSERT(optional_it_ != optional_end_ || optional_begin_ == optional_end_,
+             "Either there are no optional matchings or we can always "
+             "generate a variation");
+  if (optional_it_ != optional_end_) {
+    current_query_part_.optional_matching = *optional_it_;
+  }
+  DMG_ASSERT(merge_it_ != merge_end_ || merge_begin_ == merge_end_,
+             "Either there are no merge matchings or we can always generate "
+             "a variation");
+  if (merge_it_ != merge_end_) {
+    current_query_part_.merge_matching = *merge_it_;
+  }
+}
+
+bool VaryQueryPartMatching::iterator::operator==(const iterator &other) const {
+  if (matchings_it_ == other.matchings_it_ && matchings_it_ == matchings_end_) {
+    // matchings_it_ is the primary iterator. If both are at the end, then other
+    // iterators can be at any position.
+    return true;
+  }
+  return matchings_it_ == other.matchings_it_ && optional_it_ == other.optional_it_ && merge_it_ == other.merge_it_;
+}
+
+}  // namespace memgraph::query::v2::plan::impl
diff --git a/src/query/v2/plan/variable_start_planner.hpp b/src/query/v2/plan/variable_start_planner.hpp
new file mode 100644
index 000000000..27722b6b2
--- /dev/null
+++ b/src/query/v2/plan/variable_start_planner.hpp
@@ -0,0 +1,336 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+/// @file
+#pragma once
+
+#include <cstddef>
+#include <iterator>
+#include <optional>
+#include <type_traits>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "cppitertools/imap.hpp"
+#include "cppitertools/slice.hpp"
+#include "gflags/gflags.h"
+
+#include "query/v2/plan/rule_based_planner.hpp"
+#include "utils/logging.hpp"
+
+DECLARE_uint64(query_max_plans);
+
+namespace memgraph::query::v2::plan {
+
+/// Produces a Cartesian product among vectors between begin and end iterator.
+/// For example:
+///
+///    std::vector<int> first_set{1,2,3};
+///    std::vector<int> second_set{4,5};
+///    std::vector<std::vector<int>> all_sets{first_set, second_set};
+///    // The leftmost set varies fastest, so the products are generated as
+///    // {1, 4}, {2, 4}, {3, 4}, {1, 5}, {2, 5}, {3, 5}
+///    auto product = MakeCartesianProduct(all_sets);
+///    for (const auto &set : product) {
+///      ...
+///    }
+///
+/// The product is created lazily by iterating over the constructed
+/// CartesianProduct instance.
+template <typename TSet>
+class CartesianProduct {
+ private:
+  // The original sets whose Cartesian product we are calculating.
+  std::vector<TSet> original_sets_;
+  // Iterators to the beginning and end of original_sets_.
+  decltype(original_sets_.begin()) begin_;
+  decltype(original_sets_.end()) end_;
+
+  // Type of the set element.
+  using TElement = typename decltype(begin_->begin())::value_type;
+
+ public:
+  CartesianProduct(std::vector<TSet> sets)
+      : original_sets_(std::move(sets)), begin_(original_sets_.begin()), end_(original_sets_.end()) {}
+
+  // `begin_` and `end_` point into `original_sets_`, so they must always be
+  // rebound to the vector owned by *this* instance. The implicitly generated
+  // copy operations would leave them pointing into the source object.
+  CartesianProduct(const CartesianProduct &other)
+      : original_sets_(other.original_sets_), begin_(original_sets_.begin()), end_(original_sets_.end()) {}
+
+  CartesianProduct(CartesianProduct &&other) noexcept
+      : original_sets_(std::move(other.original_sets_)), begin_(original_sets_.begin()), end_(original_sets_.end()) {}
+
+  CartesianProduct &operator=(const CartesianProduct &other) {
+    if (this != &other) {
+      original_sets_ = other.original_sets_;
+      begin_ = original_sets_.begin();
+      end_ = original_sets_.end();
+    }
+    return *this;
+  }
+
+  CartesianProduct &operator=(CartesianProduct &&other) noexcept {
+    if (this != &other) {
+      original_sets_ = std::move(other.original_sets_);
+      begin_ = original_sets_.begin();
+      end_ = original_sets_.end();
+    }
+    return *this;
+  }
+
+  ~CartesianProduct() = default;
+
+  class iterator {
+   public:
+    typedef std::input_iterator_tag iterator_category;
+    typedef std::vector<TElement> value_type;
+    typedef long difference_type;
+    typedef const std::vector<TElement> &reference;
+    typedef const std::vector<TElement> *pointer;
+
+    /// Creates an iterator positioned on the first product, or a past-the-end
+    /// iterator if `is_done` is set, there are no sets, or any set is empty.
+    explicit iterator(CartesianProduct *self, bool is_done) : self_(self), is_done_(is_done) {
+      if (is_done || self->begin_ == self->end_) {
+        is_done_ = true;
+        return;
+      }
+      auto begin = self->begin_;
+      while (begin != self->end_) {
+        auto set_it = begin->begin();
+        if (set_it == begin->end()) {
+          // One of the sets is empty, so there is no product.
+          is_done_ = true;
+          return;
+        }
+        // Collect the first product, by taking the first element of each set.
+        current_product_.emplace_back(*set_it);
+        // Store starting iterators to all sets.
+        sets_.emplace_back(begin, set_it);
+        ++begin;
+      }
+    }
+
+    iterator &operator++() {
+      if (is_done_) return *this;
+      // Increment the leftmost set iterator.
+      auto sets_it = sets_.begin();
+      ++sets_it->second;
+      // If the leftmost is at the end, reset it and increment the next
+      // leftmost.
+      while (sets_it->second == sets_it->first->end()) {
+        sets_it->second = sets_it->first->begin();
+        ++sets_it;
+        if (sets_it == sets_.end()) {
+          // The leftmost set is the last set and it was exhausted, so we are
+          // done.
+          is_done_ = true;
+          return *this;
+        }
+        ++sets_it->second;
+      }
+      // We can now collect another product from the modified set iterators.
+      DMG_ASSERT(current_product_.size() == sets_.size(),
+                 "Expected size of current_product_ to match the size of sets_");
+      size_t i = 0;
+      // Change only the prefix of the product, remaining elements (after
+      // sets_it) should be the same.
+      auto last_unmodified = sets_it + 1;
+      for (auto kv_it = sets_.begin(); kv_it != last_unmodified; ++kv_it) {
+        current_product_[i++] = *kv_it->second;
+      }
+      return *this;
+    }
+
+    bool operator==(const iterator &other) const {
+      if (self_->begin_ != other.self_->begin_ || self_->end_ != other.self_->end_) return false;
+      return (is_done_ && other.is_done_) || (sets_ == other.sets_);
+    }
+
+    bool operator!=(const iterator &other) const { return !(*this == other); }
+
+    // Iterator interface says that dereferencing a past-the-end iterator is
+    // undefined, so don't bother checking if we are done.
+    reference operator*() const { return current_product_; }
+    pointer operator->() const { return &current_product_; }
+
+   private:
+    // Pointer instead of reference to auto generate copy constructor and
+    // assignment.
+    CartesianProduct *self_;
+    // Vector of (original_sets_iterator, set_iterator) pairs. The
+    // original_sets_iterator points to the set among all the sets, while the
+    // set_iterator points to an element inside the pointed to set.
+    std::vector<std::pair<decltype(self_->begin_), decltype(self_->begin_->begin())>> sets_;
+    // Currently built product from pointed to elements in all sets.
+    std::vector<TElement> current_product_;
+    // Set to true when we have generated all products.
+    bool is_done_ = false;
+  };
+
+  auto begin() { return iterator(this, false); }
+  auto end() { return iterator(this, true); }
+
+ private:
+  friend class iterator;
+};
+
+/// Convenience function for creating CartesianProduct by deducing template
+/// arguments from function arguments.
+template <typename TSet>
+auto MakeCartesianProduct(std::vector<TSet> sets) {
+  return CartesianProduct<TSet>(std::move(sets));
+}
+
+namespace impl {
+
+// Hashes a NodeAtom by the symbol its identifier maps to in the symbol table.
+class NodeSymbolHash {
+ public:
+  explicit NodeSymbolHash(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
+
+  size_t operator()(const NodeAtom *node_atom) const {
+    return std::hash<Symbol>{}(symbol_table_.at(*node_atom->identifier_));
+  }
+
+ private:
+  const SymbolTable &symbol_table_;
+};
+
+// Treats two NodeAtoms as equal when their identifiers map to equal symbols.
+class NodeSymbolEqual {
+ public:
+  explicit NodeSymbolEqual(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
+
+  bool operator()(const NodeAtom *node_atom1, const NodeAtom *node_atom2) const {
+    return symbol_table_.at(*node_atom1->identifier_) == symbol_table_.at(*node_atom2->identifier_);
+  }
+
+ private:
+  const SymbolTable &symbol_table_;
+};
+
+// Generates n matchings, where n is the number of nodes to match. Each Matching
+// will have a different node as a starting node for expansion.
+class VaryMatchingStart {
+ public:
+  VaryMatchingStart(Matching, const SymbolTable &);
+
+  class iterator {
+   public:
+    typedef std::input_iterator_tag iterator_category;
+    typedef Matching value_type;
+    typedef long difference_type;
+    typedef const Matching &reference;
+    typedef const Matching *pointer;
+
+    iterator(VaryMatchingStart *, bool);
+
+    iterator &operator++();
+    reference operator*() const { return current_matching_; }
+    pointer operator->() const { return &current_matching_; }
+    bool operator==(const iterator &other) const {
+      return self_ == other.self_ && start_nodes_it_ == other.start_nodes_it_;
+    }
+    bool operator!=(const iterator &other) const { return !(*this == other); }
+
+   private:
+    // Pointer instead of reference to auto generate copy constructor and
+    // assignment.
+    VaryMatchingStart *self_;
+    Matching current_matching_;
+    // Iterator over start nodes. Optional is used for differentiating the case
+    // when there are no start nodes vs. VaryMatchingStart::iterator itself
+    // being at the end. When there are no nodes, this iterator needs to produce
+    // a single result, which is the original matching passed in. Setting
+    // start_nodes_it_ to end signifies the end of our iteration.
+    std::optional<std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual>::iterator> start_nodes_it_;
+  };
+
+  auto begin() { return iterator(this, false); }
+  auto end() { return iterator(this, true); }
+
+ private:
+  friend class iterator;
+  Matching matching_;
+  const SymbolTable &symbol_table_;
+  std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual> nodes_;
+};
+
+// Similar to VaryMatchingStart, but varies the starting nodes for all given
+// matchings. After all matchings produce multiple alternative starts, the
+// Cartesian product of all of them is returned.
+CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(const std::vector<Matching> &, const SymbolTable &);
+
+// Produces alternative query parts out of a single part by varying how each
+// graph matching is done.
+class VaryQueryPartMatching {
+ public:
+  VaryQueryPartMatching(SingleQueryPart, const SymbolTable &);
+
+  class iterator {
+   public:
+    typedef std::input_iterator_tag iterator_category;
+    typedef SingleQueryPart value_type;
+    typedef long difference_type;
+    typedef const SingleQueryPart &reference;
+    typedef const SingleQueryPart *pointer;
+
+    iterator(const SingleQueryPart &, VaryMatchingStart::iterator, VaryMatchingStart::iterator,
+             CartesianProduct<VaryMatchingStart>::iterator, CartesianProduct<VaryMatchingStart>::iterator,
+             CartesianProduct<VaryMatchingStart>::iterator, CartesianProduct<VaryMatchingStart>::iterator);
+
+    iterator &operator++();
+    reference operator*() const { return current_query_part_; }
+    pointer operator->() const { return &current_query_part_; }
+    bool operator==(const iterator &) const;
+    bool operator!=(const iterator &other) const { return !(*this == other); }
+
+   private:
+    void SetCurrentQueryPart();
+
+    SingleQueryPart current_query_part_;
+    VaryMatchingStart::iterator matchings_it_;
+    VaryMatchingStart::iterator matchings_end_;
+    CartesianProduct<VaryMatchingStart>::iterator optional_it_;
+    CartesianProduct<VaryMatchingStart>::iterator optional_begin_;
+    CartesianProduct<VaryMatchingStart>::iterator optional_end_;
+    CartesianProduct<VaryMatchingStart>::iterator merge_it_;
+    CartesianProduct<VaryMatchingStart>::iterator merge_begin_;
+    CartesianProduct<VaryMatchingStart>::iterator merge_end_;
+  };
+
+  auto begin() {
+    return iterator(query_part_, matchings_.begin(), matchings_.end(), optional_matchings_.begin(),
+                    optional_matchings_.end(), merge_matchings_.begin(), merge_matchings_.end());
+  }
+  auto end() {
+    return iterator(query_part_, matchings_.end(), matchings_.end(), optional_matchings_.end(),
+                    optional_matchings_.end(), merge_matchings_.end(), merge_matchings_.end());
+  }
+
+ private:
+  SingleQueryPart query_part_;
+  // Multiple regular matchings, each starting from different node.
+  VaryMatchingStart matchings_;
+  // Multiple optional matchings, where each combination has different starting
+  // nodes.
+  CartesianProduct<VaryMatchingStart> optional_matchings_;
+  // Like optional matching, but for merge matchings.
+  CartesianProduct<VaryMatchingStart> merge_matchings_;
+};
+
+}  // namespace impl
+
+/// @brief Planner which generates multiple plans by changing the order of graph
+/// traversal.
+///
+/// This planner picks different starting nodes from which to start graph
+/// traversal. Generating a single plan is backed by @c RuleBasedPlanner.
+///
+/// @sa MakeLogicalPlan
+template <class TPlanningContext>
+class VariableStartPlanner {
+ private:
+  TPlanningContext *context_;
+
+  // Generates different, equivalent query parts by taking different graph
+  // matching routes for each query part. At most FLAGS_query_max_plans
+  // combinations are produced.
+  auto VaryQueryMatching(const std::vector<SingleQueryPart> &query_parts, const SymbolTable &symbol_table) {
+    std::vector<impl::VaryQueryPartMatching> alternative_query_parts;
+    alternative_query_parts.reserve(query_parts.size());
+    for (const auto &query_part : query_parts) {
+      // Construct in place instead of moving a temporary into the vector.
+      alternative_query_parts.emplace_back(query_part, symbol_table);
+    }
+    return iter::slice(MakeCartesianProduct(std::move(alternative_query_parts)), 0UL, FLAGS_query_max_plans);
+  }
+
+ public:
+  explicit VariableStartPlanner(TPlanningContext *context) : context_(context) {}
+
+  /// @brief Generate multiple plans by varying the order of graph traversal.
+  auto Plan(const std::vector<SingleQueryPart> &query_parts) {
+    return iter::imap(
+        [context = context_](const auto &alternative_query_parts) {
+          RuleBasedPlanner<TPlanningContext> rule_planner(context);
+          // Each alternative is planned from scratch.
+          context->bound_symbols.clear();
+          return rule_planner.Plan(alternative_query_parts);
+        },
+        VaryQueryMatching(query_parts, *context_->symbol_table));
+  }
+
+  /// @brief The result of plan generation is an iterable of roots to multiple
+  /// generated operator trees.
+  // std::invoke_result_t is used because std::result_of is deprecated in C++17
+  // and removed in C++20.
+  using PlanResult = std::invoke_result_t<decltype(&VariableStartPlanner<TPlanningContext>::Plan),
+                                          VariableStartPlanner<TPlanningContext>, std::vector<SingleQueryPart> &>;
+};
+
+}  // namespace memgraph::query::v2::plan
diff --git a/src/query/v2/plan/vertex_count_cache.hpp b/src/query/v2/plan/vertex_count_cache.hpp
new file mode 100644
index 000000000..fe8c68327
--- /dev/null
+++ b/src/query/v2/plan/vertex_count_cache.hpp
@@ -0,0 +1,141 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+/// @file
+#pragma once
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "query/v2/typed_value.hpp"
+#include "storage/v3/id_types.hpp"
+#include "storage/v3/property_value.hpp"
+#include "utils/bound.hpp"
+#include "utils/fnv.hpp"
+
+namespace memgraph::query::v2::plan {
+
+/// A stand in class for `TDbAccessor` which provides memoized calls to
+/// `VerticesCount`.
+template <class TDbAccessor>
+class VertexCountCache {
+ public:
+  VertexCountCache(TDbAccessor *db) : db_(db) {}
+
+  // Name lookups are forwarded to the wrapped accessor without caching.
+  auto NameToLabel(const std::string &name) { return db_->NameToLabel(name); }
+  auto NameToProperty(const std::string &name) { return db_->NameToProperty(name); }
+  auto NameToEdgeType(const std::string &name) { return db_->NameToEdgeType(name); }
+
+  /// Total vertex count; the accessor is queried only on the first call.
+  int64_t VerticesCount() {
+    if (!vertices_count_) vertices_count_ = db_->VerticesCount();
+    return *vertices_count_;
+  }
+
+  /// Vertex count for `label`; the accessor is queried once per label.
+  int64_t VerticesCount(storage::v3::LabelId label) {
+    // Single lookup on a hit, lookup + insert on a miss.
+    auto it = label_vertex_count_.find(label);
+    if (it == label_vertex_count_.end()) {
+      const int64_t count = db_->VerticesCount(label);
+      it = label_vertex_count_.emplace(label, count).first;
+    }
+    return it->second;
+  }
+
+  /// Vertex count for `label` and `property`; queried once per pair.
+  int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property) {
+    auto key = std::make_pair(label, property);
+    auto it = label_property_vertex_count_.find(key);
+    if (it == label_property_vertex_count_.end()) {
+      const int64_t count = db_->VerticesCount(label, property);
+      it = label_property_vertex_count_.emplace(key, count).first;
+    }
+    return it->second;
+  }
+
+  /// Vertex count for `label`, `property` and an exact property `value`;
+  /// queried once per distinct (label, property, value).
+  int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
+                        const storage::v3::PropertyValue &value) {
+    auto label_prop = std::make_pair(label, property);
+    auto &value_vertex_count = property_value_vertex_count_[label_prop];
+    // TODO: Why do we even need TypedValue in this whole file?
+    TypedValue tv_value(value);
+    auto it = value_vertex_count.find(tv_value);
+    if (it == value_vertex_count.end()) {
+      const int64_t count = db_->VerticesCount(label, property, value);
+      it = value_vertex_count.emplace(tv_value, count).first;
+    }
+    return it->second;
+  }
+
+  /// Vertex count for `label`, `property` and a property value range; queried
+  /// once per distinct (label, property, lower, upper).
+  int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
+                        const std::optional<utils::Bound<storage::v3::PropertyValue>> &lower,
+                        const std::optional<utils::Bound<storage::v3::PropertyValue>> &upper) {
+    auto label_prop = std::make_pair(label, property);
+    auto &bounds_vertex_count = property_bounds_vertex_count_[label_prop];
+    BoundsKey bounds = std::make_pair(lower, upper);
+    auto it = bounds_vertex_count.find(bounds);
+    if (it == bounds_vertex_count.end()) {
+      const int64_t count = db_->VerticesCount(label, property, lower, upper);
+      // `bounds` is not used afterwards, so it can be moved into the map.
+      it = bounds_vertex_count.emplace(std::move(bounds), count).first;
+    }
+    return it->second;
+  }
+
+  // Index existence checks are forwarded to the wrapped accessor without caching.
+  bool LabelIndexExists(storage::v3::LabelId label) { return db_->LabelIndexExists(label); }
+
+  bool LabelPropertyIndexExists(storage::v3::LabelId label, storage::v3::PropertyId property) {
+    return db_->LabelPropertyIndexExists(label, property);
+  }
+
+ private:
+  typedef std::pair<storage::v3::LabelId, storage::v3::PropertyId> LabelPropertyKey;
+
+  struct LabelPropertyHash {
+    size_t operator()(const LabelPropertyKey &key) const {
+      return utils::HashCombine<storage::v3::LabelId, storage::v3::PropertyId>{}(key.first, key.second);
+    }
+  };
+
+  typedef std::pair<std::optional<utils::Bound<storage::v3::PropertyValue>>,
+                    std::optional<utils::Bound<storage::v3::PropertyValue>>>
+      BoundsKey;
+
+  // Hashes only the bound values; a missing bound hashes as a default
+  // constructed TypedValue.
+  struct BoundsHash {
+    size_t operator()(const BoundsKey &key) const {
+      const auto &maybe_lower = key.first;
+      const auto &maybe_upper = key.second;
+      query::v2::TypedValue lower;
+      query::v2::TypedValue upper;
+      if (maybe_lower) lower = TypedValue(maybe_lower->value());
+      if (maybe_upper) upper = TypedValue(maybe_upper->value());
+      query::v2::TypedValue::Hash hash;
+      return utils::HashCombine<size_t, size_t>{}(hash(lower), hash(upper));
+    }
+  };
+
+  // Two keys are equal when both bounds compare equal by value and, if both
+  // are present, have the same bound type.
+  struct BoundsEqual {
+    bool operator()(const BoundsKey &a, const BoundsKey &b) const {
+      auto bound_equal = [](const auto &maybe_bound_a, const auto &maybe_bound_b) {
+        if (maybe_bound_a && maybe_bound_b && maybe_bound_a->type() != maybe_bound_b->type()) return false;
+        query::v2::TypedValue bound_a;
+        query::v2::TypedValue bound_b;
+        if (maybe_bound_a) bound_a = TypedValue(maybe_bound_a->value());
+        if (maybe_bound_b) bound_b = TypedValue(maybe_bound_b->value());
+        return query::v2::TypedValue::BoolEqual{}(bound_a, bound_b);
+      };
+      return bound_equal(a.first, b.first) && bound_equal(a.second, b.second);
+    }
+  };
+
+  TDbAccessor *db_;
+  std::optional<int64_t> vertices_count_;
+  std::unordered_map<storage::v3::LabelId, int64_t> label_vertex_count_;
+  std::unordered_map<LabelPropertyKey, int64_t, LabelPropertyHash> label_property_vertex_count_;
+  std::unordered_map<
+      LabelPropertyKey,
+      std::unordered_map<query::v2::TypedValue, int64_t, query::v2::TypedValue::Hash, query::v2::TypedValue::BoolEqual>,
+      LabelPropertyHash>
+      property_value_vertex_count_;
+  std::unordered_map<LabelPropertyKey, std::unordered_map<BoundsKey, int64_t, BoundsHash, BoundsEqual>,
+                     LabelPropertyHash>
+      property_bounds_vertex_count_;
+};
+
+/// Convenience function for creating a VertexCountCache by deducing the
+/// accessor type from the argument.
+template <class TDbAccessor>
+auto MakeVertexCountCache(TDbAccessor *db) {
+  return VertexCountCache<TDbAccessor>(db);
+}
+
+}  // namespace memgraph::query::v2::plan
diff --git a/src/query/v2/procedure/cypher_type_ptr.hpp b/src/query/v2/procedure/cypher_type_ptr.hpp
new file mode 100644
index 000000000..be8d5ff6e
--- /dev/null
+++ b/src/query/v2/procedure/cypher_type_ptr.hpp
@@ -0,0 +1,20 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+
+namespace memgraph::query::v2::procedure {
+// Forward declaration only, so this header does not need the full definition.
+class CypherType;
+// Owning pointer to a CypherType whose deleter is a type-erased callable
+// supplied when the pointer is created.
+using CypherTypePtr = std::unique_ptr<CypherType, std::function<void(CypherType *)>>;
+}  // namespace memgraph::query::v2::procedure
diff --git a/src/query/v2/procedure/cypher_types.hpp b/src/query/v2/procedure/cypher_types.hpp
new file mode 100644
index 000000000..dc8f0f25b
--- /dev/null
+++ b/src/query/v2/procedure/cypher_types.hpp
@@ -0,0 +1,293 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+/// @file
+#pragma once
+
+#include "mg_procedure.h"
+
+#include <functional>
+#include <memory>
+#include <string_view>
+
+#include "query/v2/procedure/cypher_type_ptr.hpp"
+#include "query/v2/procedure/mg_procedure_impl.hpp"
+#include "query/v2/typed_value.hpp"
+#include "utils/memory.hpp"
+#include "utils/pmr/string.hpp"
+
+namespace memgraph::query::v2::procedure {
+
+// Forward declarations needed by the CypherType::As*Type accessors.
+class ListType;
+class NullableType;
+
+/// Interface for all supported types in openCypher type system.
+class CypherType {
+ public:
+  // Copying and moving are disabled for every type in this hierarchy.
+  CypherType(const CypherType &) = delete;
+  CypherType &operator=(const CypherType &) = delete;
+  CypherType(CypherType &&) = delete;
+  CypherType &operator=(CypherType &&) = delete;
+
+  CypherType() = default;
+  virtual ~CypherType() = default;
+
+  /// Name of this type in the form that is shown to the user.
+  virtual std::string_view GetPresentableName() const = 0;
+
+  /// Check whether the given mgp_value belongs to the type described by
+  /// `this`.
+  virtual bool SatisfiesType(const mgp_value &value) const = 0;
+
+  /// Check whether the given TypedValue belongs to the type described by
+  /// `this`.
+  virtual bool SatisfiesType(const query::v2::TypedValue &value) const = 0;
+
+  // Lightweight replacement for RTTI, needed for a few special cases. The
+  // defaults answer "not that kind of type" by returning nullptr.
+  virtual const ListType *AsListType() const {
+    return nullptr;
+  }
+  virtual const NullableType *AsNullableType() const {
+    return nullptr;
+  }
+};
+
+// Simple Types
+
+/// Matches every value except null.
+class AnyType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "ANY";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    const bool is_null = value.type == MGP_VALUE_TYPE_NULL;
+    return !is_null;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    const bool is_null = value.IsNull();
+    return !is_null;
+  }
+};
+
+/// Matches boolean values.
+class BoolType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "BOOLEAN";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_BOOL == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsBool();
+  }
+};
+
+/// Matches string values.
+class StringType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "STRING";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_STRING == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsString();
+  }
+};
+
+/// Matches integer values.
+class IntType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "INTEGER";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_INT == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsInt();
+  }
+};
+
+/// Matches floating point (double) values.
+class FloatType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "FLOAT";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_DOUBLE == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsDouble();
+  }
+};
+
+/// Matches both integer and double values.
+class NumberType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "NUMBER";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    switch (value.type) {
+      case MGP_VALUE_TYPE_INT:
+      case MGP_VALUE_TYPE_DOUBLE:
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    if (value.IsInt()) return true;
+    return value.IsDouble();
+  }
+};
+
+/// Matches vertex values.
+class NodeType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "NODE";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_VERTEX == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsVertex();
+  }
+};
+
+/// Matches edge values.
+class RelationshipType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "RELATIONSHIP";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_EDGE == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsEdge();
+  }
+};
+
+/// Matches path values.
+class PathType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "PATH";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_PATH == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsPath();
+  }
+};
+
+// One might expect MapType to be a composite type in the same way ListType
+// is, but it is not. The reason is unclear; that is simply how it is defined
+// in "CIP2015-09-16 Public Type System and Type Annotations".
+// On top of plain maps, MapType also accepts NodeType and RelationshipType
+// values, because values of those types carry property *maps*.
+class MapType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "MAP";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    switch (value.type) {
+      case MGP_VALUE_TYPE_MAP:
+      case MGP_VALUE_TYPE_VERTEX:
+      case MGP_VALUE_TYPE_EDGE:
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    if (value.IsMap()) return true;
+    if (value.IsVertex()) return true;
+    return value.IsEdge();
+  }
+};
+
+// Temporal Types
+
+/// Matches date values.
+class DateType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "DATE";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_DATE == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsDate();
+  }
+};
+
+/// Matches local time values.
+class LocalTimeType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "LOCAL_TIME";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_LOCAL_TIME == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsLocalTime();
+  }
+};
+
+/// Matches local date-time values.
+class LocalDateTimeType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "LOCAL_DATE_TIME";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_LOCAL_DATE_TIME == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsLocalDateTime();
+  }
+};
+
+/// Matches duration values.
+class DurationType : public CypherType {
+ public:
+  std::string_view GetPresentableName() const override {
+    return "DURATION";
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    return MGP_VALUE_TYPE_DURATION == value.type;
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    return value.IsDuration();
+  }
+};
+
+// Composite Types
+
+/// A list whose every element has to satisfy `element_type_`.
+class ListType : public CypherType {
+ public:
+  // Type that each element of a matching list must satisfy.
+  CypherTypePtr element_type_;
+  // Precomputed user-facing name: "LIST OF " followed by the element type's
+  // presentable name.
+  utils::pmr::string presentable_name_;
+
+  /// Take ownership of `element_type` and build the presentable name using
+  /// `memory` for the string storage.
+  /// @throw std::bad_alloc
+  /// @throw std::length_error
+  explicit ListType(CypherTypePtr element_type, utils::MemoryResource *memory)
+      : element_type_(std::move(element_type)), presentable_name_("LIST OF ", memory) {
+    const std::string_view element_name = element_type_->GetPresentableName();
+    presentable_name_.append(element_name);
+  }
+
+  std::string_view GetPresentableName() const override {
+    return presentable_name_;
+  }
+
+  /// True when `value` is a list and all of its elements satisfy the element
+  /// type. An empty list therefore always matches.
+  bool SatisfiesType(const mgp_value &value) const override {
+    if (value.type != MGP_VALUE_TYPE_LIST) return false;
+    for (const auto &elem : value.list_v->elems) {
+      if (!element_type_->SatisfiesType(elem)) return false;
+    }
+    return true;
+  }
+
+  /// True when `value` is a list and all of its elements satisfy the element
+  /// type. An empty list therefore always matches.
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    if (!value.IsList()) {
+      return false;
+    }
+    for (const auto &elem : value.ValueList()) {
+      const bool elem_ok = element_type_->SatisfiesType(elem);
+      if (!elem_ok) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  const ListType *AsListType() const override {
+    return this;
+  }
+};
+
+/// Wraps another type and additionally accepts null.
+class NullableType : public CypherType {
+  // The wrapped, non-nullable type.
+  CypherTypePtr type_;
+  // Precomputed user-facing name, see the constructor for its format.
+  utils::pmr::string presentable_name_;
+
+  // The constructor is private on purpose: instances are only made through
+  // the Create factory, which prevents stacking NullableType on top of
+  // another NullableType.
+  // @throw std::bad_alloc
+  // @throw std::length_error
+  explicit NullableType(CypherTypePtr type, utils::MemoryResource *memory)
+      : type_(std::move(type)), presentable_name_(memory) {
+    if (const auto *list_type = type_->AsListType(); list_type != nullptr) {
+      // Lists get a special format: the question mark goes right after LIST.
+      presentable_name_.assign("LIST? OF ");
+      presentable_name_.append(list_type->element_type_->GetPresentableName());
+    } else {
+      presentable_name_.assign(type_->GetPresentableName());
+      presentable_name_.append("?");
+    }
+  }
+
+ public:
+  /// Create a NullableType of some CypherType.
+  /// When `type` already is a NullableType it is handed back unchanged;
+  /// otherwise it is wrapped in a freshly allocated NullableType whose
+  /// storage comes from `memory`.
+  /// @throw std::bad_alloc
+  /// @throw std::length_error
+  static CypherTypePtr Create(CypherTypePtr type, utils::MemoryResource *memory) {
+    const bool already_nullable = type->AsNullableType() != nullptr;
+    if (already_nullable) return type;
+
+    utils::Allocator<NullableType> allocator(memory);
+    auto *storage = allocator.allocate(1);
+    try {
+      new (storage) NullableType(std::move(type), memory);
+    } catch (...) {
+      // Construction failed, so only the raw storage has to be given back.
+      allocator.deallocate(storage, 1);
+      throw;
+    }
+    // The deleter keeps its own copy of the allocator so that the object is
+    // released through the same allocator type that created it.
+    auto deleter = [allocator](CypherType *base_ptr) mutable {
+      allocator.delete_object(static_cast<NullableType *>(base_ptr));
+    };
+    return CypherTypePtr(storage, std::move(deleter));
+  }
+
+  std::string_view GetPresentableName() const override {
+    return presentable_name_;
+  }
+
+  bool SatisfiesType(const mgp_value &value) const override {
+    if (value.type == MGP_VALUE_TYPE_NULL) return true;
+    return type_->SatisfiesType(value);
+  }
+
+  bool SatisfiesType(const query::v2::TypedValue &value) const override {
+    if (value.IsNull()) return true;
+    return type_->SatisfiesType(value);
+  }
+
+  const NullableType *AsNullableType() const override {
+    return this;
+  }
+};
+
+}  // namespace memgraph::query::v2::procedure
diff --git a/src/query/v2/procedure/mg_procedure_helpers.cpp b/src/query/v2/procedure/mg_procedure_helpers.cpp
new file mode 100644
index 000000000..f3a624736
--- /dev/null
+++ b/src/query/v2/procedure/mg_procedure_helpers.cpp
@@ -0,0 +1,36 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#include "query/v2/procedure/mg_procedure_helpers.hpp"
+
+namespace memgraph::query::v2::procedure {
+// Create an mgp_value holding a copy of `string`, allocated through `memory`.
+// When creation fails, the error message is stored in `result` and an empty
+// (null) pointer is returned.
+MgpUniquePtr<mgp_value> GetStringValueOrSetError(const char *string, mgp_memory *memory, mgp_result *result) {
+  procedure::MgpUniquePtr<mgp_value> string_value{nullptr, mgp_value_destroy};
+  auto make_string = [&] {
+    return procedure::CreateMgpObject(string_value, mgp_value_make_string, string, memory);
+  };
+  if (!TryOrSetError(make_string, result)) {
+    // Make sure callers never see a half-created value.
+    string_value.reset();
+  }
+
+  return string_value;
+}
+
+// Insert `value` under `result_name` into `record`. On failure the error is
+// described in `result` and false is returned; true means the insert worked.
+bool InsertResultOrSetError(mgp_result *result, mgp_result_record *record, const char *result_name, mgp_value *value) {
+  const auto err = mgp_result_record_insert(record, result_name, value);
+  if (err == mgp_error::MGP_ERROR_NO_ERROR) {
+    return true;
+  }
+
+  const auto error_msg = fmt::format("Unable to set the result for {}, error = {}", result_name, err);
+  // The outcome of setting the message is deliberately ignored.
+  static_cast<void>(mgp_result_set_error_msg(result, error_msg.c_str()));
+  return false;
+}
+
+}  // namespace memgraph::query::v2::procedure
diff --git a/src/query/v2/procedure/mg_procedure_helpers.hpp b/src/query/v2/procedure/mg_procedure_helpers.hpp
new file mode 100644
index 000000000..35248b794
--- /dev/null
+++ b/src/query/v2/procedure/mg_procedure_helpers.hpp
@@ -0,0 +1,69 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#pragma once
+
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+#include <fmt/format.h>
+
+#include "mg_procedure.h"
+
+namespace memgraph::query::v2::procedure {
+/// Invoke `func(args..., &result)` and return the value it wrote to `result`.
+/// The call is asserted to report MGP_ERROR_NO_ERROR. The function object
+/// and all arguments have to be trivially copyable.
+// NOTE(review): the call itself lives inside MG_ASSERT, so this relies on the
+// macro always evaluating its argument -- confirm against utils/logging.hpp.
+template <typename TResult, typename TFunc, typename... TArgs>
+TResult Call(TFunc func, TArgs... args) {
+  static_assert(std::is_trivially_copyable_v<TFunc>);
+  static_assert((std::is_trivially_copyable_v<std::remove_reference_t<TArgs>> && ...));
+  TResult result{};
+  MG_ASSERT(func(args..., &result) == mgp_error::MGP_ERROR_NO_ERROR);
+  return result;
+}
+
+/// Same as Call<int>, with the int result converted to bool (non-zero means
+/// true).
+template <typename TFunc, typename... TArgs>
+bool CallBool(TFunc func, TArgs... args) {
+  const int raw_result = Call<int>(func, args...);
+  return raw_result != 0;
+}
+
+/// Plain function pointer used to destroy an mgp object.
+template <typename TObj>
+using MgpRawObjectDeleter = void (*)(TObj *);
+
+/// Owning pointer to an mgp object, destroyed through a MgpRawObjectDeleter.
+template <typename TObj>
+using MgpUniquePtr = std::unique_ptr<TObj, MgpRawObjectDeleter<TObj>>;
+
+/// Call `func(args..., &raw)` to create an object and store whatever pointer
+/// it produced in `obj`, replacing the previous content. The error code
+/// reported by `func` is returned unchanged.
+template <typename TObj, typename TFunc, typename... TArgs>
+mgp_error CreateMgpObject(MgpUniquePtr<TObj> &obj, TFunc func, TArgs &&...args) {
+  TObj *created{nullptr};
+  const auto err = func(std::forward<TArgs>(args)..., &created);
+  obj.reset(created);
+  return err;
+}
+
+/// Run `func` and translate a failure into an error message on `result`.
+/// @return true when `func` reported MGP_ERROR_NO_ERROR, false otherwise.
+template <typename Fun>
+[[nodiscard]] bool TryOrSetError(Fun &&func, mgp_result *result) {
+  const auto err = func();
+  if (err == mgp_error::MGP_ERROR_NO_ERROR) {
+    return true;
+  }
+
+  // The outcome of setting the message is deliberately ignored in both cases.
+  if (err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) {
+    static_cast<void>(mgp_result_set_error_msg(result, "Not enough memory!"));
+  } else {
+    const auto error_msg = fmt::format("Unexpected error ({})!", err);
+    static_cast<void>(mgp_result_set_error_msg(result, error_msg.c_str()));
+  }
+  return false;
+}
+
+/// Create a string mgp_value from `string`; on failure the error is stored in
+/// `result`.
+[[nodiscard]] MgpUniquePtr<mgp_value> GetStringValueOrSetError(const char *string, mgp_memory *memory,
+                                                               mgp_result *result);
+
+/// Insert `value` as `result_name` into `record`; on failure the error is
+/// stored in `result` and false is returned.
+[[nodiscard]] bool InsertResultOrSetError(mgp_result *result, mgp_result_record *record, const char *result_name,
+                                          mgp_value *value);
+}  // namespace memgraph::query::v2::procedure
diff --git a/src/query/v2/procedure/mg_procedure_impl.cpp b/src/query/v2/procedure/mg_procedure_impl.cpp
new file mode 100644
index 000000000..b2a467ef0
--- /dev/null
+++ b/src/query/v2/procedure/mg_procedure_impl.cpp
@@ -0,0 +1,2798 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#include "query/v2/procedure/mg_procedure_impl.hpp"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstring>
+#include <exception>
+#include <memory>
+#include <optional>
+#include <regex>
+#include <stdexcept>
+#include <type_traits>
+#include <utility>
+
+#include "mg_procedure.h"
+#include "module.hpp"
+#include "query/v2/procedure/cypher_types.hpp"
+#include "query/v2/procedure/mg_procedure_helpers.hpp"
+#include "query/v2/stream/common.hpp"
+#include "storage/v3/property_value.hpp"
+#include "storage/v3/view.hpp"
+#include "utils/algorithm.hpp"
+#include "utils/concepts.hpp"
+#include "utils/logging.hpp"
+#include "utils/math.hpp"
+#include "utils/memory.hpp"
+#include "utils/string.hpp"
+#include "utils/temporal.hpp"
+#include "utils/variant_helpers.hpp"
+
+// This file contains implementation of top level C API functions, but this is
+// all actually part of memgraph::query::v2::procedure. So use that namespace for simplicity.
+// NOLINTNEXTLINE(google-build-using-namespace)
+using namespace memgraph::query::v2::procedure;
+
+namespace {
+
+/// Allocate `size_in_bytes` bytes from `memory`, aligned to `alignment`.
+///
+/// A small header holding the requested size and the alignment actually used
+/// is stored right in front of the returned pointer, so that MgpFreeImpl can
+/// reconstruct the original allocation.
+/// @return pointer to the usable bytes, or nullptr when the size is zero,
+///         the alignment is not a power of two, or the total size overflows.
+void *MgpAlignedAllocImpl(memgraph::utils::MemoryResource &memory, const size_t size_in_bytes, const size_t alignment) {
+  if (size_in_bytes == 0U || !memgraph::utils::IsPow2(alignment)) return nullptr;
+  // Simplify alignment by always using values greater or equal to max_align.
+  const size_t alloc_align = std::max(alignment, alignof(std::max_align_t));
+  // Allocate space for header containing size & alignment info.
+  const size_t header_size = sizeof(size_in_bytes) + sizeof(alloc_align);
+  // We need to return the `data` pointer aligned to the requested alignment.
+  // Since we request the initial memory to be aligned to `alloc_align`, we can
+  // just allocate an additional multiple of `alloc_align` of bytes such that
+  // the header fits. `data` will then be aligned after this multiple of bytes.
+  static_assert(std::is_same_v<size_t, uint64_t>);
+  const auto maybe_bytes_for_header = memgraph::utils::RoundUint64ToMultiple(header_size, alloc_align);
+  if (!maybe_bytes_for_header) return nullptr;
+  const size_t bytes_for_header = *maybe_bytes_for_header;
+  const size_t alloc_size = bytes_for_header + size_in_bytes;
+  // Unsigned wrap-around means the request cannot be represented.
+  if (alloc_size < size_in_bytes) return nullptr;
+
+  void *ptr = memory.Allocate(alloc_size, alloc_align);
+  char *data = reinterpret_cast<char *>(ptr) + bytes_for_header;
+  // Header layout, growing downwards from `data`: size first, then alignment.
+  std::memcpy(data - sizeof(size_in_bytes), &size_in_bytes, sizeof(size_in_bytes));
+  std::memcpy(data - sizeof(size_in_bytes) - sizeof(alloc_align), &alloc_align, sizeof(alloc_align));
+  return data;
+}
+
+/// Release memory obtained from MgpAlignedAllocImpl on the same `memory`.
+/// Passing nullptr is a no-op. Never throws: any exception is logged instead.
+void MgpFreeImpl(memgraph::utils::MemoryResource &memory, void *const p) noexcept {
+  try {
+    if (!p) return;
+    char *const data = reinterpret_cast<char *>(p);
+    // Read the header containing size & alignment info.
+    size_t size_in_bytes{};
+    std::memcpy(&size_in_bytes, data - sizeof(size_in_bytes), sizeof(size_in_bytes));
+    size_t alloc_align{};
+    std::memcpy(&alloc_align, data - sizeof(size_in_bytes) - sizeof(alloc_align), sizeof(alloc_align));
+    // Reconstruct how many bytes we allocated on top of the original request.
+    // We need not check allocation request overflow, since we did so already in
+    // mgp_aligned_alloc.
+    const size_t header_size = sizeof(size_in_bytes) + sizeof(alloc_align);
+    const size_t bytes_for_header = *memgraph::utils::RoundUint64ToMultiple(header_size, alloc_align);
+    const size_t alloc_size = bytes_for_header + size_in_bytes;
+    // Get the original ptr we allocated.
+    void *const original_ptr = data - bytes_for_header;
+    memory.Deallocate(original_ptr, alloc_size, alloc_align);
+  } catch (const memgraph::utils::BasicException &be) {
+    spdlog::error("BasicException during the release of memory for query modules: {}", be.what());
+  } catch (const std::exception &e) {
+    spdlog::error("std::exception during the release of memory for query modules: {}", e.what());
+  } catch (...) {
+    spdlog::error("Unexpected throw during the release of memory for query modules");
+  }
+}
+
+// Exceptions used inside this file; WrapExceptions translates each of them
+// into the matching mgp_error code.
+struct DeletedObjectException : public memgraph::utils::BasicException {
+  using memgraph::utils::BasicException::BasicException;
+};
+
+struct KeyAlreadyExistsException : public memgraph::utils::BasicException {
+  using memgraph::utils::BasicException::BasicException;
+};
+
+struct InsufficientBufferException : public memgraph::utils::BasicException {
+  using memgraph::utils::BasicException::BasicException;
+};
+
+struct ImmutableObjectException : public memgraph::utils::BasicException {
+  using memgraph::utils::BasicException::BasicException;
+};
+
+struct ValueConversionException : public memgraph::utils::BasicException {
+  using memgraph::utils::BasicException::BasicException;
+};
+
+struct SerializationException : public memgraph::utils::BasicException {
+  using memgraph::utils::BasicException::BasicException;
+};
+
+/// Satisfied when calling `TFunc` without arguments yields exactly `TReturn`.
+template <typename TFunc, typename TReturn>
+concept ReturnsType = std::same_as<std::invoke_result_t<TFunc>, TReturn>;
+
+/// Satisfied when calling `TFunc` without arguments yields void.
+template <typename TFunc>
+concept ReturnsVoid = ReturnsType<TFunc, void>;
+
+/// Run a callable that produces no value.
+template <ReturnsVoid TFunc>
+void WrapExceptionsHelper(TFunc &&func) {
+  std::forward<TFunc>(func)();
+}
+
+/// Run a callable and store its return value in `*result`. `*result` is
+/// value-initialized first, so it holds a defined value even if `func` throws.
+template <typename TFunc, typename TReturn = std::invoke_result_t<TFunc>>
+void WrapExceptionsHelper(TFunc &&func, TReturn *result) {
+  *result = {};
+  *result = std::forward<TFunc>(func)();
+}
+
+/// Run `func` and convert every exception it throws into an mgp_error code,
+/// logging the reason. At most one extra argument may be passed: a pointer
+/// that receives the return value of `func`.
+///
+/// Handlers are tried in order of appearance, so handlers for more derived
+/// exception types must come before handlers for their bases.
+/// @return MGP_ERROR_NO_ERROR when `func` finished without throwing.
+template <typename TFunc, typename... Args>
+[[nodiscard]] mgp_error WrapExceptions(TFunc &&func, Args &&...args) noexcept {
+  static_assert(sizeof...(args) <= 1, "WrapExceptions should have only one or zero parameter!");
+  try {
+    WrapExceptionsHelper(std::forward<TFunc>(func), std::forward<Args>(args)...);
+  } catch (const DeletedObjectException &neoe) {
+    spdlog::error("Deleted object error during mg API call: {}", neoe.what());
+    return mgp_error::MGP_ERROR_DELETED_OBJECT;
+  } catch (const KeyAlreadyExistsException &kaee) {
+    spdlog::error("Key already exists error during mg API call: {}", kaee.what());
+    return mgp_error::MGP_ERROR_KEY_ALREADY_EXISTS;
+  } catch (const InsufficientBufferException &ibe) {
+    spdlog::error("Insufficient buffer error during mg API call: {}", ibe.what());
+    return mgp_error::MGP_ERROR_INSUFFICIENT_BUFFER;
+  } catch (const ImmutableObjectException &ioe) {
+    spdlog::error("Immutable object error during mg API call: {}", ioe.what());
+    return mgp_error::MGP_ERROR_IMMUTABLE_OBJECT;
+  } catch (const ValueConversionException &vce) {
+    spdlog::error("Value conversion error during mg API call: {}", vce.what());
+    return mgp_error::MGP_ERROR_VALUE_CONVERSION;
+  } catch (const SerializationException &se) {
+    spdlog::error("Serialization error during mg API call: {}", se.what());
+    return mgp_error::MGP_ERROR_SERIALIZATION_ERROR;
+  } catch (const memgraph::utils::temporal::InvalidArgumentException &e) {
+    // This handler has to precede the std:: handlers below. The exception
+    // presumably derives from std::exception (like the other project
+    // exceptions in this file), so placed after `catch (const std::exception
+    // &)` it could never be reached and the error would be reported as
+    // MGP_ERROR_UNKNOWN_ERROR.
+    spdlog::error("Invalid argument was sent to an mg API call for temporal types: {}", e.what());
+    return mgp_error::MGP_ERROR_INVALID_ARGUMENT;
+  } catch (const std::bad_alloc &bae) {
+    spdlog::error("Memory allocation error during mg API call: {}", bae.what());
+    return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE;
+  } catch (const memgraph::utils::OutOfMemoryException &oome) {
+    spdlog::error("Memory limit exceeded during mg API call: {}", oome.what());
+    return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE;
+  } catch (const std::out_of_range &oore) {
+    spdlog::error("Out of range error during mg API call: {}", oore.what());
+    return mgp_error::MGP_ERROR_OUT_OF_RANGE;
+  } catch (const std::invalid_argument &iae) {
+    spdlog::error("Invalid argument error during mg API call: {}", iae.what());
+    return mgp_error::MGP_ERROR_INVALID_ARGUMENT;
+  } catch (const std::logic_error &lee) {
+    spdlog::error("Logic error during mg API call: {}", lee.what());
+    return mgp_error::MGP_ERROR_LOGIC_ERROR;
+  } catch (const std::exception &e) {
+    spdlog::error("Unexpected error during mg API call: {}", e.what());
+    return mgp_error::MGP_ERROR_UNKNOWN_ERROR;
+  } catch (...) {
+    spdlog::error("Unexpected error during mg API call");
+    return mgp_error::MGP_ERROR_UNKNOWN_ERROR;
+  }
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+// Graph mutations
+
+/// A graph is mutable when it is viewed as View::NEW and has a context.
+bool MgpGraphIsMutable(const mgp_graph &graph) noexcept {
+  return graph.view == memgraph::storage::v3::View::NEW && graph.ctx != nullptr;
+}
+
+/// A vertex is mutable when the graph it belongs to is mutable.
+bool MgpVertexIsMutable(const mgp_vertex &vertex) { return MgpGraphIsMutable(*vertex.graph); }
+
+/// An edge is mutable when its `from` vertex is mutable.
+bool MgpEdgeIsMutable(const mgp_edge &edge) { return MgpVertexIsMutable(edge.from); }
+}  // namespace
+
+/// Allocate `size_in_bytes` bytes with the default (max_align_t) alignment.
+mgp_error mgp_alloc(mgp_memory *memory, size_t size_in_bytes, void **result) {
+  return mgp_aligned_alloc(memory, size_in_bytes, alignof(std::max_align_t), result);
+}
+
+/// Allocate `size_in_bytes` bytes aligned to `alignment` from `memory` and
+/// store the pointer in `*result`.
+mgp_error mgp_aligned_alloc(mgp_memory *memory, const size_t size_in_bytes, const size_t alignment, void **result) {
+  return WrapExceptions(
+      [memory, size_in_bytes, alignment] { return MgpAlignedAllocImpl(*memory->impl, size_in_bytes, alignment); },
+      result);
+}
+
+/// Free memory previously returned by mgp_alloc or mgp_aligned_alloc.
+void mgp_free(mgp_memory *memory, void *const p) {
+  static_assert(noexcept(MgpFreeImpl(*memory->impl, p)));
+  MgpFreeImpl(*memory->impl, p);
+}
+
+/// Like mgp_alloc, but uses the module registry's shared memory resource.
+mgp_error mgp_global_alloc(size_t size_in_bytes, void **result) {
+  return mgp_global_aligned_alloc(size_in_bytes, alignof(std::max_align_t), result);
+}
+
+/// Like mgp_aligned_alloc, but uses the module registry's shared memory
+/// resource.
+mgp_error mgp_global_aligned_alloc(size_t size_in_bytes, size_t alignment, void **result) {
+  return WrapExceptions(
+      [size_in_bytes, alignment] {
+        return MgpAlignedAllocImpl(gModuleRegistry.GetSharedMemoryResource(), size_in_bytes, alignment);
+      },
+      result);
+}
+
+/// Free memory previously returned by mgp_global_alloc or
+/// mgp_global_aligned_alloc.
+void mgp_global_free(void *const p) {
+  static_assert(noexcept(MgpFreeImpl(gModuleRegistry.GetSharedMemoryResource(), p)));
+  MgpFreeImpl(gModuleRegistry.GetSharedMemoryResource(), p);
+}
+
+namespace {
+
+/// Allocate and construct a `U` from `args` using `memory`.
+template <class U, class... TArgs>
+U *NewRawMgpObject(memgraph::utils::MemoryResource *memory, TArgs &&...args) {
+  memgraph::utils::Allocator<U> allocator(memory);
+  return allocator.template new_object<U>(std::forward<TArgs>(args)...);
+}
+
+/// Convenience overload that takes the memory resource out of `mgp_memory`.
+template <class U, class... TArgs>
+U *NewRawMgpObject(mgp_memory *memory, TArgs &&...args) {
+  return NewRawMgpObject<U, TArgs...>(memory->impl, std::forward<TArgs>(args)...);
+}
+
+// Assume that deallocation and object destruction never throws. If it does,
+// we are in big trouble.
+/// Destroy and deallocate `ptr` through the memory resource it reports via
+/// GetMemoryResource(). Passing nullptr is a no-op.
+template <class T>
+void DeleteRawMgpObject(T *ptr) noexcept {
+  try {
+    if (!ptr) return;
+    memgraph::utils::Allocator<T> allocator(ptr->GetMemoryResource());
+    allocator.delete_object(ptr);
+  } catch (...) {
+    LOG_FATAL("Cannot deallocate mgp object");
+  }
+}
+
+/// Like NewRawMgpObject, but the result is owned by an MgpUniquePtr that
+/// releases it with DeleteRawMgpObject.
+template <class U, class... TArgs>
+MgpUniquePtr<U> NewMgpObject(mgp_memory *memory, TArgs &&...args) {
+  return MgpUniquePtr<U>(NewRawMgpObject<U>(memory->impl, std::forward<TArgs>(args)...), &DeleteRawMgpObject<U>);
+}
+
+/// Map a TypedValue type tag to the corresponding mgp_value_type.
+mgp_value_type FromTypedValueType(memgraph::query::v2::TypedValue::Type type) {
+  switch (type) {
+    case memgraph::query::v2::TypedValue::Type::Null:
+      return MGP_VALUE_TYPE_NULL;
+    case memgraph::query::v2::TypedValue::Type::Bool:
+      return MGP_VALUE_TYPE_BOOL;
+    case memgraph::query::v2::TypedValue::Type::Int:
+      return MGP_VALUE_TYPE_INT;
+    case memgraph::query::v2::TypedValue::Type::Double:
+      return MGP_VALUE_TYPE_DOUBLE;
+    case memgraph::query::v2::TypedValue::Type::String:
+      return MGP_VALUE_TYPE_STRING;
+    case memgraph::query::v2::TypedValue::Type::List:
+      return MGP_VALUE_TYPE_LIST;
+    case memgraph::query::v2::TypedValue::Type::Map:
+      return MGP_VALUE_TYPE_MAP;
+    case memgraph::query::v2::TypedValue::Type::Vertex:
+      return MGP_VALUE_TYPE_VERTEX;
+    case memgraph::query::v2::TypedValue::Type::Edge:
+      return MGP_VALUE_TYPE_EDGE;
+    case memgraph::query::v2::TypedValue::Type::Path:
+      return MGP_VALUE_TYPE_PATH;
+    case memgraph::query::v2::TypedValue::Type::Date:
+      return MGP_VALUE_TYPE_DATE;
+    case memgraph::query::v2::TypedValue::Type::LocalTime:
+      return MGP_VALUE_TYPE_LOCAL_TIME;
+    case memgraph::query::v2::TypedValue::Type::LocalDateTime:
+      return MGP_VALUE_TYPE_LOCAL_DATE_TIME;
+    case memgraph::query::v2::TypedValue::Type::Duration:
+      return MGP_VALUE_TYPE_DURATION;
+  }
+}
+}  // namespace
+
+/// Convert `val` into a TypedValue that uses `memory`. Lists and maps are
+/// converted element by element; a path is rebuilt from its first vertex by
+/// alternately expanding with an edge and the vertex that follows it.
+memgraph::query::v2::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory) {
+  switch (val.type) {
+    case MGP_VALUE_TYPE_NULL:
+      return memgraph::query::v2::TypedValue(memory);
+    case MGP_VALUE_TYPE_BOOL:
+      return memgraph::query::v2::TypedValue(val.bool_v, memory);
+    case MGP_VALUE_TYPE_INT:
+      return memgraph::query::v2::TypedValue(val.int_v, memory);
+    case MGP_VALUE_TYPE_DOUBLE:
+      return memgraph::query::v2::TypedValue(val.double_v, memory);
+    case MGP_VALUE_TYPE_STRING:
+      return {val.string_v, memory};
+    case MGP_VALUE_TYPE_LIST: {
+      const auto *list = val.list_v;
+      memgraph::query::v2::TypedValue::TVector tv_list(memory);
+      tv_list.reserve(list->elems.size());
+      for (const auto &elem : list->elems) {
+        tv_list.emplace_back(ToTypedValue(elem, memory));
+      }
+      return memgraph::query::v2::TypedValue(std::move(tv_list));
+    }
+    case MGP_VALUE_TYPE_MAP: {
+      const auto *map = val.map_v;
+      memgraph::query::v2::TypedValue::TMap tv_map(memory);
+      for (const auto &item : map->items) {
+        tv_map.emplace(item.first, ToTypedValue(item.second, memory));
+      }
+      return memgraph::query::v2::TypedValue(std::move(tv_map));
+    }
+    case MGP_VALUE_TYPE_VERTEX:
+      return memgraph::query::v2::TypedValue(val.vertex_v->impl, memory);
+    case MGP_VALUE_TYPE_EDGE:
+      return memgraph::query::v2::TypedValue(val.edge_v->impl, memory);
+    case MGP_VALUE_TYPE_PATH: {
+      const auto *path = val.path_v;
+      // A path always has one more vertex than it has edges.
+      MG_ASSERT(!path->vertices.empty());
+      MG_ASSERT(path->vertices.size() == path->edges.size() + 1);
+      memgraph::query::v2::Path tv_path(path->vertices[0].impl, memory);
+      for (size_t i = 0; i < path->edges.size(); ++i) {
+        tv_path.Expand(path->edges[i].impl);
+        tv_path.Expand(path->vertices[i + 1].impl);
+      }
+      return memgraph::query::v2::TypedValue(std::move(tv_path));
+    }
+    case MGP_VALUE_TYPE_DATE:
+      return memgraph::query::v2::TypedValue(val.date_v->date, memory);
+    case MGP_VALUE_TYPE_LOCAL_TIME:
+      return memgraph::query::v2::TypedValue(val.local_time_v->local_time, memory);
+    case MGP_VALUE_TYPE_LOCAL_DATE_TIME:
+      return memgraph::query::v2::TypedValue(val.local_date_time_v->local_date_time, memory);
+    case MGP_VALUE_TYPE_DURATION:
+      return memgraph::query::v2::TypedValue(val.duration_v->duration, memory);
+  }
+}
+
+// Constructors for scalar values: the value is stored directly.
+
+mgp_value::mgp_value(memgraph::utils::MemoryResource *m) noexcept : type(MGP_VALUE_TYPE_NULL), memory(m) {}
+
+mgp_value::mgp_value(bool val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_BOOL), memory(m), bool_v(val) {}
+
+mgp_value::mgp_value(int64_t val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_INT), memory(m), int_v(val) {}
+
+mgp_value::mgp_value(double val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_DOUBLE), memory(m), double_v(val) {}
+
+// The string is copied into storage obtained from `m`, hence not noexcept.
+mgp_value::mgp_value(const char *val, memgraph::utils::MemoryResource *m)
+    : type(MGP_VALUE_TYPE_STRING), memory(m), string_v(val, m) {}
+
+// Constructors taking a pointer store that pointer and take ownership of it.
+// This is only possible when the pointee uses the same memory resource `m`,
+// which is what each assert verifies.
+
+mgp_value::mgp_value(mgp_list *val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_LIST), memory(m), list_v(val) {
+  MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
+}
+
+mgp_value::mgp_value(mgp_map *val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_MAP), memory(m), map_v(val) {
+  MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
+}
+
+mgp_value::mgp_value(mgp_vertex *val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_VERTEX), memory(m), vertex_v(val) {
+  MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
+}
+
+mgp_value::mgp_value(mgp_edge *val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_EDGE), memory(m), edge_v(val) {
+  MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
+}
+
+mgp_value::mgp_value(mgp_path *val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_PATH), memory(m), path_v(val) {
+  MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
+}
+
+mgp_value::mgp_value(mgp_date *val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_DATE), memory(m), date_v(val) {
+  MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
+}
+
+// Pointer-taking constructors for the temporal types: each stores `val` itself and asserts that `val` reports the
+// same memory resource as `m`.
+mgp_value::mgp_value(mgp_local_time *val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_LOCAL_TIME), memory(m), local_time_v(val) {
+  MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
+}
+
+mgp_value::mgp_value(mgp_local_date_time *val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_LOCAL_DATE_TIME), memory(m), local_date_time_v(val) {
+  MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
+}
+
+mgp_value::mgp_value(mgp_duration *val, memgraph::utils::MemoryResource *m) noexcept
+    : type(MGP_VALUE_TYPE_DURATION), memory(m), duration_v(val) {
+  MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
+}
+
+// Converts a query TypedValue into an mgp_value. Scalars are copied into the matching member; every other kind is
+// built as a new object obtained from an allocator over `m`. `graph` is forwarded to the vertex/edge/path elements.
+mgp_value::mgp_value(const memgraph::query::v2::TypedValue &tv, mgp_graph *graph, memgraph::utils::MemoryResource *m)
+    : type(FromTypedValueType(tv.type())), memory(m) {
+  switch (type) {
+    case MGP_VALUE_TYPE_NULL:
+      break;
+    case MGP_VALUE_TYPE_BOOL:
+      bool_v = tv.ValueBool();
+      break;
+    case MGP_VALUE_TYPE_INT:
+      int_v = tv.ValueInt();
+      break;
+    case MGP_VALUE_TYPE_DOUBLE:
+      double_v = tv.ValueDouble();
+      break;
+    case MGP_VALUE_TYPE_STRING:
+      new (&string_v) memgraph::utils::pmr::string(tv.ValueString(), m);
+      break;
+    case MGP_VALUE_TYPE_LIST: {
+      // Fill the stack allocated container and then construct the actual member
+      // value. This handles the case when filling the container throws
+      // something and our destructor doesn't get called so member value isn't
+      // released.
+      memgraph::utils::pmr::vector<mgp_value> elems(m);
+      elems.reserve(tv.ValueList().size());
+      for (const auto &elem : tv.ValueList()) {
+        elems.emplace_back(elem, graph);
+      }
+      memgraph::utils::Allocator<mgp_list> allocator(m);
+      list_v = allocator.new_object<mgp_list>(std::move(elems));
+      break;
+    }
+    case MGP_VALUE_TYPE_MAP: {
+      // Fill the stack allocated container and then construct the actual member
+      // value. This handles the case when filling the container throws
+      // something and our destructor doesn't get called so member value isn't
+      // released.
+      memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_value> items(m);
+      for (const auto &item : tv.ValueMap()) {
+        items.emplace(item.first, mgp_value(item.second, graph, m));
+      }
+      memgraph::utils::Allocator<mgp_map> allocator(m);
+      map_v = allocator.new_object<mgp_map>(std::move(items));
+      break;
+    }
+    case MGP_VALUE_TYPE_VERTEX: {
+      memgraph::utils::Allocator<mgp_vertex> allocator(m);
+      vertex_v = allocator.new_object<mgp_vertex>(tv.ValueVertex(), graph);
+      break;
+    }
+    case MGP_VALUE_TYPE_EDGE: {
+      memgraph::utils::Allocator<mgp_edge> allocator(m);
+      edge_v = allocator.new_object<mgp_edge>(tv.ValueEdge(), graph);
+      break;
+    }
+    case MGP_VALUE_TYPE_PATH: {
+      // Fill the stack allocated container and then construct the actual member
+      // value. This handles the case when filling the container throws
+      // something and our destructor doesn't get called so member value isn't
+      // released.
+      mgp_path tmp_path(m);
+      tmp_path.vertices.reserve(tv.ValuePath().vertices().size());
+      for (const auto &v : tv.ValuePath().vertices()) {
+        tmp_path.vertices.emplace_back(v, graph);
+      }
+      tmp_path.edges.reserve(tv.ValuePath().edges().size());
+      for (const auto &e : tv.ValuePath().edges()) {
+        tmp_path.edges.emplace_back(e, graph);
+      }
+      memgraph::utils::Allocator<mgp_path> allocator(m);
+      path_v = allocator.new_object<mgp_path>(std::move(tmp_path));
+      break;
+    }
+    case MGP_VALUE_TYPE_DATE: {
+      memgraph::utils::Allocator<mgp_date> allocator(m);
+      date_v = allocator.new_object<mgp_date>(tv.ValueDate());
+      break;
+    }
+    case MGP_VALUE_TYPE_LOCAL_TIME: {
+      memgraph::utils::Allocator<mgp_local_time> allocator(m);
+      local_time_v = allocator.new_object<mgp_local_time>(tv.ValueLocalTime());
+      break;
+    }
+    case MGP_VALUE_TYPE_LOCAL_DATE_TIME: {
+      memgraph::utils::Allocator<mgp_local_date_time> allocator(m);
+      local_date_time_v = allocator.new_object<mgp_local_date_time>(tv.ValueLocalDateTime());
+      break;
+    }
+    case MGP_VALUE_TYPE_DURATION: {
+      memgraph::utils::Allocator<mgp_duration> allocator(m);
+      duration_v = allocator.new_object<mgp_duration>(tv.ValueDuration());
+      break;
+    }
+  }
+}
+
+// Converts a storage PropertyValue into an mgp_value. The type tag is set inside each case because it depends on
+// the property type (and, for temporal data, on the temporal sub-type).
+mgp_value::mgp_value(const memgraph::storage::v3::PropertyValue &pv, memgraph::utils::MemoryResource *m) : memory(m) {
+  switch (pv.type()) {
+    case memgraph::storage::v3::PropertyValue::Type::Null:
+      type = MGP_VALUE_TYPE_NULL;
+      break;
+    case memgraph::storage::v3::PropertyValue::Type::Bool:
+      type = MGP_VALUE_TYPE_BOOL;
+      bool_v = pv.ValueBool();
+      break;
+    case memgraph::storage::v3::PropertyValue::Type::Int:
+      type = MGP_VALUE_TYPE_INT;
+      int_v = pv.ValueInt();
+      break;
+    case memgraph::storage::v3::PropertyValue::Type::Double:
+      type = MGP_VALUE_TYPE_DOUBLE;
+      double_v = pv.ValueDouble();
+      break;
+    case memgraph::storage::v3::PropertyValue::Type::String:
+      type = MGP_VALUE_TYPE_STRING;
+      new (&string_v) memgraph::utils::pmr::string(pv.ValueString(), m);
+      break;
+    case memgraph::storage::v3::PropertyValue::Type::List: {
+      // Fill the stack allocated container and then construct the actual member
+      // value. This handles the case when filling the container throws
+      // something and our destructor doesn't get called so member value isn't
+      // released.
+      type = MGP_VALUE_TYPE_LIST;
+      memgraph::utils::pmr::vector<mgp_value> elems(m);
+      elems.reserve(pv.ValueList().size());
+      for (const auto &elem : pv.ValueList()) {
+        elems.emplace_back(elem);
+      }
+      memgraph::utils::Allocator<mgp_list> allocator(m);
+      list_v = allocator.new_object<mgp_list>(std::move(elems));
+      break;
+    }
+    case memgraph::storage::v3::PropertyValue::Type::Map: {
+      // Fill the stack allocated container and then construct the actual member
+      // value. This handles the case when filling the container throws
+      // something and our destructor doesn't get called so member value isn't
+      // released.
+      type = MGP_VALUE_TYPE_MAP;
+      memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_value> items(m);
+      for (const auto &item : pv.ValueMap()) {
+        items.emplace(item.first, item.second);
+      }
+      memgraph::utils::Allocator<mgp_map> allocator(m);
+      map_v = allocator.new_object<mgp_map>(std::move(items));
+      break;
+    }
+    case memgraph::storage::v3::PropertyValue::Type::TemporalData: {
+      const auto &temporal_data = pv.ValueTemporalData();
+      switch (temporal_data.type) {
+        case memgraph::storage::v3::TemporalType::Date: {
+          type = MGP_VALUE_TYPE_DATE;
+          date_v = NewRawMgpObject<mgp_date>(m, temporal_data.microseconds);
+          break;
+        }
+        case memgraph::storage::v3::TemporalType::LocalTime: {
+          type = MGP_VALUE_TYPE_LOCAL_TIME;
+          local_time_v = NewRawMgpObject<mgp_local_time>(m, temporal_data.microseconds);
+          break;
+        }
+        case memgraph::storage::v3::TemporalType::LocalDateTime: {
+          type = MGP_VALUE_TYPE_LOCAL_DATE_TIME;
+          local_date_time_v = NewRawMgpObject<mgp_local_date_time>(m, temporal_data.microseconds);
+          break;
+        }
+        case memgraph::storage::v3::TemporalType::Duration: {
+          type = MGP_VALUE_TYPE_DURATION;
+          duration_v = NewRawMgpObject<mgp_duration>(m, temporal_data.microseconds);
+          break;
+        }
+      }
+      // Explicit break so that a case added after this one cannot be reached by fall-through.
+      break;
+    }
+  }
+}
+
+// Copy constructor with an explicit memory resource: scalars are copied, every pointed-to member is duplicated
+// into a new object created with `m`.
+mgp_value::mgp_value(const mgp_value &other, memgraph::utils::MemoryResource *m) : type(other.type), memory(m) {
+  switch (other.type) {
+    case MGP_VALUE_TYPE_NULL:
+      break;
+    case MGP_VALUE_TYPE_BOOL:
+      bool_v = other.bool_v;
+      break;
+    case MGP_VALUE_TYPE_INT:
+      int_v = other.int_v;
+      break;
+    case MGP_VALUE_TYPE_DOUBLE:
+      double_v = other.double_v;
+      break;
+    case MGP_VALUE_TYPE_STRING:
+      new (&string_v) memgraph::utils::pmr::string(other.string_v, m);
+      break;
+    case MGP_VALUE_TYPE_LIST: {
+      memgraph::utils::Allocator<mgp_list> allocator(m);
+      list_v = allocator.new_object<mgp_list>(*other.list_v);
+      break;
+    }
+    case MGP_VALUE_TYPE_MAP: {
+      memgraph::utils::Allocator<mgp_map> allocator(m);
+      map_v = allocator.new_object<mgp_map>(*other.map_v);
+      break;
+    }
+    case MGP_VALUE_TYPE_VERTEX: {
+      memgraph::utils::Allocator<mgp_vertex> allocator(m);
+      vertex_v = allocator.new_object<mgp_vertex>(*other.vertex_v);
+      break;
+    }
+    case MGP_VALUE_TYPE_EDGE: {
+      memgraph::utils::Allocator<mgp_edge> allocator(m);
+      edge_v = allocator.new_object<mgp_edge>(*other.edge_v);
+      break;
+    }
+    case MGP_VALUE_TYPE_PATH: {
+      memgraph::utils::Allocator<mgp_path> allocator(m);
+      path_v = allocator.new_object<mgp_path>(*other.path_v);
+      break;
+    }
+    case MGP_VALUE_TYPE_DATE: {
+      date_v = NewRawMgpObject<mgp_date>(m, *other.date_v);
+      break;
+    }
+    case MGP_VALUE_TYPE_LOCAL_TIME: {
+      local_time_v = NewRawMgpObject<mgp_local_time>(m, *other.local_time_v);
+      break;
+    }
+    case MGP_VALUE_TYPE_LOCAL_DATE_TIME: {
+      local_date_time_v = NewRawMgpObject<mgp_local_date_time>(m, *other.local_date_time_v);
+      break;
+    }
+    case MGP_VALUE_TYPE_DURATION: {
+      duration_v = NewRawMgpObject<mgp_duration>(m, *other.duration_v);
+      break;
+    }
+  }
+}
+
+namespace {
+
+// Releases whatever the active member of `value` holds: nothing for scalars, an explicit destructor call for the
+// in-place string, and delete_object through the value's own memory resource for every pointer member.
+// The type tag is left untouched; callers reset it themselves when needed.
+void DeleteValueMember(mgp_value *value) noexcept {
+  MG_ASSERT(value);
+  memgraph::utils::Allocator<mgp_value> allocator(value->GetMemoryResource());
+  switch (Call<mgp_value_type>(mgp_value_get_type, value)) {
+    case MGP_VALUE_TYPE_NULL:
+    case MGP_VALUE_TYPE_BOOL:
+    case MGP_VALUE_TYPE_INT:
+    case MGP_VALUE_TYPE_DOUBLE:
+      return;
+    case MGP_VALUE_TYPE_STRING:
+      using TString = memgraph::utils::pmr::string;
+      value->string_v.~TString();
+      return;
+    case MGP_VALUE_TYPE_LIST:
+      allocator.delete_object(value->list_v);
+      return;
+    case MGP_VALUE_TYPE_MAP:
+      allocator.delete_object(value->map_v);
+      return;
+    case MGP_VALUE_TYPE_VERTEX:
+      allocator.delete_object(value->vertex_v);
+      return;
+    case MGP_VALUE_TYPE_EDGE:
+      allocator.delete_object(value->edge_v);
+      return;
+    case MGP_VALUE_TYPE_PATH:
+      allocator.delete_object(value->path_v);
+      return;
+    case MGP_VALUE_TYPE_DATE:
+      allocator.delete_object(value->date_v);
+      return;
+    case MGP_VALUE_TYPE_LOCAL_TIME:
+      allocator.delete_object(value->local_time_v);
+      return;
+    case MGP_VALUE_TYPE_LOCAL_DATE_TIME:
+      allocator.delete_object(value->local_date_time_v);
+      return;
+    case MGP_VALUE_TYPE_DURATION:
+      allocator.delete_object(value->duration_v);
+      return;
+  }
+}
+
+}  // namespace
+
+// Move constructor with an explicit memory resource. For pointer members: when both values use equal memory
+// resources the pointer is taken over and `other` is tagged NULL right away, so the DeleteValueMember call at the
+// end is a no-op for it; otherwise a new object is created with `m` and the old one is released at the end.
+mgp_value::mgp_value(mgp_value &&other, memgraph::utils::MemoryResource *m) : type(other.type), memory(m) {
+  switch (other.type) {
+    case MGP_VALUE_TYPE_NULL:
+      break;
+    case MGP_VALUE_TYPE_BOOL:
+      bool_v = other.bool_v;
+      break;
+    case MGP_VALUE_TYPE_INT:
+      int_v = other.int_v;
+      break;
+    case MGP_VALUE_TYPE_DOUBLE:
+      double_v = other.double_v;
+      break;
+    case MGP_VALUE_TYPE_STRING:
+      new (&string_v) memgraph::utils::pmr::string(std::move(other.string_v), m);
+      break;
+    case MGP_VALUE_TYPE_LIST:
+      static_assert(std::is_pointer_v<decltype(list_v)>, "Expected to move list_v by copying pointers.");
+      if (*other.GetMemoryResource() == *m) {
+        list_v = other.list_v;
+        other.type = MGP_VALUE_TYPE_NULL;
+      } else {
+        memgraph::utils::Allocator<mgp_list> allocator(m);
+        list_v = allocator.new_object<mgp_list>(std::move(*other.list_v));
+      }
+      break;
+    case MGP_VALUE_TYPE_MAP:
+      static_assert(std::is_pointer_v<decltype(map_v)>, "Expected to move map_v by copying pointers.");
+      if (*other.GetMemoryResource() == *m) {
+        map_v = other.map_v;
+        other.type = MGP_VALUE_TYPE_NULL;
+      } else {
+        memgraph::utils::Allocator<mgp_map> allocator(m);
+        map_v = allocator.new_object<mgp_map>(std::move(*other.map_v));
+      }
+      break;
+    case MGP_VALUE_TYPE_VERTEX:
+      static_assert(std::is_pointer_v<decltype(vertex_v)>, "Expected to move vertex_v by copying pointers.");
+      if (*other.GetMemoryResource() == *m) {
+        vertex_v = other.vertex_v;
+        other.type = MGP_VALUE_TYPE_NULL;
+      } else {
+        memgraph::utils::Allocator<mgp_vertex> allocator(m);
+        vertex_v = allocator.new_object<mgp_vertex>(std::move(*other.vertex_v));
+      }
+      break;
+    case MGP_VALUE_TYPE_EDGE:
+      static_assert(std::is_pointer_v<decltype(edge_v)>, "Expected to move edge_v by copying pointers.");
+      if (*other.GetMemoryResource() == *m) {
+        edge_v = other.edge_v;
+        other.type = MGP_VALUE_TYPE_NULL;
+      } else {
+        memgraph::utils::Allocator<mgp_edge> allocator(m);
+        edge_v = allocator.new_object<mgp_edge>(std::move(*other.edge_v));
+      }
+      break;
+    case MGP_VALUE_TYPE_PATH:
+      static_assert(std::is_pointer_v<decltype(path_v)>, "Expected to move path_v by copying pointers.");
+      if (*other.GetMemoryResource() == *m) {
+        path_v = other.path_v;
+        other.type = MGP_VALUE_TYPE_NULL;
+      } else {
+        memgraph::utils::Allocator<mgp_path> allocator(m);
+        path_v = allocator.new_object<mgp_path>(std::move(*other.path_v));
+      }
+      break;
+    case MGP_VALUE_TYPE_DATE:
+      static_assert(std::is_pointer_v<decltype(date_v)>, "Expected to move date_v by copying pointers.");
+      if (*other.GetMemoryResource() == *m) {
+        date_v = other.date_v;
+        other.type = MGP_VALUE_TYPE_NULL;
+      } else {
+        date_v = NewRawMgpObject<mgp_date>(m, *other.date_v);
+      }
+      break;
+    case MGP_VALUE_TYPE_LOCAL_TIME:
+      static_assert(std::is_pointer_v<decltype(local_time_v)>, "Expected to move local_time_v by copying pointers.");
+      if (*other.GetMemoryResource() == *m) {
+        local_time_v = other.local_time_v;
+        other.type = MGP_VALUE_TYPE_NULL;
+      } else {
+        local_time_v = NewRawMgpObject<mgp_local_time>(m, *other.local_time_v);
+      }
+      break;
+    case MGP_VALUE_TYPE_LOCAL_DATE_TIME:
+      static_assert(std::is_pointer_v<decltype(local_date_time_v)>,
+                    "Expected to move local_date_time_v by copying pointers.");
+      if (*other.GetMemoryResource() == *m) {
+        local_date_time_v = other.local_date_time_v;
+        other.type = MGP_VALUE_TYPE_NULL;
+      } else {
+        local_date_time_v = NewRawMgpObject<mgp_local_date_time>(m, *other.local_date_time_v);
+      }
+      break;
+    case MGP_VALUE_TYPE_DURATION:
+      static_assert(std::is_pointer_v<decltype(duration_v)>, "Expected to move duration_v by copying pointers.");
+      if (*other.GetMemoryResource() == *m) {
+        duration_v = other.duration_v;
+        other.type = MGP_VALUE_TYPE_NULL;
+      } else {
+        duration_v = NewRawMgpObject<mgp_duration>(m, *other.duration_v);
+      }
+      break;
+  }
+  DeleteValueMember(&other);
+  other.type = MGP_VALUE_TYPE_NULL;
+}
+
+mgp_value::~mgp_value() noexcept { DeleteValueMember(this); }
+
+// Creates a new mgp_edge in `memory` from the source edge's impl and the graph of its `from` vertex.
+mgp_edge *mgp_edge::Copy(const mgp_edge &edge, mgp_memory &memory) {
+  return NewRawMgpObject<mgp_edge>(&memory, edge.impl, edge.from.graph);
+}
+
+void mgp_value_destroy(mgp_value *val) { DeleteRawMgpObject(val); }
+
+mgp_error mgp_value_make_null(mgp_memory *memory, mgp_value **result) {
+  return WrapExceptions([memory] { return NewRawMgpObject<mgp_value>(memory); }, result);
+}
+
+// The C API passes booleans as int; any non-zero `val` becomes true.
+mgp_error mgp_value_make_bool(int val, mgp_memory *memory, mgp_value **result) {
+  return WrapExceptions([val, memory] { return NewRawMgpObject<mgp_value>(memory, val != 0); }, result);
+}
+
+// Generates mgp_value_make_<type>(param val, memory, result) for values that need an explicit mgp_memory.
+// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
+#define DEFINE_MGP_VALUE_MAKE_WITH_MEMORY(type, param)                                                \
+  mgp_error mgp_value_make_##type(param val, mgp_memory *memory, mgp_value **result) {                \
+    return WrapExceptions([val, memory] { return NewRawMgpObject<mgp_value>(memory, val); }, result); \
+  }
+
+// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
+DEFINE_MGP_VALUE_MAKE_WITH_MEMORY(int, int64_t);
+DEFINE_MGP_VALUE_MAKE_WITH_MEMORY(double, double);
+DEFINE_MGP_VALUE_MAKE_WITH_MEMORY(string, const char *);
+
+// Generates mgp_value_make_<type>(mgp_<type> *val, result); the new value is created with val's own memory
+// resource and is handed the pointer `val`.
+// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
+#define DEFINE_MGP_VALUE_MAKE(type)                                                                             \
+  mgp_error mgp_value_make_##type(mgp_##type *val, mgp_value **result) {                                        \
+    return WrapExceptions([val] { return NewRawMgpObject<mgp_value>(val->GetMemoryResource(), val); }, result); \
+  }
+
+DEFINE_MGP_VALUE_MAKE(list)
+DEFINE_MGP_VALUE_MAKE(map)
+DEFINE_MGP_VALUE_MAKE(vertex)
+DEFINE_MGP_VALUE_MAKE(edge)
+DEFINE_MGP_VALUE_MAKE(path)
+DEFINE_MGP_VALUE_MAKE(date)
+DEFINE_MGP_VALUE_MAKE(local_time)
+DEFINE_MGP_VALUE_MAKE(local_date_time)
+DEFINE_MGP_VALUE_MAKE(duration)
+
+namespace {
+mgp_value_type MgpValueGetType(const mgp_value &val) noexcept { return val.type; }
+}  // namespace
+
+mgp_error mgp_value_get_type(mgp_value *val, mgp_value_type *result) {
+  static_assert(noexcept(MgpValueGetType(*val)));
+  *result = MgpValueGetType(*val);
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+// Generates mgp_value_is_<type>(val, result), which writes 1 when the value's tag matches and 0 otherwise.
+// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
+#define DEFINE_MGP_VALUE_IS(type_lowercase, type_uppercase)              \
+  mgp_error mgp_value_is_##type_lowercase(mgp_value *val, int *result) { \
+    static_assert(noexcept(MgpValueGetType(*val)));                      \
+    *result = MgpValueGetType(*val) == MGP_VALUE_TYPE_##type_uppercase;  \
+    return mgp_error::MGP_ERROR_NO_ERROR;                                \
+  }
+
+DEFINE_MGP_VALUE_IS(null, NULL)
+DEFINE_MGP_VALUE_IS(bool, BOOL)
+DEFINE_MGP_VALUE_IS(int, INT)
+DEFINE_MGP_VALUE_IS(double, DOUBLE)
+DEFINE_MGP_VALUE_IS(string, STRING)
+DEFINE_MGP_VALUE_IS(list, LIST)
+DEFINE_MGP_VALUE_IS(map, MAP)
+DEFINE_MGP_VALUE_IS(vertex, VERTEX)
+DEFINE_MGP_VALUE_IS(edge, EDGE)
+DEFINE_MGP_VALUE_IS(path, PATH)
+DEFINE_MGP_VALUE_IS(date, DATE)
+DEFINE_MGP_VALUE_IS(local_time, LOCAL_TIME)
+DEFINE_MGP_VALUE_IS(local_date_time, LOCAL_DATE_TIME)
+DEFINE_MGP_VALUE_IS(duration, DURATION)
+
+// The getters below read the requested member directly and do not check the value's type tag.
+mgp_error mgp_value_get_bool(mgp_value *val, int *result) {
+  *result = val->bool_v ? 1 : 0;
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+mgp_error mgp_value_get_int(mgp_value *val, int64_t *result) {
+  *result = val->int_v;
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+mgp_error mgp_value_get_double(mgp_value *val, double *result) {
+  *result = val->double_v;
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+mgp_error mgp_value_get_string(mgp_value *val, const char **result) {
+  static_assert(noexcept(val->string_v.c_str()));
+  *result = val->string_v.c_str();
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+// Generates mgp_value_get_<type>(val, result), which returns the stored pointer without a type check.
+// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
+#define DEFINE_MGP_VALUE_GET(type)                                      \
+  mgp_error mgp_value_get_##type(mgp_value *val, mgp_##type **result) { \
+    *result = val->type##_v;                                            \
+    return mgp_error::MGP_ERROR_NO_ERROR;                               \
+  }
+
+DEFINE_MGP_VALUE_GET(list)
+DEFINE_MGP_VALUE_GET(map)
+DEFINE_MGP_VALUE_GET(vertex)
+DEFINE_MGP_VALUE_GET(edge)
+DEFINE_MGP_VALUE_GET(path)
+DEFINE_MGP_VALUE_GET(date)
+DEFINE_MGP_VALUE_GET(local_time)
+DEFINE_MGP_VALUE_GET(local_date_time)
+DEFINE_MGP_VALUE_GET(duration)
+
+// Creates an empty list and reserves room for `capacity` elements.
+mgp_error mgp_list_make_empty(size_t capacity, mgp_memory *memory, mgp_list **result) {
+  return WrapExceptions(
+      [capacity, memory] {
+        auto list = NewMgpObject<mgp_list>(memory);
+        list->elems.reserve(capacity);
+        return list.release();
+      },
+      result);
+}
+
+void mgp_list_destroy(mgp_list *list) { DeleteRawMgpObject(list); }
+
+namespace {
+void MgpListAppendExtend(mgp_list &list, const mgp_value &value) { list.elems.push_back(value); }
+}  // namespace
+
+// Appends a copy of `val` only while the list has spare capacity; otherwise throws InsufficientBufferException
+// inside WrapExceptions instead of growing the list.
+mgp_error mgp_list_append(mgp_list *list, mgp_value *val) {
+  return WrapExceptions([list, val] {
+    if (Call<size_t>(mgp_list_size, list) >= Call<size_t>(mgp_list_capacity, list)) {
+      throw InsufficientBufferException{
+          "Cannot append a new value to the mgp_list without extending it, because its size reached its capacity!"};
+    }
+    MgpListAppendExtend(*list, *val);
+  });
+}
+
+// Appends a copy of `val` without the capacity check performed by mgp_list_append.
+mgp_error mgp_list_append_extend(mgp_list *list, mgp_value *val) {
+  return WrapExceptions([list, val] { MgpListAppendExtend(*list, *val); });
+}
+
+mgp_error mgp_list_size(mgp_list *list, size_t *result) {
+  static_assert(noexcept(list->elems.size()));
+  *result = list->elems.size();
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+mgp_error mgp_list_capacity(mgp_list *list, size_t *result) {
+  static_assert(noexcept(list->elems.capacity()));
+  *result = list->elems.capacity();
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+// Returns a pointer to the i-th element; an index >= size throws std::out_of_range inside WrapExceptions.
+mgp_error mgp_list_at(mgp_list *list, size_t i, mgp_value **result) {
+  return WrapExceptions(
+      [list, i] {
+        if (i >= Call<size_t>(mgp_list_size, list)) {
+          throw std::out_of_range("Element cannot be retrieved, because index exceeds list's size!");
+        }
+        return &list->elems[i];
+      },
+      result);
+}
+
+mgp_error mgp_map_make_empty(mgp_memory *memory, mgp_map **result) {
+  return WrapExceptions([&memory] { return NewRawMgpObject<mgp_map>(memory); }, result);
+}
+
+void mgp_map_destroy(mgp_map *map) { DeleteRawMgpObject(map); }
+
+// Inserts a copy of `value` under `key`; an already present key throws KeyAlreadyExistsException inside
+// WrapExceptions and leaves the existing mapping in place.
+mgp_error mgp_map_insert(mgp_map *map, const char *key, mgp_value *value) {
+  return WrapExceptions([&] {
+    auto emplace_result = map->items.emplace(key, *value);
+    if (!emplace_result.second) {
+      throw KeyAlreadyExistsException{"Map already contains mapping for {}", key};
+    }
+  });
+}
+
+mgp_error mgp_map_size(mgp_map *map, size_t *result) {
+  static_assert(noexcept(map->items.size()));
+  *result = map->items.size();
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+// Looks `key` up; a missing key is reported as a nullptr result rather than as an exception.
+mgp_error mgp_map_at(mgp_map *map, const char *key, mgp_value **result) {
+  return WrapExceptions(
+      [&map, &key]() -> mgp_value * {
+        auto found_it = map->items.find(key);
+        if (found_it == map->items.end()) {
+          return nullptr;
+        }
+        return &found_it->second;
+      },
+      result);
+}
+
+mgp_error mgp_map_item_key(mgp_map_item *item, const char **result) {
+  return WrapExceptions([&item] { return item->key; }, result);
+}
+
+mgp_error mgp_map_item_value(mgp_map_item *item, mgp_value **result) {
+  return WrapExceptions([item] { return item->value; }, result);
+}
+
+mgp_error mgp_map_iter_items(mgp_map *map, mgp_memory *memory, mgp_map_items_iterator **result) {
+  return WrapExceptions([map, memory] { return NewRawMgpObject<mgp_map_items_iterator>(memory, map); }, result);
+}
+
+void mgp_map_items_iterator_destroy(mgp_map_items_iterator *it) { DeleteRawMgpObject(it); }
+
+// Returns the item the iterator currently points at, or nullptr once the end of the map has been reached.
+mgp_error mgp_map_items_iterator_get(mgp_map_items_iterator *it, mgp_map_item **result) {
+  return WrapExceptions(
+      [it]() -> mgp_map_item * {
+        if (it->current_it == it->map->items.end()) {
+          return nullptr;
+        }
+        return &it->current;
+      },
+      result);
+}
+
+// Advances the iterator and returns the new current item, or nullptr when the iterator is (or becomes) exhausted.
+mgp_error mgp_map_items_iterator_next(mgp_map_items_iterator *it, mgp_map_item **result) {
+  return WrapExceptions(
+      [it]() -> mgp_map_item * {
+        if (it->current_it == it->map->items.end()) {
+          return nullptr;
+        }
+        if (++it->current_it == it->map->items.end()) {
+          return nullptr;
+        }
+        it->current.key = it->current_it->first.c_str();
+        it->current.value = &it->current_it->second;
+        return &it->current;
+      },
+      result);
+}
+
+// Creates a path that consists of a copy of `vertex` and no edges.
+mgp_error mgp_path_make_with_start(mgp_vertex *vertex, mgp_memory *memory, mgp_path **result) {
+  return WrapExceptions(
+      [vertex, memory]() -> mgp_path * {
+        auto path = NewMgpObject<mgp_path>(memory);
+        if (path == nullptr) {
+          return nullptr;
+        }
+        path->vertices.push_back(*vertex);
+        return path.release();
+      },
+      result);
+}
+
+mgp_error mgp_path_copy(mgp_path *path, mgp_memory *memory, mgp_path **result) {
+  return WrapExceptions(
+      [path, memory] {
+        MG_ASSERT(Call<size_t>(mgp_path_size, path) == path->vertices.size() - 1, "Invalid mgp_path");
+        return NewRawMgpObject<mgp_path>(memory, *path);
+      },
+      result);
+}
+
+void mgp_path_destroy(mgp_path *path) { DeleteRawMgpObject(path); }
+
+// Extends the path by `edge` and by the endpoint of `edge` that is not the current last vertex. Throws
+// std::logic_error inside WrapExceptions when the current last vertex is not an endpoint of `edge`.
+mgp_error mgp_path_expand(mgp_path *path, mgp_edge *edge) {
+  return WrapExceptions([path, edge] {
+    MG_ASSERT(Call<size_t>(mgp_path_size, path) == path->vertices.size() - 1, "Invalid mgp_path");
+    // Check that the both the last vertex on path and dst_vertex are endpoints of
+    // the given edge.
+    auto *src_vertex = &path->vertices.back();
+    mgp_vertex *dst_vertex{nullptr};
+    if (edge->to == *src_vertex) {
+      dst_vertex = &edge->from;
+    } else if (edge->from == *src_vertex) {
+      dst_vertex = &edge->to;
+    } else {
+      // edge is not a continuation on src_vertex
+      throw std::logic_error{"The current last vertex in the path is not part of the given edge."};
+    }
+    // Try appending edge and dst_vertex to path, preserving the original mgp_path
+    // instance if anything fails.
+    memgraph::utils::OnScopeExit scope_guard(
+        [path] { MG_ASSERT(Call<size_t>(mgp_path_size, path) == path->vertices.size() - 1); });
+
+    path->edges.push_back(*edge);
+    try {
+      path->vertices.push_back(*dst_vertex);
+    } catch (...) {
+      // Undo the edge append so that edges.size() == vertices.size() - 1 still holds when the scope guard runs
+      // and the caller sees the path exactly as it was before the call.
+      path->edges.pop_back();
+      throw;
+    }
+  });
+}
+
+namespace {
+// The size of a path is its number of edges.
+size_t MgpPathSize(const mgp_path &path) noexcept { return path.edges.size(); }
+}  // namespace
+
+mgp_error mgp_path_size(mgp_path *path, size_t *result) {
+  *result = MgpPathSize(*path);
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+// Returns the i-th vertex. A path of size N has N + 1 vertices, so i == N is still a valid index here.
+mgp_error mgp_path_vertex_at(mgp_path *path, size_t i, mgp_vertex **result) {
+  return WrapExceptions(
+      [path, i] {
+        const auto path_size = Call<size_t>(mgp_path_size, path);
+        MG_ASSERT(path_size == path->vertices.size() - 1);
+        if (i > path_size) {
+          throw std::out_of_range("Vertex cannot be retrieved, because index exceeds path's size!");
+        }
+        return &path->vertices[i];
+      },
+      result);
+}
+
+// Returns the i-th edge. A path of size N has exactly N edges, so every i >= N is rejected with
+// std::out_of_range; accepting i == N would index one past the end of `edges`.
+mgp_error mgp_path_edge_at(mgp_path *path, size_t i, mgp_edge **result) {
+  return WrapExceptions(
+      [path, i] {
+        const auto path_size = Call<size_t>(mgp_path_size, path);
+        MG_ASSERT(path_size == path->vertices.size() - 1);
+        if (i >= path_size) {
+          throw std::out_of_range("Edge cannot be retrieved, because index exceeds path's size!");
+        }
+        return &path->edges[i];
+      },
+      result);
+}
+
+// Writes 1 when both paths have the same size, equal start vertices and pairwise equal edges, 0 otherwise.
+mgp_error mgp_path_equal(mgp_path *p1, mgp_path *p2, int *result) {
+  return WrapExceptions(
+      [p1, p2] {
+        const auto p1_size = MgpPathSize(*p1);
+        const auto p2_size = MgpPathSize(*p2);
+        MG_ASSERT(p1_size == p1->vertices.size() - 1);
+        MG_ASSERT(p2_size == p2->vertices.size() - 1);
+        if (p1_size != p2_size) {
+          return 0;
+        }
+        const auto *start1 = Call<mgp_vertex *>(mgp_path_vertex_at, p1, 0);
+        const auto *start2 = Call<mgp_vertex *>(mgp_path_vertex_at, p2, 0);
+        static_assert(noexcept(start1->impl == start2->impl));
+        if (*start1 != *start2) {
+          return 0;
+        }
+        for (size_t i = 0; i < p1_size; ++i) {
+          const auto *e1 = Call<mgp_edge *>(mgp_path_edge_at, p1, i);
+          const auto *e2 = Call<mgp_edge *>(mgp_path_edge_at, p2, i);
+          if (*e1 != *e2) {
+            return 0;
+          }
+        }
+        return 1;
+      },
+      result);
+}
+
+// mgp_date API: every function below forwards to the wrapped date object through WrapExceptions; functions that
+// produce a new object create it with the caller-supplied `memory`.
+mgp_error mgp_date_from_string(const char *string, mgp_memory *memory, mgp_date **date) {
+  return WrapExceptions([string, memory] { return NewRawMgpObject<mgp_date>(memory, string); }, date);
+}
+
+mgp_error mgp_date_from_parameters(mgp_date_parameters *parameters, mgp_memory *memory, mgp_date **date) {
+  return WrapExceptions([parameters, memory] { return NewRawMgpObject<mgp_date>(memory, parameters); }, date);
+}
+
+mgp_error mgp_date_copy(mgp_date *date, mgp_memory *memory, mgp_date **result) {
+  return WrapExceptions([date, memory] { return NewRawMgpObject<mgp_date>(memory, *date); }, result);
+}
+
+void mgp_date_destroy(mgp_date *date) { DeleteRawMgpObject(date); }
+
+mgp_error mgp_date_equal(mgp_date *first, mgp_date *second, int *result) {
+  return WrapExceptions([first, second] { return first->date == second->date; }, result);
+}
+
+mgp_error mgp_date_get_year(mgp_date *date, int *year) {
+  return WrapExceptions([date] { return date->date.year; }, year);
+}
+
+mgp_error mgp_date_get_month(mgp_date *date, int *month) {
+  return WrapExceptions([date] { return date->date.month; }, month);
+}
+
+mgp_error mgp_date_get_day(mgp_date *date, int *day) {
+  return WrapExceptions([date] { return date->date.day; }, day);
+}
+
+mgp_error mgp_date_timestamp(mgp_date *date, int64_t *timestamp) {
+  return WrapExceptions([date] { return date->date.MicrosecondsSinceEpoch(); }, timestamp);
+}
+
+mgp_error mgp_date_now(mgp_memory *memory, mgp_date **date) {
+  return WrapExceptions([memory] { return NewRawMgpObject<mgp_date>(memory, memgraph::utils::CurrentDate()); }, date);
+}
+
+mgp_error mgp_date_add_duration(mgp_date *date, mgp_duration *dur, mgp_memory *memory, mgp_date **result) {
+  return WrapExceptions([date, dur, memory] { return NewRawMgpObject<mgp_date>(memory, date->date + dur->duration); },
+                        result);
+}
+
+mgp_error mgp_date_sub_duration(mgp_date *date, mgp_duration *dur, mgp_memory *memory, mgp_date **result) {
+  return WrapExceptions([date, dur, memory] { return NewRawMgpObject<mgp_date>(memory, date->date - dur->duration); },
+                        result);
+}
+
+// The difference of two dates is returned as a new mgp_duration.
+mgp_error mgp_date_diff(mgp_date *first, mgp_date *second, mgp_memory *memory, mgp_duration **result) {
+  return WrapExceptions(
+      [first, second, memory] { return NewRawMgpObject<mgp_duration>(memory, first->date - second->date); }, result);
+}
+
+// mgp_local_time API: same forwarding pattern as the mgp_date functions above.
+mgp_error mgp_local_time_from_string(const char *string, mgp_memory *memory, mgp_local_time **local_time) {
+  return WrapExceptions([string, memory] { return NewRawMgpObject<mgp_local_time>(memory, string); }, local_time);
+}
+
+mgp_error mgp_local_time_from_parameters(mgp_local_time_parameters *parameters, mgp_memory *memory,
+                                         mgp_local_time **local_time) {
+  return WrapExceptions([parameters, memory] { return NewRawMgpObject<mgp_local_time>(memory, parameters); },
+                        local_time);
+}
+
+mgp_error mgp_local_time_copy(mgp_local_time *local_time, mgp_memory *memory, mgp_local_time **result) {
+  return WrapExceptions([local_time, memory] { return NewRawMgpObject<mgp_local_time>(memory, *local_time); }, result);
+}
+
+void mgp_local_time_destroy(mgp_local_time *local_time) { DeleteRawMgpObject(local_time); }
+
+mgp_error mgp_local_time_equal(mgp_local_time *first, mgp_local_time *second, int *result) {
+  return WrapExceptions([first, second] { return first->local_time == second->local_time; }, result);
+}
+
+// Field accessors of mgp_local_time: each one reads a single member of the wrapped local_time object and hands
+// the reader to WrapExceptions together with the output pointer.
+mgp_error mgp_local_time_get_hour(mgp_local_time *local_time, int *hour) {
+  auto read_hour = [local_time] { return local_time->local_time.hour; };
+  return WrapExceptions(read_hour, hour);
+}
+
+mgp_error mgp_local_time_get_minute(mgp_local_time *local_time, int *minute) {
+  auto read_minute = [local_time] { return local_time->local_time.minute; };
+  return WrapExceptions(read_minute, minute);
+}
+
+mgp_error mgp_local_time_get_second(mgp_local_time *local_time, int *second) {
+  auto read_second = [local_time] { return local_time->local_time.second; };
+  return WrapExceptions(read_second, second);
+}
+
+mgp_error mgp_local_time_get_millisecond(mgp_local_time *local_time, int *millisecond) {
+  auto read_millisecond = [local_time] { return local_time->local_time.millisecond; };
+  return WrapExceptions(read_millisecond, millisecond);
+}
+
+mgp_error mgp_local_time_get_microsecond(mgp_local_time *local_time, int *microsecond) {
+  auto read_microsecond = [local_time] { return local_time->local_time.microsecond; };
+  return WrapExceptions(read_microsecond, microsecond);
+}
+
+mgp_error mgp_local_time_timestamp(mgp_local_time *local_time, int64_t *timestamp) {
+  auto read_timestamp = [local_time] { return local_time->local_time.MicrosecondsSinceEpoch(); };
+  return WrapExceptions(read_timestamp, timestamp);
+}
+
+// Factories and arithmetic of mgp_local_time: every result is a new object created with `memory`.
+mgp_error mgp_local_time_now(mgp_memory *memory, mgp_local_time **local_time) {
+  auto make_now = [memory] {
+    return NewRawMgpObject<mgp_local_time>(memory, memgraph::utils::CurrentLocalTime());
+  };
+  return WrapExceptions(make_now, local_time);
+}
+
+mgp_error mgp_local_time_add_duration(mgp_local_time *local_time, mgp_duration *dur, mgp_memory *memory,
+                                      mgp_local_time **result) {
+  auto make_sum = [local_time, dur, memory] {
+    return NewRawMgpObject<mgp_local_time>(memory, local_time->local_time + dur->duration);
+  };
+  return WrapExceptions(make_sum, result);
+}
+
+mgp_error mgp_local_time_sub_duration(mgp_local_time *local_time, mgp_duration *dur, mgp_memory *memory,
+                                      mgp_local_time **result) {
+  auto make_difference = [local_time, dur, memory] {
+    return NewRawMgpObject<mgp_local_time>(memory, local_time->local_time - dur->duration);
+  };
+  return WrapExceptions(make_difference, result);
+}
+
+// The difference of two local times is returned as a new mgp_duration.
+mgp_error mgp_local_time_diff(mgp_local_time *first, mgp_local_time *second, mgp_memory *memory,
+                              mgp_duration **result) {
+  auto make_duration = [first, second, memory] {
+    return NewRawMgpObject<mgp_duration>(memory, first->local_time - second->local_time);
+  };
+  return WrapExceptions(make_duration, result);
+}
+
+// mgp_local_date_time: construction, copy, destruction and comparison.
+mgp_error mgp_local_date_time_from_string(const char *string, mgp_memory *memory,
+                                          mgp_local_date_time **local_date_time) {
+  auto make_from_string = [string, memory] { return NewRawMgpObject<mgp_local_date_time>(memory, string); };
+  return WrapExceptions(make_from_string, local_date_time);
+}
+
+mgp_error mgp_local_date_time_from_parameters(mgp_local_date_time_parameters *parameters, mgp_memory *memory,
+                                              mgp_local_date_time **local_date_time) {
+  auto make_from_parameters = [parameters, memory] {
+    return NewRawMgpObject<mgp_local_date_time>(memory, parameters);
+  };
+  return WrapExceptions(make_from_parameters, local_date_time);
+}
+
+mgp_error mgp_local_date_time_copy(mgp_local_date_time *local_date_time, mgp_memory *memory,
+                                   mgp_local_date_time **result) {
+  auto make_copy = [local_date_time, memory] {
+    return NewRawMgpObject<mgp_local_date_time>(memory, *local_date_time);
+  };
+  return WrapExceptions(make_copy, result);
+}
+
+void mgp_local_date_time_destroy(mgp_local_date_time *local_date_time) {
+  DeleteRawMgpObject(local_date_time);
+}
+
+mgp_error mgp_local_date_time_equal(mgp_local_date_time *first, mgp_local_date_time *second, int *result) {
+  auto compare = [first, second] { return first->local_date_time == second->local_date_time; };
+  return WrapExceptions(compare, result);
+}
+
+// Date-part accessors of mgp_local_date_time.
+mgp_error mgp_local_date_time_get_year(mgp_local_date_time *local_date_time, int *year) {
+  auto read_year = [local_date_time] { return local_date_time->local_date_time.date.year; };
+  return WrapExceptions(read_year, year);
+}
+
+mgp_error mgp_local_date_time_get_month(mgp_local_date_time *local_date_time, int *month) {
+  auto read_month = [local_date_time] { return local_date_time->local_date_time.date.month; };
+  return WrapExceptions(read_month, month);
+}
+
+mgp_error mgp_local_date_time_get_day(mgp_local_date_time *local_date_time, int *day) {
+  auto read_day = [local_date_time] { return local_date_time->local_date_time.date.day; };
+  return WrapExceptions(read_day, day);
+}
+
+// Time-part accessors of mgp_local_date_time: each one reads a single member of
+// local_date_time->local_date_time.local_time and passes it to WrapExceptions with the output pointer.
+// NOTE(review): WrapExceptions is defined outside this chunk; presumably it stores the lambda's return value in
+// the output pointer and converts exceptions to an mgp_error -- confirm against its definition.
+mgp_error mgp_local_date_time_get_hour(mgp_local_date_time *local_date_time, int *hour) {
+  return WrapExceptions([local_date_time] { return local_date_time->local_date_time.local_time.hour; }, hour);
+}
+
+mgp_error mgp_local_date_time_get_minute(mgp_local_date_time *local_date_time, int *minute) {
+  return WrapExceptions([local_date_time] { return local_date_time->local_date_time.local_time.minute; }, minute);
+}
+
+mgp_error mgp_local_date_time_get_second(mgp_local_date_time *local_date_time, int *second) {
+  return WrapExceptions([local_date_time] { return local_date_time->local_date_time.local_time.second; }, second);
+}
+
+mgp_error mgp_local_date_time_get_millisecond(mgp_local_date_time *local_date_time, int *millisecond) {
+  return WrapExceptions([local_date_time] { return local_date_time->local_date_time.local_time.millisecond; },
+                        millisecond);
+}
+
+mgp_error mgp_local_date_time_get_microsecond(mgp_local_date_time *local_date_time, int *microsecond) {
+  return WrapExceptions([local_date_time] { return local_date_time->local_date_time.local_time.microsecond; },
+                        microsecond);
+}
+
+// Returns the value of MicrosecondsSinceEpoch() of the wrapped local date-time.
+mgp_error mgp_local_date_time_timestamp(mgp_local_date_time *local_date_time, int64_t *timestamp) {
+  return WrapExceptions([local_date_time] { return local_date_time->local_date_time.MicrosecondsSinceEpoch(); },
+                        timestamp);
+}
+
+// Creates a new mgp_local_date_time in `memory` from memgraph::utils::CurrentLocalDateTime().
+mgp_error mgp_local_date_time_now(mgp_memory *memory, mgp_local_date_time **local_date_time) {
+  return WrapExceptions(
+      [memory] { return NewRawMgpObject<mgp_local_date_time>(memory, memgraph::utils::CurrentLocalDateTime()); },
+      local_date_time);
+}
+
+// Arithmetic: the operands are left untouched, the result is a new object created with `memory`.
+mgp_error mgp_local_date_time_add_duration(mgp_local_date_time *local_date_time, mgp_duration *dur, mgp_memory *memory,
+                                           mgp_local_date_time **result) {
+  return WrapExceptions(
+      [local_date_time, dur, memory] {
+        return NewRawMgpObject<mgp_local_date_time>(memory, local_date_time->local_date_time + dur->duration);
+      },
+      result);
+}
+
+mgp_error mgp_local_date_time_sub_duration(mgp_local_date_time *local_date_time, mgp_duration *dur, mgp_memory *memory,
+                                           mgp_local_date_time **result) {
+  return WrapExceptions(
+      [local_date_time, dur, memory] {
+        return NewRawMgpObject<mgp_local_date_time>(memory, local_date_time->local_date_time - dur->duration);
+      },
+      result);
+}
+
+// The difference of two local date-times is returned as a new mgp_duration.
+mgp_error mgp_local_date_time_diff(mgp_local_date_time *first, mgp_local_date_time *second, mgp_memory *memory,
+                                   mgp_duration **result) {
+  return WrapExceptions(
+      [first, second, memory] {
+        return NewRawMgpObject<mgp_duration>(memory, first->local_date_time - second->local_date_time);
+      },
+      result);
+}
+
+// mgp_duration: construction from a string, a parameter struct or a raw microsecond count.
+mgp_error mgp_duration_from_string(const char *string, mgp_memory *memory, mgp_duration **duration) {
+  return WrapExceptions([memory, string] { return NewRawMgpObject<mgp_duration>(memory, string); }, duration);
+}
+
+mgp_error mgp_duration_from_parameters(mgp_duration_parameters *parameters, mgp_memory *memory,
+                                       mgp_duration **duration) {
+  return WrapExceptions([memory, parameters] { return NewRawMgpObject<mgp_duration>(memory, parameters); }, duration);
+}
+
+mgp_error mgp_duration_from_microseconds(int64_t microseconds, mgp_memory *memory, mgp_duration **duration) {
+  return WrapExceptions([microseconds, memory] { return NewRawMgpObject<mgp_duration>(memory, microseconds); },
+                        duration);
+}
+
+mgp_error mgp_duration_copy(mgp_duration *duration, mgp_memory *memory, mgp_duration **result) {
+  return WrapExceptions([duration, memory] { return NewRawMgpObject<mgp_duration>(memory, *duration); }, result);
+}
+
+void mgp_duration_destroy(mgp_duration *duration) { DeleteRawMgpObject(duration); }
+
+mgp_error mgp_duration_get_microseconds(mgp_duration *duration, int64_t *microseconds) {
+  return WrapExceptions([duration] { return duration->duration.microseconds; }, microseconds);
+}
+
+mgp_error mgp_duration_equal(mgp_duration *first, mgp_duration *second, int *result) {
+  return WrapExceptions([first, second] { return first->duration
== second->duration; }, result); +} + +mgp_error mgp_duration_neg(mgp_duration *dur, mgp_memory *memory, mgp_duration **result) { + return WrapExceptions([memory, dur] { return NewRawMgpObject<mgp_duration>(memory, -dur->duration); }, result); +} + +mgp_error mgp_duration_add(mgp_duration *first, mgp_duration *second, mgp_memory *memory, mgp_duration **result) { + return WrapExceptions( + [memory, first, second] { return NewRawMgpObject<mgp_duration>(memory, first->duration + second->duration); }, + result); +} + +mgp_error mgp_duration_sub(mgp_duration *first, mgp_duration *second, mgp_memory *memory, mgp_duration **result) { + return WrapExceptions( + [memory, first, second] { return NewRawMgpObject<mgp_duration>(memory, first->duration - second->duration); }, + result); +} + +/// Plugin Result + +mgp_error mgp_result_set_error_msg(mgp_result *res, const char *msg) { + return WrapExceptions([=] { + auto *memory = res->rows.get_allocator().GetMemoryResource(); + res->error_msg.emplace(msg, memory); + }); +} + +mgp_error mgp_result_new_record(mgp_result *res, mgp_result_record **result) { + return WrapExceptions( + [res] { + auto *memory = res->rows.get_allocator().GetMemoryResource(); + MG_ASSERT(res->signature, "Expected to have a valid signature"); + res->rows.push_back(mgp_result_record{ + res->signature, + memgraph::utils::pmr::map<memgraph::utils::pmr::string, memgraph::query::v2::TypedValue>(memory)}); + return &res->rows.back(); + }, + result); +} + +mgp_error mgp_result_record_insert(mgp_result_record *record, const char *field_name, mgp_value *val) { + return WrapExceptions([=] { + auto *memory = record->values.get_allocator().GetMemoryResource(); + // Validate field_name & val satisfy the procedure's result signature. 
+    MG_ASSERT(record->signature, "Expected to have a valid signature");
+    auto find_it = record->signature->find(field_name);
+    if (find_it == record->signature->end()) {
+      throw std::out_of_range{fmt::format("The result doesn't have any field named '{}'.", field_name)};
+    }
+    // The signature entry is a pair; its first member is the declared type of the field.
+    const auto *type = find_it->second.first;
+    if (!type->SatisfiesType(*val)) {
+      throw std::logic_error{
+          fmt::format("The type of value doesn't satisfy the type '{}'!", type->GetPresentableName())};
+    }
+    record->values.emplace(field_name, ToTypedValue(*val, memory));
+  });
+}
+
+/// Sets the error message of the function result `res` to `msg`.
+/// The stored message is constructed with the memory resource `memory->impl`.
+mgp_error mgp_func_result_set_error_msg(mgp_func_result *res, const char *msg, mgp_memory *memory) {
+  return WrapExceptions([=] { res->error_msg.emplace(msg, memory->impl); });
+}
+
+/// Converts `value` to a TypedValue (allocated with `memory->impl`) and stores it as the value of `res`.
+mgp_error mgp_func_result_set_value(mgp_func_result *res, mgp_value *value, mgp_memory *memory) {
+  return WrapExceptions([=] { res->value = ToTypedValue(*value, memory->impl); });
+}
+
+/// Graph Constructs
+
+/// Releases a properties iterator through DeleteRawMgpObject.
+void mgp_properties_iterator_destroy(mgp_properties_iterator *it) { DeleteRawMgpObject(it); }
+
+/// Yields, through `result`, a pointer to the property the iterator currently holds (`&it->property`)
+/// when `it->current` has a value, and nullptr otherwise. The iterator itself is not advanced.
+mgp_error mgp_properties_iterator_get(mgp_properties_iterator *it, mgp_property **result) {
+  return WrapExceptions(
+      [it]() -> mgp_property * {
+        if (it->current) {
+          return &it->property;
+        }
+        return nullptr;
+      },
+      result);
+}
+
+mgp_error mgp_properties_iterator_next(mgp_properties_iterator *it, mgp_property **result) {
+  // Incrementing the iterator either for on-disk or in-memory
+  // storage, so perhaps the underlying thing can throw.
+  // Both copying TypedValue and/or string from PropertyName may fail to
+  // allocate. Also, dereferencing `it->current_it` could also throw, so
+  // either way return nullptr and leave `it` in undefined state.
+  // Hopefully iterator comparison doesn't throw, but wrap the whole thing in
+  // try ... catch just to be sure.
+  return WrapExceptions(
+      [it]() -> mgp_property * {
+        if (it->current_it == it->pvs.end()) {
+          MG_ASSERT(!it->current,
+                    "Iteration is already done, so it->current should "
+                    "have been set to std::nullopt");
+          return nullptr;
+        }
+        if (++it->current_it == it->pvs.end()) {
+          it->current = std::nullopt;
+          return nullptr;
+        }
+        // If building the new current element throws, reset `current` so that no stale element stays visible.
+        memgraph::utils::OnScopeExit clean_up([it] { it->current = std::nullopt; });
+        it->current.emplace(memgraph::utils::pmr::string(it->graph->impl->PropertyToName(it->current_it->first),
+                                                         it->GetMemoryResource()),
+                            mgp_value(it->current_it->second, it->GetMemoryResource()));
+        it->property.name = it->current->first.c_str();
+        it->property.value = &it->current->second;
+        clean_up.Disable();
+        return &it->property;
+      },
+      result);
+}
+
+/// Yields, through `result`, the id of the vertex as an mgp_vertex_id built from `v->impl.Gid().AsInt()`.
+mgp_error mgp_vertex_get_id(mgp_vertex *v, mgp_vertex_id *result) {
+  return WrapExceptions([v] { return mgp_vertex_id{.as_int = v->impl.Gid().AsInt()}; }, result);
+}
+
+/// Forwards to mgp_graph_is_mutable for the graph the vertex belongs to.
+mgp_error mgp_vertex_underlying_graph_is_mutable(mgp_vertex *v, int *result) {
+  return mgp_graph_is_mutable(v->graph, result);
+}
+
+namespace {
+// Forward declaration: the list and map overloads below recurse into the mgp_value overload.
+memgraph::storage::v3::PropertyValue ToPropertyValue(const mgp_value &value);
+
+/// Converts every element of `list` and collects the results into a list PropertyValue.
+memgraph::storage::v3::PropertyValue ToPropertyValue(const mgp_list &list) {
+  memgraph::storage::v3::PropertyValue result{std::vector<memgraph::storage::v3::PropertyValue>{}};
+  auto &result_list = result.ValueList();
+  for (const auto &value : list.elems) {
+    result_list.push_back(ToPropertyValue(value));
+  }
+  return result;
+}
+
+/// Converts every item of `map` and collects the results into a map PropertyValue keyed by std::string.
+memgraph::storage::v3::PropertyValue ToPropertyValue(const mgp_map &map) {
+  memgraph::storage::v3::PropertyValue result{std::map<std::string, memgraph::storage::v3::PropertyValue>{}};
+  auto &result_map = result.ValueMap();
+  for (const auto &[key, value] : map.items) {
+    result_map.insert_or_assign(std::string{key}, ToPropertyValue(value));
+  }
+  return result;
+}
+
+/// Converts an mgp_value into a storage PropertyValue.
+///
+/// Null, bool, int, double, string, list and map values are converted directly; date, local time,
+/// local date-time and duration values are stored as TemporalData holding a microsecond count.
+///
+/// @throw ValueConversionException if the value is a vertex, an edge or a path, or if `value.type`
+///        does not match any known enumerator.
+memgraph::storage::v3::PropertyValue ToPropertyValue(const mgp_value &value) {
+  switch (value.type) {
+    case MGP_VALUE_TYPE_NULL:
+      return memgraph::storage::v3::PropertyValue{};
+    case MGP_VALUE_TYPE_BOOL:
+      return memgraph::storage::v3::PropertyValue{value.bool_v};
+    case MGP_VALUE_TYPE_INT:
+      return memgraph::storage::v3::PropertyValue{value.int_v};
+    case MGP_VALUE_TYPE_DOUBLE:
+      return memgraph::storage::v3::PropertyValue{value.double_v};
+    case MGP_VALUE_TYPE_STRING:
+      return memgraph::storage::v3::PropertyValue{std::string{value.string_v}};
+    case MGP_VALUE_TYPE_LIST:
+      return ToPropertyValue(*value.list_v);
+    case MGP_VALUE_TYPE_MAP:
+      return ToPropertyValue(*value.map_v);
+    case MGP_VALUE_TYPE_DATE:
+      return memgraph::storage::v3::PropertyValue{memgraph::storage::v3::TemporalData{
+          memgraph::storage::v3::TemporalType::Date, value.date_v->date.MicrosecondsSinceEpoch()}};
+    case MGP_VALUE_TYPE_LOCAL_TIME:
+      return memgraph::storage::v3::PropertyValue{memgraph::storage::v3::TemporalData{
+          memgraph::storage::v3::TemporalType::LocalTime, value.local_time_v->local_time.MicrosecondsSinceEpoch()}};
+    case MGP_VALUE_TYPE_LOCAL_DATE_TIME:
+      return memgraph::storage::v3::PropertyValue{
+          memgraph::storage::v3::TemporalData{memgraph::storage::v3::TemporalType::LocalDateTime,
+                                              value.local_date_time_v->local_date_time.MicrosecondsSinceEpoch()}};
+    case MGP_VALUE_TYPE_DURATION:
+      return memgraph::storage::v3::PropertyValue{memgraph::storage::v3::TemporalData{
+          memgraph::storage::v3::TemporalType::Duration, value.duration_v->duration.microseconds}};
+    case MGP_VALUE_TYPE_VERTEX:
+      throw ValueConversionException{"A vertex is not a valid property value!"};
+    case MGP_VALUE_TYPE_EDGE:
+      throw ValueConversionException{"An edge is not a valid property value!"};
+    case MGP_VALUE_TYPE_PATH:
+      throw ValueConversionException{"A path is not a valid property value!"};
+  }
+  // Every enumerator returns or throws above. This is reached only when `value.type` holds a value outside of
+  // the enumeration; without it control would flow off the end of a non-void function, which is undefined behavior.
+  throw ValueConversionException{"Unknown value type, cannot convert it to a property value!"};
+}
+}  // namespace
+
+mgp_error mgp_vertex_set_property(struct mgp_vertex *v, const char *property_name, mgp_value *property_value) {
+  return WrapExceptions([=] {
+    if (!MgpVertexIsMutable(*v)) {
+      throw ImmutableObjectException{"Cannot set a property on an immutable vertex!"};
+    }
+    const auto prop_key = v->graph->impl->NameToProperty(property_name);
+    const auto result = v->impl.SetProperty(prop_key, ToPropertyValue(*property_value));
+    if (result.HasError()) {
+      switch (result.GetError()) {
+        case memgraph::storage::v3::Error::DELETED_OBJECT:
+          throw DeletedObjectException{"Cannot set the properties of a deleted vertex!"};
+        case memgraph::storage::v3::Error::NONEXISTENT_OBJECT:
+          LOG_FATAL("Query modules shouldn't have access to nonexistent objects when setting a property of a vertex!");
+        case memgraph::storage::v3::Error::PROPERTIES_DISABLED:
+        case memgraph::storage::v3::Error::VERTEX_HAS_EDGES:
+          LOG_FATAL("Unexpected error when setting a property of a vertex.");
+        case memgraph::storage::v3::Error::SERIALIZATION_ERROR:
+          throw SerializationException{"Cannot serialize setting a property of a vertex."};
+      }
+    }
+
+    auto &ctx = v->graph->ctx;
+
+    ctx->execution_stats[memgraph::query::v2::ExecutionStats::Key::UPDATED_PROPERTIES] += 1;
+
+    auto *trigger_ctx_collector = ctx->trigger_context_collector;
+    if (!trigger_ctx_collector ||
+        !trigger_ctx_collector->ShouldRegisterObjectPropertyChange<memgraph::query::v2::VertexAccessor>()) {
+      return;
+    }
+    const auto old_value = memgraph::query::v2::TypedValue(*result);
+    if (property_value->type == mgp_value_type::MGP_VALUE_TYPE_NULL) {
+      trigger_ctx_collector->RegisterRemovedObjectProperty(v->impl, prop_key, old_value);
+      return;
+    }
+    const auto new_value = ToTypedValue(*property_value,
property_value->memory); + trigger_ctx_collector->RegisterSetObjectProperty(v->impl, prop_key, old_value, new_value); + }); +} + +mgp_error mgp_vertex_add_label(struct mgp_vertex *v, mgp_label label) { + return WrapExceptions([=] { + if (!MgpVertexIsMutable(*v)) { + throw ImmutableObjectException{"Cannot add a label to an immutable vertex!"}; + } + const auto label_id = v->graph->impl->NameToLabel(label.name); + const auto result = v->impl.AddLabel(label_id); + + if (result.HasError()) { + switch (result.GetError()) { + case memgraph::storage::v3::Error::DELETED_OBJECT: + throw DeletedObjectException{"Cannot add a label to a deleted vertex!"}; + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL("Query modules shouldn't have access to nonexistent objects when adding a label to a vertex!"); + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + LOG_FATAL("Unexpected error when adding a label to a vertex."); + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + throw SerializationException{"Cannot serialize adding a label to a vertex."}; + } + } + + auto &ctx = v->graph->ctx; + + ctx->execution_stats[memgraph::query::v2::ExecutionStats::Key::CREATED_LABELS] += 1; + + if (ctx->trigger_context_collector) { + ctx->trigger_context_collector->RegisterSetVertexLabel(v->impl, label_id); + } + }); +} + +mgp_error mgp_vertex_remove_label(struct mgp_vertex *v, mgp_label label) { + return WrapExceptions([=] { + if (!MgpVertexIsMutable(*v)) { + throw ImmutableObjectException{"Cannot remove a label from an immutable vertex!"}; + } + const auto label_id = v->graph->impl->NameToLabel(label.name); + const auto result = v->impl.RemoveLabel(label_id); + + if (result.HasError()) { + switch (result.GetError()) { + case memgraph::storage::v3::Error::DELETED_OBJECT: + throw DeletedObjectException{"Cannot remove a label from a deleted vertex!"}; + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + 
LOG_FATAL("Query modules shouldn't have access to nonexistent objects when removing a label from a vertex!"); + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + LOG_FATAL("Unexpected error when removing a label from a vertex."); + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + throw SerializationException{"Cannot serialize removing a label from a vertex."}; + } + } + + auto &ctx = v->graph->ctx; + + ctx->execution_stats[memgraph::query::v2::ExecutionStats::Key::DELETED_LABELS] += 1; + + if (ctx->trigger_context_collector) { + ctx->trigger_context_collector->RegisterRemovedVertexLabel(v->impl, label_id); + } + }); +} + +mgp_error mgp_vertex_copy(mgp_vertex *v, mgp_memory *memory, mgp_vertex **result) { + return WrapExceptions([v, memory] { return NewRawMgpObject<mgp_vertex>(memory, *v); }, result); +} + +void mgp_vertex_destroy(mgp_vertex *v) { DeleteRawMgpObject(v); } + +mgp_error mgp_vertex_equal(mgp_vertex *v1, mgp_vertex *v2, int *result) { + // NOLINTNEXTLINE(clang-diagnostic-unevaluated-expression) + static_assert(noexcept(*result = *v1 == *v2 ? 1 : 0)); + *result = *v1 == *v2 ? 
1 : 0; + return mgp_error::MGP_ERROR_NO_ERROR; +} + +mgp_error mgp_vertex_labels_count(mgp_vertex *v, size_t *result) { + return WrapExceptions( + [v]() -> size_t { + auto maybe_labels = v->impl.Labels(v->graph->view); + if (maybe_labels.HasError()) { + switch (maybe_labels.GetError()) { + case memgraph::storage::v3::Error::DELETED_OBJECT: + throw DeletedObjectException{"Cannot get the labels of a deleted vertex!"}; + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL("Query modules shouldn't have access to nonexistent objects when getting vertex labels!"); + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + LOG_FATAL("Unexpected error when getting vertex labels."); + } + } + return maybe_labels->size(); + }, + result); +} + +mgp_error mgp_vertex_label_at(mgp_vertex *v, size_t i, mgp_label *result) { + return WrapExceptions( + [v, i]() -> const char * { + // TODO: Maybe it's worth caching this in mgp_vertex. 
+ auto maybe_labels = v->impl.Labels(v->graph->view); + if (maybe_labels.HasError()) { + switch (maybe_labels.GetError()) { + case memgraph::storage::v3::Error::DELETED_OBJECT: + throw DeletedObjectException{"Cannot get a label of a deleted vertex!"}; + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL("Query modules shouldn't have access to nonexistent objects when getting a label of a vertex!"); + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + LOG_FATAL("Unexpected error when getting a label of a vertex."); + } + } + if (i >= maybe_labels->size()) { + throw std::out_of_range("Label cannot be retrieved, because index exceeds the number of labels!"); + } + const auto &label = (*maybe_labels)[i]; + static_assert(std::is_lvalue_reference_v<decltype(v->graph->impl->LabelToName(label))>, + "Expected LabelToName to return a pointer or reference, so we " + "don't have to take a copy and manage memory."); + const auto &name = v->graph->impl->LabelToName(label); + return name.c_str(); + }, + &result->name); +} + +mgp_error mgp_vertex_has_label_named(mgp_vertex *v, const char *name, int *result) { + return WrapExceptions( + [v, name] { + memgraph::storage::v3::LabelId label; + label = v->graph->impl->NameToLabel(name); + + auto maybe_has_label = v->impl.HasLabel(v->graph->view, label); + if (maybe_has_label.HasError()) { + switch (maybe_has_label.GetError()) { + case memgraph::storage::v3::Error::DELETED_OBJECT: + throw DeletedObjectException{"Cannot check the existence of a label on a deleted vertex!"}; + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL( + "Query modules shouldn't have access to nonexistent objects when checking the existence of a label " + "on " + "a vertex!"); + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + case 
memgraph::storage::v3::Error::SERIALIZATION_ERROR: + LOG_FATAL("Unexpected error when checking the existence of a label on a vertex."); + } + } + return *maybe_has_label ? 1 : 0; + }, + result); +} + +mgp_error mgp_vertex_has_label(mgp_vertex *v, mgp_label label, int *result) { + return mgp_vertex_has_label_named(v, label.name, result); +} + +mgp_error mgp_vertex_get_property(mgp_vertex *v, const char *name, mgp_memory *memory, mgp_value **result) { + return WrapExceptions( + [v, name, memory]() -> mgp_value * { + const auto &key = v->graph->impl->NameToProperty(name); + auto maybe_prop = v->impl.GetProperty(v->graph->view, key); + if (maybe_prop.HasError()) { + switch (maybe_prop.GetError()) { + case memgraph::storage::v3::Error::DELETED_OBJECT: + throw DeletedObjectException{"Cannot get a property of a deleted vertex!"}; + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL( + "Query modules shouldn't have access to nonexistent objects when getting a property of a vertex."); + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + LOG_FATAL("Unexpected error when getting a property of a vertex."); + } + } + return NewRawMgpObject<mgp_value>(memory, std::move(*maybe_prop)); + }, + result); +} + +mgp_error mgp_vertex_iter_properties(mgp_vertex *v, mgp_memory *memory, mgp_properties_iterator **result) { + // NOTE: This copies the whole properties into the iterator. + // TODO: Think of a good way to avoid the copy which doesn't just rely on some + // assumption that storage may return a pointer to the property store. This + // will probably require a different API in storage. 
+ return WrapExceptions( + [v, memory] { + auto maybe_props = v->impl.Properties(v->graph->view); + if (maybe_props.HasError()) { + switch (maybe_props.GetError()) { + case memgraph::storage::v3::Error::DELETED_OBJECT: + throw DeletedObjectException{"Cannot get the properties of a deleted vertex!"}; + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL( + "Query modules shouldn't have access to nonexistent objects when getting the properties of a " + "vertex."); + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + LOG_FATAL("Unexpected error when getting the properties of a vertex."); + } + } + return NewRawMgpObject<mgp_properties_iterator>(memory, v->graph, std::move(*maybe_props)); + }, + result); +} + +void mgp_edges_iterator_destroy(mgp_edges_iterator *it) { DeleteRawMgpObject(it); } + +mgp_error mgp_vertex_iter_in_edges(mgp_vertex *v, mgp_memory *memory, mgp_edges_iterator **result) { + return WrapExceptions( + [v, memory] { + auto it = NewMgpObject<mgp_edges_iterator>(memory, *v); + MG_ASSERT(it != nullptr); + + auto maybe_edges = v->impl.InEdges(v->graph->view); + if (maybe_edges.HasError()) { + switch (maybe_edges.GetError()) { + case memgraph::storage::v3::Error::DELETED_OBJECT: + throw DeletedObjectException{"Cannot get the inbound edges of a deleted vertex!"}; + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL( + "Query modules shouldn't have access to nonexistent objects when getting the inbound edges of a " + "vertex."); + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + LOG_FATAL("Unexpected error when getting the inbound edges of a vertex."); + } + } + it->in.emplace(std::move(*maybe_edges)); + it->in_it.emplace(it->in->begin()); + if (*it->in_it != it->in->end()) { + 
it->current_e.emplace(**it->in_it, v->graph, it->GetMemoryResource()); + } + + return it.release(); + }, + result); +} + +mgp_error mgp_vertex_iter_out_edges(mgp_vertex *v, mgp_memory *memory, mgp_edges_iterator **result) { + return WrapExceptions( + [v, memory] { + auto it = NewMgpObject<mgp_edges_iterator>(memory, *v); + MG_ASSERT(it != nullptr); + + auto maybe_edges = v->impl.OutEdges(v->graph->view); + if (maybe_edges.HasError()) { + switch (maybe_edges.GetError()) { + case memgraph::storage::v3::Error::DELETED_OBJECT: + throw DeletedObjectException{"Cannot get the outbound edges of a deleted vertex!"}; + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL( + "Query modules shouldn't have access to nonexistent objects when getting the outbound edges of a " + "vertex."); + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + LOG_FATAL("Unexpected error when getting the outbound edges of a vertex."); + } + } + it->out.emplace(std::move(*maybe_edges)); + it->out_it.emplace(it->out->begin()); + if (*it->out_it != it->out->end()) { + it->current_e.emplace(**it->out_it, v->graph, it->GetMemoryResource()); + } + + return it.release(); + }, + result); +} + +mgp_error mgp_edges_iterator_underlying_graph_is_mutable(mgp_edges_iterator *it, int *result) { + return mgp_vertex_underlying_graph_is_mutable(&it->source_vertex, result); +} + +mgp_error mgp_edges_iterator_get(mgp_edges_iterator *it, mgp_edge **result) { + return WrapExceptions( + [it]() -> mgp_edge * { + if (it->current_e.has_value()) { + return &*it->current_e; + } + return nullptr; + }, + result); +} + +mgp_error mgp_edges_iterator_next(mgp_edges_iterator *it, mgp_edge **result) { + return WrapExceptions( + [it] { + MG_ASSERT(it->in || it->out); + auto next = [&](auto *impl_it, const auto &end) -> mgp_edge * { + if (*impl_it == end) { + MG_ASSERT(!it->current_e, + "Iteration is 
already done, so it->current_e " + "should have been set to std::nullopt"); + return nullptr; + } + if (++(*impl_it) == end) { + it->current_e = std::nullopt; + return nullptr; + } + it->current_e.emplace(**impl_it, it->source_vertex.graph, it->GetMemoryResource()); + return &*it->current_e; + }; + if (it->in_it) { + return next(&*it->in_it, it->in->end()); + } + return next(&*it->out_it, it->out->end()); + }, + result); +} + +mgp_error mgp_edge_get_id(mgp_edge *e, mgp_edge_id *result) { + return WrapExceptions([e] { return mgp_edge_id{.as_int = e->impl.Gid().AsInt()}; }, result); +} + +mgp_error mgp_edge_underlying_graph_is_mutable(mgp_edge *e, int *result) { + return mgp_vertex_underlying_graph_is_mutable(&e->from, result); +} + +mgp_error mgp_edge_copy(mgp_edge *e, mgp_memory *memory, mgp_edge **result) { + return WrapExceptions([e, memory] { return mgp_edge::Copy(*e, *memory); }, result); +} + +void mgp_edge_destroy(mgp_edge *e) { DeleteRawMgpObject(e); } + +mgp_error mgp_edge_equal(mgp_edge *e1, mgp_edge *e2, int *result) { + // NOLINTNEXTLINE(clang-diagnostic-unevaluated-expression) + static_assert(noexcept(*result = *e1 == *e2 ? 1 : 0)); + *result = *e1 == *e2 ? 
1 : 0;
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+/// Yields, through `result->name`, the C string of the edge's type name as returned by EdgeTypeToName.
+/// The static_assert guarantees EdgeTypeToName returns an lvalue reference, so the pointer refers to
+/// storage that is not owned by this call and no copy has to be managed here.
+mgp_error mgp_edge_get_type(mgp_edge *e, mgp_edge_type *result) {
+  return WrapExceptions(
+      [e] {
+        const auto &name = e->from.graph->impl->EdgeTypeToName(e->impl.EdgeType());
+        static_assert(std::is_lvalue_reference_v<decltype(e->from.graph->impl->EdgeTypeToName(e->impl.EdgeType()))>,
+                      "Expected EdgeTypeToName to return a pointer or reference, so we "
+                      "don't have to take a copy and manage memory.");
+        return name.c_str();
+      },
+      &result->name);
+}
+
+/// Points `*result` at the `from` vertex stored inside the edge; no copy is made.
+mgp_error mgp_edge_get_from(mgp_edge *e, mgp_vertex **result) {
+  *result = &e->from;
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+/// Points `*result` at the `to` vertex stored inside the edge; no copy is made.
+mgp_error mgp_edge_get_to(mgp_edge *e, mgp_vertex **result) {
+  *result = &e->to;
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+/// Reads the property called `name` from the edge, using the view of the graph the edge belongs to, and
+/// yields it through `result` as a new mgp_value allocated with `memory`.
+///
+/// @throw DeletedObjectException if storage reports the edge as deleted.
+mgp_error mgp_edge_get_property(mgp_edge *e, const char *name, mgp_memory *memory, mgp_value **result) {
+  return WrapExceptions(
+      [e, name, memory] {
+        const auto &key = e->from.graph->impl->NameToProperty(name);
+        auto view = e->from.graph->view;
+        auto maybe_prop = e->impl.GetProperty(view, key);
+        if (maybe_prop.HasError()) {
+          switch (maybe_prop.GetError()) {
+            case memgraph::storage::v3::Error::DELETED_OBJECT:
+              throw DeletedObjectException{"Cannot get a property of a deleted edge!"};
+            case memgraph::storage::v3::Error::NONEXISTENT_OBJECT:
+              LOG_FATAL(
+                  "Query modules shouldn't have access to nonexistent objects when getting a property of an edge.");
+            case memgraph::storage::v3::Error::PROPERTIES_DISABLED:
+            case memgraph::storage::v3::Error::VERTEX_HAS_EDGES:
+            case memgraph::storage::v3::Error::SERIALIZATION_ERROR:
+              LOG_FATAL("Unexpected error when getting a property of an edge.");
+          }
+        }
+        return NewRawMgpObject<mgp_value>(memory, std::move(*maybe_prop));
+      },
+      result);
+}
+
+/// Sets the property `property_name` of the edge to `property_value`.
+///
+/// On success the UPDATED_PROPERTIES execution statistic is incremented and, when a trigger context
+/// collector is present and interested in edge property changes, the change is registered with it:
+/// as a removal when the new value is null, otherwise as a set with both the old and the new value.
+///
+/// @throw ImmutableObjectException if the edge is not mutable.
+/// @throw DeletedObjectException if storage reports the edge as deleted.
+/// @throw std::logic_error if properties on edges are disabled.
+/// @throw SerializationException if storage reports a serialization error.
+mgp_error mgp_edge_set_property(struct mgp_edge *e, const char *property_name, mgp_value *property_value) {
+  return WrapExceptions([=] {
+    if (!MgpEdgeIsMutable(*e)) {
+      throw ImmutableObjectException{"Cannot set a property on an immutable edge!"};
+    }
+    const auto prop_key = e->from.graph->impl->NameToProperty(property_name);
+    const auto result = e->impl.SetProperty(prop_key, ToPropertyValue(*property_value));
+
+    if (result.HasError()) {
+      switch (result.GetError()) {
+        case memgraph::storage::v3::Error::DELETED_OBJECT:
+          throw DeletedObjectException{"Cannot set the properties of a deleted edge!"};
+        case memgraph::storage::v3::Error::NONEXISTENT_OBJECT:
+          LOG_FATAL("Query modules shouldn't have access to nonexistent objects when setting a property of an edge!");
+        case memgraph::storage::v3::Error::PROPERTIES_DISABLED:
+          throw std::logic_error{"Cannot set the properties of edges, because properties on edges are disabled!"};
+        case memgraph::storage::v3::Error::VERTEX_HAS_EDGES:
+          LOG_FATAL("Unexpected error when setting a property of an edge.");
+        case memgraph::storage::v3::Error::SERIALIZATION_ERROR:
+          throw SerializationException{"Cannot serialize setting a property of an edge."};
+      }
+    }
+
+    auto &ctx = e->from.graph->ctx;
+
+    ctx->execution_stats[memgraph::query::v2::ExecutionStats::Key::UPDATED_PROPERTIES] += 1;
+
+    // Fetch the collector once and reuse it below, mirroring mgp_vertex_set_property.
+    auto *trigger_ctx_collector = ctx->trigger_context_collector;
+    if (!trigger_ctx_collector ||
+        !trigger_ctx_collector->ShouldRegisterObjectPropertyChange<memgraph::query::v2::EdgeAccessor>()) {
+      return;
+    }
+    const auto old_value = memgraph::query::v2::TypedValue(*result);
+    if (property_value->type == mgp_value_type::MGP_VALUE_TYPE_NULL) {
+      trigger_ctx_collector->RegisterRemovedObjectProperty(e->impl, prop_key, old_value);
+      return;
+    }
+    const auto new_value = ToTypedValue(*property_value, property_value->memory);
+    trigger_ctx_collector->RegisterSetObjectProperty(e->impl, prop_key, old_value, new_value);
+  });
+}
+
+mgp_error mgp_edge_iter_properties(mgp_edge *e, mgp_memory *memory, mgp_properties_iterator **result) {
+  // NOTE: This copies the whole properties into iterator.
+  // TODO: Think of a good way to avoid the copy which doesn't just rely on some
+  // assumption that storage may return a pointer to the property store. This
+  // will probably require a different API in storage.
+  return WrapExceptions(
+      [e, memory] {
+        auto view = e->from.graph->view;
+        auto maybe_props = e->impl.Properties(view);
+        if (maybe_props.HasError()) {
+          switch (maybe_props.GetError()) {
+            case memgraph::storage::v3::Error::DELETED_OBJECT:
+              throw DeletedObjectException{"Cannot get the properties of a deleted edge!"};
+            case memgraph::storage::v3::Error::NONEXISTENT_OBJECT:
+              LOG_FATAL(
+                  "Query modules shouldn't have access to nonexistent objects when getting the properties of an edge.");
+            case memgraph::storage::v3::Error::PROPERTIES_DISABLED:
+            case memgraph::storage::v3::Error::VERTEX_HAS_EDGES:
+            case memgraph::storage::v3::Error::SERIALIZATION_ERROR:
+              LOG_FATAL("Unexpected error when getting the properties of an edge.");
+          }
+        }
+        return NewRawMgpObject<mgp_properties_iterator>(memory, e->from.graph, std::move(*maybe_props));
+      },
+      result);
+}
+
+/// Looks up the vertex with the given id in the graph's current view and yields, through `result`,
+/// a new mgp_vertex allocated with `memory`, or nullptr when FindVertex finds nothing.
+mgp_error mgp_graph_get_vertex_by_id(mgp_graph *graph, mgp_vertex_id id, mgp_memory *memory, mgp_vertex **result) {
+  return WrapExceptions(
+      [graph, id, memory]() -> mgp_vertex * {
+        auto maybe_vertex = graph->impl->FindVertex(memgraph::storage::v3::Gid::FromInt(id.as_int), graph->view);
+        if (maybe_vertex) {
+          return NewRawMgpObject<mgp_vertex>(memory, *maybe_vertex, graph);
+        }
+        return nullptr;
+      },
+      result);
+}
+
+/// Stores 1 in `*result` when MgpGraphIsMutable reports the graph as mutable and 0 otherwise.
+/// Always returns MGP_ERROR_NO_ERROR.
+mgp_error mgp_graph_is_mutable(mgp_graph *graph, int *result) {
+  *result = MgpGraphIsMutable(*graph) ? 1 : 0;
+  return mgp_error::MGP_ERROR_NO_ERROR;
+}
+
+/// Inserts a new vertex into the graph and yields it, through `result`, as a new mgp_vertex allocated
+/// with `memory`. The CREATED_NODES execution statistic is incremented and the vertex is registered
+/// with the trigger context collector when one is present.
+///
+/// @throw ImmutableObjectException if the graph is not mutable.
+mgp_error mgp_graph_create_vertex(struct mgp_graph *graph, mgp_memory *memory, mgp_vertex **result) {
+  return WrapExceptions(
+      [=] {
+        if (!MgpGraphIsMutable(*graph)) {
+          throw ImmutableObjectException{"Cannot create a vertex in an immutable graph!"};
+        }
+        auto vertex = graph->impl->InsertVertex();
+
+        auto &ctx = graph->ctx;
+        ctx->execution_stats[memgraph::query::v2::ExecutionStats::Key::CREATED_NODES] += 1;
+
+        if (ctx->trigger_context_collector) {
+          ctx->trigger_context_collector->RegisterCreatedObject(vertex);
+        }
+        return NewRawMgpObject<mgp_vertex>(memory, vertex, graph);
+      },
+      result);
+}
+
+/// Removes `vertex` from the graph. When storage returns an empty result nothing else happens;
+/// otherwise the DELETED_NODES execution statistic is incremented and the removed vertex is registered
+/// with the trigger context collector when one is present.
+///
+/// @throw ImmutableObjectException if the graph is not mutable.
+/// @throw std::logic_error if the vertex still has edges.
+/// @throw SerializationException if storage reports a serialization error.
+mgp_error mgp_graph_delete_vertex(struct mgp_graph *graph, mgp_vertex *vertex) {
+  return WrapExceptions([=] {
+    if (!MgpGraphIsMutable(*graph)) {
+      throw ImmutableObjectException{"Cannot remove a vertex from an immutable graph!"};
+    }
+    const auto result = graph->impl->RemoveVertex(&vertex->impl);
+
+    if (result.HasError()) {
+      switch (result.GetError()) {
+        case memgraph::storage::v3::Error::NONEXISTENT_OBJECT:
+          LOG_FATAL("Query modules shouldn't have access to nonexistent objects when removing a vertex!");
+        case memgraph::storage::v3::Error::DELETED_OBJECT:
+        case memgraph::storage::v3::Error::PROPERTIES_DISABLED:
+          LOG_FATAL("Unexpected error when removing a vertex.");
+        case memgraph::storage::v3::Error::VERTEX_HAS_EDGES:
+          throw std::logic_error{"Cannot remove a vertex that has edges!"};
+        case memgraph::storage::v3::Error::SERIALIZATION_ERROR:
+          throw SerializationException{"Cannot serialize removing a vertex."};
+      }
+    }
+
+    if (!*result) {
+      return;
+    }
+
+    auto &ctx = graph->ctx;
+
+    ctx->execution_stats[memgraph::query::v2::ExecutionStats::Key::DELETED_NODES] += 1;
+
+    if (ctx->trigger_context_collector) {
+      ctx->trigger_context_collector->RegisterDeletedObject(**result);
+    }
+  });
+}
+
+mgp_error mgp_graph_detach_delete_vertex(struct mgp_graph *graph, mgp_vertex *vertex) {
+  return WrapExceptions([=]
{ + if (!MgpGraphIsMutable(*graph)) { + throw ImmutableObjectException{"Cannot remove a vertex from an immutable graph!"}; + } + const auto result = graph->impl->DetachRemoveVertex(&vertex->impl); + + if (result.HasError()) { + switch (result.GetError()) { + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL("Query modules shouldn't have access to nonexistent objects when removing a vertex!"); + case memgraph::storage::v3::Error::DELETED_OBJECT: + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + LOG_FATAL("Unexpected error when removing a vertex."); + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + throw SerializationException{"Cannot serialize removing a vertex."}; + } + } + + if (!*result) { + return; + } + + auto &ctx = graph->ctx; + + ctx->execution_stats[memgraph::query::v2::ExecutionStats::Key::DELETED_NODES] += 1; + ctx->execution_stats[memgraph::query::v2::ExecutionStats::Key::DELETED_EDGES] += + static_cast<int64_t>((*result)->second.size()); + + auto *trigger_ctx_collector = ctx->trigger_context_collector; + if (!trigger_ctx_collector) { + return; + } + + trigger_ctx_collector->RegisterDeletedObject((*result)->first); + if (!trigger_ctx_collector->ShouldRegisterDeletedObject<memgraph::query::v2::EdgeAccessor>()) { + return; + } + for (const auto &edge : (*result)->second) { + trigger_ctx_collector->RegisterDeletedObject(edge); + } + }); +} + +mgp_error mgp_graph_create_edge(mgp_graph *graph, mgp_vertex *from, mgp_vertex *to, mgp_edge_type type, + mgp_memory *memory, mgp_edge **result) { + return WrapExceptions( + [=] { + if (!MgpGraphIsMutable(*graph)) { + throw ImmutableObjectException{"Cannot create an edge in an immutable graph!"}; + } + + auto edge = graph->impl->InsertEdge(&from->impl, &to->impl, from->graph->impl->NameToEdgeType(type.name)); + if (edge.HasError()) { + switch (edge.GetError()) { + case memgraph::storage::v3::Error::DELETED_OBJECT: + throw 
DeletedObjectException{"Cannot add an edge to a deleted vertex!"}; + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL("Query modules shouldn't have access to nonexistent objects when creating an edge!"); + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + LOG_FATAL("Unexpected error when creating an edge."); + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + throw SerializationException{"Cannot serialize creating an edge."}; + } + } + auto &ctx = graph->ctx; + + ctx->execution_stats[memgraph::query::v2::ExecutionStats::Key::CREATED_EDGES] += 1; + + if (ctx->trigger_context_collector) { + ctx->trigger_context_collector->RegisterCreatedObject(*edge); + } + return NewRawMgpObject<mgp_edge>(memory, edge.GetValue(), from->graph); + }, + result); +} + +mgp_error mgp_graph_delete_edge(struct mgp_graph *graph, mgp_edge *edge) { + return WrapExceptions([=] { + if (!MgpGraphIsMutable(*graph)) { + throw ImmutableObjectException{"Cannot remove an edge from an immutable graph!"}; + } + const auto result = graph->impl->RemoveEdge(&edge->impl); + + if (result.HasError()) { + switch (result.GetError()) { + case memgraph::storage::v3::Error::NONEXISTENT_OBJECT: + LOG_FATAL("Query modules shouldn't have access to nonexistent objects when removing an edge!"); + case memgraph::storage::v3::Error::DELETED_OBJECT: + case memgraph::storage::v3::Error::PROPERTIES_DISABLED: + case memgraph::storage::v3::Error::VERTEX_HAS_EDGES: + LOG_FATAL("Unexpected error when removing an edge."); + case memgraph::storage::v3::Error::SERIALIZATION_ERROR: + throw SerializationException{"Cannot serialize removing an edge."}; + } + } + + if (!*result) { + return; + } + auto &ctx = graph->ctx; + + ctx->execution_stats[memgraph::query::v2::ExecutionStats::Key::DELETED_EDGES] += 1; + if (ctx->trigger_context_collector) { + ctx->trigger_context_collector->RegisterDeletedObject(**result); + } + }); +} + +void 
mgp_vertices_iterator_destroy(mgp_vertices_iterator *it) { DeleteRawMgpObject(it); } + +mgp_error mgp_graph_iter_vertices(mgp_graph *graph, mgp_memory *memory, mgp_vertices_iterator **result) { + return WrapExceptions([graph, memory] { return NewRawMgpObject<mgp_vertices_iterator>(memory, graph); }, result); +} + +mgp_error mgp_vertices_iterator_underlying_graph_is_mutable(mgp_vertices_iterator *it, int *result) { + return mgp_graph_is_mutable(it->graph, result); +} + +mgp_error mgp_vertices_iterator_get(mgp_vertices_iterator *it, mgp_vertex **result) { + return WrapExceptions( + [it]() -> mgp_vertex * { + if (it->current_v.has_value()) { + return &*it->current_v; + } + return nullptr; + }, + result); +} + +mgp_error mgp_vertices_iterator_next(mgp_vertices_iterator *it, mgp_vertex **result) { + return WrapExceptions( + [it]() -> mgp_vertex * { + if (it->current_it == it->vertices.end()) { + MG_ASSERT(!it->current_v, + "Iteration is already done, so it->current_v " + "should have been set to std::nullopt"); + return nullptr; + } + if (++it->current_it == it->vertices.end()) { + it->current_v = std::nullopt; + return nullptr; + } + memgraph::utils::OnScopeExit clean_up([it] { it->current_v = std::nullopt; }); + it->current_v.emplace(*it->current_it, it->graph, it->GetMemoryResource()); + clean_up.Disable(); + return &*it->current_v; + }, + result); +} + +/// Type System +/// +/// All types are allocated globally, so that we simplify the API and minimize +/// allocations done for types. 
+ +namespace { +void NoOpCypherTypeDeleter(CypherType * /*type*/) {} +} // namespace + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define DEFINE_MGP_TYPE_GETTER(cypher_type_name, mgp_type_name) \ + mgp_error mgp_type_##mgp_type_name(mgp_type **result) { \ + return WrapExceptions( \ + [] { \ + static cypher_type_name##Type impl; \ + static mgp_type mgp_type_name_type{CypherTypePtr(&impl, NoOpCypherTypeDeleter)}; \ + return &mgp_type_name_type; \ + }, \ + result); \ + } + +DEFINE_MGP_TYPE_GETTER(Any, any); +DEFINE_MGP_TYPE_GETTER(Bool, bool); +DEFINE_MGP_TYPE_GETTER(String, string); +DEFINE_MGP_TYPE_GETTER(Int, int); +DEFINE_MGP_TYPE_GETTER(Float, float); +DEFINE_MGP_TYPE_GETTER(Number, number); +DEFINE_MGP_TYPE_GETTER(Map, map); +DEFINE_MGP_TYPE_GETTER(Node, node); +DEFINE_MGP_TYPE_GETTER(Relationship, relationship); +DEFINE_MGP_TYPE_GETTER(Path, path); +DEFINE_MGP_TYPE_GETTER(Date, date); +DEFINE_MGP_TYPE_GETTER(LocalTime, local_time); +DEFINE_MGP_TYPE_GETTER(LocalDateTime, local_date_time); +DEFINE_MGP_TYPE_GETTER(Duration, duration); + +mgp_error mgp_type_list(mgp_type *type, mgp_type **result) { + return WrapExceptions( + [type] { + // Maps `type` to corresponding instance of ListType. + static memgraph::utils::pmr::map<mgp_type *, mgp_type> gListTypes(memgraph::utils::NewDeleteResource()); + static memgraph::utils::SpinLock lock; + std::lock_guard<memgraph::utils::SpinLock> guard(lock); + auto found_it = gListTypes.find(type); + if (found_it != gListTypes.end()) { + return &found_it->second; + } + auto alloc = gListTypes.get_allocator(); + CypherTypePtr impl( + alloc.new_object<ListType>( + // Just obtain the pointer to original impl, don't own it. 
+ CypherTypePtr(type->impl.get(), NoOpCypherTypeDeleter), alloc.GetMemoryResource()), + [alloc](CypherType *base_ptr) mutable { alloc.delete_object(static_cast<ListType *>(base_ptr)); }); + return &gListTypes.emplace(type, mgp_type{std::move(impl)}).first->second; + }, + result); +} + +mgp_error mgp_type_nullable(mgp_type *type, mgp_type **result) { + return WrapExceptions( + [type] { + // Maps `type` to corresponding instance of NullableType. + static memgraph::utils::pmr::map<mgp_type *, mgp_type> gNullableTypes(memgraph::utils::NewDeleteResource()); + static memgraph::utils::SpinLock lock; + std::lock_guard<memgraph::utils::SpinLock> guard(lock); + auto found_it = gNullableTypes.find(type); + if (found_it != gNullableTypes.end()) return &found_it->second; + + auto alloc = gNullableTypes.get_allocator(); + auto impl = + NullableType::Create(CypherTypePtr(type->impl.get(), NoOpCypherTypeDeleter), alloc.GetMemoryResource()); + return &gNullableTypes.emplace(type, mgp_type{std::move(impl)}).first->second; + }, + result); +} + +namespace { +mgp_proc *mgp_module_add_procedure(mgp_module *module, const char *name, mgp_proc_cb cb, + const ProcedureInfo &procedure_info) { + if (!IsValidIdentifierName(name)) { + throw std::invalid_argument{fmt::format("Invalid procedure name: {}", name)}; + } + if (module->procedures.find(name) != module->procedures.end()) { + throw std::logic_error{fmt::format("Procedure already exists with name '{}'", name)}; + }; + + auto *memory = module->procedures.get_allocator().GetMemoryResource(); + // May throw std::bad_alloc, std::length_error + return &module->procedures.emplace(name, mgp_proc(name, cb, memory, procedure_info)).first->second; +} +} // namespace + +mgp_error mgp_module_add_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb, mgp_proc **result) { + return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = false}); }, result); +} + +mgp_error 
mgp_module_add_write_procedure(mgp_module *module, const char *name, mgp_proc_cb cb, mgp_proc **result) { + return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = true}); }, result); +} + +namespace { +template <typename T> +concept IsCallable = memgraph::utils::SameAsAnyOf<T, mgp_proc, mgp_func>; + +template <IsCallable TCall> +mgp_error MgpAddArg(TCall &callable, const std::string &name, mgp_type &type) { + return WrapExceptions([&]() mutable { + static constexpr std::string_view type_name = std::invoke([]() constexpr { + if constexpr (std::is_same_v<TCall, mgp_proc>) { + return "procedure"; + } else if constexpr (std::is_same_v<TCall, mgp_func>) { + return "function"; + } + }); + + if (!IsValidIdentifierName(name.c_str())) { + throw std::invalid_argument{fmt::format("Invalid argument name for {} '{}': {}", type_name, callable.name, name)}; + } + if (!callable.opt_args.empty()) { + throw std::logic_error{fmt::format("Cannot add required argument '{}' to {} '{}' after adding any optional one", + name, type_name, callable.name)}; + } + callable.args.emplace_back(name, type.impl.get()); + }); +} + +template <IsCallable TCall> +mgp_error MgpAddOptArg(TCall &callable, const std::string name, mgp_type &type, mgp_value &default_value) { + return WrapExceptions([&]() mutable { + static constexpr std::string_view type_name = std::invoke([]() constexpr { + if constexpr (std::is_same_v<TCall, mgp_proc>) { + return "procedure"; + } else if constexpr (std::is_same_v<TCall, mgp_func>) { + return "function"; + } + }); + + if (!IsValidIdentifierName(name.c_str())) { + throw std::invalid_argument{fmt::format("Invalid argument name for {} '{}': {}", type_name, callable.name, name)}; + } + switch (MgpValueGetType(default_value)) { + case MGP_VALUE_TYPE_VERTEX: + case MGP_VALUE_TYPE_EDGE: + case MGP_VALUE_TYPE_PATH: + // default_value must not be a graph element. 
+ throw ValueConversionException{"Default value of argument '{}' of {} '{}' name must not be a graph element!", + name, type_name, callable.name}; + case MGP_VALUE_TYPE_NULL: + case MGP_VALUE_TYPE_BOOL: + case MGP_VALUE_TYPE_INT: + case MGP_VALUE_TYPE_DOUBLE: + case MGP_VALUE_TYPE_STRING: + case MGP_VALUE_TYPE_LIST: + case MGP_VALUE_TYPE_MAP: + case MGP_VALUE_TYPE_DATE: + case MGP_VALUE_TYPE_LOCAL_TIME: + case MGP_VALUE_TYPE_LOCAL_DATE_TIME: + case MGP_VALUE_TYPE_DURATION: + break; + } + // Default value must be of required `type`. + if (!type.impl->SatisfiesType(default_value)) { + throw std::logic_error{fmt::format("The default value of argument '{}' for {} '{}' doesn't satisfy type '{}'", + name, type_name, callable.name, type.impl->GetPresentableName())}; + } + auto *memory = callable.opt_args.get_allocator().GetMemoryResource(); + callable.opt_args.emplace_back(memgraph::utils::pmr::string(name, memory), type.impl.get(), + ToTypedValue(default_value, memory)); + }); +} +} // namespace + +mgp_error mgp_proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) { + return MgpAddArg(*proc, std::string(name), *type); +} + +mgp_error mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, mgp_type *type, mgp_value *default_value) { + return MgpAddOptArg(*proc, std::string(name), *type, *default_value); +} + +mgp_error mgp_func_add_arg(mgp_func *func, const char *name, mgp_type *type) { + return MgpAddArg(*func, std::string(name), *type); +} + +mgp_error mgp_func_add_opt_arg(mgp_func *func, const char *name, mgp_type *type, mgp_value *default_value) { + return MgpAddOptArg(*func, std::string(name), *type, *default_value); +} + +namespace { + +template <typename T> +concept ModuleProperties = memgraph::utils::SameAsAnyOf<T, mgp_proc, mgp_trans>; + +template <ModuleProperties T> +mgp_error AddResultToProp(T *prop, const char *name, mgp_type *type, bool is_deprecated) noexcept { + return WrapExceptions([=] { + if (!IsValidIdentifierName(name)) { + throw 
std::invalid_argument{fmt::format("Invalid result name for procedure '{}': {}", prop->name, name)}; + } + if (prop->results.find(name) != prop->results.end()) { + throw std::logic_error{fmt::format("Result already exists with name '{}' for procedure '{}'", name, prop->name)}; + }; + auto *memory = prop->results.get_allocator().GetMemoryResource(); + prop->results.emplace(memgraph::utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated)); + }); +} + +} // namespace + +mgp_error mgp_proc_add_result(mgp_proc *proc, const char *name, mgp_type *type) { + return AddResultToProp(proc, name, type, false); +} + +mgp_error MgpTransAddFixedResult(mgp_trans *trans) noexcept { + if (const auto err = AddResultToProp(trans, "query", Call<mgp_type *>(mgp_type_string), false); + err != mgp_error::MGP_ERROR_NO_ERROR) { + return err; + } + return AddResultToProp(trans, "parameters", Call<mgp_type *>(mgp_type_nullable, Call<mgp_type *>(mgp_type_map)), + false); +} + +mgp_error mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, mgp_type *type) { + return AddResultToProp(proc, name, type, true); +} + +int mgp_must_abort(mgp_graph *graph) { + MG_ASSERT(graph->ctx); + static_assert(noexcept(memgraph::query::v2::MustAbort(*graph->ctx))); + return memgraph::query::v2::MustAbort(*graph->ctx) ? 1 : 0; +} + +namespace memgraph::query::v2::procedure { + +namespace { + +// Print the value in user presentable fashion. +// @throw std::bad_alloc +// @throw std::length_error +std::ostream &PrintValue(const TypedValue &value, std::ostream *stream) { + switch (value.type()) { + case TypedValue::Type::Null: + return (*stream) << "Null"; + case TypedValue::Type::Bool: + return (*stream) << (value.ValueBool() ? 
"true" : "false"); + case TypedValue::Type::Int: + return (*stream) << value.ValueInt(); + case TypedValue::Type::Double: + return (*stream) << value.ValueDouble(); + case TypedValue::Type::String: + // String value should be escaped, this allocates a new string. + return (*stream) << memgraph::utils::Escape(value.ValueString()); + case TypedValue::Type::List: + (*stream) << "["; + memgraph::utils::PrintIterable(*stream, value.ValueList(), ", ", + [](auto &stream, const auto &elem) { PrintValue(elem, &stream); }); + return (*stream) << "]"; + case TypedValue::Type::Map: + (*stream) << "{"; + memgraph::utils::PrintIterable(*stream, value.ValueMap(), ", ", [](auto &stream, const auto &item) { + // Map keys are not escaped strings. + stream << item.first << ": "; + PrintValue(item.second, &stream); + }); + return (*stream) << "}"; + case TypedValue::Type::Date: + return (*stream) << value.ValueDate(); + case TypedValue::Type::LocalTime: + return (*stream) << value.ValueLocalTime(); + case TypedValue::Type::LocalDateTime: + return (*stream) << value.ValueLocalDateTime(); + case TypedValue::Type::Duration: + return (*stream) << value.ValueDuration(); + case TypedValue::Type::Vertex: + case TypedValue::Type::Edge: + case TypedValue::Type::Path: + LOG_FATAL("value must not be a graph element"); + } +} + +} // namespace + +void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) { + (*stream) << proc.name << "("; + memgraph::utils::PrintIterable(*stream, proc.args, ", ", [](auto &stream, const auto &arg) { + stream << arg.first << " :: " << arg.second->GetPresentableName(); + }); + if (!proc.args.empty() && !proc.opt_args.empty()) (*stream) << ", "; + memgraph::utils::PrintIterable(*stream, proc.opt_args, ", ", [](auto &stream, const auto &arg) { + stream << std::get<0>(arg) << " = "; + PrintValue(std::get<2>(arg), &stream) << " :: " << std::get<1>(arg)->GetPresentableName(); + }); + (*stream) << ") :: ("; + memgraph::utils::PrintIterable(*stream, proc.results, 
", ", [](auto &stream, const auto &name_result) { + const auto &[type, is_deprecated] = name_result.second; + if (is_deprecated) stream << "DEPRECATED "; + stream << name_result.first << " :: " << type->GetPresentableName(); + }); + (*stream) << ")"; +} + +void PrintFuncSignature(const mgp_func &func, std::ostream &stream) { + stream << func.name << "("; + utils::PrintIterable(stream, func.args, ", ", [](auto &stream, const auto &arg) { + stream << arg.first << " :: " << arg.second->GetPresentableName(); + }); + if (!func.args.empty() && !func.opt_args.empty()) { + stream << ", "; + } + utils::PrintIterable(stream, func.opt_args, ", ", [](auto &stream, const auto &arg) { + const auto &[name, type, default_val] = arg; + stream << name << " = "; + PrintValue(default_val, &stream) << " :: " << type->GetPresentableName(); + }); + stream << ")"; +} + +bool IsValidIdentifierName(const char *name) { + if (!name) return false; + std::regex regex("[_[:alpha:]][_[:alnum:]]*"); + return std::regex_match(name, regex); +} + +} // namespace memgraph::query::v2::procedure + +namespace { +using StreamSourceType = memgraph::query::v2::stream::StreamSourceType; + +class InvalidMessageFunction : public std::invalid_argument { + public: + InvalidMessageFunction(const StreamSourceType type, const std::string_view function_name) + : std::invalid_argument{fmt::format("'{}' is not defined for a message from a stream of type '{}'", function_name, + StreamSourceTypeToString(type))} {} +}; + +StreamSourceType MessageToStreamSourceType(const mgp_message::KafkaMessage & /*msg*/) { + return StreamSourceType::KAFKA; +} + +StreamSourceType MessageToStreamSourceType(const mgp_message::PulsarMessage & /*msg*/) { + return StreamSourceType::PULSAR; +} + +mgp_source_type StreamSourceTypeToMgpSourceType(const StreamSourceType type) { + switch (type) { + case StreamSourceType::KAFKA: + return mgp_source_type::KAFKA; + case StreamSourceType::PULSAR: + return mgp_source_type::PULSAR; + } +} + +} // 
namespace + +mgp_error mgp_message_source_type(mgp_message *message, mgp_source_type *result) { + return WrapExceptions( + [message] { + return std::visit(memgraph::utils::Overloaded{[](const auto &message) { + return StreamSourceTypeToMgpSourceType(MessageToStreamSourceType(message)); + }}, + message->msg); + }, + result); +} + +mgp_error mgp_message_payload(mgp_message *message, const char **result) { + return WrapExceptions( + [message] { + return std::visit( + memgraph::utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Payload().data(); }, + [](const mgp_message::PulsarMessage &msg) { return msg.Payload().data(); }, + [](const auto &msg) -> const char * { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "payload"); + }}, + message->msg); + }, + result); +} + +mgp_error mgp_message_payload_size(mgp_message *message, size_t *result) { + return WrapExceptions( + [message] { + return std::visit( + memgraph::utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Payload().size(); }, + [](const mgp_message::PulsarMessage &msg) { return msg.Payload().size(); }, + [](const auto &msg) -> size_t { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "payload_size"); + }}, + message->msg); + }, + result); +} + +mgp_error mgp_message_topic_name(mgp_message *message, const char **result) { + return WrapExceptions( + [message] { + return std::visit( + memgraph::utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->TopicName().data(); }, + [](const mgp_message::PulsarMessage &msg) { return msg.TopicName().data(); }, + [](const auto &msg) -> const char * { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "topic_name"); + }}, + message->msg); + }, + result); +} + +mgp_error mgp_message_key(mgp_message *message, const char **result) { + return WrapExceptions( + [message] { + return std::visit( + memgraph::utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return 
msg->Key().data(); }, + [](const auto &msg) -> const char * { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "key"); + }}, + message->msg); + }, + result); +} + +mgp_error mgp_message_key_size(mgp_message *message, size_t *result) { + return WrapExceptions( + [message] { + return std::visit( + memgraph::utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Key().size(); }, + [](const auto &msg) -> size_t { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "key_size"); + }}, + message->msg); + }, + result); +} + +mgp_error mgp_message_timestamp(mgp_message *message, int64_t *result) { + return WrapExceptions( + [message] { + return std::visit( + memgraph::utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Timestamp(); }, + [](const auto &msg) -> int64_t { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "timestamp"); + }}, + message->msg); + }, + result); +} + +mgp_error mgp_message_offset(struct mgp_message *message, int64_t *result) { + return WrapExceptions( + [message] { + return std::visit( + memgraph::utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Offset(); }, + [](const auto &msg) -> int64_t { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "offset"); + }}, + message->msg); + }, + result); +} + +mgp_error mgp_messages_size(mgp_messages *messages, size_t *result) { + static_assert(noexcept(messages->messages.size())); + *result = messages->messages.size(); + return mgp_error::MGP_ERROR_NO_ERROR; +} + +mgp_error mgp_messages_at(mgp_messages *messages, size_t index, mgp_message **result) { + return WrapExceptions( + [messages, index] { + if (index >= Call<size_t>(mgp_messages_size, messages)) { + throw std::out_of_range("Message cannot be retrieved, because index exceeds messages' size!"); + } + return &messages->messages[index]; + }, + result); +} + +mgp_error mgp_module_add_transformation(mgp_module *module, const char *name, 
mgp_trans_cb cb) { + return WrapExceptions([=] { + if (!IsValidIdentifierName(name)) { + throw std::invalid_argument{fmt::format("Invalid transformation name: {}", name)}; + } + if (module->transformations.find(name) != module->transformations.end()) { + throw std::logic_error{fmt::format("Transformation already exists with name '{}'", name)}; + }; + auto *memory = module->transformations.get_allocator().GetMemoryResource(); + module->transformations.emplace(name, mgp_trans(name, cb, memory)); + }); +} + +mgp_error mgp_module_add_function(mgp_module *module, const char *name, mgp_func_cb cb, mgp_func **result) { + return WrapExceptions( + [=] { + if (!IsValidIdentifierName(name)) { + throw std::invalid_argument{fmt::format("Invalid function name: {}", name)}; + } + if (module->functions.find(name) != module->functions.end()) { + throw std::logic_error{fmt::format("Function with similar name already exists '{}'", name)}; + }; + auto *memory = module->functions.get_allocator().GetMemoryResource(); + + return &module->functions.emplace(name, mgp_func(name, cb, memory)).first->second; + }, + result); +} diff --git a/src/query/v2/procedure/mg_procedure_impl.hpp b/src/query/v2/procedure/mg_procedure_impl.hpp new file mode 100644 index 000000000..8bdd3c9b4 --- /dev/null +++ b/src/query/v2/procedure/mg_procedure_impl.hpp @@ -0,0 +1,926 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +/// @file +/// Contains private (implementation) declarations and definitions for +/// mg_procedure.h +#pragma once + +#include "mg_procedure.h" + +#include <optional> +#include <ostream> + +#include "integrations/kafka/consumer.hpp" +#include "integrations/pulsar/consumer.hpp" +#include "query/v2/context.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/procedure/cypher_type_ptr.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/view.hpp" +#include "utils/memory.hpp" +#include "utils/pmr/map.hpp" +#include "utils/pmr/string.hpp" +#include "utils/pmr/vector.hpp" +#include "utils/temporal.hpp" +/// Wraps memory resource used in custom procedures. +/// +/// This should have been `using mgp_memory = memgraph::utils::MemoryResource`, but that's +/// not valid C++ because we have a forward declare `struct mgp_memory` in +/// mg_procedure.h +/// TODO: Make this extendable in C API, so that custom procedure writer can add +/// their own memory management wrappers. +struct mgp_memory { + memgraph::utils::MemoryResource *impl; +}; + +/// Immutable container of various values that appear in openCypher. +struct mgp_value { + /// Allocator type so that STL containers are aware that we need one. + using allocator_type = memgraph::utils::Allocator<mgp_value>; + + // Construct MGP_VALUE_TYPE_NULL. + explicit mgp_value(memgraph::utils::MemoryResource *) noexcept; + + mgp_value(bool, memgraph::utils::MemoryResource *) noexcept; + mgp_value(int64_t, memgraph::utils::MemoryResource *) noexcept; + mgp_value(double, memgraph::utils::MemoryResource *) noexcept; + /// @throw std::bad_alloc + mgp_value(const char *, memgraph::utils::MemoryResource *); + /// Take ownership of the mgp_list, MemoryResource must match. + mgp_value(mgp_list *, memgraph::utils::MemoryResource *) noexcept; + /// Take ownership of the mgp_map, MemoryResource must match. 
+ mgp_value(mgp_map *, memgraph::utils::MemoryResource *) noexcept; + /// Take ownership of the mgp_vertex, MemoryResource must match. + mgp_value(mgp_vertex *, memgraph::utils::MemoryResource *) noexcept; + /// Take ownership of the mgp_edge, MemoryResource must match. + mgp_value(mgp_edge *, memgraph::utils::MemoryResource *) noexcept; + /// Take ownership of the mgp_path, MemoryResource must match. + mgp_value(mgp_path *, memgraph::utils::MemoryResource *) noexcept; + + mgp_value(mgp_date *, memgraph::utils::MemoryResource *) noexcept; + mgp_value(mgp_local_time *, memgraph::utils::MemoryResource *) noexcept; + mgp_value(mgp_local_date_time *, memgraph::utils::MemoryResource *) noexcept; + mgp_value(mgp_duration *, memgraph::utils::MemoryResource *) noexcept; + + /// Construct by copying memgraph::query::v2::TypedValue using memgraph::utils::MemoryResource. + /// mgp_graph is needed to construct mgp_vertex and mgp_edge. + /// @throw std::bad_alloc + mgp_value(const memgraph::query::v2::TypedValue &, mgp_graph *, memgraph::utils::MemoryResource *); + + /// Construct by copying memgraph::storage::v3::PropertyValue using memgraph::utils::MemoryResource. + /// @throw std::bad_alloc + mgp_value(const memgraph::storage::v3::PropertyValue &, memgraph::utils::MemoryResource *); + + /// Copy construction without memgraph::utils::MemoryResource is not allowed. + mgp_value(const mgp_value &) = delete; + + /// Copy construct using given memgraph::utils::MemoryResource. + /// @throw std::bad_alloc + mgp_value(const mgp_value &, memgraph::utils::MemoryResource *); + + /// Move construct using given memgraph::utils::MemoryResource. + /// @throw std::bad_alloc if MemoryResource is different, so we cannot move. + mgp_value(mgp_value &&, memgraph::utils::MemoryResource *); + + /// Move construct, memgraph::utils::MemoryResource is inherited. + mgp_value(mgp_value &&other) noexcept : mgp_value(other, other.memory) {} + + /// Copy-assignment is not allowed to preserve immutability. 
+ mgp_value &operator=(const mgp_value &) = delete; + + /// Move-assignment is not allowed to preserve immutability. + mgp_value &operator=(mgp_value &&) = delete; + + ~mgp_value() noexcept; + + memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; } + + mgp_value_type type; + memgraph::utils::MemoryResource *memory; + + union { + bool bool_v; + int64_t int_v; + double double_v; + memgraph::utils::pmr::string string_v; + // We use pointers so that taking ownership via C API is easier. Besides, + // mgp_map cannot use incomplete mgp_value type, because that would be + // undefined behaviour. + mgp_list *list_v; + mgp_map *map_v; + mgp_vertex *vertex_v; + mgp_edge *edge_v; + mgp_path *path_v; + mgp_date *date_v; + mgp_local_time *local_time_v; + mgp_local_date_time *local_date_time_v; + mgp_duration *duration_v; + }; +}; + +inline memgraph::utils::DateParameters MapDateParameters(const mgp_date_parameters *parameters) { + return {.year = parameters->year, .month = parameters->month, .day = parameters->day}; +} + +struct mgp_date { + /// Allocator type so that STL containers are aware that we need one. + /// We don't actually need this, but it simplifies the C API, because we store + /// the allocator which was used to allocate `this`. + using allocator_type = memgraph::utils::Allocator<mgp_date>; + + // Hopefully memgraph::utils::Date copy constructor remains noexcept, so that we can + // have everything noexcept here. 
+ static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::Date>); + + mgp_date(const memgraph::utils::Date &date, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), date(date) {} + + mgp_date(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), date(memgraph::utils::ParseDateParameters(string).first) {} + + mgp_date(const mgp_date_parameters *parameters, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), date(MapDateParameters(parameters)) {} + + mgp_date(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), date(microseconds) {} + + mgp_date(const mgp_date &other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), date(other.date) {} + + mgp_date(mgp_date &&other, memgraph::utils::MemoryResource *memory) noexcept : memory(memory), date(other.date) {} + + mgp_date(mgp_date &&other) noexcept : memory(other.memory), date(other.date) {} + + /// Copy construction without memgraph::utils::MemoryResource is not allowed. + mgp_date(const mgp_date &) = delete; + + mgp_date &operator=(const mgp_date &) = delete; + mgp_date &operator=(mgp_date &&) = delete; + + ~mgp_date() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; } + + memgraph::utils::MemoryResource *memory; + memgraph::utils::Date date; +}; + +inline memgraph::utils::LocalTimeParameters MapLocalTimeParameters(const mgp_local_time_parameters *parameters) { + return {.hour = parameters->hour, + .minute = parameters->minute, + .second = parameters->second, + .millisecond = parameters->millisecond, + .microsecond = parameters->microsecond}; +} + +struct mgp_local_time { + /// Allocator type so that STL containers are aware that we need one. + /// We don't actually need this, but it simplifies the C API, because we store + /// the allocator which was used to allocate `this`. 
+ using allocator_type = memgraph::utils::Allocator<mgp_local_time>; + + // Hopefully memgraph::utils::LocalTime copy constructor remains noexcept, so that we can + // have everything noexcept here. + static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::LocalTime>); + + mgp_local_time(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_time(memgraph::utils::ParseLocalTimeParameters(string).first) {} + + mgp_local_time(const mgp_local_time_parameters *parameters, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_time(MapLocalTimeParameters(parameters)) {} + + mgp_local_time(const memgraph::utils::LocalTime &local_time, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_time(local_time) {} + + mgp_local_time(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_time(microseconds) {} + + mgp_local_time(const mgp_local_time &other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_time(other.local_time) {} + + mgp_local_time(mgp_local_time &&other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_time(other.local_time) {} + + mgp_local_time(mgp_local_time &&other) noexcept : memory(other.memory), local_time(other.local_time) {} + + /// Copy construction without memgraph::utils::MemoryResource is not allowed. 
+ mgp_local_time(const mgp_local_time &) = delete; + + mgp_local_time &operator=(const mgp_local_time &) = delete; + mgp_local_time &operator=(mgp_local_time &&) = delete; + + ~mgp_local_time() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; } + + memgraph::utils::MemoryResource *memory; + memgraph::utils::LocalTime local_time; +}; + +inline memgraph::utils::LocalDateTime CreateLocalDateTimeFromString(const std::string_view string) { + const auto &[date_parameters, local_time_parameters] = memgraph::utils::ParseLocalDateTimeParameters(string); + return memgraph::utils::LocalDateTime{date_parameters, local_time_parameters}; +} + +struct mgp_local_date_time { + /// Allocator type so that STL containers are aware that we need one. + /// We don't actually need this, but it simplifies the C API, because we store + /// the allocator which was used to allocate `this`. + using allocator_type = memgraph::utils::Allocator<mgp_local_date_time>; + + // Hopefully memgraph::utils::LocalDateTime copy constructor remains noexcept, so that we can + // have everything noexcept here. 
+ static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::LocalDateTime>); + + mgp_local_date_time(const memgraph::utils::LocalDateTime &local_date_time, + memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_date_time(local_date_time) {} + + mgp_local_date_time(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_date_time(CreateLocalDateTimeFromString(string)) {} + + mgp_local_date_time(const mgp_local_date_time_parameters *parameters, + memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), + local_date_time(MapDateParameters(parameters->date_parameters), + MapLocalTimeParameters(parameters->local_time_parameters)) {} + + mgp_local_date_time(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_date_time(microseconds) {} + + mgp_local_date_time(const mgp_local_date_time &other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_date_time(other.local_date_time) {} + + mgp_local_date_time(mgp_local_date_time &&other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), local_date_time(other.local_date_time) {} + + mgp_local_date_time(mgp_local_date_time &&other) noexcept + : memory(other.memory), local_date_time(other.local_date_time) {} + + /// Copy construction without memgraph::utils::MemoryResource is not allowed. 
+ mgp_local_date_time(const mgp_local_date_time &) = delete; + + mgp_local_date_time &operator=(const mgp_local_date_time &) = delete; + mgp_local_date_time &operator=(mgp_local_date_time &&) = delete; + + ~mgp_local_date_time() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; } + + memgraph::utils::MemoryResource *memory; + memgraph::utils::LocalDateTime local_date_time; +}; + +inline memgraph::utils::DurationParameters MapDurationParameters(const mgp_duration_parameters *parameters) { + return {.day = parameters->day, + .hour = parameters->hour, + .minute = parameters->minute, + .second = parameters->second, + .millisecond = parameters->millisecond, + .microsecond = parameters->microsecond}; +} + +struct mgp_duration { + /// Allocator type so that STL containers are aware that we need one. + /// We don't actually need this, but it simplifies the C API, because we store + /// the allocator which was used to allocate `this`. + using allocator_type = memgraph::utils::Allocator<mgp_duration>; + + // Hopefully memgraph::utils::Duration copy constructor remains noexcept, so that we can + // have everything noexcept here. 
+ static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::Duration>); + + mgp_duration(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), duration(memgraph::utils::ParseDurationParameters(string)) {} + + mgp_duration(const mgp_duration_parameters *parameters, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), duration(MapDurationParameters(parameters)) {} + + mgp_duration(const memgraph::utils::Duration &duration, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), duration(duration) {} + + mgp_duration(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), duration(microseconds) {} + + mgp_duration(const mgp_duration &other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), duration(other.duration) {} + + mgp_duration(mgp_duration &&other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), duration(other.duration) {} + + mgp_duration(mgp_duration &&other) noexcept : memory(other.memory), duration(other.duration) {} + + /// Copy construction without memgraph::utils::MemoryResource is not allowed. + mgp_duration(const mgp_duration &) = delete; + + mgp_duration &operator=(const mgp_duration &) = delete; + mgp_duration &operator=(mgp_duration &&) = delete; + + ~mgp_duration() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; } + + memgraph::utils::MemoryResource *memory; + memgraph::utils::Duration duration; +}; + +struct mgp_list { + /// Allocator type so that STL containers are aware that we need one. 
+ using allocator_type = memgraph::utils::Allocator<mgp_list>; + + explicit mgp_list(memgraph::utils::MemoryResource *memory) : elems(memory) {} + + mgp_list(memgraph::utils::pmr::vector<mgp_value> &&elems, memgraph::utils::MemoryResource *memory) + : elems(std::move(elems), memory) {} + + mgp_list(const mgp_list &other, memgraph::utils::MemoryResource *memory) : elems(other.elems, memory) {} + + mgp_list(mgp_list &&other, memgraph::utils::MemoryResource *memory) : elems(std::move(other.elems), memory) {} + + mgp_list(mgp_list &&other) noexcept : elems(std::move(other.elems)) {} + + /// Copy construction without memgraph::utils::MemoryResource is not allowed. + mgp_list(const mgp_list &) = delete; + + mgp_list &operator=(const mgp_list &) = delete; + mgp_list &operator=(mgp_list &&) = delete; + + ~mgp_list() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { + return elems.get_allocator().GetMemoryResource(); + } + + // C++17 vector can work with incomplete type. + memgraph::utils::pmr::vector<mgp_value> elems; +}; + +struct mgp_map { + /// Allocator type so that STL containers are aware that we need one. + using allocator_type = memgraph::utils::Allocator<mgp_map>; + + explicit mgp_map(memgraph::utils::MemoryResource *memory) : items(memory) {} + + mgp_map(memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_value> &&items, + memgraph::utils::MemoryResource *memory) + : items(std::move(items), memory) {} + + mgp_map(const mgp_map &other, memgraph::utils::MemoryResource *memory) : items(other.items, memory) {} + + mgp_map(mgp_map &&other, memgraph::utils::MemoryResource *memory) : items(std::move(other.items), memory) {} + + mgp_map(mgp_map &&other) noexcept : items(std::move(other.items)) {} + + /// Copy construction without memgraph::utils::MemoryResource is not allowed. 
+ mgp_map(const mgp_map &) = delete; + + mgp_map &operator=(const mgp_map &) = delete; + mgp_map &operator=(mgp_map &&) = delete; + + ~mgp_map() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { + return items.get_allocator().GetMemoryResource(); + } + + // Unfortunately using incomplete type with map is undefined, so mgp_map + // needs to be defined after mgp_value. + memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_value> items; +}; + +struct mgp_map_item { + const char *key; + mgp_value *value; +}; + +struct mgp_map_items_iterator { + using allocator_type = memgraph::utils::Allocator<mgp_map_items_iterator>; + + mgp_map_items_iterator(mgp_map *map, memgraph::utils::MemoryResource *memory) + : memory(memory), map(map), current_it(map->items.begin()) { + if (current_it != map->items.end()) { + current.key = current_it->first.c_str(); + current.value = ¤t_it->second; + } + } + + mgp_map_items_iterator(const mgp_map_items_iterator &) = delete; + mgp_map_items_iterator(mgp_map_items_iterator &&) = delete; + mgp_map_items_iterator &operator=(const mgp_map_items_iterator &) = delete; + mgp_map_items_iterator &operator=(mgp_map_items_iterator &&) = delete; + + ~mgp_map_items_iterator() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; } + + memgraph::utils::MemoryResource *memory; + mgp_map *map; + decltype(map->items.begin()) current_it; + mgp_map_item current; +}; + +struct mgp_vertex { + /// Allocator type so that STL containers are aware that we need one. + /// We don't actually need this, but it simplifies the C API, because we store + /// the allocator which was used to allocate `this`. + using allocator_type = memgraph::utils::Allocator<mgp_vertex>; + + // Hopefully VertexAccessor copy constructor remains noexcept, so that we can + // have everything noexcept here. 
+ static_assert(std::is_nothrow_copy_constructible_v<memgraph::query::v2::VertexAccessor>); + + mgp_vertex(memgraph::query::v2::VertexAccessor v, mgp_graph *graph, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), impl(v), graph(graph) {} + + mgp_vertex(const mgp_vertex &other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), impl(other.impl), graph(other.graph) {} + + mgp_vertex(mgp_vertex &&other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), impl(other.impl), graph(other.graph) {} + + mgp_vertex(mgp_vertex &&other) noexcept : memory(other.memory), impl(other.impl), graph(other.graph) {} + + /// Copy construction without memgraph::utils::MemoryResource is not allowed. + mgp_vertex(const mgp_vertex &) = delete; + + mgp_vertex &operator=(const mgp_vertex &) = delete; + mgp_vertex &operator=(mgp_vertex &&) = delete; + + bool operator==(const mgp_vertex &other) const noexcept { return this->impl == other.impl; } + bool operator!=(const mgp_vertex &other) const noexcept { return !(*this == other); }; + + ~mgp_vertex() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; } + + memgraph::utils::MemoryResource *memory; + memgraph::query::v2::VertexAccessor impl; + mgp_graph *graph; +}; + +struct mgp_edge { + /// Allocator type so that STL containers are aware that we need one. + /// We don't actually need this, but it simplifies the C API, because we store + /// the allocator which was used to allocate `this`. + using allocator_type = memgraph::utils::Allocator<mgp_edge>; + + // Hopefully EdgeAccessor copy constructor remains noexcept, so that we can + // have everything noexcept here. 
+ static_assert(std::is_nothrow_copy_constructible_v<memgraph::query::v2::EdgeAccessor>); + + static mgp_edge *Copy(const mgp_edge &edge, mgp_memory &memory); + + mgp_edge(const memgraph::query::v2::EdgeAccessor &impl, mgp_graph *graph, + memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), impl(impl), from(impl.From(), graph, memory), to(impl.To(), graph, memory) {} + + mgp_edge(const mgp_edge &other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), impl(other.impl), from(other.from, memory), to(other.to, memory) {} + + /// Allocator-extended move: the new edge and both of its endpoint vertices + /// record `memory`, not the resource stored in `other`. + mgp_edge(mgp_edge &&other, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), impl(other.impl), from(std::move(other.from), memory), to(std::move(other.to), memory) {} + + /// Plain move: keeps the resource stored in `other`. + mgp_edge(mgp_edge &&other) noexcept + : memory(other.memory), impl(other.impl), from(std::move(other.from)), to(std::move(other.to)) {} + + /// Copy construction without memgraph::utils::MemoryResource is not allowed. + mgp_edge(const mgp_edge &) = delete; + + mgp_edge &operator=(const mgp_edge &) = delete; + mgp_edge &operator=(mgp_edge &&) = delete; + ~mgp_edge() = default; + + // Equality compares only the wrapped EdgeAccessor. + bool operator==(const mgp_edge &other) const noexcept { return this->impl == other.impl; } + bool operator!=(const mgp_edge &other) const noexcept { return !(*this == other); } + + memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; } + + memgraph::utils::MemoryResource *memory; + memgraph::query::v2::EdgeAccessor impl; + mgp_vertex from; + mgp_vertex to; +}; + +struct mgp_path { + /// Allocator type so that STL containers are aware that we need one.
+ using allocator_type = memgraph::utils::Allocator<mgp_path>; + + explicit mgp_path(memgraph::utils::MemoryResource *memory) : vertices(memory), edges(memory) {} + + mgp_path(const mgp_path &other, memgraph::utils::MemoryResource *memory) + : vertices(other.vertices, memory), edges(other.edges, memory) {} + + mgp_path(mgp_path &&other, memgraph::utils::MemoryResource *memory) + : vertices(std::move(other.vertices), memory), edges(std::move(other.edges), memory) {} + + mgp_path(mgp_path &&other) noexcept : vertices(std::move(other.vertices)), edges(std::move(other.edges)) {} + + /// Copy construction without memgraph::utils::MemoryResource is not allowed. + mgp_path(const mgp_path &) = delete; + + mgp_path &operator=(const mgp_path &) = delete; + mgp_path &operator=(mgp_path &&) = delete; + + ~mgp_path() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { + return vertices.get_allocator().GetMemoryResource(); + } + + memgraph::utils::pmr::vector<mgp_vertex> vertices; + memgraph::utils::pmr::vector<mgp_edge> edges; +}; + +struct mgp_result_record { + /// Result record signature as defined for mgp_proc. + const memgraph::utils::pmr::map<memgraph::utils::pmr::string, + std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> *signature; + memgraph::utils::pmr::map<memgraph::utils::pmr::string, memgraph::query::v2::TypedValue> values; +}; + +struct mgp_result { + explicit mgp_result( + const memgraph::utils::pmr::map<memgraph::utils::pmr::string, + std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> *signature, + memgraph::utils::MemoryResource *mem) + : signature(signature), rows(mem) {} + + /// Result record signature as defined for mgp_proc. 
+ const memgraph::utils::pmr::map<memgraph::utils::pmr::string, + std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> *signature; + memgraph::utils::pmr::vector<mgp_result_record> rows; + std::optional<memgraph::utils::pmr::string> error_msg; +}; + +struct mgp_func_result { + mgp_func_result() {} + /// Return Magic function result. If user forgets it, the error is raised + std::optional<memgraph::query::v2::TypedValue> value; + /// Return Magic function result with potential error + std::optional<memgraph::utils::pmr::string> error_msg; +}; + +struct mgp_graph { + memgraph::query::v2::DbAccessor *impl; + memgraph::storage::v3::View view; + // TODO: Merge `mgp_graph` and `mgp_memory` into a single `mgp_context`. The + // `ctx` field is out of place here. + memgraph::query::v2::ExecutionContext *ctx; + + static mgp_graph WritableGraph(memgraph::query::v2::DbAccessor &acc, memgraph::storage::v3::View view, + memgraph::query::v2::ExecutionContext &ctx) { + return mgp_graph{&acc, view, &ctx}; + } + + static mgp_graph NonWritableGraph(memgraph::query::v2::DbAccessor &acc, memgraph::storage::v3::View view) { + return mgp_graph{&acc, view, nullptr}; + } +}; + +// Prevents user to use ExecutionContext in writable callables +struct mgp_func_context { + memgraph::query::v2::DbAccessor *impl; + memgraph::storage::v3::View view; +}; +struct mgp_properties_iterator { + using allocator_type = memgraph::utils::Allocator<mgp_properties_iterator>; + + // Define members at the start because we use decltype a lot here, so members + // need to be visible in method definitions. + + memgraph::utils::MemoryResource *memory; + mgp_graph *graph; + std::remove_reference_t<decltype(*std::declval<memgraph::query::v2::VertexAccessor>().Properties(graph->view))> pvs; + decltype(pvs.begin()) current_it; + std::optional<std::pair<memgraph::utils::pmr::string, mgp_value>> current; + mgp_property property{nullptr, nullptr}; + + // Construct with no properties. 
+ explicit mgp_properties_iterator(mgp_graph *graph, memgraph::utils::MemoryResource *memory) + : memory(memory), graph(graph), current_it(pvs.begin()) {} + + // Construct from a property collection. May throw: PropertyValueStore does + // not document its exceptions, and it is built on top of the STL and other + // libraries that can throw. + mgp_properties_iterator(mgp_graph *graph, decltype(pvs) pvs, memgraph::utils::MemoryResource *memory) + : memory(memory), graph(graph), pvs(std::move(pvs)), current_it(this->pvs.begin()) { + if (current_it != this->pvs.end()) { + // Materialise the first property into `current` and expose it through + // `property`, whose pointers refer into `current`. + current.emplace(memgraph::utils::pmr::string(graph->impl->PropertyToName(current_it->first), memory), + mgp_value(current_it->second, memory)); + property.name = current->first.c_str(); + property.value = &current->second; + } + } + + mgp_properties_iterator(const mgp_properties_iterator &) = delete; + mgp_properties_iterator(mgp_properties_iterator &&) = delete; + + mgp_properties_iterator &operator=(const mgp_properties_iterator &) = delete; + mgp_properties_iterator &operator=(mgp_properties_iterator &&) = delete; + + ~mgp_properties_iterator() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; } +}; + +struct mgp_edges_iterator { + using allocator_type = memgraph::utils::Allocator<mgp_edges_iterator>; + + // Hopefully mgp_vertex copy constructor remains noexcept, so that we can + // have everything noexcept here.
+ static_assert(std::is_nothrow_constructible_v<mgp_vertex, const mgp_vertex &, memgraph::utils::MemoryResource *>); + + mgp_edges_iterator(const mgp_vertex &v, memgraph::utils::MemoryResource *memory) noexcept + : memory(memory), source_vertex(v, memory) {} + + mgp_edges_iterator(mgp_edges_iterator &&other) noexcept + : memory(other.memory), + source_vertex(std::move(other.source_vertex)), + in(std::move(other.in)), + in_it(std::move(other.in_it)), + out(std::move(other.out)), + out_it(std::move(other.out_it)), + current_e(std::move(other.current_e)) {} + + mgp_edges_iterator(const mgp_edges_iterator &) = delete; + mgp_edges_iterator &operator=(const mgp_edges_iterator &) = delete; + mgp_edges_iterator &operator=(mgp_edges_iterator &&) = delete; + + ~mgp_edges_iterator() = default; + + memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; } + + memgraph::utils::MemoryResource *memory; + mgp_vertex source_vertex; + std::optional<std::remove_reference_t<decltype(*source_vertex.impl.InEdges(source_vertex.graph->view))>> in; + std::optional<decltype(in->begin())> in_it; + std::optional<std::remove_reference_t<decltype(*source_vertex.impl.OutEdges(source_vertex.graph->view))>> out; + std::optional<decltype(out->begin())> out_it; + std::optional<mgp_edge> current_e; +}; + +struct mgp_vertices_iterator { + using allocator_type = memgraph::utils::Allocator<mgp_vertices_iterator>; + + /// @throw anything VerticesIterable may throw + mgp_vertices_iterator(mgp_graph *graph, memgraph::utils::MemoryResource *memory) + : memory(memory), graph(graph), vertices(graph->impl->Vertices(graph->view)), current_it(vertices.begin()) { + if (current_it != vertices.end()) { + current_v.emplace(*current_it, graph, memory); + } + } + + memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; } + + memgraph::utils::MemoryResource *memory; + mgp_graph *graph; + decltype(graph->impl->Vertices(graph->view)) vertices; + decltype(vertices.begin()) 
current_it; + std::optional<mgp_vertex> current_v; +}; + +struct mgp_type { + memgraph::query::v2::procedure::CypherTypePtr impl; +}; + +struct ProcedureInfo { + bool is_write = false; + std::optional<memgraph::query::v2::AuthQuery::Privilege> required_privilege = std::nullopt; +}; +struct mgp_proc { + using allocator_type = memgraph::utils::Allocator<mgp_proc>; + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_proc(const char *name, mgp_proc_cb cb, memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {}) + : name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {} + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_proc(const char *name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb, + memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {}) + : name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {} + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_proc(const std::string_view name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb, + memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {}) + : name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {} + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_proc(const mgp_proc &other, memgraph::utils::MemoryResource *memory) + : name(other.name, memory), + cb(other.cb), + args(other.args, memory), + opt_args(other.opt_args, memory), + results(other.results, memory), + info(other.info) {} + + mgp_proc(mgp_proc &&other, memgraph::utils::MemoryResource *memory) + : name(std::move(other.name), memory), + cb(std::move(other.cb)), + args(std::move(other.args), memory), + opt_args(std::move(other.opt_args), memory), + results(std::move(other.results), memory), + info(other.info) {} + + mgp_proc(const mgp_proc &other) = default; + mgp_proc(mgp_proc &&other) = 
default; + + mgp_proc &operator=(const mgp_proc &) = delete; + mgp_proc &operator=(mgp_proc &&) = delete; + + ~mgp_proc() = default; + + /// Name of the procedure. + memgraph::utils::pmr::string name; + /// Entry-point for the procedure. + std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb; + /// Required, positional arguments as a (name, type) pair. + memgraph::utils::pmr::vector< + std::pair<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *>> + args; + /// Optional positional arguments as a (name, type, default_value) tuple. + memgraph::utils::pmr::vector< + std::tuple<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *, + memgraph::query::v2::TypedValue>> + opt_args; + /// Fields this procedure returns, as a (name -> (type, is_deprecated)) map. + memgraph::utils::pmr::map<memgraph::utils::pmr::string, + std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> + results; + ProcedureInfo info; +}; + +struct mgp_trans { + using allocator_type = memgraph::utils::Allocator<mgp_trans>; + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_trans(const char *name, mgp_trans_cb cb, memgraph::utils::MemoryResource *memory) + : name(name, memory), cb(cb), results(memory) {} + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_trans(const char *name, std::function<void(mgp_messages *, mgp_graph *, mgp_result *, mgp_memory *)> cb, + memgraph::utils::MemoryResource *memory) + : name(name, memory), cb(cb), results(memory) {} + + /// Allocator-extended copy: `name` and `results` are both rebuilt on `memory`. + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_trans(const mgp_trans &other, memgraph::utils::MemoryResource *memory) + : name(other.name, memory), cb(other.cb), results(other.results, memory) {} + + /// Allocator-extended move: `name` and `results` both end up on `memory`. + mgp_trans(mgp_trans &&other, memgraph::utils::MemoryResource *memory) + : name(std::move(other.name), memory), cb(std::move(other.cb)), results(std::move(other.results), memory) {} + + mgp_trans(const mgp_trans &other) = default; +
mgp_trans(mgp_trans &&other) = default; + + mgp_trans &operator=(const mgp_trans &) = delete; + mgp_trans &operator=(mgp_trans &&) = delete; + + ~mgp_trans() = default; + + /// Name of the transformation. + memgraph::utils::pmr::string name; + /// Entry-point for the transformation. + std::function<void(mgp_messages *, mgp_graph *, mgp_result *, mgp_memory *)> cb; + /// Fields this transformation returns. + memgraph::utils::pmr::map<memgraph::utils::pmr::string, + std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> + results; +}; + +struct mgp_func { + using allocator_type = memgraph::utils::Allocator<mgp_func>; + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_func(const char *name, mgp_func_cb cb, memgraph::utils::MemoryResource *memory) + : name(name, memory), cb(cb), args(memory), opt_args(memory) {} + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_func(const char *name, std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb, + memgraph::utils::MemoryResource *memory) + : name(name, memory), cb(cb), args(memory), opt_args(memory) {} + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_func(const mgp_func &other, memgraph::utils::MemoryResource *memory) + : name(other.name, memory), cb(other.cb), args(other.args, memory), opt_args(other.opt_args, memory) {} + + mgp_func(mgp_func &&other, memgraph::utils::MemoryResource *memory) + : name(std::move(other.name), memory), + cb(std::move(other.cb)), + args(std::move(other.args), memory), + opt_args(std::move(other.opt_args), memory) {} + + mgp_func(const mgp_func &other) = default; + mgp_func(mgp_func &&other) = default; + + mgp_func &operator=(const mgp_func &) = delete; + mgp_func &operator=(mgp_func &&) = delete; + + ~mgp_func() = default; + + /// Name of the function. + memgraph::utils::pmr::string name; + /// Entry-point for the function. 
+ std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb; + /// Required, positional arguments as a (name, type) pair. + memgraph::utils::pmr::vector< + std::pair<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *>> + args; + /// Optional positional arguments as a (name, type, default_value) tuple. + memgraph::utils::pmr::vector< + std::tuple<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *, + memgraph::query::v2::TypedValue>> + opt_args; +}; + +mgp_error MgpTransAddFixedResult(mgp_trans *trans) noexcept; + +struct mgp_module { + using allocator_type = memgraph::utils::Allocator<mgp_module>; + + explicit mgp_module(memgraph::utils::MemoryResource *memory) + : procedures(memory), transformations(memory), functions(memory) {} + + mgp_module(const mgp_module &other, memgraph::utils::MemoryResource *memory) + : procedures(other.procedures, memory), + transformations(other.transformations, memory), + functions(other.functions, memory) {} + + mgp_module(mgp_module &&other, memgraph::utils::MemoryResource *memory) + : procedures(std::move(other.procedures), memory), + transformations(std::move(other.transformations), memory), + functions(std::move(other.functions), memory) {} + + mgp_module(const mgp_module &) = default; + mgp_module(mgp_module &&) = default; + + mgp_module &operator=(const mgp_module &) = delete; + mgp_module &operator=(mgp_module &&) = delete; + + ~mgp_module() = default; + + memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_proc> procedures; + memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_trans> transformations; + memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_func> functions; +}; + +namespace memgraph::query::v2::procedure { + +/// @throw std::bad_alloc +/// @throw std::length_error +/// @throw anything std::ostream::operator<< may throw. 
+void PrintProcSignature(const mgp_proc &, std::ostream *); + +/// @throw std::bad_alloc +/// @throw std::length_error +/// @throw anything std::ostream::operator<< may throw. +void PrintFuncSignature(const mgp_func &, std::ostream &); + +bool IsValidIdentifierName(const char *name); + +} // namespace memgraph::query::v2::procedure + +struct mgp_message { + explicit mgp_message(const memgraph::integrations::kafka::Message &message) : msg{&message} {} + explicit mgp_message(const memgraph::integrations::pulsar::Message &message) : msg{message} {} + + using KafkaMessage = const memgraph::integrations::kafka::Message *; + using PulsarMessage = memgraph::integrations::pulsar::Message; + std::variant<KafkaMessage, PulsarMessage> msg; +}; + +struct mgp_messages { + using allocator_type = memgraph::utils::Allocator<mgp_messages>; + using storage_type = memgraph::utils::pmr::vector<mgp_message>; + explicit mgp_messages(storage_type &&storage) : messages(std::move(storage)) {} + + mgp_messages(const mgp_messages &) = delete; + mgp_messages &operator=(const mgp_messages &) = delete; + + mgp_messages(mgp_messages &&) = delete; + mgp_messages &operator=(mgp_messages &&) = delete; + + ~mgp_messages() = default; + + storage_type messages; +}; + +memgraph::query::v2::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory); diff --git a/src/query/v2/procedure/module.cpp b/src/query/v2/procedure/module.cpp new file mode 100644 index 000000000..038df8fa8 --- /dev/null +++ b/src/query/v2/procedure/module.cpp @@ -0,0 +1,1258 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/procedure/module.hpp" + +#include <filesystem> +#include <optional> + +extern "C" { +#include <dlfcn.h> +} + +#include <fmt/format.h> +#include <unistd.h> + +#include "py/py.hpp" +#include "query/v2/procedure/mg_procedure_helpers.hpp" +#include "query/v2/procedure/py_module.hpp" +#include "utils/file.hpp" +#include "utils/logging.hpp" +#include "utils/memory.hpp" +#include "utils/message.hpp" +#include "utils/pmr/vector.hpp" +#include "utils/string.hpp" + +namespace memgraph::query::v2::procedure { + +ModuleRegistry gModuleRegistry; + +Module::~Module() {} + +class BuiltinModule final : public Module { + public: + BuiltinModule(); + ~BuiltinModule() override; + BuiltinModule(const BuiltinModule &) = delete; + BuiltinModule(BuiltinModule &&) = delete; + BuiltinModule &operator=(const BuiltinModule &) = delete; + BuiltinModule &operator=(BuiltinModule &&) = delete; + + bool Close() override; + + const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override; + + const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override; + + const std::map<std::string, mgp_func, std::less<>> *Functions() const override; + + void AddProcedure(std::string_view name, mgp_proc proc); + + void AddTransformation(std::string_view name, mgp_trans trans); + + std::optional<std::filesystem::path> Path() const override { return std::nullopt; } + + private: + /// Registered procedures + std::map<std::string, mgp_proc, std::less<>> procedures_; + std::map<std::string, mgp_trans, std::less<>> transformations_; + std::map<std::string, mgp_func, std::less<>> functions_; +}; + +BuiltinModule::BuiltinModule() {} + +BuiltinModule::~BuiltinModule() {} + +bool BuiltinModule::Close() { return true; } + +const 
std::map<std::string, mgp_proc, std::less<>> *BuiltinModule::Procedures() const { return &procedures_; } + +const std::map<std::string, mgp_trans, std::less<>> *BuiltinModule::Transformations() const { + return &transformations_; +} +const std::map<std::string, mgp_func, std::less<>> *BuiltinModule::Functions() const { return &functions_; } + +void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) { procedures_.emplace(name, std::move(proc)); } + +void BuiltinModule::AddTransformation(std::string_view name, mgp_trans trans) { + transformations_.emplace(name, std::move(trans)); +} + +namespace { + +auto WithUpgradedLock(auto *lock, const auto &function) { + lock->unlock_shared(); + utils::OnScopeExit shared_lock{[&] { lock->lock_shared(); }}; + function(); +}; + +void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock, BuiltinModule *module) { + // Loading relies on the fact that regular procedure invocation through + // CallProcedureCursor::Pull takes ModuleRegistry::lock_ with READ access. To + // load modules we have to upgrade our READ access to WRITE access, + // therefore we release the READ lock and invoke the load function which + // takes the WRITE lock. Obviously, some other thread may take a READ or WRITE + // lock during our transition when we hold no such lock. In this case it is + // fine, because our builtin module cannot be unloaded and we are ok with + // using the new state of module_registry when we manage to acquire the lock + // we desire. Note, deadlock between threads should not be possible, because a + // single thread may only take either a READ or a WRITE lock, it's not + // possible for a thread to hold both. If a thread tries to do that, it will + // deadlock immediately (no other thread needs to do anything). 
+ auto load_all_cb = [module_registry, lock](mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result * /*result*/, + mgp_memory * /*memory*/) { + WithUpgradedLock(lock, [&]() { module_registry->UnloadAndLoadModulesFromDirectories(); }); + }; + mgp_proc load_all("load_all", load_all_cb, utils::NewDeleteResource()); + module->AddProcedure("load_all", std::move(load_all)); + auto load_cb = [module_registry, lock](mgp_list *args, mgp_graph * /*graph*/, mgp_result *result, + mgp_memory * /*memory*/) { + MG_ASSERT(Call<size_t>(mgp_list_size, args) == 1U, "Should have been type checked already"); + auto *arg = Call<mgp_value *>(mgp_list_at, args, 0); + MG_ASSERT(CallBool(mgp_value_is_string, arg), "Should have been type checked already"); + bool succ = false; + WithUpgradedLock(lock, [&]() { + const char *arg_as_string{nullptr}; + if (const auto err = mgp_value_get_string(arg, &arg_as_string); err != mgp_error::MGP_ERROR_NO_ERROR) { + succ = false; + } else { + succ = module_registry->LoadOrReloadModuleFromName(arg_as_string); + } + }); + if (!succ) { + MG_ASSERT(mgp_result_set_error_msg(result, "Failed to (re)load the module.") == mgp_error::MGP_ERROR_NO_ERROR); + } + }; + mgp_proc load("load", load_cb, utils::NewDeleteResource()); + MG_ASSERT(mgp_proc_add_arg(&load, "module_name", Call<mgp_type *>(mgp_type_string)) == mgp_error::MGP_ERROR_NO_ERROR); + module->AddProcedure("load", std::move(load)); +} + +namespace { +[[nodiscard]] bool IsFileEditable(const std::optional<std::filesystem::path> &path) { + return path && access(path->c_str(), W_OK) == 0; +} + +std::string GetPathString(const std::optional<std::filesystem::path> &path) { + if (!path) { + return "builtin"; + } + + return std::filesystem::canonical(*path).generic_string(); +} +} // namespace + +void RegisterMgProcedures( + // We expect modules to be sorted by name. 
+ const std::map<std::string, std::unique_ptr<Module>, std::less<>> *all_modules, BuiltinModule *module) { + auto procedures_cb = [all_modules](mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result *result, + mgp_memory *memory) { + // Iterating over all_modules assumes that the standard mechanism of custom + // procedure invocations takes the ModuleRegistry::lock_ with READ access. + // For details on how the invocation is done, take a look at the + // CallProcedureCursor::Pull implementation. + for (const auto &[module_name, module] : *all_modules) { + // Return the results in sorted order by module and by procedure. + static_assert( + std::is_same_v<decltype(module->Procedures()), const std::map<std::string, mgp_proc, std::less<>> *>, + "Expected module procedures to be sorted by name"); + + const auto path = module->Path(); + const auto path_string = GetPathString(path); + const auto is_editable = IsFileEditable(path); + + for (const auto &[proc_name, proc] : *module->Procedures()) { + mgp_result_record *record{nullptr}; + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result); + if (!path_value) { + return; + } + + MgpUniquePtr<mgp_value> is_editable_value{nullptr, mgp_value_destroy}; + if (!TryOrSetError([&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); }, + result)) { + return; + } + + utils::pmr::string full_name(module_name, memory->impl); + full_name.append(1, '.'); + full_name.append(proc_name); + const auto name_value = GetStringValueOrSetError(full_name.c_str(), memory, result); + if (!name_value) { + return; + } + + std::stringstream ss; + ss << module_name << "."; + PrintProcSignature(proc, &ss); + const auto signature = ss.str(); + const auto signature_value = GetStringValueOrSetError(signature.c_str(), memory, result); + if (!signature_value) { + return; + } + + 
MgpUniquePtr<mgp_value> is_write_value{nullptr, mgp_value_destroy}; + if (!TryOrSetError( + [&, &proc = proc] { + return CreateMgpObject(is_write_value, mgp_value_make_bool, proc.info.is_write ? 1 : 0, memory); + }, + result)) { + return; + } + + if (!InsertResultOrSetError(result, record, "name", name_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "signature", signature_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "is_write", is_write_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "path", path_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "is_editable", is_editable_value.get())) { + return; + } + } + } + }; + mgp_proc procedures("procedures", procedures_cb, utils::NewDeleteResource()); + MG_ASSERT(mgp_proc_add_result(&procedures, "name", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&procedures, "signature", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&procedures, "is_write", Call<mgp_type *>(mgp_type_bool)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&procedures, "path", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&procedures, "is_editable", Call<mgp_type *>(mgp_type_bool)) == + mgp_error::MGP_ERROR_NO_ERROR); + module->AddProcedure("procedures", std::move(procedures)); +} + +void RegisterMgTransformations(const std::map<std::string, std::unique_ptr<Module>, std::less<>> *all_modules, + BuiltinModule *module) { + auto transformations_cb = [all_modules](mgp_list * /*unused*/, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory *memory) { + for (const auto &[module_name, module] : *all_modules) { + // Return the results in sorted order by module and by transformation. 
+ static_assert( + std::is_same_v<decltype(module->Transformations()), const std::map<std::string, mgp_trans, std::less<>> *>, + "Expected module transformations to be sorted by name"); + + const auto path = module->Path(); + const auto path_string = GetPathString(path); + const auto is_editable = IsFileEditable(path); + + for (const auto &[trans_name, proc] : *module->Transformations()) { + mgp_result_record *record{nullptr}; + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result); + if (!path_value) { + return; + } + + MgpUniquePtr<mgp_value> is_editable_value{nullptr, mgp_value_destroy}; + if (!TryOrSetError([&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); }, + result)) { + return; + } + + utils::pmr::string full_name(module_name, memory->impl); + full_name.append(1, '.'); + full_name.append(trans_name); + + const auto name_value = GetStringValueOrSetError(full_name.c_str(), memory, result); + if (!name_value) { + return; + } + + if (!InsertResultOrSetError(result, record, "name", name_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "path", path_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "is_editable", is_editable_value.get())) { + return; + } + } + } + }; + mgp_proc procedures("transformations", transformations_cb, utils::NewDeleteResource()); + MG_ASSERT(mgp_proc_add_result(&procedures, "name", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&procedures, "path", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&procedures, "is_editable", Call<mgp_type *>(mgp_type_bool)) == + mgp_error::MGP_ERROR_NO_ERROR); + module->AddProcedure("transformations", std::move(procedures)); +} + +void RegisterMgFunctions( + // We 
expect modules to be sorted by name. + const std::map<std::string, std::unique_ptr<Module>, std::less<>> *all_modules, BuiltinModule *module) { + auto functions_cb = [all_modules](mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result *result, + mgp_memory *memory) { + // Iterating over all_modules assumes that the standard mechanism of magic + // functions invocations takes the ModuleRegistry::lock_ with READ access. + for (const auto &[module_name, module] : *all_modules) { + // Return the results in sorted order by module and by function_name. + static_assert(std::is_same_v<decltype(module->Functions()), const std::map<std::string, mgp_func, std::less<>> *>, + "Expected module magic functions to be sorted by name"); + + const auto path = module->Path(); + const auto path_string = GetPathString(path); + const auto is_editable = IsFileEditable(path); + + for (const auto &[func_name, func] : *module->Functions()) { + mgp_result_record *record{nullptr}; + + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result); + if (!path_value) { + return; + } + + MgpUniquePtr<mgp_value> is_editable_value{nullptr, mgp_value_destroy}; + if (!TryOrSetError([&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); }, + result)) { + return; + } + + utils::pmr::string full_name(module_name, memory->impl); + full_name.append(1, '.'); + full_name.append(func_name); + const auto name_value = GetStringValueOrSetError(full_name.c_str(), memory, result); + if (!name_value) { + return; + } + + std::stringstream ss; + ss << module_name << "."; + PrintFuncSignature(func, ss); + const auto signature = ss.str(); + const auto signature_value = GetStringValueOrSetError(signature.c_str(), memory, result); + if (!signature_value) { + return; + } + + if (!InsertResultOrSetError(result, record, "name", name_value.get())) { + return; + } 
+ + if (!InsertResultOrSetError(result, record, "signature", signature_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "path", path_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "is_editable", is_editable_value.get())) { + return; + } + } + } + }; + mgp_proc functions("functions", functions_cb, utils::NewDeleteResource()); + MG_ASSERT(mgp_proc_add_result(&functions, "name", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&functions, "signature", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&functions, "path", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&functions, "is_editable", Call<mgp_type *>(mgp_type_bool)) == + mgp_error::MGP_ERROR_NO_ERROR); + module->AddProcedure("functions", std::move(functions)); +} +namespace { +bool IsAllowedExtension(const auto &extension) { + static constexpr std::array<std::string_view, 1> allowed_extensions{".py"}; + return std::any_of(allowed_extensions.begin(), allowed_extensions.end(), + [&](const auto allowed_extension) { return allowed_extension == extension; }); +} + +bool IsSubPath(const auto &base, const auto &destination) { + const auto relative = std::filesystem::relative(destination, base); + return !relative.empty() && *relative.begin() != ".."; +} + +std::optional<std::string> ReadFile(const auto &path) { + std::ifstream file(path); + if (!file.is_open()) { + return std::nullopt; + } + + const auto size = std::filesystem::file_size(path); + std::string content(size, '\0'); + file.read(content.data(), static_cast<std::streamsize>(size)); + return std::move(content); +} + +// Return the module directory that contains the `path` +utils::BasicResult<const char *, std::filesystem::path> ParentModuleDirectory(const ModuleRegistry &module_registry, + const std::filesystem::path &path) { + const 
auto &module_directories = module_registry.GetModulesDirectory(); + + auto longest_parent_directory = module_directories.end(); + auto max_length = std::numeric_limits<uint64_t>::min(); + for (auto it = module_directories.begin(); it != module_directories.end(); ++it) { + if (IsSubPath(*it, path)) { + const auto length = std::filesystem::canonical(*it).string().size(); + if (length > max_length) { + longest_parent_directory = it; + max_length = length; + } + } + } + + if (longest_parent_directory == module_directories.end()) { + return "The specified file isn't contained in any of the module directories."; + } + + return *longest_parent_directory; +} +} // namespace + +void RegisterMgGetModuleFiles(ModuleRegistry *module_registry, BuiltinModule *module) { + auto get_module_files_cb = [module_registry](mgp_list * /*args*/, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory *memory) { + for (const auto &module_directory : module_registry->GetModulesDirectory()) { + for (const auto &dir_entry : std::filesystem::recursive_directory_iterator(module_directory)) { + if (dir_entry.is_regular_file() && IsAllowedExtension(dir_entry.path().extension())) { + mgp_result_record *record{nullptr}; + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto path_string = GetPathString(dir_entry); + const auto is_editable = IsFileEditable(dir_entry); + + const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result); + if (!path_value) { + return; + } + + MgpUniquePtr<mgp_value> is_editable_value{nullptr, mgp_value_destroy}; + if (!TryOrSetError( + [&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); }, + result)) { + return; + } + + if (!InsertResultOrSetError(result, record, "path", path_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "is_editable", is_editable_value.get())) { + return; + } + } + } + } + }; + + mgp_proc 
get_module_files("get_module_files", get_module_files_cb, utils::NewDeleteResource(), + {.required_privilege = AuthQuery::Privilege::MODULE_READ}); + MG_ASSERT(mgp_proc_add_result(&get_module_files, "path", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&get_module_files, "is_editable", Call<mgp_type *>(mgp_type_bool)) == + mgp_error::MGP_ERROR_NO_ERROR); + module->AddProcedure("get_module_files", std::move(get_module_files)); +} + +void RegisterMgGetModuleFile(ModuleRegistry *module_registry, BuiltinModule *module) { + auto get_module_file_cb = [module_registry](mgp_list *args, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory *memory) { + MG_ASSERT(Call<size_t>(mgp_list_size, args) == 1U, "Should have been type checked already"); + auto *arg = Call<mgp_value *>(mgp_list_at, args, 0); + MG_ASSERT(CallBool(mgp_value_is_string, arg), "Should have been type checked already"); + const char *path_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(arg, &path_str); }, result)) { + return; + } + + const std::filesystem::path path{path_str}; + + if (!path.is_absolute()) { + static_cast<void>(mgp_result_set_error_msg(result, "The path should be an absolute path.")); + return; + } + + if (!IsAllowedExtension(path.extension())) { + static_cast<void>(mgp_result_set_error_msg(result, "The specified file isn't in the supported format.")); + return; + } + + if (!std::filesystem::exists(path)) { + static_cast<void>(mgp_result_set_error_msg(result, "The specified file doesn't exist.")); + return; + } + + if (auto maybe_error_msg = ParentModuleDirectory(*module_registry, path); maybe_error_msg.HasError()) { + static_cast<void>(mgp_result_set_error_msg(result, maybe_error_msg.GetError())); + return; + } + + const auto maybe_content = ReadFile(path); + if (!maybe_content) { + static_cast<void>(mgp_result_set_error_msg(result, "Couldn't read the content of the file.")); + return; + } + + mgp_result_record 
*record{nullptr}; + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto content_value = GetStringValueOrSetError(maybe_content->c_str(), memory, result); + if (!content_value) { + return; + } + + if (!InsertResultOrSetError(result, record, "content", content_value.get())) { + return; + } + }; + mgp_proc get_module_file("get_module_file", std::move(get_module_file_cb), utils::NewDeleteResource(), + {.required_privilege = AuthQuery::Privilege::MODULE_READ}); + MG_ASSERT(mgp_proc_add_arg(&get_module_file, "path", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&get_module_file, "content", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + module->AddProcedure("get_module_file", std::move(get_module_file)); +} + +namespace { +utils::BasicResult<std::string> WriteToFile(const std::filesystem::path &file, const std::string_view content) { + std::ofstream output_file{file}; + if (!output_file.is_open()) { + return fmt::format("Failed to open the file at location {}", file); + } + output_file.write(content.data(), static_cast<std::streamsize>(content.size())); + output_file.flush(); + return {}; +} +} // namespace + +void RegisterMgCreateModuleFile(ModuleRegistry *module_registry, utils::RWLock *lock, BuiltinModule *module) { + auto create_module_file_cb = [module_registry, lock](mgp_list *args, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory *memory) { + MG_ASSERT(Call<size_t>(mgp_list_size, args) == 2U, "Should have been type checked already"); + auto *filename_arg = Call<mgp_value *>(mgp_list_at, args, 0); + MG_ASSERT(CallBool(mgp_value_is_string, filename_arg), "Should have been type checked already"); + const char *filename_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(filename_arg, &filename_str); }, result)) { + return; + } + + const auto file_path = module_registry->InternalModuleDir() / 
filename_str; + + if (!IsSubPath(module_registry->InternalModuleDir(), file_path)) { + static_cast<void>(mgp_result_set_error_msg( + result, + "Invalid relative path defined. The module file cannot be define outside the internal modules directory.")); + return; + } + + if (!IsAllowedExtension(file_path.extension())) { + static_cast<void>(mgp_result_set_error_msg(result, "The specified file isn't in the supported format.")); + return; + } + + if (std::filesystem::exists(file_path)) { + static_cast<void>(mgp_result_set_error_msg(result, "File with the same name already exists!")); + return; + } + + utils::EnsureDir(file_path.parent_path()); + + auto *content_arg = Call<mgp_value *>(mgp_list_at, args, 1); + MG_ASSERT(CallBool(mgp_value_is_string, content_arg), "Should have been type checked already"); + const char *content_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(content_arg, &content_str); }, result)) { + return; + } + + if (auto maybe_error = WriteToFile(file_path, {content_str, std::strlen(content_str)}); maybe_error.HasError()) { + static_cast<void>(mgp_result_set_error_msg(result, maybe_error.GetError().c_str())); + return; + } + + mgp_result_record *record{nullptr}; + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto path_value = GetStringValueOrSetError(std::filesystem::canonical(file_path).c_str(), memory, result); + if (!path_value) { + return; + } + + if (!InsertResultOrSetError(result, record, "path", path_value.get())) { + return; + } + + WithUpgradedLock(lock, [&]() { module_registry->UnloadAndLoadModulesFromDirectories(); }); + }; + mgp_proc create_module_file("create_module_file", std::move(create_module_file_cb), utils::NewDeleteResource(), + {.required_privilege = AuthQuery::Privilege::MODULE_WRITE}); + MG_ASSERT(mgp_proc_add_arg(&create_module_file, "filename", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + 
MG_ASSERT(mgp_proc_add_arg(&create_module_file, "content", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&create_module_file, "path", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + module->AddProcedure("create_module_file", std::move(create_module_file)); +} + +void RegisterMgUpdateModuleFile(ModuleRegistry *module_registry, utils::RWLock *lock, BuiltinModule *module) { + auto update_module_file_cb = [module_registry, lock](mgp_list *args, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory * /*memory*/) { + MG_ASSERT(Call<size_t>(mgp_list_size, args) == 2U, "Should have been type checked already"); + auto *path_arg = Call<mgp_value *>(mgp_list_at, args, 0); + MG_ASSERT(CallBool(mgp_value_is_string, path_arg), "Should have been type checked already"); + const char *path_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(path_arg, &path_str); }, result)) { + return; + } + + const std::filesystem::path path{path_str}; + + if (!path.is_absolute()) { + static_cast<void>(mgp_result_set_error_msg(result, "The path should be an absolute path.")); + return; + } + + if (!IsAllowedExtension(path.extension())) { + static_cast<void>(mgp_result_set_error_msg(result, "The specified file isn't in the supported format.")); + return; + } + + if (!std::filesystem::exists(path)) { + static_cast<void>(mgp_result_set_error_msg(result, "The specified file doesn't exist.")); + return; + } + + if (auto maybe_error_msg = ParentModuleDirectory(*module_registry, path); maybe_error_msg.HasError()) { + static_cast<void>(mgp_result_set_error_msg(result, maybe_error_msg.GetError())); + return; + } + + auto *content_arg = Call<mgp_value *>(mgp_list_at, args, 1); + MG_ASSERT(CallBool(mgp_value_is_string, content_arg), "Should have been type checked already"); + const char *content_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(content_arg, &content_str); }, result)) { + 
return; + } + + if (auto maybe_error = WriteToFile(path, {content_str, std::strlen(content_str)}); maybe_error.HasError()) { + static_cast<void>(mgp_result_set_error_msg(result, maybe_error.GetError().c_str())); + return; + } + + WithUpgradedLock(lock, [&]() { module_registry->UnloadAndLoadModulesFromDirectories(); }); + }; + mgp_proc update_module_file("update_module_file", std::move(update_module_file_cb), utils::NewDeleteResource(), + {.required_privilege = AuthQuery::Privilege::MODULE_WRITE}); + MG_ASSERT(mgp_proc_add_arg(&update_module_file, "path", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_arg(&update_module_file, "content", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + module->AddProcedure("update_module_file", std::move(update_module_file)); +} + +void RegisterMgDeleteModuleFile(ModuleRegistry *module_registry, utils::RWLock *lock, BuiltinModule *module) { + auto delete_module_file_cb = [module_registry, lock](mgp_list *args, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory * /*memory*/) { + MG_ASSERT(Call<size_t>(mgp_list_size, args) == 1U, "Should have been type checked already"); + auto *path_arg = Call<mgp_value *>(mgp_list_at, args, 0); + MG_ASSERT(CallBool(mgp_value_is_string, path_arg), "Should have been type checked already"); + const char *path_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(path_arg, &path_str); }, result)) { + return; + } + + const std::filesystem::path path{path_str}; + + if (!path.is_absolute()) { + static_cast<void>(mgp_result_set_error_msg(result, "The path should be an absolute path.")); + return; + } + + if (!IsAllowedExtension(path.extension())) { + static_cast<void>(mgp_result_set_error_msg(result, "The specified file isn't in the supported format.")); + return; + } + + if (!std::filesystem::exists(path)) { + static_cast<void>(mgp_result_set_error_msg(result, "The specified file doesn't exist.")); + return; + } 
+ + const auto parent_module_directory = ParentModuleDirectory(*module_registry, path); + if (parent_module_directory.HasError()) { + static_cast<void>(mgp_result_set_error_msg(result, parent_module_directory.GetError())); + return; + } + + std::error_code ec; + if (!std::filesystem::remove(path, ec)) { + static_cast<void>( + mgp_result_set_error_msg(result, fmt::format("Failed to delete the module: {}", ec.message()).c_str())); + return; + } + + auto parent_path = path.parent_path(); + while (!std::filesystem::is_symlink(parent_path) && std::filesystem::is_empty(parent_path) && + !std::filesystem::equivalent(*parent_module_directory, parent_path)) { + std::filesystem::remove(parent_path); + parent_path = parent_path.parent_path(); + } + + WithUpgradedLock(lock, [&]() { module_registry->UnloadAndLoadModulesFromDirectories(); }); + }; + mgp_proc delete_module_file("delete_module_file", std::move(delete_module_file_cb), utils::NewDeleteResource(), + {.required_privilege = AuthQuery::Privilege::MODULE_WRITE}); + MG_ASSERT(mgp_proc_add_arg(&delete_module_file, "path", Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + module->AddProcedure("delete_module_file", std::move(delete_module_file)); +} + +// Run `fun` with `mgp_module *` and `mgp_memory *` arguments. If `fun` returned +// a `true` value, store the `mgp_module::procedures` and +// `mgp_module::transformations into `proc_map`. The return value of WithModuleRegistration +// is the same as that of `fun`. Note, the return value need only be convertible to `bool`, +// it does not have to be `bool` itself. +template <class TProcMap, class TTransMap, class TFuncMap, class TFun> +auto WithModuleRegistration(TProcMap *proc_map, TTransMap *trans_map, TFuncMap *func_map, const TFun &fun) { + // We probably don't need more than 256KB for module initialization. 
+ static constexpr size_t stack_bytes = 256UL * 1024UL; + unsigned char stack_memory[stack_bytes]; + utils::MonotonicBufferResource monotonic_memory(stack_memory, stack_bytes); + mgp_memory memory{&monotonic_memory}; + mgp_module module_def{memory.impl}; + auto res = fun(&module_def, &memory); + if (res) { + // Copy procedures into resulting proc_map. + for (const auto &proc : module_def.procedures) proc_map->emplace(proc); + // Copy transformations into resulting trans_map. + for (const auto &trans : module_def.transformations) trans_map->emplace(trans); + // Copy functions into resulting func_map. + for (const auto &func : module_def.functions) func_map->emplace(func); + } + return res; +} + +} // namespace + +class SharedLibraryModule final : public Module { + public: + SharedLibraryModule(); + ~SharedLibraryModule() override; + SharedLibraryModule(const SharedLibraryModule &) = delete; + SharedLibraryModule(SharedLibraryModule &&) = delete; + SharedLibraryModule &operator=(const SharedLibraryModule &) = delete; + SharedLibraryModule &operator=(SharedLibraryModule &&) = delete; + + bool Load(const std::filesystem::path &file_path); + + bool Close() override; + + const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override; + + const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override; + + const std::map<std::string, mgp_func, std::less<>> *Functions() const override; + + std::optional<std::filesystem::path> Path() const override { return file_path_; } + + private: + /// Path as requested for loading the module from a library. + std::filesystem::path file_path_; + /// System handle to shared library. + void *handle_; + /// Required initialization function called on module load. + std::function<int(mgp_module *, mgp_memory *)> init_fn_; + /// Optional shutdown function called on module unload. 
+ std::function<int()> shutdown_fn_; + /// Registered procedures + std::map<std::string, mgp_proc, std::less<>> procedures_; + /// Registered transformations + std::map<std::string, mgp_trans, std::less<>> transformations_; + /// Registered functions + std::map<std::string, mgp_func, std::less<>> functions_; +}; + +SharedLibraryModule::SharedLibraryModule() : handle_(nullptr) {} + +SharedLibraryModule::~SharedLibraryModule() { + if (handle_) Close(); +} + +bool SharedLibraryModule::Load(const std::filesystem::path &file_path) { + MG_ASSERT(!handle_, "Attempting to load an already loaded module..."); + spdlog::info("Loading module {}...", file_path); + file_path_ = file_path; + dlerror(); // Clear any existing error. + // NOLINTNEXTLINE(hicpp-signed-bitwise) + handle_ = dlopen(file_path.c_str(), RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND); + if (!handle_) { + spdlog::error( + utils::MessageWithLink("Unable to load module {}; {}.", file_path, dlerror(), "https://memgr.ph/modules")); + return false; + } + // Get required mgp_init_module + init_fn_ = reinterpret_cast<int (*)(mgp_module *, mgp_memory *)>(dlsym(handle_, "mgp_init_module")); + char *dl_errored = dlerror(); + if (!init_fn_ || dl_errored) { + spdlog::error( + utils::MessageWithLink("Unable to load module {}; {}.", file_path, dl_errored, "https://memgr.ph/modules")); + dlclose(handle_); + handle_ = nullptr; + return false; + } + auto module_cb = [&](auto *module_def, auto *memory) { + // Run mgp_init_module which must succeed. 
+ int init_res = init_fn_(module_def, memory); + auto with_error = [this](std::string_view error_msg) { + spdlog::error(error_msg); + dlclose(handle_); + handle_ = nullptr; + return false; + }; + + if (init_res != 0) { + const auto error = fmt::format("Unable to load module {}; mgp_init_module_returned {} ", file_path, init_res); + return with_error(error); + } + for (auto &trans : module_def->transformations) { + const bool success = mgp_error::MGP_ERROR_NO_ERROR == MgpTransAddFixedResult(&trans.second); + if (!success) { + const auto error = + fmt::format("Unable to add result to transformation in module {}; add result failed", file_path); + return with_error(error); + } + } + return true; + }; + if (!WithModuleRegistration(&procedures_, &transformations_, &functions_, module_cb)) { + return false; + } + // Get optional mgp_shutdown_module + shutdown_fn_ = reinterpret_cast<int (*)()>(dlsym(handle_, "mgp_shutdown_module")); + dl_errored = dlerror(); + if (dl_errored) spdlog::warn("When loading module {}; {}", file_path, dl_errored); + spdlog::info("Loaded module {}", file_path); + return true; +} + +bool SharedLibraryModule::Close() { + MG_ASSERT(handle_, "Attempting to close a module that has not been loaded..."); + spdlog::info("Closing module {}...", file_path_); + // non-existent shutdown function is semantically the same as a shutdown + // function that does nothing. 
+ int shutdown_res = 0; + if (shutdown_fn_) shutdown_res = shutdown_fn_(); + if (shutdown_res != 0) { + spdlog::warn("When closing module {}; mgp_shutdown_module returned {}", file_path_, shutdown_res); + } + if (dlclose(handle_) != 0) { + spdlog::error( + utils::MessageWithLink("Failed to close module {}; {}.", file_path_, dlerror(), "https://memgr.ph/modules")); + return false; + } + spdlog::info("Closed module {}", file_path_); + handle_ = nullptr; + procedures_.clear(); + return true; +} + +const std::map<std::string, mgp_proc, std::less<>> *SharedLibraryModule::Procedures() const { + MG_ASSERT(handle_, + "Attempting to access procedures of a module that has not " + "been loaded..."); + return &procedures_; +} + +const std::map<std::string, mgp_trans, std::less<>> *SharedLibraryModule::Transformations() const { + MG_ASSERT(handle_, + "Attempting to access procedures of a module that has not " + "been loaded..."); + return &transformations_; +} + +const std::map<std::string, mgp_func, std::less<>> *SharedLibraryModule::Functions() const { + MG_ASSERT(handle_, + "Attempting to access functions of a module that has not " + "been loaded..."); + return &functions_; +} + +class PythonModule final : public Module { + public: + PythonModule(); + ~PythonModule() override; + PythonModule(const PythonModule &) = delete; + PythonModule(PythonModule &&) = delete; + PythonModule &operator=(const PythonModule &) = delete; + PythonModule &operator=(PythonModule &&) = delete; + + bool Load(const std::filesystem::path &file_path); + + bool Close() override; + + const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override; + const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override; + const std::map<std::string, mgp_func, std::less<>> *Functions() const override; + std::optional<std::filesystem::path> Path() const override { return file_path_; } + + private: + std::filesystem::path file_path_; + py::Object py_module_; + 
std::map<std::string, mgp_proc, std::less<>> procedures_; + std::map<std::string, mgp_trans, std::less<>> transformations_; + std::map<std::string, mgp_func, std::less<>> functions_; +}; + +PythonModule::PythonModule() {} + +PythonModule::~PythonModule() { + if (py_module_) Close(); +} + +bool PythonModule::Load(const std::filesystem::path &file_path) { + MG_ASSERT(!py_module_, "Attempting to load an already loaded module..."); + spdlog::info("Loading module {}...", file_path); + file_path_ = file_path; + auto gil = py::EnsureGIL(); + auto maybe_exc = py::AppendToSysPath(file_path.parent_path().c_str()); + if (maybe_exc) { + spdlog::error( + utils::MessageWithLink("Unable to load module {}; {}.", file_path, *maybe_exc, "https://memgr.ph/modules")); + return false; + } + bool succ = true; + auto module_cb = [&](auto *module_def, auto * /*memory*/) { + auto result = ImportPyModule(file_path.stem().c_str(), module_def); + for (auto &trans : module_def->transformations) { + succ = MgpTransAddFixedResult(&trans.second) == mgp_error::MGP_ERROR_NO_ERROR; + if (!succ) { + return result; + } + }; + return result; + }; + py_module_ = WithModuleRegistration(&procedures_, &transformations_, &functions_, module_cb); + if (py_module_) { + spdlog::info("Loaded module {}", file_path); + + if (!succ) { + spdlog::error("Unable to add result to transformation"); + return false; + } + return true; + } + auto exc_info = py::FetchError().value(); + spdlog::error( + utils::MessageWithLink("Unable to load module {}; {}.", file_path, exc_info, "https://memgr.ph/modules")); + return false; +} + +bool PythonModule::Close() { + MG_ASSERT(py_module_, "Attempting to close a module that has not been loaded..."); + spdlog::info("Closing module {}...", file_path_); + // The procedures and transformations are closures which hold references to the Python callbacks. + // Releasing these references might result in deallocations so we need to take the GIL. 
+ auto gil = py::EnsureGIL(); + procedures_.clear(); + transformations_.clear(); + functions_.clear(); + // Delete the module from the `sys.modules` directory so that the module will + // be properly imported if imported again. + py::Object sys(PyImport_ImportModule("sys")); + if (PyDict_DelItemString(sys.GetAttr("modules").Ptr(), file_path_.stem().c_str()) != 0) { + spdlog::warn("Failed to remove the module from sys.modules"); + py_module_ = py::Object(nullptr); + return false; + } + + // Remove the cached bytecode if it's present + std::filesystem::remove_all(file_path_.parent_path() / "__pycache__"); + py_module_ = py::Object(nullptr); + spdlog::info("Closed module {}", file_path_); + return true; +} + +const std::map<std::string, mgp_proc, std::less<>> *PythonModule::Procedures() const { + MG_ASSERT(py_module_, + "Attempting to access procedures of a module that has " + "not been loaded..."); + return &procedures_; +} + +const std::map<std::string, mgp_trans, std::less<>> *PythonModule::Transformations() const { + MG_ASSERT(py_module_, + "Attempting to access procedures of a module that has " + "not been loaded..."); + return &transformations_; +} + +const std::map<std::string, mgp_func, std::less<>> *PythonModule::Functions() const { + MG_ASSERT(py_module_, + "Attempting to access functions of a module that has " + "not been loaded..."); + return &functions_; +} +namespace { + +std::unique_ptr<Module> LoadModuleFromFile(const std::filesystem::path &path) { + const auto &ext = path.extension(); + if (ext != ".so" && ext != ".py") { + spdlog::warn(utils::MessageWithLink("Unknown query module file {}.", path, "https://memgr.ph/modules")); + return nullptr; + } + std::unique_ptr<Module> module; + if (path.extension() == ".so") { + auto lib_module = std::make_unique<SharedLibraryModule>(); + if (!lib_module->Load(path)) return nullptr; + module = std::move(lib_module); + } else if (path.extension() == ".py") { + auto py_module = std::make_unique<PythonModule>(); + 
if (!py_module->Load(path)) return nullptr; + module = std::move(py_module); + } + return module; +} + +} // namespace + +bool ModuleRegistry::RegisterModule(const std::string_view name, std::unique_ptr<Module> module) { + MG_ASSERT(!name.empty(), "Module name cannot be empty"); + MG_ASSERT(module, "Tried to register an invalid module"); + if (modules_.find(name) != modules_.end()) { + spdlog::error( + utils::MessageWithLink("Unable to overwrite an already loaded module {}.", name, "https://memgr.ph/modules")); + return false; + } + modules_.emplace(name, std::move(module)); + return true; +} + +void ModuleRegistry::DoUnloadAllModules() { + MG_ASSERT(modules_.find("mg") != modules_.end(), "Expected the builtin \"mg\" module to be present."); + // This is correct because the destructor will close each module. However, + // we don't want to unload the builtin "mg" module. + auto module = std::move(modules_["mg"]); + modules_.clear(); + modules_.emplace("mg", std::move(module)); +} + +ModuleRegistry::ModuleRegistry() { + auto module = std::make_unique<BuiltinModule>(); + RegisterMgProcedures(&modules_, module.get()); + RegisterMgTransformations(&modules_, module.get()); + RegisterMgFunctions(&modules_, module.get()); + RegisterMgLoad(this, &lock_, module.get()); + RegisterMgGetModuleFiles(this, module.get()); + RegisterMgGetModuleFile(this, module.get()); + RegisterMgCreateModuleFile(this, &lock_, module.get()); + RegisterMgUpdateModuleFile(this, &lock_, module.get()); + RegisterMgDeleteModuleFile(this, &lock_, module.get()); + modules_.emplace("mg", std::move(module)); +} + +void ModuleRegistry::SetModulesDirectory(std::vector<std::filesystem::path> modules_dirs, + const std::filesystem::path &data_directory) { + internal_module_dir_ = data_directory / "internal_modules"; + utils::EnsureDirOrDie(internal_module_dir_); + modules_dirs_ = std::move(modules_dirs); + modules_dirs_.push_back(internal_module_dir_); +} + +const std::vector<std::filesystem::path> 
&ModuleRegistry::GetModulesDirectory() const { return modules_dirs_; } + +bool ModuleRegistry::LoadModuleIfFound(const std::filesystem::path &modules_dir, const std::string_view name) { + if (!utils::DirExists(modules_dir)) { + spdlog::error( + utils::MessageWithLink("Module directory {} doesn't exist.", modules_dir, "https://memgr.ph/modules")); + return false; + } + for (const auto &entry : std::filesystem::directory_iterator(modules_dir)) { + const auto &path = entry.path(); + if (entry.is_regular_file() && path.stem() == name) { + auto module = LoadModuleFromFile(path); + if (!module) return false; + return RegisterModule(name, std::move(module)); + } + } + return false; +} + +bool ModuleRegistry::LoadOrReloadModuleFromName(const std::string_view name) { + if (modules_dirs_.empty()) return false; + if (name.empty()) return false; + std::unique_lock<utils::RWLock> guard(lock_); + auto found_it = modules_.find(name); + if (found_it != modules_.end()) { + if (!found_it->second->Close()) { + spdlog::warn("Failed to close module {}", found_it->first); + } + modules_.erase(found_it); + } + + for (const auto &module_dir : modules_dirs_) { + if (LoadModuleIfFound(module_dir, name)) { + return true; + } + } + return false; +} + +void ModuleRegistry::LoadModulesFromDirectory(const std::filesystem::path &modules_dir) { + if (modules_dir.empty()) return; + if (!utils::DirExists(modules_dir)) { + spdlog::error( + utils::MessageWithLink("Module directory {} doesn't exist.", modules_dir, "https://memgr.ph/modules")); + return; + } + for (const auto &entry : std::filesystem::directory_iterator(modules_dir)) { + const auto &path = entry.path(); + if (entry.is_regular_file()) { + std::string name = path.stem(); + if (name.empty()) continue; + auto module = LoadModuleFromFile(path); + if (!module) continue; + RegisterModule(name, std::move(module)); + } + } +} + +void ModuleRegistry::UnloadAndLoadModulesFromDirectories() { + std::unique_lock<utils::RWLock> guard(lock_); + 
DoUnloadAllModules(); + for (const auto &module_dir : modules_dirs_) { + LoadModulesFromDirectory(module_dir); + } +} + +ModulePtr ModuleRegistry::GetModuleNamed(const std::string_view name) const { + std::shared_lock<utils::RWLock> guard(lock_); + auto found_it = modules_.find(name); + if (found_it == modules_.end()) return nullptr; + return ModulePtr(found_it->second.get(), std::move(guard)); +} + +void ModuleRegistry::UnloadAllModules() { + std::unique_lock<utils::RWLock> guard(lock_); + DoUnloadAllModules(); +} + +utils::MemoryResource &ModuleRegistry::GetSharedMemoryResource() noexcept { return *shared_; } + +bool ModuleRegistry::RegisterMgProcedure(const std::string_view name, mgp_proc proc) { + std::unique_lock<utils::RWLock> guard(lock_); + if (auto module = modules_.find("mg"); module != modules_.end()) { + auto *builtin_module = dynamic_cast<BuiltinModule *>(module->second.get()); + builtin_module->AddProcedure(name, std::move(proc)); + return true; + } + return false; +} + +const std::filesystem::path &ModuleRegistry::InternalModuleDir() const noexcept { return internal_module_dir_; } + +namespace { + +/// This function returns a pair of either +// ModuleName | Prop +/// 1. <ModuleName, ProcedureName> +/// 2. 
<ModuleName, TransformationName> +std::optional<std::pair<std::string_view, std::string_view>> FindModuleNameAndProp( + const ModuleRegistry &module_registry, std::string_view fully_qualified_name, utils::MemoryResource *memory) { + utils::pmr::vector<std::string_view> name_parts(memory); + utils::Split(&name_parts, fully_qualified_name, "."); + if (name_parts.size() == 1U) return std::nullopt; + auto last_dot_pos = fully_qualified_name.find_last_of('.'); + MG_ASSERT(last_dot_pos != std::string_view::npos); + + const auto &module_name = fully_qualified_name.substr(0, last_dot_pos); + const auto &name = name_parts.back(); + return std::make_pair(module_name, name); +} + +template <typename T> +concept ModuleProperties = utils::SameAsAnyOf<T, mgp_proc, mgp_trans, mgp_func>; + +template <ModuleProperties T> +std::optional<std::pair<ModulePtr, const T *>> MakePairIfPropFound(const ModuleRegistry &module_registry, + std::string_view fully_qualified_name, + utils::MemoryResource *memory) { + auto prop_fun = [](auto &module) { + if constexpr (std::is_same_v<T, mgp_proc>) { + return module->Procedures(); + } else if constexpr (std::is_same_v<T, mgp_trans>) { + return module->Transformations(); + } else if constexpr (std::is_same_v<T, mgp_func>) { + return module->Functions(); + } + }; + auto result = FindModuleNameAndProp(module_registry, fully_qualified_name, memory); + if (!result) return std::nullopt; + auto [module_name, prop_name] = *result; + auto module = module_registry.GetModuleNamed(module_name); + if (!module) return std::nullopt; + auto *prop = prop_fun(module); + const auto &prop_it = prop->find(prop_name); + if (prop_it == prop->end()) return std::nullopt; + return std::make_pair(std::move(module), &prop_it->second); +} + +} // namespace + +std::optional<std::pair<ModulePtr, const mgp_proc *>> FindProcedure(const ModuleRegistry &module_registry, + std::string_view fully_qualified_procedure_name, + utils::MemoryResource *memory) { + return 
MakePairIfPropFound<mgp_proc>(module_registry, fully_qualified_procedure_name, memory); +} + +std::optional<std::pair<ModulePtr, const mgp_trans *>> FindTransformation( + const ModuleRegistry &module_registry, std::string_view fully_qualified_transformation_name, + utils::MemoryResource *memory) { + return MakePairIfPropFound<mgp_trans>(module_registry, fully_qualified_transformation_name, memory); +} + +std::optional<std::pair<ModulePtr, const mgp_func *>> FindFunction(const ModuleRegistry &module_registry, + std::string_view fully_qualified_function_name, + utils::MemoryResource *memory) { + return MakePairIfPropFound<mgp_func>(module_registry, fully_qualified_function_name, memory); +} + +} // namespace memgraph::query::v2::procedure diff --git a/src/query/v2/procedure/module.hpp b/src/query/v2/procedure/module.hpp new file mode 100644 index 000000000..628f94c34 --- /dev/null +++ b/src/query/v2/procedure/module.hpp @@ -0,0 +1,246 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +/// @file +/// API for loading and registering modules providing custom oC procedures +#pragma once + +#include <dlfcn.h> +#include <filesystem> +#include <functional> +#include <optional> +#include <shared_mutex> +#include <string> +#include <string_view> +#include <unordered_map> + +#include "query/v2/procedure/cypher_types.hpp" +#include "query/v2/procedure/mg_procedure_impl.hpp" +#include "utils/memory.hpp" +#include "utils/rw_lock.hpp" + +class CypherMainVisitorTest; + +namespace memgraph::query::v2::procedure { + +class Module { + public: + Module() {} + virtual ~Module(); + Module(const Module &) = delete; + Module(Module &&) = delete; + Module &operator=(const Module &) = delete; + Module &operator=(Module &&) = delete; + + /// Invokes the (optional) shutdown function and closes the module. + virtual bool Close() = 0; + + /// Returns registered procedures of this module + virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures() const = 0; + /// Returns registered transformations of this module + virtual const std::map<std::string, mgp_trans, std::less<>> *Transformations() const = 0; + // /// Returns registered functions of this module + virtual const std::map<std::string, mgp_func, std::less<>> *Functions() const = 0; + + virtual std::optional<std::filesystem::path> Path() const = 0; +}; + +/// Proxy for a registered Module, acquires a read lock from ModuleRegistry. +class ModulePtr final { + const Module *module_{nullptr}; + std::shared_lock<utils::RWLock> lock_; + + public: + ModulePtr() = default; + ModulePtr(std::nullptr_t) {} + ModulePtr(const Module *module, std::shared_lock<utils::RWLock> lock) : module_(module), lock_(std::move(lock)) {} + + explicit operator bool() const { return static_cast<bool>(module_); } + + const Module &operator*() const { return *module_; } + const Module *operator->() const { return module_; } +}; + +/// Thread-safe registration of modules from libraries, uses utils::RWLock. 
+class ModuleRegistry final { + friend CypherMainVisitorTest; + + std::map<std::string, std::unique_ptr<Module>, std::less<>> modules_; + mutable utils::RWLock lock_{utils::RWLock::Priority::WRITE}; + std::unique_ptr<utils::MemoryResource> shared_{std::make_unique<utils::ResourceWithOutOfMemoryException>()}; + + bool RegisterModule(std::string_view name, std::unique_ptr<Module> module); + + void DoUnloadAllModules(); + + /// Loads the module if it's in the modules_dir directory + /// @return Whether the module was loaded + bool LoadModuleIfFound(const std::filesystem::path &modules_dir, std::string_view name); + + void LoadModulesFromDirectory(const std::filesystem::path &modules_dir); + + public: + ModuleRegistry(); + + /// Set the modules directories that will be used when (re)loading modules. + void SetModulesDirectory(std::vector<std::filesystem::path> modules_dir, const std::filesystem::path &data_directory); + const std::vector<std::filesystem::path> &GetModulesDirectory() const; + + /// Atomically load or reload a module with a particular name from the given + /// directory. + /// + /// Takes a write lock. If the module exists it is reloaded. Otherwise, the + /// module is loaded from the file whose filename, without the extension, + /// matches the module's name. If multiple such files exist, only one is + /// chosen, in an unspecified manner. If loading of the chosen file fails, no + /// other files are tried. + /// + /// Return true if the module was loaded or reloaded successfully, false + /// otherwise. + bool LoadOrReloadModuleFromName(std::string_view name); + + /// Atomically unload all modules and then load all possible modules from the + /// set directories. + /// + /// Takes a write lock. + void UnloadAndLoadModulesFromDirectories(); + + /// Find a module with given name or return nullptr. + /// Takes a read lock. + ModulePtr GetModuleNamed(std::string_view name) const; + + /// Remove all loaded (non-builtin) modules. + /// Takes a write lock. 
+ void UnloadAllModules(); + + /// Returns the shared memory allocator used by modules + utils::MemoryResource &GetSharedMemoryResource() noexcept; + + bool RegisterMgProcedure(std::string_view name, mgp_proc proc); + + const std::filesystem::path &InternalModuleDir() const noexcept; + + private: + class SharedLibraryHandle { + public: + SharedLibraryHandle(const std::string &shared_library, int mode) : handle_{dlopen(shared_library.c_str(), mode)} {} + SharedLibraryHandle(const SharedLibraryHandle &) = delete; + SharedLibraryHandle(SharedLibraryHandle &&) = delete; + SharedLibraryHandle operator=(const SharedLibraryHandle &) = delete; + SharedLibraryHandle operator=(SharedLibraryHandle &&) = delete; + + ~SharedLibraryHandle() { + if (handle_) { + dlclose(handle_); + } + } + + private: + void *handle_; + }; + +#if __has_feature(address_sanitizer) + // This is why we need RTLD_NODELETE and we must not use RTLD_DEEPBIND with + // ASAN: https://github.com/google/sanitizers/issues/89 + SharedLibraryHandle libstd_handle{"libstdc++.so.6", RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE}; +#else + // The reason behind opening share library during runtime is to avoid issues + // with loading symbols from stdlib. We have encounter issues with locale + // that cause std::cout not being printed and issues when python libraries + // would call stdlib (e.g. pytorch). + // The way that those issues were solved was + // by using RTLD_DEEPBIND. RTLD_DEEPBIND ensures that the lookup for the + // mentioned library will be first performed in the already existing binded + // libraries and then the global namespace. + // RTLD_DEEPBIND => https://linux.die.net/man/3/dlopen + SharedLibraryHandle libstd_handle{"libstdc++.so.6", RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND}; +#endif + std::vector<std::filesystem::path> modules_dirs_; + std::filesystem::path internal_module_dir_; +}; + +/// Single, global module registry. 
+extern ModuleRegistry gModuleRegistry; + +/// Return the ModulePtr and `mgp_proc *` of the found procedure after resolving +/// `fully_qualified_procedure_name`. `memory` is used for temporary allocations +/// inside this function. ModulePtr must be kept alive to make sure it won't be +/// unloaded. +std::optional<std::pair<procedure::ModulePtr, const mgp_proc *>> FindProcedure( + const ModuleRegistry &module_registry, std::string_view fully_qualified_procedure_name, + utils::MemoryResource *memory); + +/// Return the ModulePtr and `mgp_trans *` of the found transformation after resolving +/// `fully_qualified_transformation_name`. `memory` is used for temporary allocations +/// inside this function. ModulePtr must be kept alive to make sure it won't be +/// unloaded. +std::optional<std::pair<procedure::ModulePtr, const mgp_trans *>> FindTransformation( + const ModuleRegistry &module_registry, std::string_view fully_qualified_transformation_name, + utils::MemoryResource *memory); + +/// Return the ModulePtr and `mgp_func *` of the found function after resolving +/// `fully_qualified_function_name` if found. If there is no such function +/// std::nullopt is returned. `memory` is used for temporary allocations +/// inside this function. ModulePtr must be kept alive to make sure it won't be unloaded. 
+std::optional<std::pair<procedure::ModulePtr, const mgp_func *>> FindFunction( + const ModuleRegistry &module_registry, std::string_view fully_qualified_function_name, + utils::MemoryResource *memory); + +template <typename T> +concept IsCallable = utils::SameAsAnyOf<T, mgp_proc, mgp_func>; + +template <IsCallable TCall> +void ConstructArguments(const std::vector<TypedValue> &args, const TCall &callable, + const std::string_view fully_qualified_name, mgp_list &args_list, mgp_graph &graph) { + const auto n_args = args.size(); + const auto c_args_sz = callable.args.size(); + const auto c_opt_args_sz = callable.opt_args.size(); + + if (n_args < c_args_sz || (n_args - c_args_sz > c_opt_args_sz)) { + if (callable.args.empty() && callable.opt_args.empty()) { + throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_name); + } + + if (callable.opt_args.empty()) { + throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_name, c_args_sz, + c_args_sz == 1U ? "argument" : "arguments"); + } + + throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_name, c_args_sz, + c_args_sz + c_opt_args_sz); + } + args_list.elems.reserve(n_args); + + auto is_not_optional_arg = [c_args_sz](int i) { return c_args_sz > i; }; + for (size_t i = 0; i < n_args; ++i) { + auto arg = args[i]; + std::string_view name; + const query::v2::procedure::CypherType *type; + if (is_not_optional_arg(i)) { + name = callable.args[i].first; + type = callable.args[i].second; + } else { + name = std::get<0>(callable.opt_args[i - c_args_sz]); + type = std::get<1>(callable.opt_args[i - c_args_sz]); + } + if (!type->SatisfiesType(arg)) { + throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.", fully_qualified_name, + name, i, type->GetPresentableName()); + } + args_list.elems.emplace_back(std::move(arg), &graph); + } + // Fill missing optional arguments with their default values. 
+ const size_t passed_in_opt_args = n_args - c_args_sz; + for (size_t i = passed_in_opt_args; i < c_opt_args_sz; ++i) { + args_list.elems.emplace_back(std::get<2>(callable.opt_args[i]), &graph); + } +} +} // namespace memgraph::query::v2::procedure diff --git a/src/query/v2/procedure/py_module.cpp b/src/query/v2/procedure/py_module.cpp new file mode 100644 index 000000000..9a26532bb --- /dev/null +++ b/src/query/v2/procedure/py_module.cpp @@ -0,0 +1,2649 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/procedure/py_module.hpp" + +#include <datetime.h> +#include <pyerrors.h> +#include <array> +#include <sstream> +#include <stdexcept> +#include <string> +#include <string_view> + +#include "mg_procedure.h" +#include "query/v2/procedure/mg_procedure_helpers.hpp" +#include "query/v2/procedure/mg_procedure_impl.hpp" +#include "utils/memory.hpp" +#include "utils/on_scope_exit.hpp" +#include "utils/pmr/vector.hpp" + +namespace memgraph::query::v2::procedure { + +namespace { +// Set this as a __reduce__ special method on our types to prevent `pickle` and +// `copy` module operations on our types. 
+PyObject *DisallowPickleAndCopy(PyObject *self, PyObject *Py_UNUSED(ignored)) { + auto *type = Py_TYPE(self); + std::stringstream ss; + ss << "cannot pickle nor copy '" << type->tp_name << "' object"; + const auto &msg = ss.str(); + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return nullptr; +} + +PyObject *gMgpUnknownError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *gMgpUnableToAllocateError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *gMgpInsufficientBufferError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *gMgpOutOfRangeError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *gMgpLogicErrorError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *gMgpDeletedObjectError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *gMgpInvalidArgumentError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *gMgpKeyAlreadyExistsError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *gMgpImmutableObjectError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *gMgpValueConversionError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +PyObject *gMgpSerializationError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) + +// Returns true if an exception is raised +bool RaiseExceptionFromErrorCode(const mgp_error error) { + switch (error) { + case mgp_error::MGP_ERROR_NO_ERROR: + return false; + case mgp_error::MGP_ERROR_UNKNOWN_ERROR: { + PyErr_SetString(gMgpUnknownError, "Unknown error happened."); + return true; + } + case mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE: { + PyErr_SetString(gMgpUnableToAllocateError, "Unable to allocate memory."); + return true; + } + case mgp_error::MGP_ERROR_INSUFFICIENT_BUFFER: { + 
PyErr_SetString(gMgpInsufficientBufferError, "Insufficient buffer."); + return true; + } + case mgp_error::MGP_ERROR_OUT_OF_RANGE: { + PyErr_SetString(gMgpOutOfRangeError, "Out of range."); + return true; + } + case mgp_error::MGP_ERROR_LOGIC_ERROR: { + PyErr_SetString(gMgpLogicErrorError, "Logic error."); + return true; + } + case mgp_error::MGP_ERROR_DELETED_OBJECT: { + PyErr_SetString(gMgpDeletedObjectError, "Accessing deleted object."); + return true; + } + case mgp_error::MGP_ERROR_INVALID_ARGUMENT: { + PyErr_SetString(gMgpInvalidArgumentError, "Invalid argument."); + return true; + } + case mgp_error::MGP_ERROR_KEY_ALREADY_EXISTS: { + PyErr_SetString(gMgpKeyAlreadyExistsError, "Key already exists."); + return true; + } + case mgp_error::MGP_ERROR_IMMUTABLE_OBJECT: { + PyErr_SetString(gMgpImmutableObjectError, "Cannot modify immutable object."); + return true; + } + case mgp_error::MGP_ERROR_VALUE_CONVERSION: { + PyErr_SetString(gMgpValueConversionError, "Value conversion failed."); + return true; + } + case mgp_error::MGP_ERROR_SERIALIZATION_ERROR: { + PyErr_SetString(gMgpSerializationError, "Operation cannot be serialized."); + return true; + } + } +} + +mgp_value *PyObjectToMgpValueWithPythonExceptions(PyObject *py_value, mgp_memory *memory) noexcept { + try { + return PyObjectToMgpValue(py_value, memory); + } catch (const std::bad_alloc &e) { + PyErr_SetString(PyExc_MemoryError, e.what()); + return nullptr; + } catch (const std::overflow_error &e) { + PyErr_SetString(PyExc_OverflowError, e.what()); + return nullptr; + } catch (const std::invalid_argument &e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (const std::exception &e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } catch (...) 
{ + PyErr_SetString(PyExc_RuntimeError, "Unknown error happened"); + return nullptr; + } +} +} // namespace + +// Definitions of types wrapping C API types +// +// These should all be in the private `_mgp` Python module, which will be used +// by the `mgp` to implement the user friendly Python API. + +// Wraps mgp_graph in a PyObject. +// +// Executing a `CALL python_module.procedure(...)` in openCypher should +// instantiate exactly 1 mgp_graph instance. We will rely on this assumption in +// order to test for validity of usage. The idea is to clear the `graph` to +// `nullptr` after the execution completes. If a user stored a reference to +// `_mgp.Graph` in their global Python state, then we are no longer working with +// a valid graph so `nullptr` will catch this. `_mgp.Graph` provides `is_valid` +// method for checking this by our higher level API in `mgp` module. Python only +// does shallow copies by default, and we do not provide deep copy of +// `_mgp.Graph`, so this validity concept should work fine. +// +// clang-format off +struct PyGraph { + PyObject_HEAD + mgp_graph *graph; + mgp_memory *memory; +}; +// clang-format on + +// clang-format off +struct PyVerticesIterator { + PyObject_HEAD + mgp_vertices_iterator *it; + PyGraph *py_graph; +}; +// clang-format on + +PyObject *MakePyVertex(mgp_vertex &vertex, PyGraph *py_graph); + +void PyVerticesIteratorDealloc(PyVerticesIterator *self) { + MG_ASSERT(self->it); + MG_ASSERT(self->py_graph); + // Avoid invoking `mgp_vertices_iterator_destroy` if we are not in valid + // execution context. The query execution should free all memory used during + // execution, so we may cause a double free issue. 
+ if (self->py_graph->graph) mgp_vertices_iterator_destroy(self->it); + Py_DECREF(self->py_graph); + Py_TYPE(self)->tp_free(self); +} + +PyObject *PyVerticesIteratorGet(PyVerticesIterator *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->it); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_vertex *vertex{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_vertices_iterator_get(self->it, &vertex))) { + return nullptr; + } + if (vertex == nullptr) { + Py_INCREF(Py_None); + return Py_None; + } + return MakePyVertex(*vertex, self->py_graph); +} + +PyObject *PyVerticesIteratorNext(PyVerticesIterator *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->it); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_vertex *vertex{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_vertices_iterator_next(self->it, &vertex))) { + return nullptr; + } + if (vertex == nullptr) { + Py_INCREF(Py_None); + return Py_None; + } + return MakePyVertex(*vertex, self->py_graph); +} + +static PyMethodDef PyVerticesIteratorMethods[] = { + {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"get", reinterpret_cast<PyCFunction>(PyVerticesIteratorGet), METH_NOARGS, + "Get the current vertex pointed to by the iterator or return None."}, + {"next", reinterpret_cast<PyCFunction>(PyVerticesIteratorNext), METH_NOARGS, + "Advance the iterator to the next vertex and return it."}, + {nullptr}, +}; + +// clang-format off +static PyTypeObject PyVerticesIteratorType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.VerticesIterator", + .tp_basicsize = sizeof(PyVerticesIterator), + .tp_dealloc = reinterpret_cast<destructor>(PyVerticesIteratorDealloc), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_vertices_iterator.", + .tp_methods = PyVerticesIteratorMethods, +}; +// clang-format on + +// clang-format off +struct PyEdgesIterator { + PyObject_HEAD + mgp_edges_iterator *it; + 
PyGraph *py_graph; +}; +// clang-format on + +PyObject *MakePyEdge(mgp_edge &edge, PyGraph *py_graph); + +void PyEdgesIteratorDealloc(PyEdgesIterator *self) { + MG_ASSERT(self->it); + MG_ASSERT(self->py_graph); + // Avoid invoking `mgp_edges_iterator_destroy` if we are not in valid + // execution context. The query execution should free all memory used during + // execution, so we may cause a double free issue. + if (self->py_graph->graph) mgp_edges_iterator_destroy(self->it); + Py_DECREF(self->py_graph); + Py_TYPE(self)->tp_free(self); +} + +PyObject *PyEdgesIteratorGet(PyEdgesIterator *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->it); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_edge *edge{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_edges_iterator_get(self->it, &edge))) { + return nullptr; + } + if (edge == nullptr) { + Py_INCREF(Py_None); + return Py_None; + } + return MakePyEdge(*edge, self->py_graph); +} + +PyObject *PyEdgesIteratorNext(PyEdgesIterator *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->it); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_edge *edge{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_edges_iterator_next(self->it, &edge))) { + return nullptr; + } + if (edge == nullptr) { + Py_INCREF(Py_None); + return Py_None; + } + return MakePyEdge(*edge, self->py_graph); +} + +static PyMethodDef PyEdgesIteratorMethods[] = { + {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"get", reinterpret_cast<PyCFunction>(PyEdgesIteratorGet), METH_NOARGS, + "Get the current edge pointed to by the iterator or return None."}, + {"next", reinterpret_cast<PyCFunction>(PyEdgesIteratorNext), METH_NOARGS, + "Advance the iterator to the next edge and return it."}, + {nullptr}, +}; + +// clang-format off +static PyTypeObject PyEdgesIteratorType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.EdgesIterator", + 
.tp_basicsize = sizeof(PyEdgesIterator), + .tp_dealloc = reinterpret_cast<destructor>(PyEdgesIteratorDealloc), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_edges_iterator.", + .tp_methods = PyEdgesIteratorMethods, +}; +// clang-format on + +PyObject *PyGraphInvalidate(PyGraph *self, PyObject *Py_UNUSED(ignored)) { + self->graph = nullptr; + self->memory = nullptr; + Py_RETURN_NONE; +} + +bool PyGraphIsValidImpl(PyGraph &self) { return self.graph != nullptr; } + +PyObject *PyGraphIsValid(PyGraph *self, PyObject *Py_UNUSED(ignored)) { + return PyBool_FromLong(PyGraphIsValidImpl(*self)); +} + +PyObject *PyGraphIsMutable(PyGraph *self, PyObject *Py_UNUSED(ignored)) { + return PyBool_FromLong(CallBool(mgp_graph_is_mutable, self->graph)); +} + +PyObject *MakePyVertexWithoutCopy(mgp_vertex &vertex, PyGraph *py_graph); + +PyObject *PyGraphGetVertexById(PyGraph *self, PyObject *args) { + MG_ASSERT(PyGraphIsValidImpl(*self)); + MG_ASSERT(self->memory); + static_assert(std::is_same_v<int64_t, long>); + int64_t id = 0; + if (!PyArg_ParseTuple(args, "l", &id)) return nullptr; + mgp_vertex *vertex{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_graph_get_vertex_by_id(self->graph, mgp_vertex_id{id}, self->memory, &vertex))) { + return nullptr; + } + if (!vertex) { + PyErr_SetString(PyExc_IndexError, "Unable to find the vertex with given ID."); + return nullptr; + } + auto *py_vertex = MakePyVertexWithoutCopy(*vertex, self); + if (!py_vertex) mgp_vertex_destroy(vertex); + return py_vertex; +} + +PyObject *PyGraphCreateVertex(PyGraph *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(PyGraphIsValidImpl(*self)); + MG_ASSERT(self->memory); + MgpUniquePtr<mgp_vertex> new_vertex{nullptr, mgp_vertex_destroy}; + if (RaiseExceptionFromErrorCode(CreateMgpObject(new_vertex, mgp_graph_create_vertex, self->graph, self->memory))) { + return nullptr; + } + auto *py_vertex = MakePyVertexWithoutCopy(*new_vertex, self); + if (py_vertex != nullptr) { + 
static_cast<void>(new_vertex.release()); + } + return py_vertex; +} + +PyObject *PyGraphCreateEdge(PyGraph *self, PyObject *args); + +PyObject *PyGraphDeleteVertex(PyGraph *self, PyObject *args); + +PyObject *PyGraphDetachDeleteVertex(PyGraph *self, PyObject *args); + +PyObject *PyGraphDeleteEdge(PyGraph *self, PyObject *args); + +PyObject *PyGraphIterVertices(PyGraph *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(PyGraphIsValidImpl(*self)); + MG_ASSERT(self->memory); + mgp_vertices_iterator *vertices_it{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_graph_iter_vertices(self->graph, self->memory, &vertices_it))) { + return nullptr; + } + auto *py_vertices_it = PyObject_New(PyVerticesIterator, &PyVerticesIteratorType); + if (!py_vertices_it) { + mgp_vertices_iterator_destroy(vertices_it); + return nullptr; + } + py_vertices_it->it = vertices_it; + Py_INCREF(self); + py_vertices_it->py_graph = self; + return reinterpret_cast<PyObject *>(py_vertices_it); +} + +PyObject *PyGraphMustAbort(PyGraph *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(PyGraphIsValidImpl(*self)); + return PyBool_FromLong(mgp_must_abort(self->graph)); +} + +static PyMethodDef PyGraphMethods[] = { + {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"invalidate", reinterpret_cast<PyCFunction>(PyGraphInvalidate), METH_NOARGS, + "Invalidate the Graph context thus preventing the Graph from being used."}, + {"is_valid", reinterpret_cast<PyCFunction>(PyGraphIsValid), METH_NOARGS, + "Return True if Graph is in valid context and may be used."}, + {"is_mutable", reinterpret_cast<PyCFunction>(PyGraphIsMutable), METH_NOARGS, + "Return True if Graph is mutable and can be used to modify vertices and edges."}, + {"get_vertex_by_id", reinterpret_cast<PyCFunction>(PyGraphGetVertexById), METH_VARARGS, + "Get the vertex or raise IndexError."}, + {"create_vertex", reinterpret_cast<PyCFunction>(PyGraphCreateVertex), METH_NOARGS, "Create a 
vertex."}, + {"create_edge", reinterpret_cast<PyCFunction>(PyGraphCreateEdge), METH_VARARGS, "Create an edge."}, + {"delete_vertex", reinterpret_cast<PyCFunction>(PyGraphDeleteVertex), METH_VARARGS, "Delete a vertex."}, + {"detach_delete_vertex", reinterpret_cast<PyCFunction>(PyGraphDetachDeleteVertex), METH_VARARGS, + "Delete a vertex and all of its edges."}, + {"delete_edge", reinterpret_cast<PyCFunction>(PyGraphDeleteEdge), METH_VARARGS, "Delete an edge."}, + {"iter_vertices", reinterpret_cast<PyCFunction>(PyGraphIterVertices), METH_NOARGS, "Return _mgp.VerticesIterator."}, + {"must_abort", reinterpret_cast<PyCFunction>(PyGraphMustAbort), METH_NOARGS, + "Check whether the running procedure should abort"}, + {nullptr}, +}; + +// clang-format off +static PyTypeObject PyGraphType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Graph", + .tp_basicsize = sizeof(PyGraph), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_graph.", + .tp_methods = PyGraphMethods, +}; +// clang-format on + +PyObject *MakePyGraph(mgp_graph *graph, mgp_memory *memory) { + MG_ASSERT(!graph || (graph && memory)); + auto *py_graph = PyObject_New(PyGraph, &PyGraphType); + if (!py_graph) return nullptr; + py_graph->graph = graph; + py_graph->memory = memory; + return reinterpret_cast<PyObject *>(py_graph); +} + +// clang-format off +struct PyCypherType { + PyObject_HEAD + mgp_type *type; +}; +// clang-format on + +// clang-format off +static PyTypeObject PyCypherTypeType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Type", + .tp_basicsize = sizeof(PyCypherType), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_type.", +}; +// clang-format on + +PyObject *MakePyCypherType(mgp_type *type) { + MG_ASSERT(type); + auto *py_type = PyObject_New(PyCypherType, &PyCypherTypeType); + if (!py_type) return nullptr; + py_type->type = type; + return reinterpret_cast<PyObject *>(py_type); +} + +// clang-format off +struct PyQueryProc { + PyObject_HEAD + 
mgp_proc *callable; +}; +// clang-format on + +// clang-format off +struct PyMagicFunc{ + PyObject_HEAD + mgp_func *callable; +}; +// clang-format on + +template <typename T> +concept IsCallable = utils::SameAsAnyOf<T, PyQueryProc, PyMagicFunc>; + +template <IsCallable TCall> +PyObject *PyCallableAddArg(TCall *self, PyObject *args) { + MG_ASSERT(self->callable); + const char *name = nullptr; + PyCypherType *py_type = nullptr; + if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr; + auto *type = py_type->type; + + if constexpr (std::is_same_v<TCall, PyQueryProc>) { + if (RaiseExceptionFromErrorCode(mgp_proc_add_arg(self->callable, name, type))) { + return nullptr; + } + } else if constexpr (std::is_same_v<TCall, PyMagicFunc>) { + if (RaiseExceptionFromErrorCode(mgp_func_add_arg(self->callable, name, type))) { + return nullptr; + } + } + + Py_RETURN_NONE; +} + +template <IsCallable TCall> +PyObject *PyCallableAddOptArg(TCall *self, PyObject *args) { + MG_ASSERT(self->callable); + const char *name = nullptr; + PyCypherType *py_type = nullptr; + PyObject *py_value = nullptr; + if (!PyArg_ParseTuple(args, "sO!O", &name, &PyCypherTypeType, &py_type, &py_value)) return nullptr; + auto *type = py_type->type; + mgp_memory memory{self->callable->opt_args.get_allocator().GetMemoryResource()}; + mgp_value *value = PyObjectToMgpValueWithPythonExceptions(py_value, &memory); + if (value == nullptr) { + return nullptr; + } + if constexpr (std::is_same_v<TCall, PyQueryProc>) { + if (RaiseExceptionFromErrorCode(mgp_proc_add_opt_arg(self->callable, name, type, value))) { + mgp_value_destroy(value); + return nullptr; + } + } else if constexpr (std::is_same_v<TCall, PyMagicFunc>) { + if (RaiseExceptionFromErrorCode(mgp_func_add_opt_arg(self->callable, name, type, value))) { + mgp_value_destroy(value); + return nullptr; + } + } + + mgp_value_destroy(value); + Py_RETURN_NONE; +} + +PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) { return 
PyCallableAddArg(self, args); } + +PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) { return PyCallableAddOptArg(self, args); } + +PyObject *PyQueryProcAddResult(PyQueryProc *self, PyObject *args) { + MG_ASSERT(self->callable); + const char *name = nullptr; + PyCypherType *py_type = nullptr; + if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr; + + auto *type = reinterpret_cast<PyCypherType *>(py_type)->type; + if (RaiseExceptionFromErrorCode(mgp_proc_add_result(self->callable, name, type))) { + return nullptr; + } + Py_RETURN_NONE; +} + +PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) { + MG_ASSERT(self->callable); + const char *name = nullptr; + PyCypherType *py_type = nullptr; + if (!PyArg_ParseTuple(args, "sO!", &name, &PyCypherTypeType, &py_type)) return nullptr; + auto *type = reinterpret_cast<PyCypherType *>(py_type)->type; + if (RaiseExceptionFromErrorCode(mgp_proc_add_deprecated_result(self->callable, name, type))) { + return nullptr; + } + Py_RETURN_NONE; +} + +static PyMethodDef PyQueryProcMethods[] = { + {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"add_arg", reinterpret_cast<PyCFunction>(PyQueryProcAddArg), METH_VARARGS, + "Add a required argument to a procedure."}, + {"add_opt_arg", reinterpret_cast<PyCFunction>(PyQueryProcAddOptArg), METH_VARARGS, + "Add an optional argument with a default value to a procedure."}, + {"add_result", reinterpret_cast<PyCFunction>(PyQueryProcAddResult), METH_VARARGS, + "Add a result field to a procedure."}, + {"add_deprecated_result", reinterpret_cast<PyCFunction>(PyQueryProcAddDeprecatedResult), METH_VARARGS, + "Add a result field to a procedure and mark it as deprecated."}, + {nullptr}, +}; + +// clang-format off +static PyTypeObject PyQueryProcType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Proc", + .tp_basicsize = sizeof(PyQueryProc), + .tp_flags = 
Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_proc.", + .tp_methods = PyQueryProcMethods, +}; +// clang-format on + +PyObject *PyMagicFuncAddArg(PyMagicFunc *self, PyObject *args) { return PyCallableAddArg(self, args); } + +PyObject *PyMagicFuncAddOptArg(PyMagicFunc *self, PyObject *args) { return PyCallableAddOptArg(self, args); } + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static PyMethodDef PyMagicFuncMethods[] = { + {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"add_arg", reinterpret_cast<PyCFunction>(PyMagicFuncAddArg), METH_VARARGS, + "Add a required argument to a function."}, + {"add_opt_arg", reinterpret_cast<PyCFunction>(PyMagicFuncAddOptArg), METH_VARARGS, + "Add an optional argument with a default value to a function."}, + {nullptr}, +}; + +// clang-format off +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static PyTypeObject PyMagicFuncType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Func", + .tp_basicsize = sizeof(PyMagicFunc), + // NOLINTNEXTLINE(hicpp-signed-bitwise) + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_func.", + .tp_methods = PyMagicFuncMethods, +}; +// clang-format on + +// clang-format off +struct PyQueryModule { + PyObject_HEAD + mgp_module *module; +}; +// clang-format on + +struct PyMessages { + PyObject_HEAD; + mgp_messages *messages; + mgp_memory *memory; +}; + +struct PyMessage { + PyObject_HEAD; + mgp_message *message; + const PyMessages *messages; + mgp_memory *memory; +}; + +PyObject *PyMessagesIsValid(const PyMessages *self, PyObject *Py_UNUSED(ignored)) { + return PyBool_FromLong(!!self->messages); +} + +PyObject *PyMessageIsValid(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + return PyMessagesIsValid(self->messages, nullptr); +} + +PyObject *PyMessageGetSourceType(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + 
MG_ASSERT(self->memory); + mgp_source_type source_type{mgp_source_type::KAFKA}; + if (RaiseExceptionFromErrorCode(mgp_message_source_type(self->message, &source_type))) { + return nullptr; + } + auto *py_source_type = PyLong_FromLong(static_cast<int64_t>(source_type)); + if (!py_source_type) { + PyErr_SetString(PyExc_RuntimeError, "Unable to get long from source type"); + return nullptr; + } + return py_source_type; +} + +PyObject *PyMessageGetPayload(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + size_t payload_size{0}; + if (RaiseExceptionFromErrorCode(mgp_message_payload_size(self->message, &payload_size))) { + return nullptr; + } + const char *payload{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_message_payload(self->message, &payload))) { + return nullptr; + } + auto *raw_bytes = PyByteArray_FromStringAndSize(payload, payload_size); + if (!raw_bytes) { + PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from payload"); + return nullptr; + } + return raw_bytes; +} + +PyObject *PyMessageGetTopicName(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + MG_ASSERT(self->memory); + const char *topic_name{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_message_topic_name(self->message, &topic_name))) { + return nullptr; + } + auto *py_topic_name = PyUnicode_FromString(topic_name); + if (!py_topic_name) { + PyErr_SetString(PyExc_RuntimeError, "Unable to get string from topic_name"); + return nullptr; + } + return py_topic_name; +} + +PyObject *PyMessageGetKey(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + MG_ASSERT(self->memory); + size_t key_size{0}; + if (RaiseExceptionFromErrorCode(mgp_message_key_size(self->message, &key_size))) { + return nullptr; + } + const char *key{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_message_key(self->message, &key))) { + return nullptr; + } + auto *raw_bytes = PyByteArray_FromStringAndSize(key, key_size); + if (!raw_bytes) { + 
PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from payload"); + return nullptr; + } + return raw_bytes; +} + +PyObject *PyMessageGetTimestamp(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + MG_ASSERT(self->memory); + int64_t timestamp{0}; + if (RaiseExceptionFromErrorCode(mgp_message_timestamp(self->message, ×tamp))) { + return nullptr; + } + auto *py_int = PyLong_FromUnsignedLong(timestamp); + if (!py_int) { + PyErr_SetString(PyExc_IndexError, "Unable to get timestamp."); + return nullptr; + } + return py_int; +} + +PyObject *PyMessageGetOffset(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + MG_ASSERT(self->memory); + int64_t offset{0}; + if (RaiseExceptionFromErrorCode(mgp_message_offset(self->message, &offset))) { + return nullptr; + } + auto *py_int = PyLong_FromLongLong(offset); + if (!py_int) { + PyErr_SetString(PyExc_IndexError, "Unable to get offset"); + return nullptr; + } + return py_int; +} + +// NOLINTNEXTLINE +static PyMethodDef PyMessageMethods[] = { + {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"is_valid", reinterpret_cast<PyCFunction>(PyMessageIsValid), METH_NOARGS, + "Return True if messages is in valid context and may be used."}, + {"source_type", reinterpret_cast<PyCFunction>(PyMessageGetSourceType), METH_NOARGS, "Get stream source type."}, + {"payload", reinterpret_cast<PyCFunction>(PyMessageGetPayload), METH_NOARGS, "Get payload"}, + {"topic_name", reinterpret_cast<PyCFunction>(PyMessageGetTopicName), METH_NOARGS, "Get topic name."}, + {"key", reinterpret_cast<PyCFunction>(PyMessageGetKey), METH_NOARGS, "Get message key."}, + {"timestamp", reinterpret_cast<PyCFunction>(PyMessageGetTimestamp), METH_NOARGS, "Get message timestamp."}, + {"offset", reinterpret_cast<PyCFunction>(PyMessageGetOffset), METH_NOARGS, "Get message offset."}, + {nullptr}, +}; + +void PyMessageDealloc(PyMessage *self) { + 
MG_ASSERT(self->memory); + MG_ASSERT(self->message); + MG_ASSERT(self->messages); + // NOLINTNEXTLINE + Py_DECREF(self->messages); + // NOLINTNEXTLINE + Py_TYPE(self)->tp_free(self); +} + +// NOLINTNEXTLINE +static PyTypeObject PyMessageType = { + PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Message", + .tp_basicsize = sizeof(PyMessage), + .tp_dealloc = reinterpret_cast<destructor>(PyMessageDealloc), + // NOLINTNEXTLINE + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_message.", + // NOLINTNEXTLINE + .tp_methods = PyMessageMethods, +}; + +PyObject *PyMessagesInvalidate(PyMessages *self, PyObject *Py_UNUSED(ignored)) { + self->messages = nullptr; + self->memory = nullptr; + Py_RETURN_NONE; +} + +PyObject *PyMessagesGetTotalMessages(PyMessages *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->messages); + MG_ASSERT(self->memory); + auto size = self->messages->messages.size(); + auto *py_int = PyLong_FromSize_t(size); + if (!py_int) { + PyErr_SetString(PyExc_IndexError, "Unable to get total messages count."); + return nullptr; + } + return py_int; +} + +PyObject *PyMessagesGetMessageAt(PyMessages *self, PyObject *args) { + MG_ASSERT(self->messages); + MG_ASSERT(self->memory); + int64_t id = 0; + if (!PyArg_ParseTuple(args, "l", &id)) return nullptr; + if (id < 0 || id >= self->messages->messages.size()) return nullptr; + auto *message = &self->messages->messages[id]; + // NOLINTNEXTLINE + auto *py_message = PyObject_New(PyMessage, &PyMessageType); + if (!py_message) { + return nullptr; + } + py_message->message = message; + // NOLINTNEXTLINE + Py_INCREF(self); + py_message->messages = self; + py_message->memory = self->memory; + if (!message) { + PyErr_SetString(PyExc_IndexError, "Unable to find the message with given index."); + return nullptr; + } + // NOLINTNEXTLINE + return reinterpret_cast<PyObject *>(py_message); +} + +// NOLINTNEXTLINE +static PyMethodDef PyMessagesMethods[] = { + {"__reduce__", 
reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"invalidate", reinterpret_cast<PyCFunction>(PyMessagesInvalidate), METH_NOARGS, + "Invalidate the messages context thus preventing the messages from being used"}, + {"is_valid", reinterpret_cast<PyCFunction>(PyMessagesIsValid), METH_NOARGS, + "Return True if messages is in valid context and may be used."}, + {"total_messages", reinterpret_cast<PyCFunction>(PyMessagesGetTotalMessages), METH_VARARGS, + "Get number of messages available"}, + {"message_at", reinterpret_cast<PyCFunction>(PyMessagesGetMessageAt), METH_VARARGS, + "Get message at index idx from messages"}, + {nullptr}, +}; + +// NOLINTNEXTLINE +static PyTypeObject PyMessagesType = { + PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Messages", + .tp_basicsize = sizeof(PyMessages), + // NOLINTNEXTLINE + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_messages.", + // NOLINTNEXTLINE + .tp_methods = PyMessagesMethods, +}; + +PyObject *MakePyMessages(mgp_messages *msgs, mgp_memory *memory) { + MG_ASSERT(!msgs || (msgs && memory)); + // NOLINTNEXTLINE + auto *py_messages = PyObject_New(PyMessages, &PyMessagesType); + if (!py_messages) return nullptr; + py_messages->messages = msgs; + py_messages->memory = memory; + return reinterpret_cast<PyObject *>(py_messages); +} + +py::Object MgpListToPyTuple(mgp_list *list, PyGraph *py_graph) { + MG_ASSERT(list); + MG_ASSERT(py_graph); + const auto len = list->elems.size(); + py::Object py_tuple(PyTuple_New(len)); + if (!py_tuple) return nullptr; + for (size_t i = 0; i < len; ++i) { + auto elem = MgpValueToPyObject(list->elems[i], py_graph); + if (!elem) return nullptr; + // Explicitly convert `py_tuple`, which is `py::Object`, via static_cast. + // Then the macro will cast it to `PyTuple *`. 
+ PyTuple_SET_ITEM(py_tuple.Ptr(), i, elem.Steal()); + } + return py_tuple; +} + +py::Object MgpListToPyTuple(mgp_list *list, PyObject *py_graph) { + if (Py_TYPE(py_graph) != &PyGraphType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Graph."); + return nullptr; + } + return MgpListToPyTuple(list, reinterpret_cast<PyGraph *>(py_graph)); +} + +namespace { +std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result, py::Object py_record) { + py::Object py_mgp(PyImport_ImportModule("mgp")); + if (!py_mgp) return py::FetchError(); + auto record_cls = py_mgp.GetAttr("Record"); + if (!record_cls) return py::FetchError(); + if (!PyObject_IsInstance(py_record.Ptr(), record_cls.Ptr())) { + std::stringstream ss; + ss << "Value '" << py_record << "' is not an instance of 'mgp.Record'"; + const auto &msg = ss.str(); + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return py::FetchError(); + } + py::Object fields(py_record.GetAttr("fields")); + if (!fields) return py::FetchError(); + if (!PyDict_Check(fields)) { + PyErr_SetString(PyExc_TypeError, "Expected 'mgp.Record.fields' to be a 'dict'"); + return py::FetchError(); + } + py::Object items(PyDict_Items(fields.Ptr())); + if (!items) return py::FetchError(); + mgp_result_record *record{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_result_new_record(result, &record))) { + return py::FetchError(); + } + Py_ssize_t len = PyList_GET_SIZE(items.Ptr()); + for (Py_ssize_t i = 0; i < len; ++i) { + auto *item = PyList_GET_ITEM(items.Ptr(), i); + if (!item) return py::FetchError(); + MG_ASSERT(PyTuple_Check(item)); + auto *key = PyTuple_GetItem(item, 0); + if (!key) return py::FetchError(); + if (!PyUnicode_Check(key)) { + std::stringstream ss; + ss << "Field name '" << py::Object::FromBorrow(key) << "' is not an instance of 'str'"; + const auto &msg = ss.str(); + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return py::FetchError(); + } + const auto *field_name = PyUnicode_AsUTF8(key); + if (!field_name) 
return py::FetchError(); + auto *val = PyTuple_GetItem(item, 1); + if (!val) return py::FetchError(); + mgp_memory memory{result->rows.get_allocator().GetMemoryResource()}; + mgp_value *field_val = PyObjectToMgpValueWithPythonExceptions(val, &memory); + if (field_val == nullptr) { + return py::FetchError(); + } + if (mgp_result_record_insert(record, field_name, field_val) != mgp_error::MGP_ERROR_NO_ERROR) { + std::stringstream ss; + ss << "Unable to insert field '" << py::Object::FromBorrow(key) << "' with value: '" + << py::Object::FromBorrow(val) << "'; did you set the correct field type?"; + const auto &msg = ss.str(); + PyErr_SetString(PyExc_ValueError, msg.c_str()); + mgp_value_destroy(field_val); + return py::FetchError(); + } + mgp_value_destroy(field_val); + } + return std::nullopt; +} + +std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(mgp_result *result, py::Object py_seq) { + Py_ssize_t len = PySequence_Size(py_seq.Ptr()); + if (len == -1) return py::FetchError(); + for (Py_ssize_t i = 0; i < len; ++i) { + py::Object py_record(PySequence_GetItem(py_seq.Ptr(), i)); + if (!py_record) return py::FetchError(); + auto maybe_exc = AddRecordFromPython(result, py_record); + if (maybe_exc) return maybe_exc; + } + return std::nullopt; +} + +std::function<void()> PyObjectCleanup(py::Object &py_object) { + return [py_object]() { + // Run `gc.collect` (reference cycle-detection) explicitly, so that we are + // sure the procedure cleaned up everything it held references to. If the + // user stored a reference to one of our `_mgp` instances then the + // internally used `mgp_*` structs will stay unfreed and a memory leak + // will be reported at the end of the query execution. 
+ py::Object gc(PyImport_ImportModule("gc")); + if (!gc) { + LOG_FATAL(py::FetchError().value()); + } + + if (!gc.CallMethod("collect")) { + LOG_FATAL(py::FetchError().value()); + } + + // After making sure all references from our side have been cleared, + // invalidate the `_mgp.Graph` object. If the user kept a reference to one + // of our `_mgp` instances then this will prevent them from using those + // objects (whose internal `mgp_*` pointers are now invalid and would cause + // a crash). + if (!py_object.CallMethod("invalidate")) { + LOG_FATAL(py::FetchError().value()); + } + }; +} + +void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_result *result, + mgp_memory *memory) { + auto gil = py::EnsureGIL(); + + auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> { + if (!exc_info) return std::nullopt; + // Here we tell the traceback formatter to skip the first line of the + // traceback because that line will always be our wrapper function in our + // internal `mgp.py` file. With that line skipped, the user will always + // get only the relevant traceback that happened in his Python code. + return py::FormatException(*exc_info, /* skip_first_line = */ true); + }; + + auto call = [&](py::Object py_graph) -> std::optional<py::ExceptionInfo> { + py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr())); + if (!py_args) return py::FetchError(); + auto py_res = py_cb.Call(py_graph, py_args); + if (!py_res) return py::FetchError(); + if (PySequence_Check(py_res.Ptr())) { + return AddMultipleRecordsFromPython(result, py_res); + } else { + return AddRecordFromPython(result, py_res); + } + }; + + // It is *VERY IMPORTANT* to note that this code takes great care not to keep + // any extra references to any `_mgp` instances (except for `_mgp.Graph`), so + // as not to introduce extra reference counts and prevent their deallocation. 
+ // In particular, the `ExceptionInfo` object has a `traceback` field that + // contains references to the Python frames and their arguments, and therefore + // our `_mgp` instances as well. Within this code we ensure not to keep the + // `ExceptionInfo` object alive so that no extra reference counts are + // introduced. We only fetch the error message and immediately destroy the + // object. + std::optional<std::string> maybe_msg; + { + py::Object py_graph(MakePyGraph(graph, memory)); + utils::OnScopeExit clean_up(PyObjectCleanup(py_graph)); + if (py_graph) { + maybe_msg = error_to_msg(call(py_graph)); + } else { + maybe_msg = error_to_msg(py::FetchError()); + } + } + + if (maybe_msg) { + static_cast<void>(mgp_result_set_error_msg(result, maybe_msg->c_str())); + } +} + +void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_graph *graph, mgp_result *result, + mgp_memory *memory) { + auto gil = py::EnsureGIL(); + + auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> { + if (!exc_info) return std::nullopt; + // Here we tell the traceback formatter to skip the first line of the + // traceback because that line will always be our wrapper function in our + // internal `mgp.py` file. With that line skipped, the user will always + // get only the relevant traceback that happened in his Python code. 
+ return py::FormatException(*exc_info, /* skip_first_line = */ true); + }; + + auto call = [&](py::Object py_graph, py::Object py_messages) -> std::optional<py::ExceptionInfo> { + auto py_res = py_cb.Call(py_graph, py_messages); + if (!py_res) return py::FetchError(); + if (PySequence_Check(py_res.Ptr())) { + return AddMultipleRecordsFromPython(result, py_res); + } + return AddRecordFromPython(result, py_res); + }; + + // It is *VERY IMPORTANT* to note that this code takes great care not to keep + // any extra references to any `_mgp` instances (except for `_mgp.Graph`), so + // as not to introduce extra reference counts and prevent their deallocation. + // In particular, the `ExceptionInfo` object has a `traceback` field that + // contains references to the Python frames and their arguments, and therefore + // our `_mgp` instances as well. Within this code we ensure not to keep the + // `ExceptionInfo` object alive so that no extra reference counts are + // introduced. We only fetch the error message and immediately destroy the + // object. 
+ std::optional<std::string> maybe_msg; + { + py::Object py_graph(MakePyGraph(graph, memory)); + py::Object py_messages(MakePyMessages(msgs, memory)); + + utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph)); + utils::OnScopeExit clean_up_messages(PyObjectCleanup(py_messages)); + + if (py_graph && py_messages) { + maybe_msg = error_to_msg(call(py_graph, py_messages)); + } else { + maybe_msg = error_to_msg(py::FetchError()); + } + } + + if (maybe_msg) { + static_cast<void>(mgp_result_set_error_msg(result, maybe_msg->c_str())); + } +} + +void CallPythonFunction(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_func_result *result, + mgp_memory *memory) { + auto gil = py::EnsureGIL(); + + auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> { + if (!exc_info) return std::nullopt; + // Here we tell the traceback formatter to skip the first line of the + // traceback because that line will always be our wrapper function in our + // internal `mgp.py` file. With that line skipped, the user will always + // get only the relevant traceback that happened in his Python code. + return py::FormatException(*exc_info, /* skip_first_line = */ true); + }; + + auto call = [&](py::Object py_graph) -> utils::BasicResult<std::optional<py::ExceptionInfo>, mgp_value *> { + py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr())); + if (!py_args) return {py::FetchError()}; + auto py_res = py_cb.Call(py_graph, py_args); + if (!py_res) return {py::FetchError()}; + mgp_value *ret_val = PyObjectToMgpValueWithPythonExceptions(py_res.Ptr(), memory); + if (ret_val == nullptr) { + return {py::FetchError()}; + } + return ret_val; + }; + + // It is *VERY IMPORTANT* to note that this code takes great care not to keep + // any extra references to any `_mgp` instances (except for `_mgp.Graph`), so + // as not to introduce extra reference counts and prevent their deallocation. 
+ // In particular, the `ExceptionInfo` object has a `traceback` field that + // contains references to the Python frames and their arguments, and therefore + // our `_mgp` instances as well. Within this code we ensure not to keep the + // `ExceptionInfo` object alive so that no extra reference counts are + // introduced. We only fetch the error message and immediately destroy the + // object. + std::optional<std::string> maybe_msg; + { + py::Object py_graph(MakePyGraph(graph, memory)); + utils::OnScopeExit clean_up(PyObjectCleanup(py_graph)); + if (py_graph) { + auto maybe_result = call(py_graph); + if (!maybe_result.HasError()) { + static_cast<void>(mgp_func_result_set_value(result, maybe_result.GetValue(), memory)); + return; + } + maybe_msg = error_to_msg(maybe_result.GetError()); + } else { + maybe_msg = error_to_msg(py::FetchError()); + } + } + + if (maybe_msg) { + static_cast<void>( + mgp_func_result_set_error_msg(result, maybe_msg->c_str(), memory)); // No error fetching if this fails + } +} + +PyObject *PyQueryModuleAddProcedure(PyQueryModule *self, PyObject *cb, bool is_write_procedure) { + MG_ASSERT(self->module); + if (!PyCallable_Check(cb)) { + PyErr_SetString(PyExc_TypeError, "Expected a callable object."); + return nullptr; + } + auto py_cb = py::Object::FromBorrow(cb); + py::Object py_name(py_cb.GetAttr("__name__")); + const auto *name = PyUnicode_AsUTF8(py_name.Ptr()); + if (!name) return nullptr; + if (!IsValidIdentifierName(name)) { + PyErr_SetString(PyExc_ValueError, "Procedure name is not a valid identifier"); + return nullptr; + } + auto *memory = self->module->procedures.get_allocator().GetMemoryResource(); + mgp_proc proc(name, + [py_cb](mgp_list *args, mgp_graph *graph, mgp_result *result, mgp_memory *memory) { + CallPythonProcedure(py_cb, args, graph, result, memory); + }, + memory, {.is_write = is_write_procedure}); + const auto &[proc_it, did_insert] = self->module->procedures.emplace(name, std::move(proc)); + if (!did_insert) { + 
PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name."); + return nullptr; + } + auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + if (!py_proc) return nullptr; + py_proc->callable = &proc_it->second; + return reinterpret_cast<PyObject *>(py_proc); +} +} // namespace + +PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) { + return PyQueryModuleAddProcedure(self, cb, false); +} + +PyObject *PyQueryModuleAddWriteProcedure(PyQueryModule *self, PyObject *cb) { + return PyQueryModuleAddProcedure(self, cb, true); +} + +PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) { + MG_ASSERT(self->module); + if (!PyCallable_Check(cb)) { + PyErr_SetString(PyExc_TypeError, "Expected a callable object."); + return nullptr; + } + auto py_cb = py::Object::FromBorrow(cb); + py::Object py_name(py_cb.GetAttr("__name__")); + const auto *name = PyUnicode_AsUTF8(py_name.Ptr()); + if (!name) return nullptr; + if (!IsValidIdentifierName(name)) { + PyErr_SetString(PyExc_ValueError, "Transformation name is not a valid identifier"); + return nullptr; + } + auto *memory = self->module->transformations.get_allocator().GetMemoryResource(); + mgp_trans trans( + name, + [py_cb](mgp_messages *msgs, mgp_graph *graph, mgp_result *result, mgp_memory *memory) { + CallPythonTransformation(py_cb, msgs, graph, result, memory); + }, + memory); + const auto [trans_it, did_insert] = self->module->transformations.emplace(name, std::move(trans)); + if (!did_insert) { + PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name."); + return nullptr; + } + Py_RETURN_NONE; +} + +PyObject *PyQueryModuleAddFunction(PyQueryModule *self, PyObject *cb) { + MG_ASSERT(self->module); + if (!PyCallable_Check(cb)) { + PyErr_SetString(PyExc_TypeError, "Expected a callable object."); + return nullptr; + } + auto py_cb = py::Object::FromBorrow(cb); + py::Object 
py_name(py_cb.GetAttr("__name__")); + const auto *name = PyUnicode_AsUTF8(py_name.Ptr()); + if (!name) return nullptr; + if (!IsValidIdentifierName(name)) { + PyErr_SetString(PyExc_ValueError, "Function name is not a valid identifier"); + return nullptr; + } + auto *memory = self->module->functions.get_allocator().GetMemoryResource(); + mgp_func func( + name, + [py_cb](mgp_list *args, mgp_func_context *func_ctx, mgp_func_result *result, mgp_memory *memory) { + auto graph = mgp_graph::NonWritableGraph(*(func_ctx->impl), func_ctx->view); + return CallPythonFunction(py_cb, args, &graph, result, memory); + }, + memory); + const auto [func_it, did_insert] = self->module->functions.emplace(name, std::move(func)); + if (!did_insert) { + PyErr_SetString(PyExc_ValueError, "Already registered a function with the same name."); + return nullptr; + } + auto *py_func = PyObject_New(PyMagicFunc, &PyMagicFuncType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + if (!py_func) return nullptr; + py_func->callable = &func_it->second; + return reinterpret_cast<PyObject *>(py_func); +} + +static PyMethodDef PyQueryModuleMethods[] = { + {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"add_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O, + "Register a read-only procedure with this module."}, + {"add_write_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddWriteProcedure), METH_O, + "Register a writeable procedure with this module."}, + {"add_transformation", reinterpret_cast<PyCFunction>(PyQueryModuleAddTransformation), METH_O, + "Register a transformation with this module."}, + {"add_function", reinterpret_cast<PyCFunction>(PyQueryModuleAddFunction), METH_O, + "Register a function with this module."}, + {nullptr}, +}; + +// clang-format off +static PyTypeObject PyQueryModuleType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Module", + .tp_basicsize = 
sizeof(PyQueryModule), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_module.", + .tp_methods = PyQueryModuleMethods, +}; +// clang-format on + +PyObject *MakePyQueryModule(mgp_module *module) { + MG_ASSERT(module); + auto *py_query_module = PyObject_New(PyQueryModule, &PyQueryModuleType); + if (!py_query_module) return nullptr; + py_query_module->module = module; + return reinterpret_cast<PyObject *>(py_query_module); +} + +PyObject *PyMgpModuleTypeNullable(PyObject *mod, PyObject *obj) { + if (Py_TYPE(obj) != &PyCypherTypeType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type."); + return nullptr; + } + auto *py_type = reinterpret_cast<PyCypherType *>(obj); + mgp_type *type{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_type_nullable(py_type->type, &type))) { + return nullptr; + } + return MakePyCypherType(type); +} + +PyObject *PyMgpModuleTypeList(PyObject *mod, PyObject *obj) { + if (Py_TYPE(obj) != &PyCypherTypeType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type."); + return nullptr; + } + auto *py_type = reinterpret_cast<PyCypherType *>(obj); + mgp_type *type{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_type_list(py_type->type, &type))) { + return nullptr; + } + return MakePyCypherType(type); +} + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define DEFINE_PY_MGP_MODULE_TYPE(capital_type, small_type) \ + PyObject *PyMgpModuleType##capital_type(PyObject * /*mod*/, PyObject *Py_UNUSED(ignored)) { \ + mgp_type *type{nullptr}; \ + if (RaiseExceptionFromErrorCode(mgp_type_##small_type(&type))) { \ + return nullptr; \ + } \ + return MakePyCypherType(type); \ + } + +DEFINE_PY_MGP_MODULE_TYPE(Any, any); +DEFINE_PY_MGP_MODULE_TYPE(Bool, bool); +DEFINE_PY_MGP_MODULE_TYPE(String, string); +DEFINE_PY_MGP_MODULE_TYPE(Int, int); +DEFINE_PY_MGP_MODULE_TYPE(Float, float); +DEFINE_PY_MGP_MODULE_TYPE(Number, number); +DEFINE_PY_MGP_MODULE_TYPE(Map, map); +DEFINE_PY_MGP_MODULE_TYPE(Node, node); 
+DEFINE_PY_MGP_MODULE_TYPE(Relationship, relationship); +DEFINE_PY_MGP_MODULE_TYPE(Path, path); +DEFINE_PY_MGP_MODULE_TYPE(Date, date); +DEFINE_PY_MGP_MODULE_TYPE(LocalTime, local_time); +DEFINE_PY_MGP_MODULE_TYPE(LocalDateTime, local_date_time); +DEFINE_PY_MGP_MODULE_TYPE(Duration, duration); + +static PyMethodDef PyMgpModuleMethods[] = { + {"type_nullable", PyMgpModuleTypeNullable, METH_O, + "Build a type representing either a `null` value or a value of given " + "type."}, + {"type_list", PyMgpModuleTypeList, METH_O, "Build a type representing a list of values of given type."}, + {"type_any", PyMgpModuleTypeAny, METH_NOARGS, "Get the type representing any value that isn't `null`."}, + {"type_bool", PyMgpModuleTypeBool, METH_NOARGS, "Get the type representing boolean values."}, + {"type_string", PyMgpModuleTypeString, METH_NOARGS, "Get the type representing string values."}, + {"type_int", PyMgpModuleTypeInt, METH_NOARGS, "Get the type representing integer values."}, + {"type_float", PyMgpModuleTypeFloat, METH_NOARGS, "Get the type representing floating-point values."}, + {"type_number", PyMgpModuleTypeNumber, METH_NOARGS, "Get the type representing any number value."}, + {"type_map", PyMgpModuleTypeMap, METH_NOARGS, "Get the type representing map values."}, + {"type_node", PyMgpModuleTypeNode, METH_NOARGS, "Get the type representing graph node values."}, + {"type_relationship", PyMgpModuleTypeRelationship, METH_NOARGS, + "Get the type representing graph relationship values."}, + {"type_path", PyMgpModuleTypePath, METH_NOARGS, + "Get the type representing a graph path (walk) from one node to another."}, + {"type_date", PyMgpModuleTypeDate, METH_NOARGS, "Get the type representing a Date."}, + {"type_local_time", PyMgpModuleTypeLocalTime, METH_NOARGS, "Get the type representing a LocalTime."}, + {"type_local_date_time", PyMgpModuleTypeLocalDateTime, METH_NOARGS, "Get the type representing a LocalDateTime."}, + {"type_duration", PyMgpModuleTypeDuration, METH_NOARGS, 
"Get the type representing a Duration."}, + {nullptr}, +}; + +// clang-format off +static PyModuleDef PyMgpModule = { + PyModuleDef_HEAD_INIT, + .m_name = "_mgp", + .m_doc = "Contains raw bindings to mg_procedure.h C API.", + .m_size = -1, + .m_methods = PyMgpModuleMethods, +}; +// clang-format on + +// clang-format off +struct PyPropertiesIterator { + PyObject_HEAD + mgp_properties_iterator *it; + PyGraph *py_graph; +}; +// clang-format on + +void PyPropertiesIteratorDealloc(PyPropertiesIterator *self) { + MG_ASSERT(self->it); + MG_ASSERT(self->py_graph); + // Avoid invoking `mgp_properties_iterator_destroy` if we are not in valid + // execution context. The query execution should free all memory used during + // execution, so we may cause a double free issue. + if (self->py_graph->graph) mgp_properties_iterator_destroy(self->it); + Py_DECREF(self->py_graph); + Py_TYPE(self)->tp_free(self); +} + +PyObject *PyPropertiesIteratorGet(PyPropertiesIterator *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->it); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_property *property{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_properties_iterator_get(self->it, &property))) { + return nullptr; + } + if (property == nullptr) { + Py_INCREF(Py_None); + return Py_None; + } + py::Object py_name(PyUnicode_FromString(property->name)); + if (!py_name) return nullptr; + auto py_value = MgpValueToPyObject(*property->value, self->py_graph); + if (!py_value) return nullptr; + return PyTuple_Pack(2, py_name.Ptr(), py_value.Ptr()); +} + +PyObject *PyPropertiesIteratorNext(PyPropertiesIterator *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->it); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_property *property{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_properties_iterator_next(self->it, &property))) { + return nullptr; + } + if (property == nullptr) { + Py_INCREF(Py_None); + return Py_None; + } + py::Object 
py_name(PyUnicode_FromString(property->name)); + if (!py_name) return nullptr; + auto py_value = MgpValueToPyObject(*property->value, self->py_graph); + if (!py_value) return nullptr; + return PyTuple_Pack(2, py_name.Ptr(), py_value.Ptr()); +} + +static PyMethodDef PyPropertiesIteratorMethods[] = { + {"get", reinterpret_cast<PyCFunction>(PyPropertiesIteratorGet), METH_NOARGS, + "Get the current property pointed to by the iterator or return None."}, + {"next", reinterpret_cast<PyCFunction>(PyPropertiesIteratorNext), METH_NOARGS, + "Advance the iterator to the next property and return it."}, + {nullptr}, +}; + +// clang-format off +static PyTypeObject PyPropertiesIteratorType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.PropertiesIterator", + .tp_basicsize = sizeof(PyPropertiesIterator), + .tp_dealloc = reinterpret_cast<destructor>(PyPropertiesIteratorDealloc), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_properties_iterator.", + .tp_methods = PyPropertiesIteratorMethods, +}; +// clang-format on + +// clang-format off +struct PyEdge { + PyObject_HEAD + mgp_edge *edge; + PyGraph *py_graph; +}; +// clang-format on + +PyObject *PyEdgeGetTypeName(PyEdge *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->edge); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_edge_type edge_type{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_edge_get_type(self->edge, &edge_type))) { + return nullptr; + } + return PyUnicode_FromString(edge_type.name); +} + +PyObject *PyEdgeFromVertex(PyEdge *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->edge); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + return MakePyVertex(self->edge->from, self->py_graph); +} + +PyObject *PyEdgeToVertex(PyEdge *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->edge); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + return 
MakePyVertex(self->edge->to, self->py_graph); +} + +void PyEdgeDealloc(PyEdge *self) { + MG_ASSERT(self->edge); + MG_ASSERT(self->py_graph); + // Avoid invoking `mgp_edge_destroy` if we are not in valid execution context. + // The query execution should free all memory used during execution, so we may + // cause a double free issue. + if (self->py_graph->graph) mgp_edge_destroy(self->edge); + Py_DECREF(self->py_graph); + Py_TYPE(self)->tp_free(self); +} + +PyObject *PyEdgeIsValid(PyEdge *self, PyObject *Py_UNUSED(ignored)) { + return PyBool_FromLong(self->py_graph && self->py_graph->graph); +} + +PyObject *PyEdgeUnderlyingGraphIsMutable(PyEdge *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->edge); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + return PyBool_FromLong(CallBool(mgp_graph_is_mutable, self->py_graph->graph)); +} + +PyObject *PyEdgeGetId(PyEdge *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->edge); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_edge_id edge_id{0}; + if (RaiseExceptionFromErrorCode(mgp_edge_get_id(self->edge, &edge_id))) { + return nullptr; + } + return PyLong_FromLongLong(edge_id.as_int); +} + +PyObject *PyEdgeIterProperties(PyEdge *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->edge); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_properties_iterator *properties_it{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_edge_iter_properties(self->edge, self->py_graph->memory, &properties_it))) { + return nullptr; + } + auto *py_properties_it = PyObject_New(PyPropertiesIterator, &PyPropertiesIteratorType); + if (!py_properties_it) { + mgp_properties_iterator_destroy(properties_it); + return nullptr; + } + py_properties_it->it = properties_it; + Py_INCREF(self->py_graph); + py_properties_it->py_graph = self->py_graph; + return reinterpret_cast<PyObject *>(py_properties_it); +} + +PyObject 
*PyEdgeGetProperty(PyEdge *self, PyObject *args) { + MG_ASSERT(self); + MG_ASSERT(self->edge); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + const char *prop_name = nullptr; + if (!PyArg_ParseTuple(args, "s", &prop_name)) return nullptr; + mgp_value *prop_value{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_edge_get_property(self->edge, prop_name, self->py_graph->memory, &prop_value))) { + return nullptr; + } + auto py_prop_value = MgpValueToPyObject(*prop_value, self->py_graph); + mgp_value_destroy(prop_value); + return py_prop_value.Steal(); +} + +PyObject *PyEdgeSetProperty(PyEdge *self, PyObject *args) { + MG_ASSERT(self); + MG_ASSERT(self->edge); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + const char *prop_name = nullptr; + PyObject *py_value{nullptr}; + if (!PyArg_ParseTuple(args, "sO", &prop_name, &py_value)) { + return nullptr; + } + MgpUniquePtr<mgp_value> prop_value{PyObjectToMgpValueWithPythonExceptions(py_value, self->py_graph->memory), + mgp_value_destroy}; + + if (prop_value == nullptr || + RaiseExceptionFromErrorCode(mgp_edge_set_property(self->edge, prop_name, prop_value.get()))) { + return nullptr; + } + Py_RETURN_NONE; +} + +static PyMethodDef PyEdgeMethods[] = { + {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported."}, + {"is_valid", reinterpret_cast<PyCFunction>(PyEdgeIsValid), METH_NOARGS, + "Return True if the edge is in valid context and may be used."}, + {"underlying_graph_is_mutable", reinterpret_cast<PyCFunction>(PyEdgeUnderlyingGraphIsMutable), METH_NOARGS, + "Return True if the edge is mutable and can be modified."}, + {"get_id", reinterpret_cast<PyCFunction>(PyEdgeGetId), METH_NOARGS, "Return edge id."}, + {"get_type_name", reinterpret_cast<PyCFunction>(PyEdgeGetTypeName), METH_NOARGS, "Return the edge's type name."}, + {"from_vertex", reinterpret_cast<PyCFunction>(PyEdgeFromVertex), METH_NOARGS, "Return the edge's source vertex."}, + 
{"to_vertex", reinterpret_cast<PyCFunction>(PyEdgeToVertex), METH_NOARGS, "Return the edge's destination vertex."}, + {"iter_properties", reinterpret_cast<PyCFunction>(PyEdgeIterProperties), METH_NOARGS, + "Return _mgp.PropertiesIterator for this edge."}, + {"get_property", reinterpret_cast<PyCFunction>(PyEdgeGetProperty), METH_VARARGS, + "Return edge property with given name."}, + {"set_property", reinterpret_cast<PyCFunction>(PyEdgeSetProperty), METH_VARARGS, + "Set the value of the property on the edge."}, + {nullptr}, +}; + +PyObject *PyEdgeRichCompare(PyObject *self, PyObject *other, int op); + +// clang-format off +static PyTypeObject PyEdgeType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Edge", + .tp_basicsize = sizeof(PyEdge), + .tp_dealloc = reinterpret_cast<destructor>(PyEdgeDealloc), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_edge.", + .tp_richcompare = PyEdgeRichCompare, + .tp_methods = PyEdgeMethods, +}; +// clang-format on + +PyObject *MakePyEdgeWithoutCopy(mgp_edge &edge, PyGraph *py_graph) { + MG_ASSERT(py_graph); + MG_ASSERT(py_graph->graph && py_graph->memory); + MG_ASSERT(edge.GetMemoryResource() == py_graph->memory->impl); + auto *py_edge = PyObject_New(PyEdge, &PyEdgeType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + if (!py_edge) return nullptr; + py_edge->edge = &edge; + py_edge->py_graph = py_graph; + Py_INCREF(py_graph); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) + return reinterpret_cast<PyObject *>(py_edge); +} + +/// Create an instance of `_mgp.Edge` class. +/// +/// The created instance references an existing `_mgp.Graph` instance, which +/// marks the execution context. 
+PyObject *MakePyEdge(mgp_edge &edge, PyGraph *py_graph) { + MG_ASSERT(py_graph); + MG_ASSERT(py_graph->graph && py_graph->memory); + MgpUniquePtr<mgp_edge> edge_copy{nullptr, mgp_edge_destroy}; + if (RaiseExceptionFromErrorCode(CreateMgpObject(edge_copy, mgp_edge_copy, &edge, py_graph->memory))) { + return nullptr; + } + auto *py_edge = MakePyEdgeWithoutCopy(*edge_copy, py_graph); + if (py_edge != nullptr) { + static_cast<void>(edge_copy.release()); + } + return py_edge; +} + +PyObject *PyEdgeRichCompare(PyObject *self, PyObject *other, int op) { + MG_ASSERT(self); + MG_ASSERT(other); + + if (Py_TYPE(self) != &PyEdgeType || Py_TYPE(other) != &PyEdgeType || op != Py_EQ) { + Py_RETURN_NOTIMPLEMENTED; + } + + auto *e1 = reinterpret_cast<PyEdge *>(self); + auto *e2 = reinterpret_cast<PyEdge *>(other); + MG_ASSERT(e1->edge); + MG_ASSERT(e2->edge); + return PyBool_FromLong(Call<int>(mgp_edge_equal, e1->edge, e2->edge)); +} + +// clang-format off +struct PyVertex { + PyObject_HEAD + mgp_vertex *vertex; + PyGraph *py_graph; +}; +// clang-format on + +void PyVertexDealloc(PyVertex *self) { + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + // Avoid invoking `mgp_vertex_destroy` if we are not in valid execution + // context. The query execution should free all memory used during + // execution, so we may cause a double free issue. 
+ if (self->py_graph->graph) mgp_vertex_destroy(self->vertex); + Py_DECREF(self->py_graph); + Py_TYPE(self)->tp_free(self); +} + +PyObject *PyVertexIsValid(PyVertex *self, PyObject *Py_UNUSED(ignored)) { + return PyBool_FromLong(self->py_graph && self->py_graph->graph); +} + +PyObject *PyVertexUnderlyingGraphIsMutable(PyVertex *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + return PyBool_FromLong(CallBool(mgp_graph_is_mutable, self->py_graph->graph)); +} + +PyObject *PyVertexGetId(PyVertex *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_vertex_id id{}; + if (RaiseExceptionFromErrorCode(mgp_vertex_get_id(self->vertex, &id))) { + return nullptr; + } + return PyLong_FromLongLong(id.as_int); +} + +PyObject *PyVertexLabelsCount(PyVertex *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + size_t label_count{0}; + if (RaiseExceptionFromErrorCode(mgp_vertex_labels_count(self->vertex, &label_count))) { + return nullptr; + } + return PyLong_FromSize_t(label_count); +} + +PyObject *PyVertexLabelAt(PyVertex *self, PyObject *args) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max()); + Py_ssize_t id; + if (!PyArg_ParseTuple(args, "n", &id)) { + return nullptr; + } + mgp_label label{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_vertex_label_at(self->vertex, id, &label))) { + return nullptr; + } + return PyUnicode_FromString(label.name); +} + +PyObject *PyVertexIterInEdges(PyVertex *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + 
MG_ASSERT(self->py_graph->graph); + mgp_edges_iterator *edges_it{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_vertex_iter_in_edges(self->vertex, self->py_graph->memory, &edges_it))) { + return nullptr; + } + auto *py_edges_it = PyObject_New(PyEdgesIterator, &PyEdgesIteratorType); + if (!py_edges_it) { + mgp_edges_iterator_destroy(edges_it); + return nullptr; + } + py_edges_it->it = edges_it; + Py_INCREF(self->py_graph); + py_edges_it->py_graph = self->py_graph; + return reinterpret_cast<PyObject *>(py_edges_it); +} + +PyObject *PyVertexIterOutEdges(PyVertex *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_edges_iterator *edges_it{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_vertex_iter_out_edges(self->vertex, self->py_graph->memory, &edges_it))) { + return nullptr; + } + auto *py_edges_it = PyObject_New(PyEdgesIterator, &PyEdgesIteratorType); + if (!py_edges_it) { + mgp_edges_iterator_destroy(edges_it); + return nullptr; + } + py_edges_it->it = edges_it; + Py_INCREF(self->py_graph); + py_edges_it->py_graph = self->py_graph; + return reinterpret_cast<PyObject *>(py_edges_it); +} + +PyObject *PyVertexIterProperties(PyVertex *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + mgp_properties_iterator *properties_it{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_vertex_iter_properties(self->vertex, self->py_graph->memory, &properties_it))) { + return nullptr; + } + auto *py_properties_it = PyObject_New(PyPropertiesIterator, &PyPropertiesIteratorType); + if (!py_properties_it) { + mgp_properties_iterator_destroy(properties_it); + return nullptr; + } + py_properties_it->it = properties_it; + Py_INCREF(self->py_graph); + py_properties_it->py_graph = self->py_graph; + return reinterpret_cast<PyObject *>(py_properties_it); +} + +PyObject 
*PyVertexGetProperty(PyVertex *self, PyObject *args) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + const char *prop_name{nullptr}; + if (!PyArg_ParseTuple(args, "s", &prop_name)) { + return nullptr; + } + mgp_value *prop_value{nullptr}; + if (RaiseExceptionFromErrorCode( + mgp_vertex_get_property(self->vertex, prop_name, self->py_graph->memory, &prop_value))) { + return nullptr; + } + auto py_prop_value = MgpValueToPyObject(*prop_value, self->py_graph); + mgp_value_destroy(prop_value); + return py_prop_value.Steal(); +} + +PyObject *PyVertexSetProperty(PyVertex *self, PyObject *args) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + const char *prop_name = nullptr; + PyObject *py_value{nullptr}; + if (!PyArg_ParseTuple(args, "sO", &prop_name, &py_value)) { + return nullptr; + } + MgpUniquePtr<mgp_value> prop_value{PyObjectToMgpValueWithPythonExceptions(py_value, self->py_graph->memory), + mgp_value_destroy}; + + if (prop_value == nullptr || + RaiseExceptionFromErrorCode(mgp_vertex_set_property(self->vertex, prop_name, prop_value.get()))) { + return nullptr; + } + Py_RETURN_NONE; +} + +PyObject *PyVertexAddLabel(PyVertex *self, PyObject *args) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + const char *label_name = nullptr; + if (!PyArg_ParseTuple(args, "s", &label_name)) { + return nullptr; + } + if (RaiseExceptionFromErrorCode(mgp_vertex_add_label(self->vertex, mgp_label{label_name}))) { + return nullptr; + } + Py_RETURN_NONE; +} + +PyObject *PyVertexRemoveLabel(PyVertex *self, PyObject *args) { + MG_ASSERT(self); + MG_ASSERT(self->vertex); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + const char *label_name = nullptr; + if (!PyArg_ParseTuple(args, "s", &label_name)) { + return nullptr; + } + if 
(RaiseExceptionFromErrorCode(mgp_vertex_remove_label(self->vertex, mgp_label{label_name}))) { + return nullptr; + } + Py_RETURN_NONE; +} + +static PyMethodDef PyVertexMethods[] = { + {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported."}, + {"is_valid", reinterpret_cast<PyCFunction>(PyVertexIsValid), METH_NOARGS, + "Return True if the vertex is in valid context and may be used."}, + {"underlying_graph_is_mutable", reinterpret_cast<PyCFunction>(PyVertexUnderlyingGraphIsMutable), METH_NOARGS, + "Return True if the vertex is mutable and can be modified."}, + {"get_id", reinterpret_cast<PyCFunction>(PyVertexGetId), METH_NOARGS, "Return vertex id."}, + {"labels_count", reinterpret_cast<PyCFunction>(PyVertexLabelsCount), METH_NOARGS, + "Return number of labels of a vertex."}, + {"label_at", reinterpret_cast<PyCFunction>(PyVertexLabelAt), METH_VARARGS, + "Return label of a vertex on a given index."}, + {"add_label", reinterpret_cast<PyCFunction>(PyVertexAddLabel), METH_VARARGS, "Add the label to the vertex."}, + {"remove_label", reinterpret_cast<PyCFunction>(PyVertexRemoveLabel), METH_VARARGS, + "Remove the label from the vertex."}, + {"iter_in_edges", reinterpret_cast<PyCFunction>(PyVertexIterInEdges), METH_NOARGS, + "Return _mgp.EdgesIterator for in edges."}, + {"iter_out_edges", reinterpret_cast<PyCFunction>(PyVertexIterOutEdges), METH_NOARGS, + "Return _mgp.EdgesIterator for out edges."}, + {"iter_properties", reinterpret_cast<PyCFunction>(PyVertexIterProperties), METH_NOARGS, + "Return _mgp.PropertiesIterator for this vertex."}, + {"get_property", reinterpret_cast<PyCFunction>(PyVertexGetProperty), METH_VARARGS, + "Return vertex property with given name."}, + {"set_property", reinterpret_cast<PyCFunction>(PyVertexSetProperty), METH_VARARGS, + "Set the value of the property on the vertex."}, + {nullptr}, +}; + +PyObject *PyVertexRichCompare(PyObject *self, PyObject *other, int op); + +// clang-format off 
+static PyTypeObject PyVertexType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Vertex", + .tp_basicsize = sizeof(PyVertex), + .tp_dealloc = reinterpret_cast<destructor>(PyVertexDealloc), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_vertex.", + .tp_richcompare = PyVertexRichCompare, + .tp_methods = PyVertexMethods, +}; +// clang-format on + +PyObject *MakePyVertexWithoutCopy(mgp_vertex &vertex, PyGraph *py_graph) { + MG_ASSERT(py_graph); + MG_ASSERT(py_graph->graph && py_graph->memory); + MG_ASSERT(vertex.GetMemoryResource() == py_graph->memory->impl); + auto *py_vertex = PyObject_New(PyVertex, &PyVertexType); + if (!py_vertex) return nullptr; + py_vertex->vertex = &vertex; + py_vertex->py_graph = py_graph; + Py_INCREF(py_graph); + return reinterpret_cast<PyObject *>(py_vertex); +} + +PyObject *MakePyVertex(mgp_vertex &vertex, PyGraph *py_graph) { + MG_ASSERT(py_graph); + MG_ASSERT(py_graph->graph && py_graph->memory); + + MgpUniquePtr<mgp_vertex> vertex_copy{nullptr, mgp_vertex_destroy}; + if (RaiseExceptionFromErrorCode(CreateMgpObject(vertex_copy, mgp_vertex_copy, &vertex, py_graph->memory))) { + return nullptr; + } + auto *py_vertex = MakePyVertexWithoutCopy(*vertex_copy, py_graph); + if (py_vertex != nullptr) { + static_cast<void>(vertex_copy.release()); + } + return py_vertex; +} + +PyObject *PyVertexRichCompare(PyObject *self, PyObject *other, int op) { + MG_ASSERT(self); + MG_ASSERT(other); + + if (Py_TYPE(self) != &PyVertexType || Py_TYPE(other) != &PyVertexType || op != Py_EQ) { + Py_RETURN_NOTIMPLEMENTED; + } + + auto *v1 = reinterpret_cast<PyVertex *>(self); + auto *v2 = reinterpret_cast<PyVertex *>(other); + MG_ASSERT(v1->vertex); + MG_ASSERT(v2->vertex); + + return PyBool_FromLong(Call<int>(mgp_vertex_equal, v1->vertex, v2->vertex)); +} + +// clang-format off +struct PyPath { + PyObject_HEAD + mgp_path *path; + PyGraph *py_graph; +}; +// clang-format on + +void PyPathDealloc(PyPath *self) { + MG_ASSERT(self->path); + 
MG_ASSERT(self->py_graph); + // Avoid invoking `mgp_path_destroy` if we are not in valid execution + // context. The query execution should free all memory used during + // execution, so we may cause a double free issue. + if (self->py_graph->graph) mgp_path_destroy(self->path); + Py_DECREF(self->py_graph); + Py_TYPE(self)->tp_free(self); +} + +PyObject *PyPathIsValid(PyPath *self, PyObject *Py_UNUSED(ignored)) { + return PyBool_FromLong(self->py_graph && self->py_graph->graph); +} + +PyObject *PyPathMakeWithStart(PyTypeObject *type, PyObject *vertex); + +PyObject *PyPathExpand(PyPath *self, PyObject *edge) { + MG_ASSERT(self->path); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + if (Py_TYPE(edge) != &PyEdgeType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Edge."); + return nullptr; + } + auto *py_edge = reinterpret_cast<PyEdge *>(edge); + + if (RaiseExceptionFromErrorCode(mgp_path_expand(self->path, py_edge->edge))) { + return nullptr; + } + Py_RETURN_NONE; +} + +PyObject *PyPathSize(PyPath *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->path); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + return PyLong_FromSize_t(Call<size_t>(mgp_path_size, self->path)); +} + +PyObject *PyPathVertexAt(PyPath *self, PyObject *args) { + MG_ASSERT(self->path); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max()); + Py_ssize_t i; + if (!PyArg_ParseTuple(args, "n", &i)) { + return nullptr; + } + mgp_vertex *vertex{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_path_vertex_at(self->path, i, &vertex))) { + return nullptr; + } + return MakePyVertex(*vertex, self->py_graph); +} + +PyObject *PyPathEdgeAt(PyPath *self, PyObject *args) { + MG_ASSERT(self->path); + MG_ASSERT(self->py_graph); + MG_ASSERT(self->py_graph->graph); + static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max()); + 
Py_ssize_t i; + if (!PyArg_ParseTuple(args, "n", &i)) { + return nullptr; + } + mgp_edge *edge{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_path_edge_at(self->path, i, &edge))) { + return nullptr; + } + return MakePyEdge(*edge, self->py_graph); +} + +static PyMethodDef PyPathMethods[] = { + {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"is_valid", reinterpret_cast<PyCFunction>(PyPathIsValid), METH_NOARGS, + "Return True if Path is in valid context and may be used."}, + {"make_with_start", reinterpret_cast<PyCFunction>(PyPathMakeWithStart), METH_O | METH_CLASS, + "Create a path with a starting vertex."}, + {"expand", reinterpret_cast<PyCFunction>(PyPathExpand), METH_O, + "Append an edge continuing from the last vertex on the path."}, + {"size", reinterpret_cast<PyCFunction>(PyPathSize), METH_NOARGS, "Return the number of edges in a mgp_path."}, + {"vertex_at", reinterpret_cast<PyCFunction>(PyPathVertexAt), METH_VARARGS, + "Return the vertex from a path at given index."}, + {"edge_at", reinterpret_cast<PyCFunction>(PyPathEdgeAt), METH_VARARGS, + "Return the edge from a path at given index."}, + {nullptr}, +}; + +// clang-format off +static PyTypeObject PyPathType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Path", + .tp_basicsize = sizeof(PyPath), + .tp_dealloc = reinterpret_cast<destructor>(PyPathDealloc), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_path.", + .tp_methods = PyPathMethods, +}; +// clang-format on + +PyObject *MakePyPath(mgp_path *path, PyGraph *py_graph) { + MG_ASSERT(path); + MG_ASSERT(py_graph->graph && py_graph->memory); + MG_ASSERT(path->GetMemoryResource() == py_graph->memory->impl); + auto *py_path = PyObject_New(PyPath, &PyPathType); + if (!py_path) return nullptr; + py_path->path = path; + py_path->py_graph = py_graph; + Py_INCREF(py_graph); + return reinterpret_cast<PyObject *>(py_path); +} + +PyObject *MakePyPath(mgp_path &path, 
PyGraph *py_graph) { + MG_ASSERT(py_graph); + MG_ASSERT(py_graph->graph && py_graph->memory); + mgp_path *path_copy{nullptr}; + + if (RaiseExceptionFromErrorCode(mgp_path_copy(&path, py_graph->memory, &path_copy))) { + return nullptr; + } + auto *py_path = MakePyPath(path_copy, py_graph); + if (!py_path) mgp_path_destroy(path_copy); + return py_path; +} + +PyObject *PyPathMakeWithStart(PyTypeObject *type, PyObject *vertex) { + if (type != &PyPathType) { + PyErr_SetString(PyExc_TypeError, "Expected '<class _mgp.Path>' as the first argument."); + return nullptr; + } + if (Py_TYPE(vertex) != &PyVertexType) { + PyErr_SetString(PyExc_TypeError, "Expected a '_mgp.Vertex' as the second argument."); + return nullptr; + } + auto *py_vertex = reinterpret_cast<PyVertex *>(vertex); + mgp_path *path{nullptr}; + if (RaiseExceptionFromErrorCode(mgp_path_make_with_start(py_vertex->vertex, py_vertex->py_graph->memory, &path))) { + return nullptr; + } + auto *py_path = MakePyPath(path, py_vertex->py_graph); + if (!py_path) mgp_path_destroy(path); + return py_path; +} + +struct PyMgpError { + const char *name; + PyObject *&exception; + PyObject *&base; + const char *docstring; +}; + +bool AddModuleConstants(PyObject &module) { + // add source type constants + if (PyModule_AddIntConstant(&module, "SOURCE_TYPE_KAFKA", static_cast<int64_t>(mgp_source_type::KAFKA))) { + return false; + } + if (PyModule_AddIntConstant(&module, "SOURCE_TYPE_PULSAR", static_cast<int64_t>(mgp_source_type::PULSAR))) { + return false; + } + + return true; +} + +PyObject *PyInitMgpModule() { + PyObject *mgp = PyModule_Create(&PyMgpModule); + if (!mgp) return nullptr; + auto register_type = [mgp](auto *type, const auto *name) -> bool { + if (PyType_Ready(type) < 0) { + Py_DECREF(mgp); + return false; + } + Py_INCREF(type); + if (PyModule_AddObject(mgp, name, reinterpret_cast<PyObject *>(type)) < 0) { + Py_DECREF(type); + Py_DECREF(mgp); + return false; + } + return true; + }; + + if (!AddModuleConstants(*mgp)) 
return nullptr; + + if (!register_type(&PyPropertiesIteratorType, "PropertiesIterator")) return nullptr; + if (!register_type(&PyVerticesIteratorType, "VerticesIterator")) return nullptr; + if (!register_type(&PyEdgesIteratorType, "EdgesIterator")) return nullptr; + if (!register_type(&PyGraphType, "Graph")) return nullptr; + if (!register_type(&PyEdgeType, "Edge")) return nullptr; + if (!register_type(&PyQueryProcType, "Proc")) return nullptr; + if (!register_type(&PyMagicFuncType, "Func")) return nullptr; + if (!register_type(&PyQueryModuleType, "Module")) return nullptr; + if (!register_type(&PyVertexType, "Vertex")) return nullptr; + if (!register_type(&PyPathType, "Path")) return nullptr; + if (!register_type(&PyCypherTypeType, "Type")) return nullptr; + if (!register_type(&PyMessagesType, "Messages")) return nullptr; + if (!register_type(&PyMessageType, "Message")) return nullptr; + + std::array py_mgp_errors{ + PyMgpError{"_mgp.UnknownError", gMgpUnknownError, PyExc_RuntimeError, nullptr}, + PyMgpError{"_mgp.UnableToAllocateError", gMgpUnableToAllocateError, PyExc_MemoryError, nullptr}, + PyMgpError{"_mgp.InsufficientBufferError", gMgpInsufficientBufferError, PyExc_BufferError, nullptr}, + PyMgpError{"_mgp.OutOfRangeError", gMgpOutOfRangeError, PyExc_BufferError, nullptr}, + PyMgpError{"_mgp.LogicErrorError", gMgpLogicErrorError, PyExc_RuntimeError, nullptr}, + PyMgpError{"_mgp.DeletedObjectError", gMgpDeletedObjectError, PyExc_RuntimeError, nullptr}, + PyMgpError{"_mgp.InvalidArgumentError", gMgpInvalidArgumentError, PyExc_ValueError, nullptr}, + PyMgpError{"_mgp.KeyAlreadyExistsError", gMgpKeyAlreadyExistsError, PyExc_RuntimeError, nullptr}, + PyMgpError{"_mgp.ImmutableObjectError", gMgpImmutableObjectError, PyExc_RuntimeError, nullptr}, + PyMgpError{"_mgp.ValueConversionError", gMgpValueConversionError, PyExc_RuntimeError, nullptr}, + PyMgpError{"_mgp.SerializationError", gMgpSerializationError, PyExc_RuntimeError, nullptr}, + }; + Py_INCREF(Py_None); + + 
utils::OnScopeExit clean_up{[mgp, &py_mgp_errors] { + for (const auto &py_mgp_error : py_mgp_errors) { + Py_XDECREF(py_mgp_error.exception); + } + Py_DECREF(Py_None); + Py_DECREF(mgp); + }}; + + if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) { + return nullptr; + } + + auto register_custom_error = [mgp](PyMgpError &py_mgp_error) { + py_mgp_error.exception = PyErr_NewException(py_mgp_error.name, py_mgp_error.base, nullptr); + if (py_mgp_error.exception == nullptr) { + return false; + } + + const auto *name_in_module = std::string_view(py_mgp_error.name).substr(5).data(); + return PyModule_AddObject(mgp, name_in_module, py_mgp_error.exception) == 0; + }; + + for (auto &py_mgp_error : py_mgp_errors) { + if (!register_custom_error(py_mgp_error)) { + return nullptr; + } + } + clean_up.Disable(); + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) + PyDateTime_IMPORT; + + return mgp; +} + +namespace { + +template <class TFun> +auto WithMgpModule(mgp_module *module_def, const TFun &fun) { + py::Object py_mgp(PyImport_ImportModule("_mgp")); + MG_ASSERT(py_mgp, "Expected builtin '_mgp' to be available for import"); + py::Object py_mgp_module(py_mgp.GetAttr("_MODULE")); + MG_ASSERT(py_mgp_module, "Expected '_mgp' to have attribute '_MODULE'"); + // NOTE: This check is not thread safe, but this should only go through + // ModuleRegistry::LoadModuleLibrary which ought to serialize loading. + MG_ASSERT(py_mgp_module.Ptr() == Py_None, + "Expected '_mgp._MODULE' to be None as we are just starting to " + "import a new module. 
Is some other thread also importing Python " + "modules?"); + auto *py_query_module = MakePyQueryModule(module_def); + MG_ASSERT(py_query_module); + MG_ASSERT(py_mgp.SetAttr("_MODULE", py_query_module)); + auto ret = fun(); + auto maybe_exc = py::FetchError(); + MG_ASSERT(py_mgp.SetAttr("_MODULE", Py_None)); + if (maybe_exc) { + py::RestoreError(*maybe_exc); + } + return ret; +} + +} // namespace + +py::Object ImportPyModule(const char *name, mgp_module *module_def) { + return WithMgpModule(module_def, [name]() { return py::Object(PyImport_ImportModule(name)); }); +} + +py::Object ReloadPyModule(PyObject *py_module, mgp_module *module_def) { + return WithMgpModule(module_def, [py_module]() { return py::Object(PyImport_ReloadModule(py_module)); }); +} + +py::Object MgpValueToPyObject(const mgp_value &value, PyObject *py_graph) { + if (Py_TYPE(py_graph) != &PyGraphType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Graph."); + return nullptr; + } + return MgpValueToPyObject(value, reinterpret_cast<PyGraph *>(py_graph)); +} + +py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph) { + switch (value.type) { + case MGP_VALUE_TYPE_NULL: + Py_INCREF(Py_None); + return py::Object(Py_None); + case MGP_VALUE_TYPE_BOOL: + return py::Object(PyBool_FromLong(value.bool_v)); + case MGP_VALUE_TYPE_INT: + return py::Object(PyLong_FromLongLong(value.int_v)); + case MGP_VALUE_TYPE_DOUBLE: + return py::Object(PyFloat_FromDouble(value.double_v)); + case MGP_VALUE_TYPE_STRING: + return py::Object(PyUnicode_FromString(value.string_v.c_str())); + case MGP_VALUE_TYPE_LIST: + return MgpListToPyTuple(value.list_v, py_graph); + case MGP_VALUE_TYPE_MAP: { + auto *map = value.map_v; + py::Object py_dict(PyDict_New()); + if (!py_dict) { + return nullptr; + } + for (const auto &[key, val] : map->items) { + auto py_val = MgpValueToPyObject(val, py_graph); + if (!py_val) { + return nullptr; + } + // Unlike PyList_SET_ITEM, PyDict_SetItem does not steal the value. 
+ if (PyDict_SetItemString(py_dict.Ptr(), key.c_str(), py_val.Ptr()) != 0) return nullptr; + } + return py_dict; + } + case MGP_VALUE_TYPE_VERTEX: { + py::Object py_mgp(PyImport_ImportModule("mgp")); + if (!py_mgp) return nullptr; + auto *v = value.vertex_v; + py::Object py_vertex(reinterpret_cast<PyObject *>(MakePyVertex(*v, py_graph))); + return py_mgp.CallMethod("Vertex", py_vertex); + } + case MGP_VALUE_TYPE_EDGE: { + py::Object py_mgp(PyImport_ImportModule("mgp")); + if (!py_mgp) return nullptr; + auto *e = value.edge_v; + py::Object py_edge(reinterpret_cast<PyObject *>(MakePyEdge(*e, py_graph))); + return py_mgp.CallMethod("Edge", py_edge); + } + case MGP_VALUE_TYPE_PATH: { + py::Object py_mgp(PyImport_ImportModule("mgp")); + if (!py_mgp) return nullptr; + auto *p = value.path_v; + py::Object py_path(reinterpret_cast<PyObject *>(MakePyPath(*p, py_graph))); + return py_mgp.CallMethod("Path", py_path); + } + case MGP_VALUE_TYPE_DATE: { + const auto &date = value.date_v->date; + py::Object py_date(PyDate_FromDate(date.year, date.month, date.day)); + return py_date; + } + case MGP_VALUE_TYPE_LOCAL_TIME: { + const auto &local_time = value.local_time_v->local_time; + py::Object py_local_time(PyTime_FromTime(local_time.hour, local_time.minute, local_time.second, + local_time.millisecond * 1000 + local_time.microsecond)); + return py_local_time; + } + case MGP_VALUE_TYPE_LOCAL_DATE_TIME: { + const auto &local_time = value.local_date_time_v->local_date_time.local_time; + const auto &date = value.local_date_time_v->local_date_time.date; + py::Object py_local_date_time(PyDateTime_FromDateAndTime(date.year, date.month, date.day, local_time.hour, + local_time.minute, local_time.second, + local_time.millisecond * 1000 + local_time.microsecond)); + return py_local_date_time; + } + case MGP_VALUE_TYPE_DURATION: { + const auto &duration = value.duration_v->duration; + py::Object py_duration(PyDelta_FromDSU(0, 0, duration.microseconds)); + return py_duration; + } + } +} + 
+mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) { + auto py_seq_to_list = [memory](PyObject *seq, Py_ssize_t len, const auto &py_seq_get_item) { + static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max()); + MgpUniquePtr<mgp_list> list{nullptr, &mgp_list_destroy}; + if (const auto err = CreateMgpObject(list, mgp_list_make_empty, len, memory); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during making mgp_list"}; + } + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject *e = py_seq_get_item(seq, i); + mgp_value *v{nullptr}; + v = PyObjectToMgpValue(e, memory); + const auto err = mgp_list_append(list.get(), v); + mgp_value_destroy(v); + if (err != mgp_error::MGP_ERROR_NO_ERROR) { + if (err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } + throw std::runtime_error{"Unexpected error during appending to mgp_list"}; + } + } + mgp_value *v{nullptr}; + if (const auto err = mgp_value_make_list(list.get(), &v); err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during making mgp_value"}; + } + static_cast<void>(list.release()); + return v; + }; + + auto is_mgp_instance = [](PyObject *obj, const char *mgp_type_name) { + py::Object py_mgp(PyImport_ImportModule("mgp")); + if (!py_mgp) { + PyErr_Clear(); + // This way we skip conversions of types from user-facing 'mgp' module. 
+ return false; + } + auto mgp_type = py_mgp.GetAttr(mgp_type_name); + if (!mgp_type) { + PyErr_Clear(); + std::stringstream ss; + ss << "'mgp' module is missing '" << mgp_type_name << "' type"; + throw std::invalid_argument(ss.str()); + } + int res = PyObject_IsInstance(obj, mgp_type.Ptr()); + if (res == -1) { + PyErr_Clear(); + std::stringstream ss; + ss << "Error when checking object is instance of 'mgp." << mgp_type_name << "' type"; + throw std::invalid_argument(ss.str()); + } + return static_cast<bool>(res); + }; + + mgp_value *mgp_v{nullptr}; + mgp_error last_error{mgp_error::MGP_ERROR_NO_ERROR}; + + if (o == Py_None) { + last_error = mgp_value_make_null(memory, &mgp_v); + } else if (PyBool_Check(o)) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) Py_True is defined with C-style cast + last_error = mgp_value_make_bool(static_cast<int>(o == Py_True), memory, &mgp_v); + } else if (PyLong_Check(o)) { + int64_t value = PyLong_AsLong(o); + if (PyErr_Occurred()) { + PyErr_Clear(); + throw std::overflow_error("Python integer is out of range"); + } + last_error = mgp_value_make_int(value, memory, &mgp_v); + } else if (PyFloat_Check(o)) { + last_error = mgp_value_make_double(PyFloat_AsDouble(o), memory, &mgp_v); + } else if (PyUnicode_Check(o)) { // NOLINT(hicpp-signed-bitwise) + last_error = mgp_value_make_string(PyUnicode_AsUTF8(o), memory, &mgp_v); + } else if (PyList_Check(o)) { + mgp_v = py_seq_to_list(o, PyList_Size(o), [](auto *list, const auto i) { return PyList_GET_ITEM(list, i); }); + } else if (PyTuple_Check(o)) { + mgp_v = py_seq_to_list(o, PyTuple_Size(o), [](auto *tuple, const auto i) { return PyTuple_GET_ITEM(tuple, i); }); + } else if (PyDict_Check(o)) { // NOLINT(hicpp-signed-bitwise) + MgpUniquePtr<mgp_map> map{nullptr, mgp_map_destroy}; + const auto map_err = CreateMgpObject(map, mgp_map_make_empty, memory); + + if (map_err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } + if (map_err != 
mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during creating mgp_map"}; + } + + PyObject *key{nullptr}; + PyObject *value{nullptr}; + Py_ssize_t pos{0}; + while (PyDict_Next(o, &pos, &key, &value)) { + if (!PyUnicode_Check(key)) { + throw std::invalid_argument("Dictionary keys must be strings"); + } + + const char *k = PyUnicode_AsUTF8(key); + + if (!k) { + PyErr_Clear(); + throw std::bad_alloc{}; + } + + MgpUniquePtr<mgp_value> v{PyObjectToMgpValue(value, memory), mgp_value_destroy}; + + if (const auto err = mgp_map_insert(map.get(), k, v.get()); err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during inserting an item to mgp_map"}; + } + } + + if (const auto err = mgp_value_make_map(map.get(), &mgp_v); err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during creating mgp_value"}; + } + static_cast<void>(map.release()); + } else if (Py_TYPE(o) == &PyEdgeType) { + MgpUniquePtr<mgp_edge> e{nullptr, mgp_edge_destroy}; + // Copy the edge and pass the ownership to the created mgp_value. 
+ + if (const auto err = CreateMgpObject(e, mgp_edge_copy, reinterpret_cast<PyEdge *>(o)->edge, memory); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during copying mgp_edge"}; + } + if (const auto err = mgp_value_make_edge(e.get(), &mgp_v); err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during copying mgp_edge"}; + } + static_cast<void>(e.release()); + } else if (Py_TYPE(o) == &PyPathType) { + MgpUniquePtr<mgp_path> p{nullptr, mgp_path_destroy}; + // Copy the edge and pass the ownership to the created mgp_value. + + if (const auto err = CreateMgpObject(p, mgp_path_copy, reinterpret_cast<PyPath *>(o)->path, memory); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during copying mgp_path"}; + } + if (const auto err = mgp_value_make_path(p.get(), &mgp_v); err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during copying mgp_path"}; + } + static_cast<void>(p.release()); + } else if (Py_TYPE(o) == &PyVertexType) { + MgpUniquePtr<mgp_vertex> v{nullptr, mgp_vertex_destroy}; + // Copy the edge and pass the ownership to the created mgp_value. 
+ + if (const auto err = CreateMgpObject(v, mgp_vertex_copy, reinterpret_cast<PyVertex *>(o)->vertex, memory); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during copying mgp_vertex"}; + } + if (const auto err = mgp_value_make_vertex(v.get(), &mgp_v); err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error during copying mgp_vertex"}; + } + static_cast<void>(v.release()); + } else if (is_mgp_instance(o, "Edge")) { + py::Object edge(PyObject_GetAttrString(o, "_edge")); + if (!edge) { + PyErr_Clear(); + throw std::invalid_argument("'mgp.Edge' is missing '_edge' attribute"); + } + return PyObjectToMgpValue(edge.Ptr(), memory); + } else if (is_mgp_instance(o, "Vertex")) { + py::Object vertex(PyObject_GetAttrString(o, "_vertex")); + if (!vertex) { + PyErr_Clear(); + throw std::invalid_argument("'mgp.Vertex' is missing '_vertex' attribute"); + } + return PyObjectToMgpValue(vertex.Ptr(), memory); + } else if (is_mgp_instance(o, "Path")) { + py::Object path(PyObject_GetAttrString(o, "_path")); + if (!path) { + PyErr_Clear(); + throw std::invalid_argument("'mgp.Path' is missing '_path' attribute"); + } + return PyObjectToMgpValue(path.Ptr(), memory); + } else if (PyDate_CheckExact(o)) { + mgp_date_parameters parameters{ + .year = PyDateTime_GET_YEAR(o), // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + .month = PyDateTime_GET_MONTH(o), // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + .day = PyDateTime_GET_DAY(o)}; // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + MgpUniquePtr<mgp_date> date{nullptr, mgp_date_destroy}; + + if (const auto err = CreateMgpObject(date, mgp_date_from_parameters, ¶meters, memory); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + 
throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error while creating mgp_date"}; + } + if (const auto err = mgp_value_make_date(date.get(), &mgp_v); err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error while creating mgp_value"}; + } + static_cast<void>(date.release()); + } else if (PyTime_CheckExact(o)) { + mgp_local_time_parameters parameters{ + .hour = PyDateTime_TIME_GET_HOUR(o), // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + .minute = PyDateTime_TIME_GET_MINUTE(o), // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + .second = PyDateTime_TIME_GET_SECOND(o), // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + .millisecond = + PyDateTime_TIME_GET_MICROSECOND(o) / // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + 1000, + .microsecond = + PyDateTime_TIME_GET_MICROSECOND(o) % // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + 1000}; + MgpUniquePtr<mgp_local_time> local_time{nullptr, mgp_local_time_destroy}; + + if (const auto err = CreateMgpObject(local_time, mgp_local_time_from_parameters, ¶meters, memory); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error while creating mgp_local_time"}; + } + if (const auto err = mgp_value_make_local_time(local_time.get(), &mgp_v); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error while creating mgp_value"}; + } + static_cast<void>(local_time.release()); + } else if (PyDateTime_CheckExact(o)) { + mgp_date_parameters date_parameters{ + .year = PyDateTime_GET_YEAR(o), // 
NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + .month = PyDateTime_GET_MONTH(o), // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + .day = PyDateTime_GET_DAY(o)}; // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + mgp_local_time_parameters local_time_parameters{ + .hour = PyDateTime_DATE_GET_HOUR(o), // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + .minute = PyDateTime_DATE_GET_MINUTE(o), // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + .second = PyDateTime_DATE_GET_SECOND(o), // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + .millisecond = + PyDateTime_DATE_GET_MICROSECOND(o) / // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + 1000, + .microsecond = + PyDateTime_DATE_GET_MICROSECOND(o) % // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + 1000}; + + mgp_local_date_time_parameters parameters{&date_parameters, &local_time_parameters}; + + MgpUniquePtr<mgp_local_date_time> local_date_time{nullptr, mgp_local_date_time_destroy}; + + if (const auto err = CreateMgpObject(local_date_time, mgp_local_date_time_from_parameters, ¶meters, memory); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error while creating mgp_local_date_time"}; + } + if (const auto err = mgp_value_make_local_date_time(local_date_time.get(), &mgp_v); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error while creating mgp_value"}; + } + static_cast<void>(local_date_time.release()); + } else if (PyDelta_CheckExact(o)) { + static constexpr int64_t microseconds_in_days = + static_cast<std::chrono::microseconds>(std::chrono::days{1}).count(); + const auto days = + 
PyDateTime_DELTA_GET_DAYS(o); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + auto microseconds = + std::abs(days) * microseconds_in_days + + PyDateTime_DELTA_GET_SECONDS(o) * 1000 * // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + 1000 + + PyDateTime_DELTA_GET_MICROSECONDS(o); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast,hicpp-signed-bitwise) + microseconds *= days < 0 ? -1 : 1; + + MgpUniquePtr<mgp_duration> duration{nullptr, mgp_duration_destroy}; + + if (const auto err = CreateMgpObject(duration, mgp_duration_from_microseconds, microseconds, memory); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error while creating mgp_duration"}; + } + if (const auto err = mgp_value_make_duration(duration.get(), &mgp_v); + err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } else if (err != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error while creating mgp_value"}; + } + static_cast<void>(duration.release()); + } else { + throw std::invalid_argument("Unsupported PyObject conversion"); + } + + if (last_error == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + throw std::bad_alloc{}; + } + if (last_error != mgp_error::MGP_ERROR_NO_ERROR) { + throw std::runtime_error{"Unexpected error while creating mgp_value"}; + } + + return mgp_v; +} + +PyObject *PyGraphCreateEdge(PyGraph *self, PyObject *args) { + MG_ASSERT(PyGraphIsValidImpl(*self)); + MG_ASSERT(self->memory); + PyVertex *from{nullptr}; + PyVertex *to{nullptr}; + const char *edge_type{nullptr}; + if (!PyArg_ParseTuple(args, "O!O!s", &PyVertexType, &from, &PyVertexType, &to, &edge_type)) { + return nullptr; + } + MgpUniquePtr<mgp_edge> new_edge{nullptr, mgp_edge_destroy}; + if (RaiseExceptionFromErrorCode(CreateMgpObject(new_edge, mgp_graph_create_edge, self->graph, from->vertex, + to->vertex, 
mgp_edge_type{edge_type}, self->memory))) { + return nullptr; + } + auto *py_edge = MakePyEdgeWithoutCopy(*new_edge, self); + if (py_edge != nullptr) { + static_cast<void>(new_edge.release()); + } + return py_edge; +} + +PyObject *PyGraphDeleteVertex(PyGraph *self, PyObject *args) { + MG_ASSERT(PyGraphIsValidImpl(*self)); + MG_ASSERT(self->memory); + PyVertex *vertex{nullptr}; + if (!PyArg_ParseTuple(args, "O!", &PyVertexType, &vertex)) { + return nullptr; + } + if (RaiseExceptionFromErrorCode(mgp_graph_delete_vertex(self->graph, vertex->vertex))) { + return nullptr; + } + Py_RETURN_NONE; +} + +PyObject *PyGraphDetachDeleteVertex(PyGraph *self, PyObject *args) { + MG_ASSERT(PyGraphIsValidImpl(*self)); + MG_ASSERT(self->memory); + PyVertex *vertex{nullptr}; + if (!PyArg_ParseTuple(args, "O!", &PyVertexType, &vertex)) { + return nullptr; + } + if (RaiseExceptionFromErrorCode(mgp_graph_detach_delete_vertex(self->graph, vertex->vertex))) { + return nullptr; + } + Py_RETURN_NONE; +} + +PyObject *PyGraphDeleteEdge(PyGraph *self, PyObject *args) { + MG_ASSERT(PyGraphIsValidImpl(*self)); + MG_ASSERT(self->memory); + PyEdge *edge{nullptr}; + if (!PyArg_ParseTuple(args, "O!", &PyEdgeType, &edge)) { + return nullptr; + } + if (RaiseExceptionFromErrorCode(mgp_graph_delete_edge(self->graph, edge->edge))) { + return nullptr; + } + Py_RETURN_NONE; +} + +} // namespace memgraph::query::v2::procedure diff --git a/src/query/v2/procedure/py_module.hpp b/src/query/v2/procedure/py_module.hpp new file mode 100644 index 000000000..b037f7797 --- /dev/null +++ b/src/query/v2/procedure/py_module.hpp @@ -0,0 +1,82 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +/// Functions and types for loading Query Modules written in Python. +#pragma once + +#include "py/py.hpp" + +struct mgp_graph; +struct mgp_memory; +struct mgp_module; +struct mgp_value; + +namespace memgraph::query::v2::procedure { + +struct PyGraph; + +/// Convert an `mgp_value` into a Python object, referencing the given `PyGraph` +/// instance and using the same allocator as the graph. +/// +/// Values of type `MGP_VALUE_TYPE_VERTEX`, `MGP_VALUE_TYPE_EDGE` and +/// `MGP_VALUE_TYPE_PATH` are returned as `mgp.Vertex`, `mgp.Edge` and +/// `mgp.Path` respectively, and *not* their internal `_mgp` +/// representations. Other value types are converted to equivalent builtin +/// Python objects. +/// +/// Return a non-null `py::Object` instance on success. Otherwise, return a null +/// `py::Object` instance and set the appropriate Python exception. +py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph); + +py::Object MgpValueToPyObject(const mgp_value &value, PyObject *py_graph); + +/// Convert a Python object into `mgp_value`, constructing it using the given +/// `mgp_memory` allocator. +/// +/// If the user-facing 'mgp' module can be imported, this function will handle +/// conversion of 'mgp.Vertex', 'mgp.Edge' and 'mgp.Path' values. +/// +/// @throw std::bad_alloc +/// @throw std::overflow_error if attempting to convert a Python integer which +/// too large to fit into int64_t. +/// @throw std::invalid_argument if the given Python object cannot be converted +/// to an mgp_value (e.g. a dictionary whose keys aren't strings or an object +/// of unsupported type). +mgp_value *PyObjectToMgpValue(PyObject *, mgp_memory *); + +/// Create the _mgp module for use in embedded Python. 
+/// +/// The function is to be used before Py_Initialize via the following code. +/// +/// PyImport_AppendInittab("_mgp", &query::v2::procedure::PyInitMgpModule); +PyObject *PyInitMgpModule(); + +/// Create an instance of _mgp.Graph class. +PyObject *MakePyGraph(mgp_graph *, mgp_memory *); + +/// Import a module with given name in the context of mgp_module. +/// +/// This function can only be called when '_mgp' module has been initialized in +/// Python. +/// +/// Return nullptr and set appropriate Python exception on failure. +py::Object ImportPyModule(const char *, mgp_module *); + +/// Reload already loaded Python module in the context of mgp_module. +/// +/// This function can only be called when '_mgp' module has been initialized in +/// Python. +/// +/// Return nullptr and set appropriate Python exception on failure. +py::Object ReloadPyModule(PyObject *, mgp_module *); + +} // namespace memgraph::query::v2::procedure diff --git a/src/query/v2/serialization/property_value.cpp b/src/query/v2/serialization/property_value.cpp new file mode 100644 index 000000000..d78aa7021 --- /dev/null +++ b/src/query/v2/serialization/property_value.cpp @@ -0,0 +1,127 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "query/v2/serialization/property_value.hpp" +#include "storage/v3/property_value.hpp" +#include "utils/logging.hpp" + +namespace memgraph::query::v2::serialization { + +namespace { +enum class ObjectType : uint8_t { MAP, TEMPORAL_DATA }; +} // namespace + +nlohmann::json SerializePropertyValue(const storage::v3::PropertyValue &property_value) { + using Type = storage::v3::PropertyValue::Type; + switch (property_value.type()) { + case Type::Null: + return {}; + case Type::Bool: + return property_value.ValueBool(); + case Type::Int: + return property_value.ValueInt(); + case Type::Double: + return property_value.ValueDouble(); + case Type::String: + return property_value.ValueString(); + case Type::List: + return SerializePropertyValueVector(property_value.ValueList()); + case Type::Map: + return SerializePropertyValueMap(property_value.ValueMap()); + case Type::TemporalData: + const auto temporal_data = property_value.ValueTemporalData(); + auto data = nlohmann::json::object(); + data.emplace("type", static_cast<uint64_t>(ObjectType::TEMPORAL_DATA)); + data.emplace("value", nlohmann::json::object({{"type", static_cast<uint64_t>(temporal_data.type)}, + {"microseconds", temporal_data.microseconds}})); + return data; + } +} + +nlohmann::json SerializePropertyValueVector(const std::vector<storage::v3::PropertyValue> &values) { + nlohmann::json array = nlohmann::json::array(); + for (const auto &value : values) { + array.push_back(SerializePropertyValue(value)); + } + return array; +} + +nlohmann::json SerializePropertyValueMap(const std::map<std::string, storage::v3::PropertyValue> ¶meters) { + nlohmann::json data = nlohmann::json::object(); + data.emplace("type", static_cast<uint64_t>(ObjectType::MAP)); + data.emplace("value", nlohmann::json::object()); + + for (const auto &[key, value] : parameters) { + data["value"][key] = SerializePropertyValue(value); + } + + return data; +}; + +storage::v3::PropertyValue DeserializePropertyValue(const nlohmann::json 
&data) { + if (data.is_null()) { + return storage::v3::PropertyValue(); + } + + if (data.is_boolean()) { + return storage::v3::PropertyValue(data.get<bool>()); + } + + if (data.is_number_integer()) { + return storage::v3::PropertyValue(data.get<int64_t>()); + } + + if (data.is_number_float()) { + return storage::v3::PropertyValue(data.get<double>()); + } + + if (data.is_string()) { + return storage::v3::PropertyValue(data.get<std::string>()); + } + + if (data.is_array()) { + return storage::v3::PropertyValue(DeserializePropertyValueList(data)); + } + + MG_ASSERT(data.is_object(), "Unknown type found in the trigger storage"); + + switch (data["type"].get<ObjectType>()) { + case ObjectType::MAP: + return storage::v3::PropertyValue(DeserializePropertyValueMap(data)); + case ObjectType::TEMPORAL_DATA: + return storage::v3::PropertyValue(storage::v3::TemporalData{ + data["value"]["type"].get<storage::v3::TemporalType>(), data["value"]["microseconds"].get<int64_t>()}); + } +} + +std::vector<storage::v3::PropertyValue> DeserializePropertyValueList(const nlohmann::json::array_t &data) { + std::vector<storage::v3::PropertyValue> property_values; + property_values.reserve(data.size()); + for (const auto &value : data) { + property_values.emplace_back(DeserializePropertyValue(value)); + } + + return property_values; +} + +std::map<std::string, storage::v3::PropertyValue> DeserializePropertyValueMap(const nlohmann::json::object_t &data) { + MG_ASSERT(data.at("type").get<ObjectType>() == ObjectType::MAP, "Invalid map serialization"); + std::map<std::string, storage::v3::PropertyValue> property_values; + + const nlohmann::json::object_t &values = data.at("value"); + for (const auto &[key, value] : values) { + property_values.emplace(key, DeserializePropertyValue(value)); + } + + return property_values; +} + +} // namespace memgraph::query::v2::serialization diff --git a/src/query/v2/serialization/property_value.hpp b/src/query/v2/serialization/property_value.hpp new file mode 
100644 index 000000000..85f35c043 --- /dev/null +++ b/src/query/v2/serialization/property_value.hpp @@ -0,0 +1,32 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <json/json.hpp> + +#include "storage/v3/property_value.hpp" + +namespace memgraph::query::v2::serialization { + +nlohmann::json SerializePropertyValue(const storage::v3::PropertyValue &property_value); + +nlohmann::json SerializePropertyValueVector(const std::vector<storage::v3::PropertyValue> &values); + +nlohmann::json SerializePropertyValueMap(const std::map<std::string, storage::v3::PropertyValue> ¶meters); + +storage::v3::PropertyValue DeserializePropertyValue(const nlohmann::json &data); + +std::vector<storage::v3::PropertyValue> DeserializePropertyValueList(const nlohmann::json::array_t &data); + +std::map<std::string, storage::v3::PropertyValue> DeserializePropertyValueMap(const nlohmann::json::object_t &data); + +} // namespace memgraph::query::v2::serialization diff --git a/src/query/v2/stream.hpp b/src/query/v2/stream.hpp new file mode 100644 index 000000000..cc27c3daf --- /dev/null +++ b/src/query/v2/stream.hpp @@ -0,0 +1,63 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <memory> +#include <vector> + +#include "query/v2/typed_value.hpp" +#include "utils/memory.hpp" + +namespace memgraph::query::v2 { + +/** + * `AnyStream` can wrap *any* type implementing the `Stream` concept into a + * single type. + * + * The type erasure technique is used. The original type which an `AnyStream` + * was constructed from is "erased", as `AnyStream` is not a class template and + * doesn't use the type in any way. Client code can then program just for + * `AnyStream`, rather than using static polymorphism to handle any type + * implementing the `Stream` concept. + */ +class AnyStream final { + public: + template <class TStream> + AnyStream(TStream *stream, utils::MemoryResource *memory_resource) + : content_{ + utils::Allocator<GenericWrapper<TStream>>{memory_resource}.template new_object<GenericWrapper<TStream>>( + stream), + [memory_resource](Wrapper *ptr) { + utils::Allocator<GenericWrapper<TStream>>{memory_resource} + .template delete_object<GenericWrapper<TStream>>(static_cast<GenericWrapper<TStream> *>(ptr)); + }} {} + + void Result(const std::vector<TypedValue> &values) { content_->Result(values); } + + private: + struct Wrapper { + virtual void Result(const std::vector<TypedValue> &values) = 0; + }; + + template <class TStream> + struct GenericWrapper final : public Wrapper { + explicit GenericWrapper(TStream *stream) : stream_{stream} {} + + void Result(const std::vector<TypedValue> &values) override { stream_->Result(values); } + + TStream *stream_; + }; + + std::unique_ptr<Wrapper, std::function<void(Wrapper *)>> content_; +}; + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/stream/common.cpp b/src/query/v2/stream/common.cpp new file mode 100644 index 000000000..268899b9a 
--- /dev/null +++ b/src/query/v2/stream/common.cpp @@ -0,0 +1,45 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/stream/common.hpp" + +#include <json/json.hpp> + +namespace memgraph::query::v2::stream { +namespace { +const std::string kBatchIntervalKey{"batch_interval"}; +const std::string kBatchSizeKey{"batch_size"}; +const std::string kTransformationName{"transformation_name"}; +} // namespace + +void to_json(nlohmann::json &data, CommonStreamInfo &&common_info) { + data[kBatchIntervalKey] = common_info.batch_interval.count(); + data[kBatchSizeKey] = common_info.batch_size; + data[kTransformationName] = common_info.transformation_name; +} + +void from_json(const nlohmann::json &data, CommonStreamInfo &common_info) { + if (const auto batch_interval = data.at(kBatchIntervalKey); !batch_interval.is_null()) { + using BatchInterval = decltype(common_info.batch_interval); + common_info.batch_interval = BatchInterval{batch_interval.get<typename BatchInterval::rep>()}; + } else { + common_info.batch_interval = kDefaultBatchInterval; + } + + if (const auto batch_size = data.at(kBatchSizeKey); !batch_size.is_null()) { + common_info.batch_size = batch_size.get<decltype(common_info.batch_size)>(); + } else { + common_info.batch_size = kDefaultBatchSize; + } + + data.at(kTransformationName).get_to(common_info.transformation_name); +} +} // namespace memgraph::query::v2::stream diff --git a/src/query/v2/stream/common.hpp b/src/query/v2/stream/common.hpp new 
file mode 100644 index 000000000..ee8761545 --- /dev/null +++ b/src/query/v2/stream/common.hpp @@ -0,0 +1,87 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <chrono> +#include <cstdint> +#include <functional> +#include <optional> +#include <string> + +#include <json/json.hpp> + +#include "query/v2/procedure/mg_procedure_impl.hpp" + +namespace memgraph::query::v2::stream { + +inline constexpr std::chrono::milliseconds kDefaultBatchInterval{100}; +inline constexpr int64_t kDefaultBatchSize{1000}; + +template <typename TMessage> +using ConsumerFunction = std::function<void(const std::vector<TMessage> &)>; + +struct CommonStreamInfo { + std::chrono::milliseconds batch_interval; + int64_t batch_size; + std::string transformation_name; +}; + +template <typename T> +concept ConvertableToJson = requires(T value, nlohmann::json data) { + { to_json(data, std::move(value)) } -> std::same_as<void>; + { from_json(data, value) } -> std::same_as<void>; +}; + +template <typename T> +concept ConvertableToMgpMessage = requires(T value) { + mgp_message{value}; +}; + +template <typename TStream> +concept Stream = requires(TStream stream) { + typename TStream::StreamInfo; + typename TStream::Message; + TStream{std::string{""}, typename TStream::StreamInfo{}, ConsumerFunction<typename TStream::Message>{}}; + { stream.Start() } -> std::same_as<void>; + { stream.StartWithLimit(uint64_t{}, std::optional<std::chrono::milliseconds>{}) } -> std::same_as<void>; + { 
stream.Stop() } -> std::same_as<void>; + { stream.IsRunning() } -> std::same_as<bool>; + { + stream.Check(std::optional<std::chrono::milliseconds>{}, std::optional<uint64_t>{}, + ConsumerFunction<typename TStream::Message>{}) + } -> std::same_as<void>; + requires std::same_as<std::decay_t<decltype(std::declval<typename TStream::StreamInfo>().common_info)>, + CommonStreamInfo>; + + requires ConvertableToMgpMessage<typename TStream::Message>; + requires ConvertableToJson<typename TStream::StreamInfo>; +}; + +enum class StreamSourceType : uint8_t { KAFKA, PULSAR }; + +constexpr std::string_view StreamSourceTypeToString(StreamSourceType type) { + switch (type) { + case StreamSourceType::KAFKA: + return "kafka"; + case StreamSourceType::PULSAR: + return "pulsar"; + } +} + +template <Stream T> +StreamSourceType StreamType(const T & /*stream*/); + +const std::string kCommonInfoKey = "common_info"; + +void to_json(nlohmann::json &data, CommonStreamInfo &&info); +void from_json(const nlohmann::json &data, CommonStreamInfo &common_info); +} // namespace memgraph::query::v2::stream diff --git a/src/query/v2/stream/sources.cpp b/src/query/v2/stream/sources.cpp new file mode 100644 index 000000000..686bc4d56 --- /dev/null +++ b/src/query/v2/stream/sources.cpp @@ -0,0 +1,137 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "query/v2/stream/sources.hpp" + +#include <json/json.hpp> + +#include "integrations/constants.hpp" + +namespace memgraph::query::v2::stream { +KafkaStream::KafkaStream(std::string stream_name, StreamInfo stream_info, + ConsumerFunction<integrations::kafka::Message> consumer_function) { + integrations::kafka::ConsumerInfo consumer_info{ + .consumer_name = std::move(stream_name), + .topics = std::move(stream_info.topics), + .consumer_group = std::move(stream_info.consumer_group), + .bootstrap_servers = std::move(stream_info.bootstrap_servers), + .batch_interval = stream_info.common_info.batch_interval, + .batch_size = stream_info.common_info.batch_size, + .public_configs = std::move(stream_info.configs), + .private_configs = std::move(stream_info.credentials), + }; + consumer_.emplace(std::move(consumer_info), std::move(consumer_function)); +}; + +KafkaStream::StreamInfo KafkaStream::Info(std::string transformation_name) const { + const auto &info = consumer_->Info(); + return {{.batch_interval = info.batch_interval, + .batch_size = info.batch_size, + .transformation_name = std::move(transformation_name)}, + .topics = info.topics, + .consumer_group = info.consumer_group, + .bootstrap_servers = info.bootstrap_servers, + .configs = info.public_configs, + .credentials = info.private_configs}; +} + +void KafkaStream::Start() { consumer_->Start(); } +void KafkaStream::StartWithLimit(uint64_t batch_limit, std::optional<std::chrono::milliseconds> timeout) const { + consumer_->StartWithLimit(batch_limit, timeout); +} +void KafkaStream::Stop() { consumer_->Stop(); } +bool KafkaStream::IsRunning() const { return consumer_->IsRunning(); } + +void KafkaStream::Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit, + const ConsumerFunction<integrations::kafka::Message> &consumer_function) const { + consumer_->Check(timeout, batch_limit, consumer_function); +} + +utils::BasicResult<std::string> KafkaStream::SetStreamOffset(const 
int64_t offset) { + return consumer_->SetConsumerOffsets(offset); +} + +namespace { +const std::string kTopicsKey{"topics"}; +const std::string kConsumerGroupKey{"consumer_group"}; +const std::string kBoostrapServers{"bootstrap_servers"}; +const std::string kConfigs{"configs"}; +const std::string kCredentials{"credentials"}; + +const std::unordered_map<std::string, std::string> kDefaultConfigsMap; +} // namespace + +void to_json(nlohmann::json &data, KafkaStream::StreamInfo &&info) { + data[kCommonInfoKey] = std::move(info.common_info); + data[kTopicsKey] = std::move(info.topics); + data[kConsumerGroupKey] = info.consumer_group; + data[kBoostrapServers] = std::move(info.bootstrap_servers); + data[kConfigs] = std::move(info.configs); + data[kCredentials] = std::move(info.credentials); +} + +void from_json(const nlohmann::json &data, KafkaStream::StreamInfo &info) { + data.at(kCommonInfoKey).get_to(info.common_info); + data.at(kTopicsKey).get_to(info.topics); + data.at(kConsumerGroupKey).get_to(info.consumer_group); + data.at(kBoostrapServers).get_to(info.bootstrap_servers); + // These values might not be present in the persisted JSON object + info.configs = data.value(kConfigs, kDefaultConfigsMap); + info.credentials = data.value(kCredentials, kDefaultConfigsMap); +} + +PulsarStream::PulsarStream(std::string stream_name, StreamInfo stream_info, + ConsumerFunction<integrations::pulsar::Message> consumer_function) { + integrations::pulsar::ConsumerInfo consumer_info{.batch_size = stream_info.common_info.batch_size, + .batch_interval = stream_info.common_info.batch_interval, + .topics = std::move(stream_info.topics), + .consumer_name = std::move(stream_name), + .service_url = std::move(stream_info.service_url)}; + + consumer_.emplace(std::move(consumer_info), std::move(consumer_function)); +}; + +PulsarStream::StreamInfo PulsarStream::Info(std::string transformation_name) const { + const auto &info = consumer_->Info(); + return {{.batch_interval = info.batch_interval, 
+ .batch_size = info.batch_size, + .transformation_name = std::move(transformation_name)}, + .topics = info.topics, + .service_url = info.service_url}; +} + +void PulsarStream::Start() { consumer_->Start(); } +void PulsarStream::StartWithLimit(uint64_t batch_limit, std::optional<std::chrono::milliseconds> timeout) const { + consumer_->StartWithLimit(batch_limit, timeout); +} +void PulsarStream::Stop() { consumer_->Stop(); } +bool PulsarStream::IsRunning() const { return consumer_->IsRunning(); } +void PulsarStream::Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit, + const ConsumerFunction<Message> &consumer_function) const { + consumer_->Check(timeout, batch_limit, consumer_function); +} + +namespace { +const std::string kServiceUrl{"service_url"}; +} // namespace + +void to_json(nlohmann::json &data, PulsarStream::StreamInfo &&info) { + data[kCommonInfoKey] = std::move(info.common_info); + data[kTopicsKey] = std::move(info.topics); + data[kServiceUrl] = std::move(info.service_url); +} + +void from_json(const nlohmann::json &data, PulsarStream::StreamInfo &info) { + data.at(kCommonInfoKey).get_to(info.common_info); + data.at(kTopicsKey).get_to(info.topics); + data.at(kServiceUrl).get_to(info.service_url); +} +} // namespace memgraph::query::v2::stream diff --git a/src/query/v2/stream/sources.hpp b/src/query/v2/stream/sources.hpp new file mode 100644 index 000000000..ba5b75dc0 --- /dev/null +++ b/src/query/v2/stream/sources.hpp @@ -0,0 +1,95 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. 
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#pragma once
+
+#include "query/v2/stream/common.hpp"
+
+#include "integrations/kafka/consumer.hpp"
+#include "integrations/pulsar/consumer.hpp"
+
+namespace memgraph::query::v2::stream {
+
+/// Stream source backed by an `integrations::kafka::Consumer`.
+struct KafkaStream {
+  /// Settings of a Kafka stream; this is the structure converted to and from JSON by the
+  /// `to_json` / `from_json` overloads declared below.
+  struct StreamInfo {
+    CommonStreamInfo common_info;
+    std::vector<std::string> topics;
+    std::string consumer_group;
+    std::string bootstrap_servers;
+    std::unordered_map<std::string, std::string> configs;
+    std::unordered_map<std::string, std::string> credentials;
+  };
+
+  using Message = integrations::kafka::Message;
+
+  /// Takes the stream name, its settings and the callback that receives message batches.
+  KafkaStream(std::string stream_name, StreamInfo stream_info,
+              ConsumerFunction<integrations::kafka::Message> consumer_function);
+
+  /// Returns the stream's settings; the transformation name is supplied by the caller and placed
+  /// into `common_info`.
+  StreamInfo Info(std::string transformation_name) const;
+
+  void Start();
+  // Declared const although it takes a batch limit and an optional timeout for a run.
+  void StartWithLimit(uint64_t batch_limit, std::optional<std::chrono::milliseconds> timeout) const;
+  void Stop();
+  bool IsRunning() const;
+
+  /// Takes an optional timeout, an optional batch limit and a callback for the message batches.
+  /// NOTE(review): presumably a dry run that does not affect the running stream -- confirm against
+  /// the consumer implementation.
+  void Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit,
+             const ConsumerFunction<Message> &consumer_function) const;
+
+  /// Kafka-only operation (PulsarStream has no counterpart); the returned result carries a string
+  /// error on failure.
+  utils::BasicResult<std::string> SetStreamOffset(int64_t offset);
+
+ private:
+  using Consumer = integrations::kafka::Consumer;
+  // Held in an optional so the consumer can be constructed in place inside the constructor body.
+  std::optional<Consumer> consumer_;
+};
+
+void to_json(nlohmann::json &data, KafkaStream::StreamInfo &&info);
+void from_json(const nlohmann::json &data, KafkaStream::StreamInfo &info);
+
+/// Explicit specialization of `StreamType` for Kafka sources.
+template <>
+inline StreamSourceType StreamType(const KafkaStream & /*stream*/) {
+  return StreamSourceType::KAFKA;
+}
+
+/// Stream source backed by an `integrations::pulsar::Consumer`.
+struct PulsarStream {
+  /// Settings of a Pulsar stream; converted to and from JSON by the overloads declared below.
+  struct StreamInfo {
+    CommonStreamInfo common_info;
+    std::vector<std::string> topics;
+    std::string service_url;
+  };
+
+  using Message = integrations::pulsar::Message;
+
+  /// Takes the stream name, its settings and the callback that receives message batches.
+  PulsarStream(std::string stream_name, StreamInfo stream_info, ConsumerFunction<Message> consumer_function);
+
+  /// Returns the stream's settings; the transformation name is supplied by the caller and placed
+  /// into `common_info`.
+  StreamInfo Info(std::string transformation_name) const;
+
+  void Start();
+  // Declared const, mirroring KafkaStream::StartWithLimit.
+  void StartWithLimit(uint64_t batch_limit, std::optional<std::chrono::milliseconds> timeout) const;
+  void Stop();
+  bool IsRunning() const;
+
+  /// Same parameters as KafkaStream::Check.
+  void Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit,
+             const ConsumerFunction<Message> &consumer_function) const;
+
+ private:
+  using Consumer = integrations::pulsar::Consumer;
+  // Held in an optional so the consumer can be constructed in place inside the constructor body.
+  std::optional<Consumer> consumer_;
+};
+
+void to_json(nlohmann::json &data, PulsarStream::StreamInfo &&info);
+void from_json(const nlohmann::json &data, PulsarStream::StreamInfo &info);
+
+/// Explicit specialization of `StreamType` for Pulsar sources.
+template <>
+inline StreamSourceType StreamType(const PulsarStream & /*stream*/) {
+  return StreamSourceType::PULSAR;
+}
+
+}  // namespace memgraph::query::v2::stream
diff --git a/src/query/v2/stream/streams.cpp b/src/query/v2/stream/streams.cpp
new file mode 100644
index 000000000..563e1401f
--- /dev/null
+++ b/src/query/v2/stream/streams.cpp
@@ -0,0 +1,772 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+ +#include "query/v2/stream/streams.hpp" + +#include <shared_mutex> +#include <string_view> +#include <utility> + +#include <spdlog/spdlog.h> +#include <json/json.hpp> + +#include "integrations/constants.hpp" +#include "mg_procedure.h" +#include "query/v2/db_accessor.hpp" +#include "query/v2/discard_value_stream.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/interpreter.hpp" +#include "query/v2/procedure/mg_procedure_helpers.hpp" +#include "query/v2/procedure/mg_procedure_impl.hpp" +#include "query/v2/procedure/module.hpp" +#include "query/v2/stream/sources.hpp" +#include "query/v2/typed_value.hpp" +#include "utils/event_counter.hpp" +#include "utils/logging.hpp" +#include "utils/memory.hpp" +#include "utils/on_scope_exit.hpp" +#include "utils/pmr/string.hpp" +#include "utils/variant_helpers.hpp" + +namespace EventCounter { +extern const Event MessagesConsumed; +} // namespace EventCounter + +namespace memgraph::query::v2::stream { +namespace { +inline constexpr auto kExpectedTransformationResultSize = 2; +inline constexpr auto kCheckStreamResultSize = 2; +const utils::pmr::string query_param_name{"query", utils::NewDeleteResource()}; +const utils::pmr::string params_param_name{"parameters", utils::NewDeleteResource()}; + +const std::map<std::string, storage::v3::PropertyValue> empty_parameters{}; + +auto GetStream(auto &map, const std::string &stream_name) { + if (auto it = map.find(stream_name); it != map.end()) { + return it; + } + throw StreamsException("Couldn't find stream '{}'", stream_name); +} + +std::pair<TypedValue /*query*/, TypedValue /*parameters*/> ExtractTransformationResult( + const utils::pmr::map<utils::pmr::string, TypedValue> &values, const std::string_view transformation_name, + const std::string_view stream_name) { + if (values.size() != kExpectedTransformationResultSize) { + throw StreamsException( + "Transformation '{}' in stream '{}' did not yield all fields (query, parameters) as required.", + transformation_name, 
stream_name); + } + + auto get_value = [&](const utils::pmr::string &field_name) mutable -> const TypedValue & { + auto it = values.find(field_name); + if (it == values.end()) { + throw StreamsException{"Transformation '{}' in stream '{}' did not yield a record with '{}' field.", + transformation_name, stream_name, field_name}; + }; + return it->second; + }; + + const auto &query_value = get_value(query_param_name); + MG_ASSERT(query_value.IsString()); + const auto ¶ms_value = get_value(params_param_name); + MG_ASSERT(params_value.IsNull() || params_value.IsMap()); + return {query_value, params_value}; +} + +template <typename TMessage> +void CallCustomTransformation(const std::string &transformation_name, const std::vector<TMessage> &messages, + mgp_result &result, storage::v3::Storage::Accessor &storage_accessor, + utils::MemoryResource &memory_resource, const std::string &stream_name) { + DbAccessor db_accessor{&storage_accessor}; + { + auto maybe_transformation = + procedure::FindTransformation(procedure::gModuleRegistry, transformation_name, utils::NewDeleteResource()); + + if (!maybe_transformation) { + throw StreamsException("Couldn't find transformation {} for stream '{}'", transformation_name, stream_name); + }; + const auto &trans = *maybe_transformation->second; + mgp_messages mgp_messages{mgp_messages::storage_type{&memory_resource}}; + std::transform(messages.begin(), messages.end(), std::back_inserter(mgp_messages.messages), + [](const TMessage &message) { return mgp_message{message}; }); + mgp_graph graph{&db_accessor, storage::v3::View::OLD, nullptr}; + mgp_memory memory{&memory_resource}; + result.rows.clear(); + result.error_msg.reset(); + result.signature = &trans.results; + + MG_ASSERT(result.signature->size() == kExpectedTransformationResultSize); + MG_ASSERT(result.signature->contains(query_param_name)); + MG_ASSERT(result.signature->contains(params_param_name)); + + spdlog::trace("Calling transformation in stream '{}'", stream_name); + 
trans.cb(&mgp_messages, &graph, &result, &memory); + } + if (result.error_msg.has_value()) { + throw StreamsException(result.error_msg->c_str()); + } +} + +template <Stream TStream> +StreamStatus<TStream> CreateStatus(std::string stream_name, std::string transformation_name, + std::optional<std::string> owner, const TStream &stream) { + return {.name = std::move(stream_name), + .type = StreamType(stream), + .is_running = stream.IsRunning(), + .info = stream.Info(std::move(transformation_name)), + .owner = std::move(owner)}; +} + +// nlohmann::json doesn't support string_view access yet +const std::string kStreamName{"name"}; +const std::string kIsRunningKey{"is_running"}; +const std::string kOwner{"owner"}; +const std::string kType{"type"}; +} // namespace + +template <Stream TStream> +void to_json(nlohmann::json &data, StreamStatus<TStream> &&status) { + data[kStreamName] = std::move(status.name); + data[kType] = status.type; + data[kIsRunningKey] = status.is_running; + + if (status.owner.has_value()) { + data[kOwner] = std::move(*status.owner); + } else { + data[kOwner] = nullptr; + } + + to_json(data, std::move(status.info)); +} + +template <Stream TStream> +void from_json(const nlohmann::json &data, StreamStatus<TStream> &status) { + data.at(kStreamName).get_to(status.name); + data.at(kIsRunningKey).get_to(status.is_running); + + if (const auto &owner = data.at(kOwner); !owner.is_null()) { + status.owner = owner.get<typename decltype(status.owner)::value_type>(); + } else { + status.owner = {}; + } + + from_json(data, status.info); +} + +Streams::Streams(InterpreterContext *interpreter_context, std::filesystem::path directory) + : interpreter_context_(interpreter_context), storage_(std::move(directory)) { + RegisterProcedures(); +} + +void Streams::RegisterProcedures() { + RegisterKafkaProcedures(); + RegisterPulsarProcedures(); +} + +void Streams::RegisterKafkaProcedures() { + { + static constexpr std::string_view proc_name = "kafka_set_stream_offset"; + auto 
set_stream_offset = [this](mgp_list *args, mgp_graph * /*graph*/, mgp_result *result, + mgp_memory * /*memory*/) { + auto *arg_stream_name = procedure::Call<mgp_value *>(mgp_list_at, args, 0); + const auto *stream_name = procedure::Call<const char *>(mgp_value_get_string, arg_stream_name); + auto *arg_offset = procedure::Call<mgp_value *>(mgp_list_at, args, 1); + const auto offset = procedure::Call<int64_t>(mgp_value_get_int, arg_offset); + auto lock_ptr = streams_.Lock(); + auto it = GetStream(*lock_ptr, std::string(stream_name)); + std::visit(utils::Overloaded{[&](StreamData<KafkaStream> &kafka_stream) { + auto stream_source_ptr = kafka_stream.stream_source->Lock(); + const auto error = stream_source_ptr->SetStreamOffset(offset); + if (error.HasError()) { + MG_ASSERT(mgp_result_set_error_msg(result, error.GetError().c_str()) == + mgp_error::MGP_ERROR_NO_ERROR, + "Unable to set procedure error message of procedure: {}", proc_name); + } + }, + [](auto && /*other*/) { + throw QueryRuntimeException("'{}' can be only used for Kafka stream sources", + proc_name); + }}, + it->second); + }; + + mgp_proc proc(proc_name, set_stream_offset, utils::NewDeleteResource()); + MG_ASSERT(mgp_proc_add_arg(&proc, "stream_name", procedure::Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_arg(&proc, "offset", procedure::Call<mgp_type *>(mgp_type_int)) == + mgp_error::MGP_ERROR_NO_ERROR); + + procedure::gModuleRegistry.RegisterMgProcedure(proc_name, std::move(proc)); + } + + { + static constexpr std::string_view proc_name = "kafka_stream_info"; + + static constexpr std::string_view consumer_group_result_name = "consumer_group"; + static constexpr std::string_view topics_result_name = "topics"; + static constexpr std::string_view bootstrap_servers_result_name = "bootstrap_servers"; + static constexpr std::string_view configs_result_name = "configs"; + static constexpr std::string_view credentials_result_name = "credentials"; + + auto 
get_stream_info = [this](mgp_list *args, mgp_graph * /*graph*/, mgp_result *result, mgp_memory *memory) { + auto *arg_stream_name = procedure::Call<mgp_value *>(mgp_list_at, args, 0); + const auto *stream_name = procedure::Call<const char *>(mgp_value_get_string, arg_stream_name); + auto lock_ptr = streams_.Lock(); + auto it = GetStream(*lock_ptr, std::string(stream_name)); + std::visit( + utils::Overloaded{ + [&](StreamData<KafkaStream> &kafka_stream) { + auto stream_source_ptr = kafka_stream.stream_source->Lock(); + const auto info = stream_source_ptr->Info(kafka_stream.transformation_name); + mgp_result_record *record{nullptr}; + if (!procedure::TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto consumer_group_value = + procedure::GetStringValueOrSetError(info.consumer_group.c_str(), memory, result); + if (!consumer_group_value) { + return; + } + + procedure::MgpUniquePtr<mgp_list> topic_names{nullptr, mgp_list_destroy}; + if (!procedure::TryOrSetError( + [&] { + return procedure::CreateMgpObject(topic_names, mgp_list_make_empty, info.topics.size(), + memory); + }, + result)) { + return; + } + + for (const auto &topic : info.topics) { + auto topic_value = procedure::GetStringValueOrSetError(topic.c_str(), memory, result); + if (!topic_value) { + return; + } + topic_names->elems.push_back(std::move(*topic_value)); + } + + procedure::MgpUniquePtr<mgp_value> topics_value{nullptr, mgp_value_destroy}; + if (!procedure::TryOrSetError( + [&] { + return procedure::CreateMgpObject(topics_value, mgp_value_make_list, topic_names.get()); + }, + result)) { + return; + } + static_cast<void>(topic_names.release()); + + const auto bootstrap_servers_value = + procedure::GetStringValueOrSetError(info.bootstrap_servers.c_str(), memory, result); + if (!bootstrap_servers_value) { + return; + } + + const auto convert_config_map = + [result, memory](const std::unordered_map<std::string, std::string> &configs_to_convert) + -> 
procedure::MgpUniquePtr<mgp_value> { + procedure::MgpUniquePtr<mgp_value> configs_value{nullptr, mgp_value_destroy}; + procedure::MgpUniquePtr<mgp_map> configs{nullptr, mgp_map_destroy}; + if (!procedure::TryOrSetError( + [&] { return procedure::CreateMgpObject(configs, mgp_map_make_empty, memory); }, result)) { + return configs_value; + } + + for (const auto &[key, value] : configs_to_convert) { + auto value_value = procedure::GetStringValueOrSetError(value.c_str(), memory, result); + if (!value_value) { + return configs_value; + } + configs->items.emplace(key, std::move(*value_value)); + } + + if (!procedure::TryOrSetError( + [&] { return procedure::CreateMgpObject(configs_value, mgp_value_make_map, configs.get()); }, + result)) { + return configs_value; + } + static_cast<void>(configs.release()); + return configs_value; + }; + + const auto configs_value = convert_config_map(info.configs); + if (configs_value == nullptr) { + return; + } + + using CredentialsType = decltype(KafkaStream::StreamInfo::credentials); + CredentialsType reducted_credentials; + std::transform(info.credentials.begin(), info.credentials.end(), + std::inserter(reducted_credentials, reducted_credentials.end()), + [](const auto &pair) -> CredentialsType::value_type { + return {pair.first, integrations::kReducted}; + }); + + const auto credentials_value = convert_config_map(reducted_credentials); + if (credentials_value == nullptr) { + return; + } + + if (!procedure::InsertResultOrSetError(result, record, consumer_group_result_name.data(), + consumer_group_value.get())) { + return; + } + + if (!procedure::InsertResultOrSetError(result, record, topics_result_name.data(), topics_value.get())) { + return; + } + + if (!procedure::InsertResultOrSetError(result, record, bootstrap_servers_result_name.data(), + bootstrap_servers_value.get())) { + return; + } + + if (!procedure::InsertResultOrSetError(result, record, configs_result_name.data(), + configs_value.get())) { + return; + } + + if 
(!procedure::InsertResultOrSetError(result, record, credentials_result_name.data(), + credentials_value.get())) { + return; + } + }, + [](auto && /*other*/) { + throw QueryRuntimeException("'{}' can be only used for Kafka stream sources", proc_name); + }}, + it->second); + }; + + mgp_proc proc(proc_name, get_stream_info, utils::NewDeleteResource()); + MG_ASSERT(mgp_proc_add_arg(&proc, "stream_name", procedure::Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&proc, consumer_group_result_name.data(), + procedure::Call<mgp_type *>(mgp_type_string)) == mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT( + mgp_proc_add_result(&proc, topics_result_name.data(), + procedure::Call<mgp_type *>(mgp_type_list, procedure::Call<mgp_type *>(mgp_type_string))) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&proc, bootstrap_servers_result_name.data(), + procedure::Call<mgp_type *>(mgp_type_string)) == mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&proc, configs_result_name.data(), procedure::Call<mgp_type *>(mgp_type_map)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&proc, credentials_result_name.data(), procedure::Call<mgp_type *>(mgp_type_map)) == + mgp_error::MGP_ERROR_NO_ERROR); + + procedure::gModuleRegistry.RegisterMgProcedure(proc_name, std::move(proc)); + } +} + +void Streams::RegisterPulsarProcedures() { + { + static constexpr std::string_view proc_name = "pulsar_stream_info"; + static constexpr std::string_view service_url_result_name = "service_url"; + static constexpr std::string_view topics_result_name = "topics"; + auto get_stream_info = [this](mgp_list *args, mgp_graph * /*graph*/, mgp_result *result, mgp_memory *memory) { + auto *arg_stream_name = procedure::Call<mgp_value *>(mgp_list_at, args, 0); + const auto *stream_name = procedure::Call<const char *>(mgp_value_get_string, arg_stream_name); + auto lock_ptr = streams_.Lock(); + auto it = 
GetStream(*lock_ptr, std::string(stream_name)); + std::visit( + utils::Overloaded{ + [&](StreamData<PulsarStream> &pulsar_stream) { + auto stream_source_ptr = pulsar_stream.stream_source->Lock(); + const auto info = stream_source_ptr->Info(pulsar_stream.transformation_name); + mgp_result_record *record{nullptr}; + if (!procedure::TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + auto service_url_value = procedure::GetStringValueOrSetError(info.service_url.c_str(), memory, result); + if (!service_url_value) { + return; + } + + procedure::MgpUniquePtr<mgp_list> topic_names{nullptr, mgp_list_destroy}; + if (!procedure::TryOrSetError( + [&] { + return procedure::CreateMgpObject(topic_names, mgp_list_make_empty, info.topics.size(), + memory); + }, + result)) { + return; + } + + for (const auto &topic : info.topics) { + auto topic_value = procedure::GetStringValueOrSetError(topic.c_str(), memory, result); + if (!topic_value) { + return; + } + topic_names->elems.push_back(std::move(*topic_value)); + } + + procedure::MgpUniquePtr<mgp_value> topics_value{nullptr, mgp_value_destroy}; + if (!procedure::TryOrSetError( + [&] { + return procedure::CreateMgpObject(topics_value, mgp_value_make_list, topic_names.release()); + }, + result)) { + return; + } + + if (!procedure::InsertResultOrSetError(result, record, topics_result_name.data(), topics_value.get())) { + return; + } + + if (!procedure::InsertResultOrSetError(result, record, service_url_result_name.data(), + service_url_value.get())) { + return; + } + }, + [](auto && /*other*/) { + throw QueryRuntimeException("'{}' can be only used for Pulsar stream sources", proc_name); + }}, + it->second); + }; + + mgp_proc proc(proc_name, get_stream_info, utils::NewDeleteResource()); + MG_ASSERT(mgp_proc_add_arg(&proc, "stream_name", procedure::Call<mgp_type *>(mgp_type_string)) == + mgp_error::MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&proc, service_url_result_name.data(), + 
procedure::Call<mgp_type *>(mgp_type_string)) == mgp_error::MGP_ERROR_NO_ERROR); + + MG_ASSERT( + mgp_proc_add_result(&proc, topics_result_name.data(), + procedure::Call<mgp_type *>(mgp_type_list, procedure::Call<mgp_type *>(mgp_type_string))) == + mgp_error::MGP_ERROR_NO_ERROR); + + procedure::gModuleRegistry.RegisterMgProcedure(proc_name, std::move(proc)); + } +} + +template <Stream TStream> +void Streams::Create(const std::string &stream_name, typename TStream::StreamInfo info, + std::optional<std::string> owner) { + auto locked_streams = streams_.Lock(); + auto it = CreateConsumer<TStream>(*locked_streams, stream_name, std::move(info), std::move(owner)); + + try { + std::visit( + [&](const auto &stream_data) { + const auto stream_source_ptr = stream_data.stream_source->ReadLock(); + Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *stream_source_ptr)); + }, + it->second); + } catch (...) { + locked_streams->erase(it); + throw; + } +} + +template void Streams::Create<KafkaStream>(const std::string &stream_name, KafkaStream::StreamInfo info, + std::optional<std::string> owner); +template void Streams::Create<PulsarStream>(const std::string &stream_name, PulsarStream::StreamInfo info, + std::optional<std::string> owner); + +template <Stream TStream> +Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std::string &stream_name, + typename TStream::StreamInfo stream_info, + std::optional<std::string> owner) { + if (map.contains(stream_name)) { + throw StreamsException{"Stream already exists with name '{}'", stream_name}; + } + + auto *memory_resource = utils::NewDeleteResource(); + + auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, stream_name, + transformation_name = stream_info.common_info.transformation_name, owner = owner, + interpreter = std::make_shared<Interpreter>(interpreter_context_), + result = mgp_result{nullptr, memory_resource}, + total_retries = 
interpreter_context_->config.stream_transaction_conflict_retries, + retry_interval = interpreter_context_->config.stream_transaction_retry_interval]( + const std::vector<typename TStream::Message> &messages) mutable { + auto accessor = interpreter_context->db->Access(); + EventCounter::IncrementCounter(EventCounter::MessagesConsumed, messages.size()); + CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name); + + DiscardValueResultStream stream; + + spdlog::trace("Start transaction in stream '{}'", stream_name); + utils::OnScopeExit cleanup{[&interpreter, &result]() { + result.rows.clear(); + interpreter->Abort(); + }}; + + const static std::map<std::string, storage::v3::PropertyValue> empty_parameters{}; + uint32_t i = 0; + while (true) { + try { + interpreter->BeginTransaction(); + for (auto &row : result.rows) { + spdlog::trace("Processing row in stream '{}'", stream_name); + auto [query_value, params_value] = ExtractTransformationResult(row.values, transformation_name, stream_name); + storage::v3::PropertyValue params_prop{params_value}; + + std::string query{query_value.ValueString()}; + spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name); + auto prepare_result = + interpreter->Prepare(query, params_prop.IsNull() ? 
empty_parameters : params_prop.ValueMap(), nullptr); + if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges)) { + throw StreamsException{ + "Couldn't execute query '{}' for stream '{}' because the owner is not authorized to execute the " + "query!", + query, stream_name}; + } + interpreter->PullAll(&stream); + } + + spdlog::trace("Commit transaction in stream '{}'", stream_name); + interpreter->CommitTransaction(); + result.rows.clear(); + break; + } catch (const query::v2::TransactionSerializationException &e) { + interpreter->Abort(); + if (i == total_retries) { + throw; + } + ++i; + std::this_thread::sleep_for(retry_interval); + } + } + }; + + auto insert_result = map.try_emplace( + stream_name, StreamData<TStream>{std::move(stream_info.common_info.transformation_name), std::move(owner), + std::make_unique<SynchronizedStreamSource<TStream>>( + stream_name, std::move(stream_info), std::move(consumer_function))}); + MG_ASSERT(insert_result.second, "Unexpected error during storing consumer '{}'", stream_name); + return insert_result.first; +} + +void Streams::RestoreStreams() { + spdlog::info("Loading streams..."); + auto locked_streams_map = streams_.Lock(); + MG_ASSERT(locked_streams_map->empty(), "Cannot restore streams when some streams already exist!"); + + for (const auto &[stream_name, stream_data] : storage_) { + const auto get_failed_message = [&stream_name = stream_name](const std::string_view message, + const std::string_view nested_message) { + return fmt::format("Failed to load stream '{}', because: {} caused by {}", stream_name, message, nested_message); + }; + + const auto create_consumer = [&, &stream_name = stream_name, this]<typename T>(StreamStatus<T> status, + auto &&stream_json_data) { + try { + stream_json_data.get_to(status); + } catch (const nlohmann::json::type_error &exception) { + spdlog::warn(get_failed_message("invalid type conversion", exception.what())); + return; + } catch (const 
nlohmann::json::out_of_range &exception) { + spdlog::warn(get_failed_message("non existing field", exception.what())); + return; + } + MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", status.name, stream_name); + + try { + auto it = CreateConsumer<T>(*locked_streams_map, stream_name, std::move(status.info), std::move(status.owner)); + if (status.is_running) { + std::visit( + [&](const auto &stream_data) { + auto stream_source_ptr = stream_data.stream_source->Lock(); + stream_source_ptr->Start(); + }, + it->second); + } + spdlog::info("Stream '{}' is loaded", stream_name); + } catch (const utils::BasicException &exception) { + spdlog::warn(get_failed_message("unexpected error", exception.what())); + } + }; + + auto stream_json_data = nlohmann::json::parse(stream_data); + if (const auto it = stream_json_data.find(kType); it != stream_json_data.end()) { + const auto stream_type = static_cast<StreamSourceType>(*it); + switch (stream_type) { + case StreamSourceType::KAFKA: + create_consumer(StreamStatus<KafkaStream>{}, std::move(stream_json_data)); + break; + case StreamSourceType::PULSAR: + create_consumer(StreamStatus<PulsarStream>{}, std::move(stream_json_data)); + break; + } + } else { + spdlog::warn( + "Unable to load stream '{}', because it does not contain the type of the stream. Most probably the stream " + "was saved before Memgraph 2.1. Please recreate the stream manually to make it work. For more information " + "please check https://memgraph.com/docs/memgraph/changelog#v210---nov-22-2021 .", + stream_json_data.value(kStreamName, "<invalid format>")); + } + } +} + +void Streams::Drop(const std::string &stream_name) { + auto locked_streams = streams_.Lock(); + + auto it = GetStream(*locked_streams, stream_name); + + // streams_ is write locked, which means there is no access to it outside of this function, thus only the Test + // function can be executing with the consumer, nothing else. 
+ // By acquiring the write lock here for the consumer, we make sure there is + // no running Test function for this consumer, therefore it can be erased. + std::visit([&](const auto &stream_data) { stream_data.stream_source->Lock(); }, it->second); + + locked_streams->erase(it); + if (!storage_.Delete(stream_name)) { + throw StreamsException("Couldn't delete stream '{}' from persistent store!", stream_name); + } + + // TODO(antaljanosbenjamin) Release the transformation +} + +void Streams::Start(const std::string &stream_name) { + auto locked_streams = streams_.Lock(); + auto it = GetStream(*locked_streams, stream_name); + + std::visit( + [&, this](const auto &stream_data) { + auto stream_source_ptr = stream_data.stream_source->Lock(); + stream_source_ptr->Start(); + Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *stream_source_ptr)); + }, + it->second); +} + +void Streams::StartWithLimit(const std::string &stream_name, uint64_t batch_limit, + std::optional<std::chrono::milliseconds> timeout) const { + std::optional locked_streams{streams_.ReadLock()}; + auto it = GetStream(**locked_streams, stream_name); + + std::visit( + [&](const auto &stream_data) { + const auto locked_stream_source = stream_data.stream_source->ReadLock(); + locked_streams.reset(); + + locked_stream_source->StartWithLimit(batch_limit, timeout); + }, + it->second); +} + +void Streams::Stop(const std::string &stream_name) { + auto locked_streams = streams_.Lock(); + auto it = GetStream(*locked_streams, stream_name); + + std::visit( + [&, this](const auto &stream_data) { + auto stream_source_ptr = stream_data.stream_source->Lock(); + stream_source_ptr->Stop(); + + Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *stream_source_ptr)); + }, + it->second); +} + +void Streams::StartAll() { + for (auto locked_streams = streams_.Lock(); auto &[stream_name, stream_data] : *locked_streams) { + std::visit( + [&stream_name = 
stream_name, this](const auto &stream_data) { + auto locked_stream_source = stream_data.stream_source->Lock(); + if (!locked_stream_source->IsRunning()) { + locked_stream_source->Start(); + Persist( + CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *locked_stream_source)); + } + }, + stream_data); + } +} + +void Streams::StopAll() { + for (auto locked_streams = streams_.Lock(); auto &[stream_name, stream_data] : *locked_streams) { + std::visit( + [&stream_name = stream_name, this](const auto &stream_data) { + auto locked_stream_source = stream_data.stream_source->Lock(); + if (locked_stream_source->IsRunning()) { + locked_stream_source->Stop(); + Persist( + CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *locked_stream_source)); + } + }, + stream_data); + } +} + +std::vector<StreamStatus<>> Streams::GetStreamInfo() const { + std::vector<StreamStatus<>> result; + { + for (auto locked_streams = streams_.ReadLock(); const auto &[stream_name, stream_data] : *locked_streams) { + std::visit( + [&, &stream_name = stream_name](const auto &stream_data) { + auto locked_stream_source = stream_data.stream_source->ReadLock(); + auto info = locked_stream_source->Info(stream_data.transformation_name); + result.emplace_back(StreamStatus<>{stream_name, StreamType(*locked_stream_source), + locked_stream_source->IsRunning(), std::move(info.common_info), + stream_data.owner}); + }, + stream_data); + } + } + return result; +} + +TransformationResult Streams::Check(const std::string &stream_name, std::optional<std::chrono::milliseconds> timeout, + std::optional<uint64_t> batch_limit) const { + std::optional locked_streams{streams_.ReadLock()}; + auto it = GetStream(**locked_streams, stream_name); + + return std::visit( + [&](const auto &stream_data) { + // This depends on the fact that Drop will first acquire a write lock to the consumer, and erase it only after + // that + const auto locked_stream_source = 
stream_data.stream_source->ReadLock(); + const auto transformation_name = stream_data.transformation_name; + locked_streams.reset(); + + auto *memory_resource = utils::NewDeleteResource(); + mgp_result result{nullptr, memory_resource}; + TransformationResult test_result; + + auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, &stream_name, + &transformation_name = transformation_name, &result, + &test_result]<typename T>(const std::vector<T> &messages) mutable { + auto accessor = interpreter_context->db->Access(); + CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name); + + auto result_row = std::vector<TypedValue>(); + result_row.reserve(kCheckStreamResultSize); + + auto queries_and_parameters = std::vector<TypedValue>(result.rows.size()); + std::transform( + result.rows.cbegin(), result.rows.cend(), queries_and_parameters.begin(), [&](const auto &row) { + auto [query, parameters] = ExtractTransformationResult(row.values, transformation_name, stream_name); + + return std::map<std::string, TypedValue>{{"query", std::move(query)}, + {"parameters", std::move(parameters)}}; + }); + result_row.emplace_back(std::move(queries_and_parameters)); + + auto messages_list = std::vector<TypedValue>(messages.size()); + std::transform(messages.cbegin(), messages.cend(), messages_list.begin(), [](const auto &message) { + return std::string_view(message.Payload().data(), message.Payload().size()); + }); + + result_row.emplace_back(std::move(messages_list)); + + test_result.emplace_back(std::move(result_row)); + }; + + locked_stream_source->Check(timeout, batch_limit, consumer_function); + return test_result; + }, + it->second); +} + +} // namespace memgraph::query::v2::stream diff --git a/src/query/v2/stream/streams.hpp b/src/query/v2/stream/streams.hpp new file mode 100644 index 000000000..4fc7fc33c --- /dev/null +++ b/src/query/v2/stream/streams.hpp @@ -0,0 +1,206 @@ +// Copyright 2022 Memgraph 
Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <concepts> +#include <functional> +#include <map> +#include <optional> +#include <type_traits> +#include <unordered_map> + +#include <json/json.hpp> + +#include "integrations/kafka/consumer.hpp" +#include "kvstore/kvstore.hpp" +#include "query/v2/stream/common.hpp" +#include "query/v2/stream/sources.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/property_value.hpp" +#include "utils/event_counter.hpp" +#include "utils/exceptions.hpp" +#include "utils/rw_lock.hpp" +#include "utils/synchronized.hpp" + +class StreamsTest; +namespace memgraph::query::v2 { + +struct InterpreterContext; + +namespace stream { + +class StreamsException : public utils::BasicException { + public: + using BasicException::BasicException; +}; + +template <typename T> +struct StreamInfo; + +template <> +struct StreamInfo<void> { + using Type = CommonStreamInfo; +}; + +template <Stream TStream> +struct StreamInfo<TStream> { + using Type = typename TStream::StreamInfo; +}; + +template <typename T> +using StreamInfoType = typename StreamInfo<T>::Type; + +template <typename T = void> +struct StreamStatus { + std::string name; + StreamSourceType type; + bool is_running; + StreamInfoType<T> info; + std::optional<std::string> owner; +}; + +using TransformationResult = std::vector<std::vector<TypedValue>>; + +/// Manages Kafka and Pulsar stream consumers. +/// +/// This class is responsible for all query supported actions to happen. 
+class Streams final { + friend StreamsTest; + + public: + /// Initializes the streams. + /// + /// @param interpreter_context context to use to run the result of transformations + /// @param directory a directory path to store the persisted streams metadata + Streams(InterpreterContext *interpreter_context, std::filesystem::path directory); + + /// Restores the streams from the persisted metadata. + /// The restoration is done in a best effort manner, therefore no exception is thrown on failure, but the error is + /// logged. If a stream was running previously, then after restoration it will be started. + /// This function should only be called when there are no existing streams. + void RestoreStreams(); + + /// Creates a new import stream. + /// The create implies connecting to the server to get metadata necessary to initialize the stream. This + /// method assures there is no other stream with the same name. + /// + /// @param stream_name the name of the stream which can be used to uniquely identify the stream + /// @param info the necessary information needed to create the stream consumer and transform the messages + /// + /// @throws StreamsException if the stream with the same name exists or if the creation of the consumer fails + template <Stream TStream> + void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::optional<std::string> owner); + + /// Deletes an existing stream and all the data that was persisted. + /// + /// @param stream_name name of the stream that needs to be deleted. + /// + /// @throws StreamsException if the stream doesn't exist or if the persisted metadata can't be deleted. + void Drop(const std::string &stream_name); + + /// Start consuming from a stream. 
+ /// + /// @param stream_name name of the stream that needs to be started + /// + /// @throws StreamsException if the stream doesn't exist or if the metadata cannot be persisted + /// @throws ConsumerRunningException if the consumer is already running + void Start(const std::string &stream_name); + + /// Start consuming from a stream. + /// + /// @param stream_name name of the stream that needs to be started + /// @param batch_limit number of batches we want to consume before stopping + /// @param timeout the maximum duration during which the command should run. + /// + /// @throws StreamsException if the stream doesn't exist + /// @throws ConsumerRunningException if the consumer is already running + void StartWithLimit(const std::string &stream_name, uint64_t batch_limit, + std::optional<std::chrono::milliseconds> timeout) const; + + /// Stop consuming from a stream. + /// + /// @param stream_name name of the stream that needs to be stopped + /// + /// @throws StreamsException if the stream doesn't exist or if the metadata cannot be persisted + /// @throws ConsumerStoppedException if the consumer is already stopped + void Stop(const std::string &stream_name); + + /// Start consuming from all streams that are stopped. + /// + /// @throws StreamsException if the metadata cannot be persisted + void StartAll(); + + /// Stop consuming from all streams that are running. + /// + /// @throws StreamsException if the metadata cannot be persisted + void StopAll(); + + /// Return current status for all streams. + /// It might happen that the is_running field is out of date if one of the streams stops during the invocation of + /// this function because of an error. + std::vector<StreamStatus<>> GetStreamInfo() const; + + /// Do a dry-run consume from a stream. 
+ /// + /// @param stream_name name of the stream we want to test + /// @param batch_limit number of batches we want to test before stopping + /// @param timeout the maximum duration during which the command should run. + /// + /// @returns A vector of vectors of TypedValue. Each subvector contains two elements: the list of query and + /// parameters maps, and the list of message payloads. + /// + /// @throws StreamsException if the stream doesn't exist + /// @throws ConsumerRunningException if the consumer is already running + /// @throws ConsumerCheckFailedException if the transformation function throws any std::exception during processing + TransformationResult Check(const std::string &stream_name, + std::optional<std::chrono::milliseconds> timeout = std::nullopt, + std::optional<uint64_t> batch_limit = std::nullopt) const; + + private: + template <Stream TStream> + using SynchronizedStreamSource = utils::Synchronized<TStream, utils::WritePrioritizedRWLock>; + + template <Stream TStream> + struct StreamData { + std::string transformation_name; + std::optional<std::string> owner; + std::unique_ptr<SynchronizedStreamSource<TStream>> stream_source; + }; + + using StreamDataVariant = std::variant<StreamData<KafkaStream>, StreamData<PulsarStream>>; + using StreamsMap = std::unordered_map<std::string, StreamDataVariant>; + using SynchronizedStreamsMap = utils::Synchronized<StreamsMap, utils::WritePrioritizedRWLock>; + + template <Stream TStream> + StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name, + typename TStream::StreamInfo stream_info, std::optional<std::string> owner); + + template <Stream TStream> + void Persist(StreamStatus<TStream> &&status) { + const std::string stream_name = status.name; + if (!storage_.Put(stream_name, nlohmann::json(std::move(status)).dump())) { + throw StreamsException{"Couldn't persist steam data for stream '{}'", stream_name}; + } + } + + void RegisterProcedures(); + void RegisterKafkaProcedures(); + void 
RegisterPulsarProcedures(); + + InterpreterContext *interpreter_context_; + kvstore::KVStore storage_; + + SynchronizedStreamsMap streams_; +}; + +} // namespace stream +} // namespace memgraph::query::v2 diff --git a/src/query/v2/trigger.cpp b/src/query/v2/trigger.cpp new file mode 100644 index 000000000..08f9ace9f --- /dev/null +++ b/src/query/v2/trigger.cpp @@ -0,0 +1,441 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/trigger.hpp" + +#include <concepts> + +#include "query/v2/config.hpp" +#include "query/v2/context.hpp" +#include "query/v2/cypher_query_interpreter.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/interpret/frame.hpp" +#include "query/v2/serialization/property_value.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/property_value.hpp" +#include "utils/event_counter.hpp" +#include "utils/memory.hpp" + +namespace EventCounter { +extern const Event TriggersExecuted; +} // namespace EventCounter + +namespace memgraph::query::v2 { +namespace { +auto IdentifierString(const TriggerIdentifierTag tag) noexcept { + switch (tag) { + case TriggerIdentifierTag::CREATED_VERTICES: + return "createdVertices"; + + case TriggerIdentifierTag::CREATED_EDGES: + return "createdEdges"; + + case TriggerIdentifierTag::CREATED_OBJECTS: + return "createdObjects"; + + case TriggerIdentifierTag::DELETED_VERTICES: + return "deletedVertices"; + + case TriggerIdentifierTag::DELETED_EDGES: + 
return "deletedEdges"; + + case TriggerIdentifierTag::DELETED_OBJECTS: + return "deletedObjects"; + + case TriggerIdentifierTag::SET_VERTEX_PROPERTIES: + return "setVertexProperties"; + + case TriggerIdentifierTag::SET_EDGE_PROPERTIES: + return "setEdgeProperties"; + + case TriggerIdentifierTag::REMOVED_VERTEX_PROPERTIES: + return "removedVertexProperties"; + + case TriggerIdentifierTag::REMOVED_EDGE_PROPERTIES: + return "removedEdgeProperties"; + + case TriggerIdentifierTag::SET_VERTEX_LABELS: + return "setVertexLabels"; + + case TriggerIdentifierTag::REMOVED_VERTEX_LABELS: + return "removedVertexLabels"; + + case TriggerIdentifierTag::UPDATED_VERTICES: + return "updatedVertices"; + + case TriggerIdentifierTag::UPDATED_EDGES: + return "updatedEdges"; + + case TriggerIdentifierTag::UPDATED_OBJECTS: + return "updatedObjects"; + } +} + +template <typename T> +concept SameAsIdentifierTag = std::same_as<T, TriggerIdentifierTag>; + +template <SameAsIdentifierTag... TArgs> +std::vector<std::pair<Identifier, TriggerIdentifierTag>> TagsToIdentifiers(const TArgs &...args) { + std::vector<std::pair<Identifier, TriggerIdentifierTag>> identifiers; + identifiers.reserve(sizeof...(args)); + + auto add_identifier = [&identifiers](const auto tag) { + identifiers.emplace_back(Identifier{IdentifierString(tag), false}, tag); + }; + + (add_identifier(args), ...); + + return identifiers; +}; + +std::vector<std::pair<Identifier, TriggerIdentifierTag>> GetPredefinedIdentifiers(const TriggerEventType event_type) { + using IdentifierTag = TriggerIdentifierTag; + using EventType = TriggerEventType; + + switch (event_type) { + case EventType::ANY: + return TagsToIdentifiers( + IdentifierTag::CREATED_VERTICES, IdentifierTag::CREATED_EDGES, IdentifierTag::CREATED_OBJECTS, + IdentifierTag::DELETED_VERTICES, IdentifierTag::DELETED_EDGES, IdentifierTag::DELETED_OBJECTS, + IdentifierTag::SET_VERTEX_PROPERTIES, IdentifierTag::REMOVED_VERTEX_PROPERTIES, + IdentifierTag::SET_VERTEX_LABELS, 
IdentifierTag::REMOVED_VERTEX_LABELS, IdentifierTag::UPDATED_VERTICES, + IdentifierTag::SET_EDGE_PROPERTIES, IdentifierTag::REMOVED_EDGE_PROPERTIES, IdentifierTag::UPDATED_EDGES, + IdentifierTag::UPDATED_OBJECTS); + + case EventType::CREATE: + return TagsToIdentifiers(IdentifierTag::CREATED_VERTICES, IdentifierTag::CREATED_EDGES, + IdentifierTag::CREATED_OBJECTS); + + case EventType::VERTEX_CREATE: + return TagsToIdentifiers(IdentifierTag::CREATED_VERTICES); + + case EventType::EDGE_CREATE: + return TagsToIdentifiers(IdentifierTag::CREATED_EDGES); + + case EventType::DELETE: + return TagsToIdentifiers(IdentifierTag::DELETED_VERTICES, IdentifierTag::DELETED_EDGES, + IdentifierTag::DELETED_OBJECTS); + + case EventType::VERTEX_DELETE: + return TagsToIdentifiers(IdentifierTag::DELETED_VERTICES); + + case EventType::EDGE_DELETE: + return TagsToIdentifiers(IdentifierTag::DELETED_EDGES); + + case EventType::UPDATE: + return TagsToIdentifiers(IdentifierTag::SET_VERTEX_PROPERTIES, IdentifierTag::REMOVED_VERTEX_PROPERTIES, + IdentifierTag::SET_VERTEX_LABELS, IdentifierTag::REMOVED_VERTEX_LABELS, + IdentifierTag::UPDATED_VERTICES, IdentifierTag::SET_EDGE_PROPERTIES, + IdentifierTag::REMOVED_EDGE_PROPERTIES, IdentifierTag::UPDATED_EDGES, + IdentifierTag::UPDATED_OBJECTS); + + case EventType::VERTEX_UPDATE: + return TagsToIdentifiers(IdentifierTag::SET_VERTEX_PROPERTIES, IdentifierTag::REMOVED_VERTEX_PROPERTIES, + IdentifierTag::SET_VERTEX_LABELS, IdentifierTag::REMOVED_VERTEX_LABELS, + IdentifierTag::UPDATED_VERTICES); + + case EventType::EDGE_UPDATE: + return TagsToIdentifiers(IdentifierTag::SET_EDGE_PROPERTIES, IdentifierTag::REMOVED_EDGE_PROPERTIES, + IdentifierTag::UPDATED_EDGES); + } +} +} // namespace + +Trigger::Trigger(std::string name, const std::string &query, + const std::map<std::string, storage::v3::PropertyValue> &user_parameters, + const TriggerEventType event_type, utils::SkipList<QueryCacheEntry> *query_cache, + DbAccessor *db_accessor, utils::SpinLock 
*antlr_lock, const InterpreterConfig::Query &query_config, + std::optional<std::string> owner, const query::v2::AuthChecker *auth_checker) + : name_{std::move(name)}, + parsed_statements_{ParseQuery(query, user_parameters, query_cache, antlr_lock, query_config)}, + event_type_{event_type}, + owner_{std::move(owner)} { + // We check immediately if the query is valid by trying to create a plan. + GetPlan(db_accessor, auth_checker); +} + +Trigger::TriggerPlan::TriggerPlan(std::unique_ptr<LogicalPlan> logical_plan, std::vector<IdentifierInfo> identifiers) + : cached_plan(std::move(logical_plan)), identifiers(std::move(identifiers)) {} + +std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor, + const query::v2::AuthChecker *auth_checker) const { + std::lock_guard plan_guard{plan_lock_}; + if (!parsed_statements_.is_cacheable || !trigger_plan_ || trigger_plan_->cached_plan.IsExpired()) { + auto identifiers = GetPredefinedIdentifiers(event_type_); + + AstStorage ast_storage; + ast_storage.properties_ = parsed_statements_.ast_storage.properties_; + ast_storage.labels_ = parsed_statements_.ast_storage.labels_; + ast_storage.edge_types_ = parsed_statements_.ast_storage.edge_types_; + + std::vector<Identifier *> predefined_identifiers; + predefined_identifiers.reserve(identifiers.size()); + std::transform(identifiers.begin(), identifiers.end(), std::back_inserter(predefined_identifiers), + [](auto &identifier) { return &identifier.first; }); + + auto logical_plan = MakeLogicalPlan(std::move(ast_storage), utils::Downcast<CypherQuery>(parsed_statements_.query), + parsed_statements_.parameters, db_accessor, predefined_identifiers); + + trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers)); + } + if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges)) { + throw utils::BasicException("The owner of trigger '{}' is not authorized to execute the query!", name_); + } + return trigger_plan_; 
+} + +void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, + const double max_execution_time_sec, std::atomic<bool> *is_shutting_down, + const TriggerContext &context, const AuthChecker *auth_checker) const { + if (!context.ShouldEventTrigger(event_type_)) { + return; + } + + spdlog::debug("Executing trigger '{}'", name_); + auto trigger_plan = GetPlan(dba, auth_checker); + MG_ASSERT(trigger_plan, "Invalid trigger plan received"); + auto &[plan, identifiers] = *trigger_plan; + + ExecutionContext ctx; + ctx.db_accessor = dba; + ctx.symbol_table = plan.symbol_table(); + ctx.evaluation_context.timestamp = QueryTimestamp(); + ctx.evaluation_context.parameters = parsed_statements_.parameters; + ctx.evaluation_context.properties = NamesToProperties(plan.ast_storage().properties_, dba); + ctx.evaluation_context.labels = NamesToLabels(plan.ast_storage().labels_, dba); + ctx.timer = utils::AsyncTimer(max_execution_time_sec); + ctx.is_shutting_down = is_shutting_down; + ctx.is_profile_query = false; + + // Set up temporary memory for a single Pull. Initial memory comes from the + // stack. 256 KiB should fit on the stack and should be more than enough for a + // single `Pull`. + static constexpr size_t stack_size = 256UL * 1024UL; + char stack_data[stack_size]; + + // We can throw on every query because a simple queries for deleting will use only + // the stack allocated buffer. + // Also, we want to throw only when the query engine requests more memory and not the storage + // so we add the exception to the allocator. + utils::ResourceWithOutOfMemoryException resource_with_exception; + utils::MonotonicBufferResource monotonic_memory(&stack_data[0], stack_size, &resource_with_exception); + // TODO (mferencevic): Tune the parameters accordingly. 
+ utils::PoolResource pool_memory(128, 1024, &monotonic_memory); + ctx.evaluation_context.memory = &pool_memory; + + auto cursor = plan.plan().MakeCursor(execution_memory); + Frame frame{plan.symbol_table().max_position(), execution_memory}; + for (const auto &[identifier, tag] : identifiers) { + if (identifier.symbol_pos_ == -1) { + continue; + } + + frame[plan.symbol_table().at(identifier)] = context.GetTypedValue(tag, dba); + } + + while (cursor->Pull(frame, ctx)) + ; + + cursor->Shutdown(); + EventCounter::IncrementCounter(EventCounter::TriggersExecuted); +} + +namespace { +// When the format of the persisted trigger is changed, increase this version +inline constexpr uint64_t kVersion{2}; +} // namespace + +TriggerStore::TriggerStore(std::filesystem::path directory) : storage_{std::move(directory)} {} + +void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, + utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config, + const query::v2::AuthChecker *auth_checker) { + MG_ASSERT(before_commit_triggers_.size() == 0 && after_commit_triggers_.size() == 0, + "Cannot restore trigger when some triggers already exist!"); + spdlog::info("Loading triggers..."); + + for (const auto &[trigger_name, trigger_data] : storage_) { + const auto get_failed_message = [&trigger_name = trigger_name](const std::string_view message) { + return fmt::format("Failed to load trigger '{}'. 
{}", trigger_name, message); + }; + + const auto invalid_state_message = get_failed_message("Invalid state of the trigger data."); + + spdlog::debug("Loading trigger '{}'", trigger_name); + auto json_trigger_data = nlohmann::json::parse(trigger_data); + + if (!json_trigger_data["version"].is_number_unsigned()) { + spdlog::warn(invalid_state_message); + continue; + } + if (json_trigger_data["version"] != kVersion) { + spdlog::warn(get_failed_message("Invalid version of the trigger data.")); + continue; + } + + if (!json_trigger_data["statement"].is_string()) { + spdlog::warn(invalid_state_message); + continue; + } + auto statement = json_trigger_data["statement"].get<std::string>(); + + if (!json_trigger_data["phase"].is_number_integer()) { + spdlog::warn(invalid_state_message); + continue; + } + const auto phase = json_trigger_data["phase"].get<TriggerPhase>(); + + if (!json_trigger_data["event_type"].is_number_integer()) { + spdlog::warn(invalid_state_message); + continue; + } + const auto event_type = json_trigger_data["event_type"].get<TriggerEventType>(); + + if (!json_trigger_data["user_parameters"].is_object()) { + spdlog::warn(invalid_state_message); + continue; + } + const auto user_parameters = serialization::DeserializePropertyValueMap(json_trigger_data["user_parameters"]); + + const auto owner_json = json_trigger_data["owner"]; + std::optional<std::string> owner{}; + if (owner_json.is_string()) { + owner.emplace(owner_json.get<std::string>()); + } else if (!owner_json.is_null()) { + spdlog::warn(invalid_state_message); + continue; + } + + std::optional<Trigger> trigger; + try { + trigger.emplace(trigger_name, statement, user_parameters, event_type, query_cache, db_accessor, antlr_lock, + query_config, std::move(owner), auth_checker); + } catch (const utils::BasicException &e) { + spdlog::warn("Failed to create trigger '{}' because: {}", trigger_name, e.what()); + continue; + } + + auto triggers_acc = + phase == TriggerPhase::BEFORE_COMMIT ? 
before_commit_triggers_.access() : after_commit_triggers_.access(); + triggers_acc.insert(std::move(*trigger)); + + spdlog::debug("Trigger loaded successfully!"); + } +} + +void TriggerStore::AddTrigger(std::string name, const std::string &query, + const std::map<std::string, storage::v3::PropertyValue> &user_parameters, + TriggerEventType event_type, TriggerPhase phase, + utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, + utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config, + std::optional<std::string> owner, const query::v2::AuthChecker *auth_checker) { + std::unique_lock store_guard{store_lock_}; + if (storage_.Get(name)) { + throw utils::BasicException("Trigger with the same name already exists."); + } + + std::optional<Trigger> trigger; + try { + trigger.emplace(std::move(name), query, user_parameters, event_type, query_cache, db_accessor, antlr_lock, + query_config, std::move(owner), auth_checker); + } catch (const utils::BasicException &e) { + const auto identifiers = GetPredefinedIdentifiers(event_type); + std::stringstream identifier_names_stream; + utils::PrintIterable(identifier_names_stream, identifiers, ", ", + [](auto &stream, const auto &identifier) { stream << identifier.first.name_; }); + + throw utils::BasicException( + "Failed creating the trigger.\nError message: '{}'\nThe error was mostly likely generated because of the wrong " + "statement that this trigger executes.\nMake sure all predefined variables used are present for the specified " + "event.\nAllowed variables for event '{}' are: {}", + e.what(), TriggerEventTypeToString(event_type), identifier_names_stream.str()); + } + + // When the format of the persisted trigger is changed, update the kVersion + nlohmann::json data = nlohmann::json::object(); + data["statement"] = query; + data["user_parameters"] = serialization::SerializePropertyValueMap(user_parameters); + data["event_type"] = event_type; + data["phase"] = phase; + data["version"] = 
kVersion; + + if (const auto &owner_from_trigger = trigger->Owner(); owner_from_trigger.has_value()) { + data["owner"] = *owner_from_trigger; + } else { + data["owner"] = nullptr; + } + storage_.Put(trigger->Name(), data.dump()); + store_guard.unlock(); + + auto triggers_acc = + phase == TriggerPhase::BEFORE_COMMIT ? before_commit_triggers_.access() : after_commit_triggers_.access(); + triggers_acc.insert(std::move(*trigger)); +} + +void TriggerStore::DropTrigger(const std::string &name) { + std::unique_lock store_guard{store_lock_}; + const auto maybe_trigger_data = storage_.Get(name); + if (!maybe_trigger_data) { + throw utils::BasicException("Trigger with name '{}' doesn't exist", name); + } + + nlohmann::json data; + try { + data = nlohmann::json::parse(*maybe_trigger_data); + } catch (const nlohmann::json::parse_error &e) { + throw utils::BasicException("Couldn't load trigger data!"); + } + + if (!data.is_object()) { + throw utils::BasicException("Couldn't load trigger data!"); + } + + if (!data["phase"].is_number_integer()) { + throw utils::BasicException("Invalid type loaded inside the trigger data!"); + } + + auto triggers_acc = + data["phase"] == TriggerPhase::BEFORE_COMMIT ? 
before_commit_triggers_.access() : after_commit_triggers_.access(); + triggers_acc.remove(name); + storage_.Delete(name); +} + +std::vector<TriggerStore::TriggerInfo> TriggerStore::GetTriggerInfo() const { + std::vector<TriggerInfo> info; + info.reserve(before_commit_triggers_.size() + after_commit_triggers_.size()); + + const auto add_info = [&](const utils::SkipList<Trigger> &trigger_list, const TriggerPhase phase) { + for (const auto &trigger : trigger_list.access()) { + info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase, trigger.Owner()}); + } + }; + + add_info(before_commit_triggers_, TriggerPhase::BEFORE_COMMIT); + add_info(after_commit_triggers_, TriggerPhase::AFTER_COMMIT); + + return info; +} + +std::unordered_set<TriggerEventType> TriggerStore::GetEventTypes() const { + std::unordered_set<TriggerEventType> event_types; + + const auto add_event_types = [&](const utils::SkipList<Trigger> &trigger_list) { + for (const auto &trigger : trigger_list.access()) { + event_types.insert(trigger.EventType()); + } + }; + + add_event_types(before_commit_triggers_); + add_event_types(after_commit_triggers_); + return event_types; +} +} // namespace memgraph::query::v2 diff --git a/src/query/v2/trigger.hpp b/src/query/v2/trigger.hpp new file mode 100644 index 000000000..5ceaaa63e --- /dev/null +++ b/src/query/v2/trigger.hpp @@ -0,0 +1,119 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include <atomic> +#include <filesystem> +#include <map> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "kvstore/kvstore.hpp" +#include "query/v2/auth_checker.hpp" +#include "query/v2/config.hpp" +#include "query/v2/cypher_query_interpreter.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/trigger_context.hpp" +#include "storage/v3/property_value.hpp" +#include "utils/skip_list.hpp" +#include "utils/spin_lock.hpp" + +namespace memgraph::query::v2 { +struct Trigger { + explicit Trigger(std::string name, const std::string &query, + const std::map<std::string, storage::v3::PropertyValue> &user_parameters, + TriggerEventType event_type, utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, + utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config, + std::optional<std::string> owner, const AuthChecker *auth_checker); + + void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec, + std::atomic<bool> *is_shutting_down, const TriggerContext &context, + const AuthChecker *auth_checker) const; + + bool operator==(const Trigger &other) const { return name_ == other.name_; } + // NOLINTNEXTLINE (modernize-use-nullptr) + bool operator<(const Trigger &other) const { return name_ < other.name_; } + bool operator==(const std::string &other) const { return name_ == other; } + // NOLINTNEXTLINE (modernize-use-nullptr) + bool operator<(const std::string &other) const { return name_ < other; } + + const auto &Name() const noexcept { return name_; } + const auto &OriginalStatement() const noexcept { return parsed_statements_.query_string; } + const auto &Owner() const noexcept { return owner_; } + auto EventType() const noexcept { return event_type_; } + + private: + struct TriggerPlan { + using IdentifierInfo = std::pair<Identifier, TriggerIdentifierTag>; + + explicit 
TriggerPlan(std::unique_ptr<LogicalPlan> logical_plan, std::vector<IdentifierInfo> identifiers); + + CachedPlan cached_plan; + std::vector<IdentifierInfo> identifiers; + }; + std::shared_ptr<TriggerPlan> GetPlan(DbAccessor *db_accessor, const AuthChecker *auth_checker) const; + + std::string name_; + ParsedQuery parsed_statements_; + + TriggerEventType event_type_; + + mutable utils::SpinLock plan_lock_; + mutable std::shared_ptr<TriggerPlan> trigger_plan_; + std::optional<std::string> owner_; +}; + +enum class TriggerPhase : uint8_t { BEFORE_COMMIT, AFTER_COMMIT }; + +struct TriggerStore { + explicit TriggerStore(std::filesystem::path directory); + + void RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, + utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config, + const query::v2::AuthChecker *auth_checker); + + void AddTrigger(std::string name, const std::string &query, + const std::map<std::string, storage::v3::PropertyValue> &user_parameters, TriggerEventType event_type, + TriggerPhase phase, utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, + utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config, + std::optional<std::string> owner, const query::v2::AuthChecker *auth_checker); + + void DropTrigger(const std::string &name); + + struct TriggerInfo { + std::string name; + std::string statement; + TriggerEventType event_type; + TriggerPhase phase; + std::optional<std::string> owner; + }; + + std::vector<TriggerInfo> GetTriggerInfo() const; + + const auto &BeforeCommitTriggers() const noexcept { return before_commit_triggers_; } + const auto &AfterCommitTriggers() const noexcept { return after_commit_triggers_; } + + bool HasTriggers() const noexcept { return before_commit_triggers_.size() > 0 || after_commit_triggers_.size() > 0; } + std::unordered_set<TriggerEventType> GetEventTypes() const; + + private: + utils::SpinLock store_lock_; + kvstore::KVStore storage_; + + 
utils::SkipList<Trigger> before_commit_triggers_; + utils::SkipList<Trigger> after_commit_triggers_; +}; + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/trigger_context.cpp b/src/query/v2/trigger_context.cpp new file mode 100644 index 000000000..f69a48ed4 --- /dev/null +++ b/src/query/v2/trigger_context.cpp @@ -0,0 +1,557 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/v2/trigger.hpp" + +#include <concepts> + +#include "query/v2/context.hpp" +#include "query/v2/cypher_query_interpreter.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/interpret/frame.hpp" +#include "query/v2/serialization/property_value.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/property_value.hpp" +#include "utils/memory.hpp" + +namespace memgraph::query::v2 { +namespace { +template <typename T> +concept WithToMap = requires(const T value, DbAccessor *dba) { + { value.ToMap(dba) } -> std::same_as<std::map<std::string, TypedValue>>; +}; + +template <WithToMap T> +TypedValue ToTypedValue(const T &value, DbAccessor *dba) { + return TypedValue{value.ToMap(dba)}; +} + +template <detail::ObjectAccessor TAccessor> +TypedValue ToTypedValue(const detail::CreatedObject<TAccessor> &created_object, [[maybe_unused]] DbAccessor *dba) { + return TypedValue{created_object.object}; +} + +template <detail::ObjectAccessor TAccessor> +TypedValue ToTypedValue(const detail::DeletedObject<TAccessor> &deleted_object, 
[[maybe_unused]] DbAccessor *dba) { + return TypedValue{deleted_object.object}; +} + +template <typename T> +concept WithIsValid = requires(const T value) { + { value.IsValid() } -> std::same_as<bool>; +}; + +template <typename T> +concept ConvertableToTypedValue = requires(T value, DbAccessor *dba) { + { ToTypedValue(value, dba) } -> std::same_as<TypedValue>; +} +&&WithIsValid<T>; + +template <typename T> +concept LabelUpdateContext = utils::SameAsAnyOf<T, detail::SetVertexLabel, detail::RemovedVertexLabel>; + +template <LabelUpdateContext TContext> +TypedValue ToTypedValue(const std::vector<TContext> &values, DbAccessor *dba) { + std::unordered_map<storage::v3::LabelId, std::vector<TypedValue>> vertices_by_labels; + + for (const auto &value : values) { + if (value.IsValid()) { + vertices_by_labels[value.label_id].emplace_back(value.object); + } + } + + TypedValue result{std::vector<TypedValue>{}}; + auto &typed_values = result.ValueList(); + for (auto &[label_id, vertices] : vertices_by_labels) { + typed_values.emplace_back(std::map<std::string, TypedValue>{ + {std::string{"label"}, TypedValue(dba->LabelToName(label_id))}, + {std::string{"vertices"}, TypedValue(std::move(vertices))}, + }); + } + + return result; +} + +template <ConvertableToTypedValue T> +TypedValue ToTypedValue(const std::vector<T> &values, DbAccessor *dba) requires(!LabelUpdateContext<T>) { + TypedValue result{std::vector<TypedValue>{}}; + auto &typed_values = result.ValueList(); + typed_values.reserve(values.size()); + + for (const auto &value : values) { + if (value.IsValid()) { + typed_values.push_back(ToTypedValue(value, dba)); + } + } + + return result; +} + +template <typename T> +const char *TypeToString() { + if constexpr (std::same_as<T, detail::CreatedObject<VertexAccessor>>) { + return "created_vertex"; + } else if constexpr (std::same_as<T, detail::CreatedObject<EdgeAccessor>>) { + return "created_edge"; + } else if constexpr (std::same_as<T, detail::DeletedObject<VertexAccessor>>) 
{ + return "deleted_vertex"; + } else if constexpr (std::same_as<T, detail::DeletedObject<EdgeAccessor>>) { + return "deleted_edge"; + } else if constexpr (std::same_as<T, detail::SetObjectProperty<VertexAccessor>>) { + return "set_vertex_property"; + } else if constexpr (std::same_as<T, detail::SetObjectProperty<EdgeAccessor>>) { + return "set_edge_property"; + } else if constexpr (std::same_as<T, detail::RemovedObjectProperty<VertexAccessor>>) { + return "removed_vertex_property"; + } else if constexpr (std::same_as<T, detail::RemovedObjectProperty<EdgeAccessor>>) { + return "removed_edge_property"; + } else if constexpr (std::same_as<T, detail::SetVertexLabel>) { + return "set_vertex_label"; + } else if constexpr (std::same_as<T, detail::RemovedVertexLabel>) { + return "removed_vertex_label"; + } +} + +template <typename T> +concept ContextInfo = WithToMap<T> && WithIsValid<T>; + +template <ContextInfo... Args> +TypedValue Concatenate(DbAccessor *dba, const std::vector<Args> &...args) { + const auto size = (args.size() + ...); + TypedValue result{std::vector<TypedValue>{}}; + auto &concatenated = result.ValueList(); + concatenated.reserve(size); + + const auto add_to_concatenated = [&]<ContextInfo T>(const std::vector<T> &values) { + for (const auto &value : values) { + if (value.IsValid()) { + auto map = value.ToMap(dba); + map["event_type"] = TypeToString<T>(); + concatenated.emplace_back(std::move(map)); + } + } + }; + + (add_to_concatenated(args), ...); + + return result; +} + +template <typename T> +concept WithEmpty = requires(const T value) { + { value.empty() } -> std::same_as<bool>; +}; + +template <WithEmpty... 
TContainer> +bool AnyContainsValue(const TContainer &...value_containers) { + return (!value_containers.empty() || ...); +} + +template <detail::ObjectAccessor TAccessor> +using ChangesSummary = + std::tuple<std::vector<detail::CreatedObject<TAccessor>>, std::vector<detail::DeletedObject<TAccessor>>, + std::vector<detail::SetObjectProperty<TAccessor>>, + std::vector<detail::RemovedObjectProperty<TAccessor>>>; + +template <detail::ObjectAccessor TAccessor> +using PropertyChangesLists = + std::pair<std::vector<detail::SetObjectProperty<TAccessor>>, std::vector<detail::RemovedObjectProperty<TAccessor>>>; + +template <detail::ObjectAccessor TAccessor> +[[nodiscard]] PropertyChangesLists<TAccessor> PropertyMapToList( + query::v2::TriggerContextCollector::PropertyChangesMap<TAccessor> &&map) { + std::vector<detail::SetObjectProperty<TAccessor>> set_object_properties; + std::vector<detail::RemovedObjectProperty<TAccessor>> removed_object_properties; + + for (auto it = map.begin(); it != map.end(); it = map.erase(it)) { + const auto &[key, property_change_info] = *it; + if (property_change_info.old_value.IsNull() && property_change_info.new_value.IsNull()) { + // no change happened on the transaction level + continue; + } + + if (const auto is_equal = property_change_info.old_value == property_change_info.new_value; + is_equal.IsBool() && is_equal.ValueBool()) { + // no change happened on the transaction level + continue; + } + + if (property_change_info.new_value.IsNull()) { + removed_object_properties.emplace_back(key.first, key.second /* property_id */, + std::move(property_change_info.old_value)); + } else { + set_object_properties.emplace_back(key.first, key.second, std::move(property_change_info.old_value), + std::move(property_change_info.new_value)); + } + } + + return PropertyChangesLists<TAccessor>{std::move(set_object_properties), std::move(removed_object_properties)}; +} + +template <detail::ObjectAccessor TAccessor> +[[nodiscard]] ChangesSummary<TAccessor> 
Summarize(query::v2::TriggerContextCollector::Registry<TAccessor> &&registry) { + auto [set_object_properties, removed_object_properties] = PropertyMapToList(std::move(registry.property_changes)); + std::vector<detail::CreatedObject<TAccessor>> created_objects_vec; + created_objects_vec.reserve(registry.created_objects.size()); + std::transform(registry.created_objects.begin(), registry.created_objects.end(), + std::back_inserter(created_objects_vec), + [](const auto &gid_and_created_object) { return gid_and_created_object.second; }); + registry.created_objects.clear(); + + return {std::move(created_objects_vec), std::move(registry.deleted_objects), std::move(set_object_properties), + std::move(removed_object_properties)}; +} +} // namespace + +namespace detail { +bool SetVertexLabel::IsValid() const { return object.IsVisible(storage::v3::View::OLD); } + +std::map<std::string, TypedValue> SetVertexLabel::ToMap(DbAccessor *dba) const { + return {{"vertex", TypedValue{object}}, {"label", TypedValue{dba->LabelToName(label_id)}}}; +} + +bool RemovedVertexLabel::IsValid() const { return object.IsVisible(storage::v3::View::OLD); } + +std::map<std::string, TypedValue> RemovedVertexLabel::ToMap(DbAccessor *dba) const { + return {{"vertex", TypedValue{object}}, {"label", TypedValue{dba->LabelToName(label_id)}}}; +} +} // namespace detail + +const char *TriggerEventTypeToString(const TriggerEventType event_type) { + switch (event_type) { + case TriggerEventType::ANY: + return "ANY"; + + case TriggerEventType::CREATE: + return "CREATE"; + + case TriggerEventType::VERTEX_CREATE: + return "() CREATE"; + + case TriggerEventType::EDGE_CREATE: + return "--> CREATE"; + + case TriggerEventType::DELETE: + return "DELETE"; + + case TriggerEventType::VERTEX_DELETE: + return "() DELETE"; + + case TriggerEventType::EDGE_DELETE: + return "--> DELETE"; + + case TriggerEventType::UPDATE: + return "UPDATE"; + + case TriggerEventType::VERTEX_UPDATE: + return "() UPDATE"; + + case 
TriggerEventType::EDGE_UPDATE: + return "--> UPDATE"; + } +} + +void TriggerContext::AdaptForAccessor(DbAccessor *accessor) { + { + // adapt created_vertices_ + auto it = created_vertices_.begin(); + for (auto &created_vertex : created_vertices_) { + if (auto maybe_vertex = accessor->FindVertex(created_vertex.object.Gid(), storage::v3::View::OLD); maybe_vertex) { + *it = detail::CreatedObject{*maybe_vertex}; + ++it; + } + } + created_vertices_.erase(it, created_vertices_.end()); + } + + // deleted_vertices_ should keep the transaction context of the transaction which deleted it + // because no other transaction can modify an object after it's deleted so it should be the + // latest state of the object + + const auto adapt_context_with_vertex = [accessor](auto *values) { + auto it = values->begin(); + for (auto &value : *values) { + if (auto maybe_vertex = accessor->FindVertex(value.object.Gid(), storage::v3::View::OLD); maybe_vertex) { + *it = std::move(value); + it->object = *maybe_vertex; + ++it; + } + } + values->erase(it, values->end()); + }; + + adapt_context_with_vertex(&set_vertex_properties_); + adapt_context_with_vertex(&removed_vertex_properties_); + adapt_context_with_vertex(&set_vertex_labels_); + adapt_context_with_vertex(&removed_vertex_labels_); + + { + // adapt created_edges + auto it = created_edges_.begin(); + for (auto &created_edge : created_edges_) { + const auto maybe_from_vertex = accessor->FindVertex(created_edge.object.From().Gid(), storage::v3::View::OLD); + if (!maybe_from_vertex) { + continue; + } + auto maybe_out_edges = maybe_from_vertex->OutEdges(storage::v3::View::OLD); + MG_ASSERT(maybe_out_edges.HasValue()); + const auto edge_gid = created_edge.object.Gid(); + for (const auto &edge : *maybe_out_edges) { + if (edge.Gid() == edge_gid) { + *it = detail::CreatedObject{edge}; + ++it; + } + } + } + created_edges_.erase(it, created_edges_.end()); + } + + // deleted_edges_ should keep the transaction context of the transaction which 
deleted it + // because no other transaction can modify an object after it's deleted so it should be the + // latest state of the object + + const auto adapt_context_with_edge = [accessor](auto *values) { + auto it = values->begin(); + for (const auto &value : *values) { + if (auto maybe_vertex = accessor->FindVertex(value.object.From().Gid(), storage::v3::View::OLD); maybe_vertex) { + auto maybe_out_edges = maybe_vertex->OutEdges(storage::v3::View::OLD); + MG_ASSERT(maybe_out_edges.HasValue()); + for (const auto &edge : *maybe_out_edges) { + if (edge.Gid() == value.object.Gid()) { + *it = std::move(value); + it->object = edge; + ++it; + break; + } + } + } + } + values->erase(it, values->end()); + }; + + adapt_context_with_edge(&set_edge_properties_); + adapt_context_with_edge(&removed_edge_properties_); +} + +TypedValue TriggerContext::GetTypedValue(const TriggerIdentifierTag tag, DbAccessor *dba) const { + switch (tag) { + case TriggerIdentifierTag::CREATED_VERTICES: + return ToTypedValue(created_vertices_, dba); + + case TriggerIdentifierTag::CREATED_EDGES: + return ToTypedValue(created_edges_, dba); + + case TriggerIdentifierTag::CREATED_OBJECTS: + return Concatenate(dba, created_vertices_, created_edges_); + + case TriggerIdentifierTag::DELETED_VERTICES: + return ToTypedValue(deleted_vertices_, dba); + + case TriggerIdentifierTag::DELETED_EDGES: + return ToTypedValue(deleted_edges_, dba); + + case TriggerIdentifierTag::DELETED_OBJECTS: + return Concatenate(dba, deleted_vertices_, deleted_edges_); + + case TriggerIdentifierTag::SET_VERTEX_PROPERTIES: + return ToTypedValue(set_vertex_properties_, dba); + + case TriggerIdentifierTag::SET_EDGE_PROPERTIES: + return ToTypedValue(set_edge_properties_, dba); + + case TriggerIdentifierTag::REMOVED_VERTEX_PROPERTIES: + return ToTypedValue(removed_vertex_properties_, dba); + + case TriggerIdentifierTag::REMOVED_EDGE_PROPERTIES: + return ToTypedValue(removed_edge_properties_, dba); + + case 
TriggerIdentifierTag::SET_VERTEX_LABELS: + return ToTypedValue(set_vertex_labels_, dba); + + case TriggerIdentifierTag::REMOVED_VERTEX_LABELS: + return ToTypedValue(removed_vertex_labels_, dba); + + case TriggerIdentifierTag::UPDATED_VERTICES: + return Concatenate(dba, set_vertex_properties_, removed_vertex_properties_, set_vertex_labels_, + removed_vertex_labels_); + + case TriggerIdentifierTag::UPDATED_EDGES: + return Concatenate(dba, set_edge_properties_, removed_edge_properties_); + + case TriggerIdentifierTag::UPDATED_OBJECTS: + return Concatenate(dba, set_vertex_properties_, set_edge_properties_, removed_vertex_properties_, + removed_edge_properties_, set_vertex_labels_, removed_vertex_labels_); + } +} + +bool TriggerContext::ShouldEventTrigger(const TriggerEventType event_type) const { + using EventType = TriggerEventType; + switch (event_type) { + case EventType::ANY: + return AnyContainsValue(created_vertices_, created_edges_, deleted_vertices_, deleted_edges_, + set_vertex_properties_, set_edge_properties_, removed_vertex_properties_, + removed_edge_properties_, set_vertex_labels_, removed_vertex_labels_); + + case EventType::CREATE: + return AnyContainsValue(created_vertices_, created_edges_); + + case EventType::VERTEX_CREATE: + return AnyContainsValue(created_vertices_); + + case EventType::EDGE_CREATE: + return AnyContainsValue(created_edges_); + + case EventType::DELETE: + return AnyContainsValue(deleted_vertices_, deleted_edges_); + + case EventType::VERTEX_DELETE: + return AnyContainsValue(deleted_vertices_); + + case EventType::EDGE_DELETE: + return AnyContainsValue(deleted_edges_); + + case EventType::UPDATE: + return AnyContainsValue(set_vertex_properties_, set_edge_properties_, removed_vertex_properties_, + removed_edge_properties_, set_vertex_labels_, removed_vertex_labels_); + + case EventType::VERTEX_UPDATE: + return AnyContainsValue(set_vertex_properties_, removed_vertex_properties_, set_vertex_labels_, + removed_vertex_labels_); + + case 
EventType::EDGE_UPDATE: + return AnyContainsValue(set_edge_properties_, removed_edge_properties_); + } +} + +void TriggerContextCollector::UpdateLabelMap(const VertexAccessor vertex, const storage::v3::LabelId label_id, + const LabelChange change) { + auto &registry = GetRegistry<VertexAccessor>(); + if (!registry.should_register_updated_objects || registry.created_objects.count(vertex.Gid())) { + return; + } + + if (auto it = label_changes_.find({vertex, label_id}); it != label_changes_.end()) { + it->second = std::clamp(it->second + LabelChangeToInt(change), -1, 1); + return; + } + + label_changes_.emplace(std::make_pair(vertex, label_id), LabelChangeToInt(change)); +} + +TriggerContextCollector::TriggerContextCollector(const std::unordered_set<TriggerEventType> &event_types) { + for (const auto event_type : event_types) { + switch (event_type) { + case TriggerEventType::ANY: + vertex_registry_.should_register_created_objects = true; + edge_registry_.should_register_created_objects = true; + vertex_registry_.should_register_deleted_objects = true; + edge_registry_.should_register_deleted_objects = true; + vertex_registry_.should_register_updated_objects = true; + edge_registry_.should_register_updated_objects = true; + break; + case TriggerEventType::VERTEX_CREATE: + vertex_registry_.should_register_created_objects = true; + break; + case TriggerEventType::EDGE_CREATE: + edge_registry_.should_register_created_objects = true; + break; + case TriggerEventType::CREATE: + vertex_registry_.should_register_created_objects = true; + edge_registry_.should_register_created_objects = true; + break; + case TriggerEventType::VERTEX_DELETE: + vertex_registry_.should_register_deleted_objects = true; + break; + case TriggerEventType::EDGE_DELETE: + edge_registry_.should_register_deleted_objects = true; + break; + case TriggerEventType::DELETE: + vertex_registry_.should_register_deleted_objects = true; + edge_registry_.should_register_deleted_objects = true; + break; + case 
TriggerEventType::VERTEX_UPDATE: + vertex_registry_.should_register_updated_objects = true; + break; + case TriggerEventType::EDGE_UPDATE: + edge_registry_.should_register_updated_objects = true; + break; + case TriggerEventType::UPDATE: + vertex_registry_.should_register_updated_objects = true; + edge_registry_.should_register_updated_objects = true; + break; + } + } + + const auto deduce_if_should_register_created = [](auto &registry) { + // Registering the created objects is necessary to: + // - eliminate deleted objects that were created in the same transaction + // - eliminate set/removed properties and labels of newly created objects + // because those changes are only relevant for objects that have existed before the transaction. + registry.should_register_created_objects |= + registry.should_register_updated_objects || registry.should_register_deleted_objects; + }; + + deduce_if_should_register_created(vertex_registry_); + deduce_if_should_register_created(edge_registry_); +} + +bool TriggerContextCollector::ShouldRegisterVertexLabelChange() const { + return vertex_registry_.should_register_updated_objects; +} + +void TriggerContextCollector::RegisterSetVertexLabel(const VertexAccessor &vertex, + const storage::v3::LabelId label_id) { + UpdateLabelMap(vertex, label_id, LabelChange::ADD); +} + +void TriggerContextCollector::RegisterRemovedVertexLabel(const VertexAccessor &vertex, + const storage::v3::LabelId label_id) { + UpdateLabelMap(vertex, label_id, LabelChange::REMOVE); +} + +int8_t TriggerContextCollector::LabelChangeToInt(LabelChange change) { + static_assert(std::is_same_v<std::underlying_type_t<LabelChange>, int8_t>, + "The underlying type of LabelChange doesn't match the return type!"); + return static_cast<int8_t>(change); +} + +TriggerContext TriggerContextCollector::TransformToTriggerContext() && { + auto [created_vertices, deleted_vertices, set_vertex_properties, removed_vertex_properties] = + Summarize(std::move(vertex_registry_)); + auto 
[set_vertex_labels, removed_vertex_labels] = LabelMapToList(std::move(label_changes_)); + auto [created_edges, deleted_edges, set_edge_properties, removed_edge_properties] = + Summarize(std::move(edge_registry_)); + + return {std::move(created_vertices), std::move(deleted_vertices), + std::move(set_vertex_properties), std::move(removed_vertex_properties), + std::move(set_vertex_labels), std::move(removed_vertex_labels), + std::move(created_edges), std::move(deleted_edges), + std::move(set_edge_properties), std::move(removed_edge_properties)}; +} + +TriggerContextCollector::LabelChangesLists TriggerContextCollector::LabelMapToList(LabelChangesMap &&label_changes) { + std::vector<detail::SetVertexLabel> set_vertex_labels; + std::vector<detail::RemovedVertexLabel> removed_vertex_labels; + + for (const auto &[key, label_state] : label_changes) { + if (label_state == LabelChangeToInt(LabelChange::ADD)) { + set_vertex_labels.emplace_back(key.first, key.second); + } else if (label_state == LabelChangeToInt(LabelChange::REMOVE)) { + removed_vertex_labels.emplace_back(key.first, key.second); + } + } + + label_changes.clear(); + + return {std::move(set_vertex_labels), std::move(removed_vertex_labels)}; +} +} // namespace memgraph::query::v2 diff --git a/src/query/v2/trigger_context.hpp b/src/query/v2/trigger_context.hpp new file mode 100644 index 000000000..5f4e62b40 --- /dev/null +++ b/src/query/v2/trigger_context.hpp @@ -0,0 +1,365 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <cstdint> +#include <map> +#include <string> +#include <string_view> +#include <type_traits> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "query/v2/db_accessor.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/view.hpp" +#include "utils/concepts.hpp" +#include "utils/fnv.hpp" + +namespace memgraph::query::v2 { +namespace detail { +template <typename T> +concept ObjectAccessor = utils::SameAsAnyOf<T, VertexAccessor, EdgeAccessor>; + +template <ObjectAccessor TAccessor> +const char *ObjectString() { + if constexpr (std::same_as<TAccessor, VertexAccessor>) { + return "vertex"; + } else { + return "edge"; + } +} + +template <ObjectAccessor TAccessor> +struct CreatedObject { + explicit CreatedObject(const TAccessor &object) : object{object} {} + + bool IsValid() const { return object.IsVisible(storage::v3::View::OLD); } + std::map<std::string, TypedValue> ToMap([[maybe_unused]] DbAccessor *dba) const { + return {{ObjectString<TAccessor>(), TypedValue{object}}}; + } + + TAccessor object; +}; + +template <ObjectAccessor TAccessor> +struct DeletedObject { + explicit DeletedObject(const TAccessor &object) : object{object} {} + + bool IsValid() const { return object.IsVisible(storage::v3::View::OLD); } + std::map<std::string, TypedValue> ToMap([[maybe_unused]] DbAccessor *dba) const { + return {{ObjectString<TAccessor>(), TypedValue{object}}}; + } + + TAccessor object; +}; + +template <ObjectAccessor TAccessor> +struct SetObjectProperty { + explicit SetObjectProperty(const TAccessor &object, storage::v3::PropertyId key, TypedValue old_value, + TypedValue new_value) + : object{object}, key{key}, old_value{std::move(old_value)}, 
new_value{std::move(new_value)} {} + + std::map<std::string, TypedValue> ToMap(DbAccessor *dba) const { + return {{ObjectString<TAccessor>(), TypedValue{object}}, + {"key", TypedValue{dba->PropertyToName(key)}}, + {"old", old_value}, + {"new", new_value}}; + } + + bool IsValid() const { return object.IsVisible(storage::v3::View::OLD); } + + TAccessor object; + storage::v3::PropertyId key; + TypedValue old_value; + TypedValue new_value; +}; + +template <ObjectAccessor TAccessor> +struct RemovedObjectProperty { + explicit RemovedObjectProperty(const TAccessor &object, storage::v3::PropertyId key, TypedValue old_value) + : object{object}, key{key}, old_value{std::move(old_value)} {} + + std::map<std::string, TypedValue> ToMap(DbAccessor *dba) const { + return {{ObjectString<TAccessor>(), TypedValue{object}}, + {"key", TypedValue{dba->PropertyToName(key)}}, + {"old", old_value}}; + } + + bool IsValid() const { return object.IsVisible(storage::v3::View::OLD); } + + TAccessor object; + storage::v3::PropertyId key; + TypedValue old_value; +}; + +struct SetVertexLabel { + explicit SetVertexLabel(const VertexAccessor &vertex, const storage::v3::LabelId label_id) + : object{vertex}, label_id{label_id} {} + + std::map<std::string, TypedValue> ToMap(DbAccessor *dba) const; + bool IsValid() const; + + VertexAccessor object; + storage::v3::LabelId label_id; +}; + +struct RemovedVertexLabel { + explicit RemovedVertexLabel(const VertexAccessor &vertex, const storage::v3::LabelId label_id) + : object{vertex}, label_id{label_id} {} + + std::map<std::string, TypedValue> ToMap(DbAccessor *dba) const; + bool IsValid() const; + + VertexAccessor object; + storage::v3::LabelId label_id; +}; +} // namespace detail + +enum class TriggerIdentifierTag : uint8_t { + CREATED_VERTICES, + CREATED_EDGES, + CREATED_OBJECTS, + DELETED_VERTICES, + DELETED_EDGES, + DELETED_OBJECTS, + SET_VERTEX_PROPERTIES, + SET_EDGE_PROPERTIES, + REMOVED_VERTEX_PROPERTIES, + REMOVED_EDGE_PROPERTIES, + 
SET_VERTEX_LABELS, + REMOVED_VERTEX_LABELS, + UPDATED_VERTICES, + UPDATED_EDGES, + UPDATED_OBJECTS +}; + +enum class TriggerEventType : uint8_t { + ANY, // Triggers on any change + VERTEX_CREATE, + EDGE_CREATE, + CREATE, + VERTEX_DELETE, + EDGE_DELETE, + DELETE, + VERTEX_UPDATE, + EDGE_UPDATE, + UPDATE +}; + +const char *TriggerEventTypeToString(TriggerEventType event_type); + +static_assert(std::is_trivially_copy_constructible_v<VertexAccessor>, + "VertexAccessor is not trivially copy constructible, move it where possible and remove this assert"); +static_assert(std::is_trivially_copy_constructible_v<EdgeAccessor>, + "EdgeAccessor is not trivially copy constructible, move it where possible and remove this asssert"); + +// Holds the information necessary for triggers +class TriggerContext { + public: + TriggerContext() = default; + TriggerContext(std::vector<detail::CreatedObject<VertexAccessor>> created_vertices, + std::vector<detail::DeletedObject<VertexAccessor>> deleted_vertices, + std::vector<detail::SetObjectProperty<VertexAccessor>> set_vertex_properties, + std::vector<detail::RemovedObjectProperty<VertexAccessor>> removed_vertex_properties, + std::vector<detail::SetVertexLabel> set_vertex_labels, + std::vector<detail::RemovedVertexLabel> removed_vertex_labels, + std::vector<detail::CreatedObject<EdgeAccessor>> created_edges, + std::vector<detail::DeletedObject<EdgeAccessor>> deleted_edges, + std::vector<detail::SetObjectProperty<EdgeAccessor>> set_edge_properties, + std::vector<detail::RemovedObjectProperty<EdgeAccessor>> removed_edge_properties) + : created_vertices_{std::move(created_vertices)}, + deleted_vertices_{std::move(deleted_vertices)}, + set_vertex_properties_{std::move(set_vertex_properties)}, + removed_vertex_properties_{std::move(removed_vertex_properties)}, + set_vertex_labels_{std::move(set_vertex_labels)}, + removed_vertex_labels_{std::move(removed_vertex_labels)}, + created_edges_{std::move(created_edges)}, + 
deleted_edges_{std::move(deleted_edges)}, + set_edge_properties_{std::move(set_edge_properties)}, + removed_edge_properties_{std::move(removed_edge_properties)} {} + TriggerContext(const TriggerContext &) = default; + TriggerContext(TriggerContext &&) = default; + TriggerContext &operator=(const TriggerContext &) = default; + TriggerContext &operator=(TriggerContext &&) = default; + + // Adapt the TriggerContext object inplace for a different DbAccessor + // (each derived accessor, e.g. VertexAccessor, gets adapted + // to the sent DbAccessor so they can be used safely) + void AdaptForAccessor(DbAccessor *accessor); + + // Get TypedValue for the identifier defined with tag + TypedValue GetTypedValue(TriggerIdentifierTag tag, DbAccessor *dba) const; + bool ShouldEventTrigger(TriggerEventType) const; + + private: + std::vector<detail::CreatedObject<VertexAccessor>> created_vertices_; + std::vector<detail::DeletedObject<VertexAccessor>> deleted_vertices_; + std::vector<detail::SetObjectProperty<VertexAccessor>> set_vertex_properties_; + std::vector<detail::RemovedObjectProperty<VertexAccessor>> removed_vertex_properties_; + std::vector<detail::SetVertexLabel> set_vertex_labels_; + std::vector<detail::RemovedVertexLabel> removed_vertex_labels_; + + std::vector<detail::CreatedObject<EdgeAccessor>> created_edges_; + std::vector<detail::DeletedObject<EdgeAccessor>> deleted_edges_; + std::vector<detail::SetObjectProperty<EdgeAccessor>> set_edge_properties_; + std::vector<detail::RemovedObjectProperty<EdgeAccessor>> removed_edge_properties_; +}; + +// Collects the information necessary for triggers during a single transaction run. 
+class TriggerContextCollector { + public: + struct HashPairWithAccessor { + template <detail::ObjectAccessor TAccessor, typename T2> + size_t operator()(const std::pair<TAccessor, T2> &pair) const { + using GidType = decltype(std::declval<TAccessor>().Gid()); + return utils::HashCombine<GidType, T2>{}(pair.first.Gid(), pair.second); + } + }; + + struct PropertyChangeInfo { + TypedValue old_value; + TypedValue new_value; + }; + + template <detail::ObjectAccessor TAccessor> + using PropertyChangesMap = + std::unordered_map<std::pair<TAccessor, storage::v3::PropertyId>, PropertyChangeInfo, HashPairWithAccessor>; + + template <detail::ObjectAccessor TAccessor> + struct Registry { + bool should_register_created_objects{false}; + bool should_register_deleted_objects{false}; + bool should_register_updated_objects{false}; // Set/removed properties (and labels for vertices) + std::unordered_map<storage::v3::Gid, detail::CreatedObject<TAccessor>> created_objects; + std::vector<detail::DeletedObject<TAccessor>> deleted_objects; + // During the transaction, a single property on a single object could be changed multiple times. + // We want to register only the global change, at the end of the transaction. The change consists of + // the value before the transaction start, and the latest value assigned throughout the transaction. 
+ PropertyChangesMap<TAccessor> property_changes; + }; + + explicit TriggerContextCollector(const std::unordered_set<TriggerEventType> &event_types); + TriggerContextCollector(const TriggerContextCollector &) = default; + TriggerContextCollector(TriggerContextCollector &&) = default; + TriggerContextCollector &operator=(const TriggerContextCollector &) = default; + TriggerContextCollector &operator=(TriggerContextCollector &&) = default; + ~TriggerContextCollector() = default; + + template <detail::ObjectAccessor TAccessor> + bool ShouldRegisterCreatedObject() const { + return GetRegistry<TAccessor>().should_register_created_objects; + } + + template <detail::ObjectAccessor TAccessor> + void RegisterCreatedObject(const TAccessor &created_object) { + auto ®istry = GetRegistry<TAccessor>(); + if (!registry.should_register_created_objects) { + return; + } + registry.created_objects.emplace(created_object.Gid(), detail::CreatedObject{created_object}); + } + + template <detail::ObjectAccessor TAccessor> + bool ShouldRegisterDeletedObject() const { + return GetRegistry<TAccessor>().should_register_deleted_objects; + } + + template <detail::ObjectAccessor TAccessor> + void RegisterDeletedObject(const TAccessor &deleted_object) { + auto ®istry = GetRegistry<TAccessor>(); + if (!registry.should_register_deleted_objects || registry.created_objects.count(deleted_object.Gid())) { + return; + } + + registry.deleted_objects.emplace_back(deleted_object); + } + + template <detail::ObjectAccessor TAccessor> + bool ShouldRegisterObjectPropertyChange() const { + return GetRegistry<TAccessor>().should_register_updated_objects; + } + + template <detail::ObjectAccessor TAccessor> + void RegisterSetObjectProperty(const TAccessor &object, const storage::v3::PropertyId key, TypedValue old_value, + TypedValue new_value) { + auto ®istry = GetRegistry<TAccessor>(); + if (!registry.should_register_updated_objects) { + return; + } + + if (registry.created_objects.count(object.Gid())) { + 
return; + } + + if (auto it = registry.property_changes.find({object, key}); it != registry.property_changes.end()) { + it->second.new_value = std::move(new_value); + return; + } + + registry.property_changes.emplace(std::make_pair(object, key), + PropertyChangeInfo{std::move(old_value), std::move(new_value)}); + } + + template <detail::ObjectAccessor TAccessor> + void RegisterRemovedObjectProperty(const TAccessor &object, const storage::v3::PropertyId key, TypedValue old_value) { + // property is already removed + if (old_value.IsNull()) { + return; + } + + RegisterSetObjectProperty(object, key, std::move(old_value), TypedValue()); + } + + bool ShouldRegisterVertexLabelChange() const; + void RegisterSetVertexLabel(const VertexAccessor &vertex, storage::v3::LabelId label_id); + void RegisterRemovedVertexLabel(const VertexAccessor &vertex, storage::v3::LabelId label_id); + [[nodiscard]] TriggerContext TransformToTriggerContext() &&; + + private: + template <detail::ObjectAccessor TAccessor> + const Registry<TAccessor> &GetRegistry() const { + if constexpr (std::same_as<TAccessor, VertexAccessor>) { + return vertex_registry_; + } else { + return edge_registry_; + } + } + + template <detail::ObjectAccessor TAccessor> + Registry<TAccessor> &GetRegistry() { + return const_cast<Registry<TAccessor> &>( + const_cast<const TriggerContextCollector *>(this)->GetRegistry<TAccessor>()); + } + + using LabelChangesMap = + std::unordered_map<std::pair<VertexAccessor, storage::v3::LabelId>, int8_t, HashPairWithAccessor>; + using LabelChangesLists = std::pair<std::vector<detail::SetVertexLabel>, std::vector<detail::RemovedVertexLabel>>; + + enum class LabelChange : int8_t { REMOVE = -1, ADD = 1 }; + + static int8_t LabelChangeToInt(LabelChange change); + + [[nodiscard]] static LabelChangesLists LabelMapToList(LabelChangesMap &&label_changes); + + void UpdateLabelMap(VertexAccessor vertex, storage::v3::LabelId label_id, LabelChange change); + + Registry<VertexAccessor> 
vertex_registry_; + Registry<EdgeAccessor> edge_registry_; + // During the transaction, a single label on a single vertex could be added and removed multiple times. + // We want to register only the global change, at the end of the transaction. The change consists of + // the state of the label before the transaction start, and the latest state assigned throughout the transaction. + LabelChangesMap label_changes_; +}; +} // namespace memgraph::query::v2 diff --git a/src/query/v2/typed_value.cpp b/src/query/v2/typed_value.cpp new file mode 100644 index 000000000..b6ded09d8 --- /dev/null +++ b/src/query/v2/typed_value.cpp @@ -0,0 +1,1108 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "query/v2/typed_value.hpp" + +#include <fmt/format.h> +#include <chrono> +#include <cmath> +#include <iostream> +#include <memory> +#include <string_view> +#include <utility> + +#include "storage/v3/temporal.hpp" +#include "utils/exceptions.hpp" +#include "utils/fnv.hpp" + +namespace memgraph::query::v2 { + +TypedValue::TypedValue(const storage::v3::PropertyValue &value) + // TODO: MemoryResource in storage::v3::PropertyValue + : TypedValue(value, utils::NewDeleteResource()) {} + +TypedValue::TypedValue(const storage::v3::PropertyValue &value, utils::MemoryResource *memory) : memory_(memory) { + switch (value.type()) { + case storage::v3::PropertyValue::Type::Null: + type_ = Type::Null; + return; + case storage::v3::PropertyValue::Type::Bool: + type_ = Type::Bool; + bool_v = value.ValueBool(); + return; + case storage::v3::PropertyValue::Type::Int: + type_ = Type::Int; + int_v = value.ValueInt(); + return; + case storage::v3::PropertyValue::Type::Double: + type_ = Type::Double; + double_v = value.ValueDouble(); + return; + case storage::v3::PropertyValue::Type::String: + type_ = Type::String; + new (&string_v) TString(value.ValueString(), memory_); + return; + case storage::v3::PropertyValue::Type::List: { + type_ = Type::List; + const auto &vec = value.ValueList(); + new (&list_v) TVector(memory_); + list_v.reserve(vec.size()); + for (const auto &v : vec) list_v.emplace_back(v); + return; + } + case storage::v3::PropertyValue::Type::Map: { + type_ = Type::Map; + const auto &map = value.ValueMap(); + new (&map_v) TMap(memory_); + for (const auto &kv : map) map_v.emplace(kv.first, kv.second); + return; + } + case storage::v3::PropertyValue::Type::TemporalData: { + const auto &temporal_data = value.ValueTemporalData(); + switch (temporal_data.type) { + case storage::v3::TemporalType::Date: { + type_ = Type::Date; + new (&date_v) utils::Date(temporal_data.microseconds); + break; + } + case storage::v3::TemporalType::LocalTime: { + type_ = Type::LocalTime; 
+ new (&local_time_v) utils::LocalTime(temporal_data.microseconds); + break; + } + case storage::v3::TemporalType::LocalDateTime: { + type_ = Type::LocalDateTime; + new (&local_date_time_v) utils::LocalDateTime(temporal_data.microseconds); + break; + } + case storage::v3::TemporalType::Duration: { + type_ = Type::Duration; + new (&duration_v) utils::Duration(temporal_data.microseconds); + break; + } + } + return; + } + } + LOG_FATAL("Unsupported type"); +} + +TypedValue::TypedValue(storage::v3::PropertyValue &&other) /* noexcept */ + // TODO: MemoryResource in storage::v3::PropertyValue, so this can be noexcept + : TypedValue(std::move(other), utils::NewDeleteResource()) {} + +TypedValue::TypedValue(storage::v3::PropertyValue &&other, utils::MemoryResource *memory) : memory_(memory) { + switch (other.type()) { + case storage::v3::PropertyValue::Type::Null: + type_ = Type::Null; + break; + case storage::v3::PropertyValue::Type::Bool: + type_ = Type::Bool; + bool_v = other.ValueBool(); + break; + case storage::v3::PropertyValue::Type::Int: + type_ = Type::Int; + int_v = other.ValueInt(); + break; + case storage::v3::PropertyValue::Type::Double: + type_ = Type::Double; + double_v = other.ValueDouble(); + break; + case storage::v3::PropertyValue::Type::String: + type_ = Type::String; + new (&string_v) TString(other.ValueString(), memory_); + break; + case storage::v3::PropertyValue::Type::List: { + type_ = Type::List; + auto &vec = other.ValueList(); + new (&list_v) TVector(memory_); + list_v.reserve(vec.size()); + for (auto &v : vec) list_v.emplace_back(std::move(v)); + break; + } + case storage::v3::PropertyValue::Type::Map: { + type_ = Type::Map; + auto &map = other.ValueMap(); + new (&map_v) TMap(memory_); + for (auto &kv : map) map_v.emplace(kv.first, std::move(kv.second)); + break; + } + case storage::v3::PropertyValue::Type::TemporalData: { + const auto &temporal_data = other.ValueTemporalData(); + switch (temporal_data.type) { + case 
storage::v3::TemporalType::Date: { + type_ = Type::Date; + new (&date_v) utils::Date(temporal_data.microseconds); + break; + } + case storage::v3::TemporalType::LocalTime: { + type_ = Type::LocalTime; + new (&local_time_v) utils::LocalTime(temporal_data.microseconds); + break; + } + case storage::v3::TemporalType::LocalDateTime: { + type_ = Type::LocalDateTime; + new (&local_date_time_v) utils::LocalDateTime(temporal_data.microseconds); + break; + } + case storage::v3::TemporalType::Duration: { + type_ = Type::Duration; + new (&duration_v) utils::Duration(temporal_data.microseconds); + break; + } + } + break; + } + } + + other = storage::v3::PropertyValue(); +} + +TypedValue::TypedValue(const TypedValue &other) + : TypedValue(other, std::allocator_traits<utils::Allocator<TypedValue>>::select_on_container_copy_construction( + other.memory_) + .GetMemoryResource()) {} + +TypedValue::TypedValue(const TypedValue &other, utils::MemoryResource *memory) : memory_(memory), type_(other.type_) { + switch (other.type_) { + case TypedValue::Type::Null: + return; + case TypedValue::Type::Bool: + this->bool_v = other.bool_v; + return; + case Type::Int: + this->int_v = other.int_v; + return; + case Type::Double: + this->double_v = other.double_v; + return; + case TypedValue::Type::String: + new (&string_v) TString(other.string_v, memory_); + return; + case Type::List: + new (&list_v) TVector(other.list_v, memory_); + return; + case Type::Map: + new (&map_v) TMap(other.map_v, memory_); + return; + case Type::Vertex: + new (&vertex_v) VertexAccessor(other.vertex_v); + return; + case Type::Edge: + new (&edge_v) EdgeAccessor(other.edge_v); + return; + case Type::Path: + new (&path_v) Path(other.path_v, memory_); + return; + case Type::Date: + new (&date_v) utils::Date(other.date_v); + return; + case Type::LocalTime: + new (&local_time_v) utils::LocalTime(other.local_time_v); + return; + case Type::LocalDateTime: + new (&local_date_time_v) 
utils::LocalDateTime(other.local_date_time_v); + return; + case Type::Duration: + new (&duration_v) utils::Duration(other.duration_v); + return; + } + LOG_FATAL("Unsupported TypedValue::Type"); +} + +TypedValue::TypedValue(TypedValue &&other) noexcept : TypedValue(std::move(other), other.memory_) {} + +TypedValue::TypedValue(TypedValue &&other, utils::MemoryResource *memory) : memory_(memory), type_(other.type_) { + switch (other.type_) { + case TypedValue::Type::Null: + break; + case TypedValue::Type::Bool: + this->bool_v = other.bool_v; + break; + case Type::Int: + this->int_v = other.int_v; + break; + case Type::Double: + this->double_v = other.double_v; + break; + case TypedValue::Type::String: + new (&string_v) TString(std::move(other.string_v), memory_); + break; + case Type::List: + new (&list_v) TVector(std::move(other.list_v), memory_); + break; + case Type::Map: + new (&map_v) TMap(std::move(other.map_v), memory_); + break; + case Type::Vertex: + new (&vertex_v) VertexAccessor(std::move(other.vertex_v)); + break; + case Type::Edge: + new (&edge_v) EdgeAccessor(std::move(other.edge_v)); + break; + case Type::Path: + new (&path_v) Path(std::move(other.path_v), memory_); + break; + case Type::Date: + new (&date_v) utils::Date(other.date_v); + break; + case Type::LocalTime: + new (&local_time_v) utils::LocalTime(other.local_time_v); + break; + case Type::LocalDateTime: + new (&local_date_time_v) utils::LocalDateTime(other.local_date_time_v); + break; + case Type::Duration: + new (&duration_v) utils::Duration(other.duration_v); + break; + } + other.DestroyValue(); +} + +TypedValue::operator storage::v3::PropertyValue() const { + switch (type_) { + case TypedValue::Type::Null: + return storage::v3::PropertyValue(); + case TypedValue::Type::Bool: + return storage::v3::PropertyValue(bool_v); + case TypedValue::Type::Int: + return storage::v3::PropertyValue(int_v); + case TypedValue::Type::Double: + return storage::v3::PropertyValue(double_v); + case 
TypedValue::Type::String: + return storage::v3::PropertyValue(std::string(string_v)); + case TypedValue::Type::List: + return storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>(list_v.begin(), list_v.end())); + case TypedValue::Type::Map: { + std::map<std::string, storage::v3::PropertyValue> map; + for (const auto &kv : map_v) map.emplace(kv.first, kv.second); + return storage::v3::PropertyValue(std::move(map)); + } + case Type::Date: + return storage::v3::PropertyValue( + storage::v3::TemporalData{storage::v3::TemporalType::Date, date_v.MicrosecondsSinceEpoch()}); + case Type::LocalTime: + return storage::v3::PropertyValue( + storage::v3::TemporalData{storage::v3::TemporalType::LocalTime, local_time_v.MicrosecondsSinceEpoch()}); + case Type::LocalDateTime: + return storage::v3::PropertyValue(storage::v3::TemporalData{storage::v3::TemporalType::LocalDateTime, + local_date_time_v.MicrosecondsSinceEpoch()}); + case Type::Duration: + return storage::v3::PropertyValue( + storage::v3::TemporalData{storage::v3::TemporalType::Duration, duration_v.microseconds}); + default: + break; + } + throw TypedValueException("Unsupported conversion from TypedValue to PropertyValue"); +} + +#define DEFINE_VALUE_AND_TYPE_GETTERS(type_param, type_enum, field) \ + type_param &TypedValue::Value##type_enum() { \ + if (type_ != Type::type_enum) \ + throw TypedValueException("TypedValue is of type '{}', not '{}'", type_, Type::type_enum); \ + return field; \ + } \ + \ + const type_param &TypedValue::Value##type_enum() const { \ + if (type_ != Type::type_enum) \ + throw TypedValueException("TypedValue is of type '{}', not '{}'", type_, Type::type_enum); \ + return field; \ + } \ + \ + bool TypedValue::Is##type_enum() const { return type_ == Type::type_enum; } + +DEFINE_VALUE_AND_TYPE_GETTERS(bool, Bool, bool_v) +DEFINE_VALUE_AND_TYPE_GETTERS(int64_t, Int, int_v) +DEFINE_VALUE_AND_TYPE_GETTERS(double, Double, double_v) +DEFINE_VALUE_AND_TYPE_GETTERS(TypedValue::TString, String, 
string_v) +DEFINE_VALUE_AND_TYPE_GETTERS(TypedValue::TVector, List, list_v) +DEFINE_VALUE_AND_TYPE_GETTERS(TypedValue::TMap, Map, map_v) +DEFINE_VALUE_AND_TYPE_GETTERS(VertexAccessor, Vertex, vertex_v) +DEFINE_VALUE_AND_TYPE_GETTERS(EdgeAccessor, Edge, edge_v) +DEFINE_VALUE_AND_TYPE_GETTERS(Path, Path, path_v) +DEFINE_VALUE_AND_TYPE_GETTERS(utils::Date, Date, date_v) +DEFINE_VALUE_AND_TYPE_GETTERS(utils::LocalTime, LocalTime, local_time_v) +DEFINE_VALUE_AND_TYPE_GETTERS(utils::LocalDateTime, LocalDateTime, local_date_time_v) +DEFINE_VALUE_AND_TYPE_GETTERS(utils::Duration, Duration, duration_v) + +#undef DEFINE_VALUE_AND_TYPE_GETTERS + +bool TypedValue::IsNull() const { return type_ == Type::Null; } + +bool TypedValue::IsNumeric() const { return IsInt() || IsDouble(); } + +bool TypedValue::IsPropertyValue() const { + switch (type_) { + case Type::Null: + case Type::Bool: + case Type::Int: + case Type::Double: + case Type::String: + case Type::List: + case Type::Map: + case Type::Date: + case Type::LocalTime: + case Type::LocalDateTime: + case Type::Duration: + return true; + default: + return false; + } +} + +std::ostream &operator<<(std::ostream &os, const TypedValue::Type &type) { + switch (type) { + case TypedValue::Type::Null: + return os << "null"; + case TypedValue::Type::Bool: + return os << "bool"; + case TypedValue::Type::Int: + return os << "int"; + case TypedValue::Type::Double: + return os << "double"; + case TypedValue::Type::String: + return os << "string"; + case TypedValue::Type::List: + return os << "list"; + case TypedValue::Type::Map: + return os << "map"; + case TypedValue::Type::Vertex: + return os << "vertex"; + case TypedValue::Type::Edge: + return os << "edge"; + case TypedValue::Type::Path: + return os << "path"; + case TypedValue::Type::Date: + return os << "date"; + case TypedValue::Type::LocalTime: + return os << "local_time"; + case TypedValue::Type::LocalDateTime: + return os << "local_date_time"; + case TypedValue::Type::Duration: + 
return os << "duration"; + } + LOG_FATAL("Unsupported TypedValue::Type"); +} + +#define DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(type_param, typed_value_type, member) \ + TypedValue &TypedValue::operator=(type_param other) { \ + if (this->type_ == TypedValue::Type::typed_value_type) { \ + this->member = other; \ + } else { \ + *this = TypedValue(other, memory_); \ + } \ + \ + return *this; \ + } + +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const char *, String, string_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(int, Int, int_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(bool, Bool, bool_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(int64_t, Int, int_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(double, Double, double_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const std::string_view, String, string_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const TypedValue::TVector &, List, list_v) + +TypedValue &TypedValue::operator=(const std::vector<TypedValue> &other) { + if (type_ == Type::List) { + list_v.reserve(other.size()); + list_v.assign(other.begin(), other.end()); + } else { + *this = TypedValue(other, memory_); + } + return *this; +} + +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const TypedValue::TMap &, Map, map_v) + +TypedValue &TypedValue::operator=(const std::map<std::string, TypedValue> &other) { + if (type_ == Type::Map) { + map_v.clear(); + for (const auto &kv : other) map_v.emplace(kv.first, kv.second); + } else { + *this = TypedValue(other, memory_); + } + return *this; +} + +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const VertexAccessor &, Vertex, vertex_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const EdgeAccessor &, Edge, edge_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const Path &, Path, path_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const utils::Date &, Date, date_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const utils::LocalTime &, LocalTime, local_time_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const utils::LocalDateTime &, LocalDateTime, local_date_time_v) +DEFINE_TYPED_VALUE_COPY_ASSIGNMENT(const utils::Duration &, Duration, 
duration_v) + +#undef DEFINE_TYPED_VALUE_COPY_ASSIGNMENT + +#define DEFINE_TYPED_VALUE_MOVE_ASSIGNMENT(type_param, typed_value_type, member) \ + TypedValue &TypedValue::operator=(type_param &&other) { \ + if (this->type_ == TypedValue::Type::typed_value_type) { \ + this->member = std::move(other); \ + } else { \ + *this = TypedValue(std::move(other), memory_); \ + } \ + return *this; \ + } + +DEFINE_TYPED_VALUE_MOVE_ASSIGNMENT(TypedValue::TString, String, string_v) +DEFINE_TYPED_VALUE_MOVE_ASSIGNMENT(TypedValue::TVector, List, list_v) + +TypedValue &TypedValue::operator=(std::vector<TypedValue> &&other) { + if (type_ == Type::List) { + list_v.reserve(other.size()); + list_v.assign(std::make_move_iterator(other.begin()), std::make_move_iterator(other.end())); + } else { + *this = TypedValue(std::move(other), memory_); + } + return *this; +} + +DEFINE_TYPED_VALUE_MOVE_ASSIGNMENT(TMap, Map, map_v) + +TypedValue &TypedValue::operator=(std::map<std::string, TypedValue> &&other) { + if (type_ == Type::Map) { + map_v.clear(); + for (auto &kv : other) map_v.emplace(kv.first, std::move(kv.second)); + } else { + *this = TypedValue(std::move(other), memory_); + } + return *this; +} + +DEFINE_TYPED_VALUE_MOVE_ASSIGNMENT(Path, Path, path_v) + +#undef DEFINE_TYPED_VALUE_MOVE_ASSIGNMENT + +TypedValue &TypedValue::operator=(const TypedValue &other) { + if (this != &other) { + // NOTE: STL uses + // std::allocator_traits<>::propagate_on_container_copy_assignment to + // determine whether to take the allocator from `other`, or use the one in + // `this`. Our utils::Allocator never propagates, so we use the allocator + // from `this`. 
+ static_assert(!std::allocator_traits<utils::Allocator<TypedValue>>::propagate_on_container_copy_assignment::value, + "Allocator propagation not implemented"); + DestroyValue(); + type_ = other.type_; + switch (other.type_) { + case TypedValue::Type::Null: + return *this; + case TypedValue::Type::Bool: + this->bool_v = other.bool_v; + return *this; + case TypedValue::Type::Int: + this->int_v = other.int_v; + return *this; + case TypedValue::Type::Double: + this->double_v = other.double_v; + return *this; + case TypedValue::Type::String: + new (&string_v) TString(other.string_v, memory_); + return *this; + case TypedValue::Type::List: + new (&list_v) TVector(other.list_v, memory_); + return *this; + case TypedValue::Type::Map: + new (&map_v) TMap(other.map_v, memory_); + return *this; + case TypedValue::Type::Vertex: + new (&vertex_v) VertexAccessor(other.vertex_v); + return *this; + case TypedValue::Type::Edge: + new (&edge_v) EdgeAccessor(other.edge_v); + return *this; + case TypedValue::Type::Path: + new (&path_v) Path(other.path_v, memory_); + return *this; + case Type::Date: + new (&date_v) utils::Date(other.date_v); + return *this; + case Type::LocalTime: + new (&local_time_v) utils::LocalTime(other.local_time_v); + return *this; + case Type::LocalDateTime: + new (&local_date_time_v) utils::LocalDateTime(other.local_date_time_v); + return *this; + case Type::Duration: + new (&duration_v) utils::Duration(other.duration_v); + return *this; + } + LOG_FATAL("Unsupported TypedValue::Type"); + } + return *this; +} + +TypedValue &TypedValue::operator=(TypedValue &&other) noexcept(false) { + if (this != &other) { + DestroyValue(); + // NOTE: STL uses + // std::allocator_traits<>::propagate_on_container_move_assignment to + // determine whether to take the allocator from `other`, or use the one in + // `this`. Our utils::Allocator never propagates, so we use the allocator + // from `this`. 
+ static_assert(!std::allocator_traits<utils::Allocator<TypedValue>>::propagate_on_container_move_assignment::value, + "Allocator propagation not implemented"); + type_ = other.type_; + switch (other.type_) { + case TypedValue::Type::Null: + break; + case TypedValue::Type::Bool: + this->bool_v = other.bool_v; + break; + case TypedValue::Type::Int: + this->int_v = other.int_v; + break; + case TypedValue::Type::Double: + this->double_v = other.double_v; + break; + case TypedValue::Type::String: + new (&string_v) TString(std::move(other.string_v), memory_); + break; + case TypedValue::Type::List: + new (&list_v) TVector(std::move(other.list_v), memory_); + break; + case TypedValue::Type::Map: + new (&map_v) TMap(std::move(other.map_v), memory_); + break; + case TypedValue::Type::Vertex: + new (&vertex_v) VertexAccessor(std::move(other.vertex_v)); + break; + case TypedValue::Type::Edge: + new (&edge_v) EdgeAccessor(std::move(other.edge_v)); + break; + case TypedValue::Type::Path: + new (&path_v) Path(std::move(other.path_v), memory_); + break; + case Type::Date: + new (&date_v) utils::Date(other.date_v); + break; + case Type::LocalTime: + new (&local_time_v) utils::LocalTime(other.local_time_v); + break; + case Type::LocalDateTime: + new (&local_date_time_v) utils::LocalDateTime(other.local_date_time_v); + break; + case Type::Duration: + new (&duration_v) utils::Duration(other.duration_v); + break; + } + other.DestroyValue(); + } + return *this; +} + +void TypedValue::DestroyValue() { + switch (type_) { + // destructor for primitive types does nothing + case Type::Null: + case Type::Bool: + case Type::Int: + case Type::Double: + break; + + // we need to call destructors for non primitive types since we used + // placement new + case Type::String: + string_v.~TString(); + break; + case Type::List: + list_v.~TVector(); + break; + case Type::Map: + map_v.~TMap(); + break; + case Type::Vertex: + vertex_v.~VertexAccessor(); + break; + case Type::Edge: + 
edge_v.~EdgeAccessor(); + break; + case Type::Path: + path_v.~Path(); + break; + case Type::Date: + case Type::LocalTime: + case Type::LocalDateTime: + case Type::Duration: + break; + } + + type_ = TypedValue::Type::Null; +} + +TypedValue::~TypedValue() { DestroyValue(); } + +/** + * Returns the double value of a value. + * The value MUST be either Double or Int. + * + * @param value + * @return + */ +double ToDouble(const TypedValue &value) { + switch (value.type()) { + case TypedValue::Type::Int: + return (double)value.ValueInt(); + case TypedValue::Type::Double: + return value.ValueDouble(); + default: + throw TypedValueException("Unsupported TypedValue::Type conversion to double"); + } +} + +namespace { +bool IsTemporalType(const TypedValue::Type type) { + static constexpr std::array temporal_types{TypedValue::Type::Date, TypedValue::Type::LocalTime, + TypedValue::Type::LocalDateTime, TypedValue::Type::Duration}; + return std::any_of(temporal_types.begin(), temporal_types.end(), + [type](const auto temporal_type) { return temporal_type == type; }); +}; +} // namespace + +TypedValue operator<(const TypedValue &a, const TypedValue &b) { + auto is_legal = [](TypedValue::Type type) { + switch (type) { + case TypedValue::Type::Null: + case TypedValue::Type::Int: + case TypedValue::Type::Double: + case TypedValue::Type::String: + case TypedValue::Type::Date: + case TypedValue::Type::LocalTime: + case TypedValue::Type::LocalDateTime: + case TypedValue::Type::Duration: + return true; + default: + return false; + } + }; + if (!is_legal(a.type()) || !is_legal(b.type())) + throw TypedValueException("Invalid 'less' operand types({} + {})", a.type(), b.type()); + + if (a.IsNull() || b.IsNull()) return TypedValue(a.GetMemoryResource()); + + if (a.IsString() || b.IsString()) { + if (a.type() != b.type()) { + throw TypedValueException("Invalid 'less' operand types({} + {})", a.type(), b.type()); + } else { + return TypedValue(a.ValueString() < b.ValueString(), 
a.GetMemoryResource()); + } + } + + if (IsTemporalType(a.type()) || IsTemporalType(b.type())) { + if (a.type() != b.type()) { + throw TypedValueException("Invalid 'less' operand types({} + {})", a.type(), b.type()); + } + + switch (a.type()) { + case TypedValue::Type::Date: + // NOLINTNEXTLINE(modernize-use-nullptr) + return TypedValue(a.ValueDate() < b.ValueDate(), a.GetMemoryResource()); + case TypedValue::Type::LocalTime: + // NOLINTNEXTLINE(modernize-use-nullptr) + return TypedValue(a.ValueLocalTime() < b.ValueLocalTime(), a.GetMemoryResource()); + case TypedValue::Type::LocalDateTime: + // NOLINTNEXTLINE(modernize-use-nullptr) + return TypedValue(a.ValueLocalDateTime() < b.ValueLocalDateTime(), a.GetMemoryResource()); + case TypedValue::Type::Duration: + // NOLINTNEXTLINE(modernize-use-nullptr) + return TypedValue(a.ValueDuration() < b.ValueDuration(), a.GetMemoryResource()); + default: + LOG_FATAL("Invalid temporal type"); + } + } + + // at this point we only have int and double + if (a.IsDouble() || b.IsDouble()) { + return TypedValue(ToDouble(a) < ToDouble(b), a.GetMemoryResource()); + } else { + return TypedValue(a.ValueInt() < b.ValueInt(), a.GetMemoryResource()); + } +} + +TypedValue operator==(const TypedValue &a, const TypedValue &b) { + if (a.IsNull() || b.IsNull()) return TypedValue(a.GetMemoryResource()); + + // check we have values that can be compared + // this means that either they're the same type, or (int, double) combo + if ((a.type() != b.type() && !(a.IsNumeric() && b.IsNumeric()))) return TypedValue(false, a.GetMemoryResource()); + + switch (a.type()) { + case TypedValue::Type::Bool: + return TypedValue(a.ValueBool() == b.ValueBool(), a.GetMemoryResource()); + case TypedValue::Type::Int: + if (b.IsDouble()) + return TypedValue(ToDouble(a) == ToDouble(b), a.GetMemoryResource()); + else + return TypedValue(a.ValueInt() == b.ValueInt(), a.GetMemoryResource()); + case TypedValue::Type::Double: + return TypedValue(ToDouble(a) == ToDouble(b), 
a.GetMemoryResource()); + case TypedValue::Type::String: + return TypedValue(a.ValueString() == b.ValueString(), a.GetMemoryResource()); + case TypedValue::Type::Vertex: + return TypedValue(a.ValueVertex() == b.ValueVertex(), a.GetMemoryResource()); + case TypedValue::Type::Edge: + return TypedValue(a.ValueEdge() == b.ValueEdge(), a.GetMemoryResource()); + case TypedValue::Type::List: { + // We are not compatible with neo4j at this point. In neo4j 2 = [2] + // compares + // to true. That is not the end of unselfishness of developers at neo4j so + // they allow us to use as many braces as we want to get to the truth in + // list comparison, so [[2]] = [[[[[[2]]]]]] compares to true in neo4j as + // well. Because, why not? + // At memgraph we prefer sanity so [1,2] = [1,2] compares to true and + // 2 = [2] compares to false. + const auto &list_a = a.ValueList(); + const auto &list_b = b.ValueList(); + if (list_a.size() != list_b.size()) return TypedValue(false, a.GetMemoryResource()); + // two arrays are considered equal (by neo) if all their + // elements are bool-equal. 
this means that: + // [1] == [null] -> false + // [null] == [null] -> true + // in that sense array-comparison never results in Null + return TypedValue(std::equal(list_a.begin(), list_a.end(), list_b.begin(), TypedValue::BoolEqual{}), + a.GetMemoryResource()); + } + case TypedValue::Type::Map: { + const auto &map_a = a.ValueMap(); + const auto &map_b = b.ValueMap(); + if (map_a.size() != map_b.size()) return TypedValue(false, a.GetMemoryResource()); + for (const auto &kv_a : map_a) { + auto found_b_it = map_b.find(kv_a.first); + if (found_b_it == map_b.end()) return TypedValue(false, a.GetMemoryResource()); + TypedValue comparison = kv_a.second == found_b_it->second; + if (comparison.IsNull() || !comparison.ValueBool()) return TypedValue(false, a.GetMemoryResource()); + } + return TypedValue(true, a.GetMemoryResource()); + } + case TypedValue::Type::Path: + return TypedValue(a.ValuePath() == b.ValuePath(), a.GetMemoryResource()); + case TypedValue::Type::Date: + return TypedValue(a.ValueDate() == b.ValueDate(), a.GetMemoryResource()); + case TypedValue::Type::LocalTime: + return TypedValue(a.ValueLocalTime() == b.ValueLocalTime(), a.GetMemoryResource()); + case TypedValue::Type::LocalDateTime: + return TypedValue(a.ValueLocalDateTime() == b.ValueLocalDateTime(), a.GetMemoryResource()); + case TypedValue::Type::Duration: + return TypedValue(a.ValueDuration() == b.ValueDuration(), a.GetMemoryResource()); + default: + LOG_FATAL("Unhandled comparison for types"); + } +} + +TypedValue operator!(const TypedValue &a) { + if (a.IsNull()) return TypedValue(a.GetMemoryResource()); + if (a.IsBool()) return TypedValue(!a.ValueBool(), a.GetMemoryResource()); + throw TypedValueException("Invalid logical not operand type (!{})", a.type()); +} + +/** + * Turns a numeric or string value into a string. + * + * @param value a value. + * @return A string. 
+ */ +std::string ValueToString(const TypedValue &value) { + // TODO: Should this allocate a string through value.GetMemoryResource()? + if (value.IsString()) return std::string(value.ValueString()); + if (value.IsInt()) return std::to_string(value.ValueInt()); + if (value.IsDouble()) return fmt::format("{}", value.ValueDouble()); + // unsupported situations + throw TypedValueException("Unsupported TypedValue::Type conversion to string"); +} + +TypedValue operator-(const TypedValue &a) { + if (a.IsNull()) return TypedValue(a.GetMemoryResource()); + if (a.IsInt()) return TypedValue(-a.ValueInt(), a.GetMemoryResource()); + if (a.IsDouble()) return TypedValue(-a.ValueDouble(), a.GetMemoryResource()); + if (a.IsDuration()) return TypedValue(-a.ValueDuration(), a.GetMemoryResource()); + throw TypedValueException("Invalid unary minus operand type (-{})", a.type()); +} + +TypedValue operator+(const TypedValue &a) { + if (a.IsNull()) return TypedValue(a.GetMemoryResource()); + if (a.IsInt()) return TypedValue(+a.ValueInt(), a.GetMemoryResource()); + if (a.IsDouble()) return TypedValue(+a.ValueDouble(), a.GetMemoryResource()); + throw TypedValueException("Invalid unary plus operand type (+{})", a.type()); +} + +/** + * Raises a TypedValueException if the given values do not support arithmetic + * operations. If they do, nothing happens. + * + * @param a First value. + * @param b Second value. + * @param string_ok If or not for the given operation it's valid to work with + * String values (typically it's OK only for sum). + * @param op_name Name of the operation, used only for exception description, + * if raised. + */ +inline void EnsureArithmeticallyOk(const TypedValue &a, const TypedValue &b, bool string_ok, + const std::string &op_name) { + auto is_legal = [string_ok](const TypedValue &value) { + return value.IsNumeric() || (string_ok && value.type() == TypedValue::Type::String); + }; + + // Note that List and Null can also be valid in arithmetic ops. 
They are not + // checked here because they are handled before this check is performed in + // arithmetic op implementations. + + if (!is_legal(a) || !is_legal(b)) + throw TypedValueException("Invalid {} operand types {}, {}", op_name, a.type(), b.type()); +} + +namespace { + +std::optional<TypedValue> MaybeDoTemporalTypeAddition(const TypedValue &a, const TypedValue &b) { + // Duration + if (a.IsDuration() && b.IsDuration()) { + return TypedValue(a.ValueDuration() + b.ValueDuration()); + } + // Date + if (a.IsDate() && b.IsDuration()) { + return TypedValue(a.ValueDate() + b.ValueDuration()); + } + if (a.IsDuration() && b.IsDate()) { + return TypedValue(a.ValueDuration() + b.ValueDate()); + } + // LocalTime + if (a.IsLocalTime() && b.IsDuration()) { + return TypedValue(a.ValueLocalTime() + b.ValueDuration()); + } + if (a.IsDuration() && b.IsLocalTime()) { + return TypedValue(a.ValueDuration() + b.ValueLocalTime()); + } + // LocalDateTime + if (a.IsLocalDateTime() && b.IsDuration()) { + return TypedValue(a.ValueLocalDateTime() + b.ValueDuration()); + } + if (a.IsDuration() && b.IsLocalDateTime()) { + return TypedValue(a.ValueDuration() + b.ValueLocalDateTime()); + } + return std::nullopt; +} + +std::optional<TypedValue> MaybeDoTemporalTypeSubtraction(const TypedValue &a, const TypedValue &b) { + // Duration + if (a.IsDuration() && b.IsDuration()) { + return TypedValue(a.ValueDuration() - b.ValueDuration()); + } + // Date + if (a.IsDate() && b.IsDuration()) { + return TypedValue(a.ValueDate() - b.ValueDuration()); + } + if (a.IsDate() && b.IsDate()) { + return TypedValue(a.ValueDate() - b.ValueDate()); + } + // LocalTime + if (a.IsLocalTime() && b.IsDuration()) { + return TypedValue(a.ValueLocalTime() - b.ValueDuration()); + } + if (a.IsLocalTime() && b.IsLocalTime()) { + return TypedValue(a.ValueLocalTime() - b.ValueLocalTime()); + } + // LocalDateTime + if (a.IsLocalDateTime() && b.IsDuration()) { + return TypedValue(a.ValueLocalDateTime() - b.ValueDuration()); + } 
+ if (a.IsLocalDateTime() && b.IsLocalDateTime()) { + return TypedValue(a.ValueLocalDateTime() - b.ValueLocalDateTime()); + } + return std::nullopt; +} +} // namespace + +TypedValue operator+(const TypedValue &a, const TypedValue &b) { + if (a.IsNull() || b.IsNull()) return TypedValue(a.GetMemoryResource()); + + if (a.IsList() || b.IsList()) { + TypedValue::TVector list(a.GetMemoryResource()); + auto append_list = [&list](const TypedValue &v) { + if (v.IsList()) { + auto list2 = v.ValueList(); + list.insert(list.end(), list2.begin(), list2.end()); + } else { + list.push_back(v); + } + }; + append_list(a); + append_list(b); + return TypedValue(std::move(list), a.GetMemoryResource()); + } + + if (const auto maybe_add = MaybeDoTemporalTypeAddition(a, b); maybe_add) { + return *maybe_add; + } + + EnsureArithmeticallyOk(a, b, true, "addition"); + // no more Bool nor Null, summing works on anything from here onward + + if (a.IsString() || b.IsString()) return TypedValue(ValueToString(a) + ValueToString(b), a.GetMemoryResource()); + + // at this point we only have int and double + if (a.IsDouble() || b.IsDouble()) { + return TypedValue(ToDouble(a) + ToDouble(b), a.GetMemoryResource()); + } + return TypedValue(a.ValueInt() + b.ValueInt(), a.GetMemoryResource()); +} + +TypedValue operator-(const TypedValue &a, const TypedValue &b) { + if (a.IsNull() || b.IsNull()) return TypedValue(a.GetMemoryResource()); + if (const auto maybe_sub = MaybeDoTemporalTypeSubtraction(a, b); maybe_sub) { + return *maybe_sub; + } + EnsureArithmeticallyOk(a, b, true, "subraction"); + // at this point we only have int and double + if (a.IsDouble() || b.IsDouble()) { + return TypedValue(ToDouble(a) - ToDouble(b), a.GetMemoryResource()); + } + return TypedValue(a.ValueInt() - b.ValueInt(), a.GetMemoryResource()); +} + +TypedValue operator/(const TypedValue &a, const TypedValue &b) { + if (a.IsNull() || b.IsNull()) return TypedValue(a.GetMemoryResource()); + EnsureArithmeticallyOk(a, b, false, 
"division"); + + // at this point we only have int and double + if (a.IsDouble() || b.IsDouble()) { + return TypedValue(ToDouble(a) / ToDouble(b), a.GetMemoryResource()); + } else { + if (b.ValueInt() == 0LL) throw TypedValueException("Division by zero"); + return TypedValue(a.ValueInt() / b.ValueInt(), a.GetMemoryResource()); + } +} + +TypedValue operator*(const TypedValue &a, const TypedValue &b) { + if (a.IsNull() || b.IsNull()) return TypedValue(a.GetMemoryResource()); + EnsureArithmeticallyOk(a, b, false, "multiplication"); + + // at this point we only have int and double + if (a.IsDouble() || b.IsDouble()) { + return TypedValue(ToDouble(a) * ToDouble(b), a.GetMemoryResource()); + } else { + return TypedValue(a.ValueInt() * b.ValueInt(), a.GetMemoryResource()); + } +} + +TypedValue operator%(const TypedValue &a, const TypedValue &b) { + if (a.IsNull() || b.IsNull()) return TypedValue(a.GetMemoryResource()); + EnsureArithmeticallyOk(a, b, false, "modulo"); + + // at this point we only have int and double + if (a.IsDouble() || b.IsDouble()) { + return TypedValue(static_cast<double>(fmod(ToDouble(a), ToDouble(b))), a.GetMemoryResource()); + } else { + if (b.ValueInt() == 0LL) throw TypedValueException("Mod with zero"); + return TypedValue(a.ValueInt() % b.ValueInt(), a.GetMemoryResource()); + } +} + +inline void EnsureLogicallyOk(const TypedValue &a, const TypedValue &b, const std::string &op_name) { + if (!((a.IsBool() || a.IsNull()) && (b.IsBool() || b.IsNull()))) + throw TypedValueException("Invalid {} operand types({} && {})", op_name, a.type(), b.type()); +} + +TypedValue operator&&(const TypedValue &a, const TypedValue &b) { + EnsureLogicallyOk(a, b, "logical AND"); + // at this point we only have null and bool + // if either operand is false, the result is false + if (a.IsBool() && !a.ValueBool()) return TypedValue(false, a.GetMemoryResource()); + if (b.IsBool() && !b.ValueBool()) return TypedValue(false, a.GetMemoryResource()); + if (a.IsNull() || 
b.IsNull()) return TypedValue(a.GetMemoryResource()); + // neither is false, neither is null, thus both are true + return TypedValue(true, a.GetMemoryResource()); +} + +TypedValue operator||(const TypedValue &a, const TypedValue &b) { + EnsureLogicallyOk(a, b, "logical OR"); + // at this point we only have null and bool + // if either operand is true, the result is true + if (a.IsBool() && a.ValueBool()) return TypedValue(true, a.GetMemoryResource()); + if (b.IsBool() && b.ValueBool()) return TypedValue(true, a.GetMemoryResource()); + if (a.IsNull() || b.IsNull()) return TypedValue(a.GetMemoryResource()); + // neither is true, neither is null, thus both are false + return TypedValue(false, a.GetMemoryResource()); +} + +TypedValue operator^(const TypedValue &a, const TypedValue &b) { + EnsureLogicallyOk(a, b, "logical XOR"); + // at this point we only have null and bool + if (a.IsNull() || b.IsNull()) + return TypedValue(a.GetMemoryResource()); + else + return TypedValue(static_cast<bool>(a.ValueBool() ^ b.ValueBool()), a.GetMemoryResource()); +} + +bool TypedValue::BoolEqual::operator()(const TypedValue &lhs, const TypedValue &rhs) const { + if (lhs.IsNull() && rhs.IsNull()) return true; + TypedValue equality_result = lhs == rhs; + switch (equality_result.type()) { + case TypedValue::Type::Bool: + return equality_result.ValueBool(); + case TypedValue::Type::Null: + return false; + default: + LOG_FATAL( + "Equality between two TypedValues resulted in something other " + "than Null or bool"); + } +} + +size_t TypedValue::Hash::operator()(const TypedValue &value) const { + switch (value.type()) { + case TypedValue::Type::Null: + return 31; + case TypedValue::Type::Bool: + return std::hash<bool>{}(value.ValueBool()); + case TypedValue::Type::Int: + // we cast int to double for hashing purposes + // to be consistent with TypedValue equality + // in which (2.0 == 2) returns true + return std::hash<double>{}((double)value.ValueInt()); + case TypedValue::Type::Double: + 
return std::hash<double>{}(value.ValueDouble()); + case TypedValue::Type::String: + return std::hash<std::string_view>{}(value.ValueString()); + case TypedValue::Type::List: { + return utils::FnvCollection<TypedValue::TVector, TypedValue, Hash>{}(value.ValueList()); + } + case TypedValue::Type::Map: { + size_t hash = 6543457; + for (const auto &kv : value.ValueMap()) { + hash ^= std::hash<std::string_view>{}(kv.first); + hash ^= this->operator()(kv.second); + } + return hash; + } + case TypedValue::Type::Vertex: + return value.ValueVertex().Gid().AsUint(); + case TypedValue::Type::Edge: + return value.ValueEdge().Gid().AsUint(); + case TypedValue::Type::Path: { + const auto &vertices = value.ValuePath().vertices(); + const auto &edges = value.ValuePath().edges(); + return utils::FnvCollection<decltype(vertices), VertexAccessor>{}(vertices) ^ + utils::FnvCollection<decltype(edges), EdgeAccessor>{}(edges); + } + case TypedValue::Type::Date: + return utils::DateHash{}(value.ValueDate()); + case TypedValue::Type::LocalTime: + return utils::LocalTimeHash{}(value.ValueLocalTime()); + case TypedValue::Type::LocalDateTime: + return utils::LocalDateTimeHash{}(value.ValueLocalDateTime()); + case TypedValue::Type::Duration: + return utils::DurationHash{}(value.ValueDuration()); + break; + } + LOG_FATAL("Unhandled TypedValue.type() in hash function"); +} + +} // namespace memgraph::query::v2 diff --git a/src/query/v2/typed_value.hpp b/src/query/v2/typed_value.hpp new file mode 100644 index 000000000..0bd8e9695 --- /dev/null +++ b/src/query/v2/typed_value.hpp @@ -0,0 +1,739 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <cstdint> +#include <iostream> +#include <map> +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "query/v2/db_accessor.hpp" +#include "query/v2/path.hpp" +#include "utils/exceptions.hpp" +#include "utils/memory.hpp" +#include "utils/pmr/map.hpp" +#include "utils/pmr/string.hpp" +#include "utils/pmr/vector.hpp" +#include "utils/temporal.hpp" + +namespace memgraph::query::v2 { + +// TODO: Neo4j does overflow checking. Should we also implement it? +/** + * Stores a query runtime value and its type. + * + * Values can be of a number of predefined types that are enumerated in + * TypedValue::Type. Each such type corresponds to exactly one C++ type. + * + * Non-primitive value types perform additional memory allocations. To tune the + * allocation scheme, each TypedValue stores a MemoryResource for said + * allocations. When copying and moving TypedValue instances, take care that the + * appropriate MemoryResource is used. + */ +class TypedValue { + public: + /** Custom TypedValue equality function that returns a bool + * (as opposed to returning TypedValue as the default equality does). + * This implementation treats two nulls as being equal and null + * not being equal to everything else. + */ + struct BoolEqual { + bool operator()(const TypedValue &left, const TypedValue &right) const; + }; + + /** Hash operator for TypedValue. + * + * Not injecting into std + * due to linking problems. If the implementation is in this header, + * then it implicitly instantiates TypedValue::Value<T> before + * explicit instantiation in .cpp file. If the implementation is in + * the .cpp file, it won't link. + * TODO: No longer the case as Value<T> was removed. 
+ */ + struct Hash { + size_t operator()(const TypedValue &value) const; + }; + + /** A value type. Each type corresponds to exactly one C++ type */ + enum class Type : unsigned { + Null, + Bool, + Int, + Double, + String, + List, + Map, + Vertex, + Edge, + Path, + Date, + LocalTime, + LocalDateTime, + Duration + }; + + // TypedValue at this exact moment of compilation is an incomplete type, and + // the standard says that instantiating a container with an incomplete type + // invokes undefined behaviour. The libstdc++-8.3.0 we are using supports + // std::map with incomplete type, but this is still murky territory. Note that + // since C++17, std::vector is explicitly said to support incomplete types. + + using TString = utils::pmr::string; + using TVector = utils::pmr::vector<TypedValue>; + using TMap = utils::pmr::map<utils::pmr::string, TypedValue>; + + /** Allocator type so that STL containers are aware that we need one */ + using allocator_type = utils::Allocator<TypedValue>; + + /** Construct a Null value with default utils::NewDeleteResource(). */ + TypedValue() : type_(Type::Null) {} + + /** Construct a Null value with given utils::MemoryResource. */ + explicit TypedValue(utils::MemoryResource *memory) : memory_(memory), type_(Type::Null) {} + + /** + * Construct a copy of other. + * utils::MemoryResource is obtained by calling + * std::allocator_traits<>::select_on_container_copy_construction(other.memory_). + * Since we use utils::Allocator, which does not propagate, this means that + * memory_ will be the default utils::NewDeleteResource(). + */ + TypedValue(const TypedValue &other); + + /** Construct a copy using the given utils::MemoryResource */ + TypedValue(const TypedValue &other, utils::MemoryResource *memory); + + /** + * Construct with the value of other. + * utils::MemoryResource is obtained from other. After the move, other will be + * set to Null. 
+ */ + TypedValue(TypedValue &&other) noexcept; + + /** + * Construct with the value of other, but use the given utils::MemoryResource. + * After the move, other will be set to Null. + * If `*memory != *other.GetMemoryResource()`, then a copy is made instead of + * a move. + */ + TypedValue(TypedValue &&other, utils::MemoryResource *memory); + + explicit TypedValue(bool value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::Bool) { + bool_v = value; + } + + explicit TypedValue(int value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::Int) { + int_v = value; + } + + explicit TypedValue(int64_t value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::Int) { + int_v = value; + } + + explicit TypedValue(double value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::Double) { + double_v = value; + } + + explicit TypedValue(const utils::Date &value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::Date) { + date_v = value; + } + + explicit TypedValue(const utils::LocalTime &value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::LocalTime) { + local_time_v = value; + } + + explicit TypedValue(const utils::LocalDateTime &value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::LocalDateTime) { + local_date_time_v = value; + } + + explicit TypedValue(const utils::Duration &value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::Duration) { + duration_v = value; + } + + // conversion function to storage::v3::PropertyValue + explicit operator storage::v3::PropertyValue() const; + + // copy constructors for non-primitive types + explicit TypedValue(const std::string &value, utils::MemoryResource *memory = 
utils::NewDeleteResource()) + : memory_(memory), type_(Type::String) { + new (&string_v) TString(value, memory_); + } + + explicit TypedValue(const char *value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::String) { + new (&string_v) TString(value, memory_); + } + + explicit TypedValue(const std::string_view value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::String) { + new (&string_v) TString(value, memory_); + } + + /** + * Construct a copy of other. + * utils::MemoryResource is obtained by calling + * std::allocator_traits<>:: + * select_on_container_copy_construction(other.get_allocator()). + * Since we use utils::Allocator, which does not propagate, this means that + * memory_ will be the default utils::NewDeleteResource(). + */ + explicit TypedValue(const TString &other) + : TypedValue(other, std::allocator_traits<utils::Allocator<TypedValue>>::select_on_container_copy_construction( + other.get_allocator()) + .GetMemoryResource()) {} + + /** Construct a copy using the given utils::MemoryResource */ + TypedValue(const TString &other, utils::MemoryResource *memory) : memory_(memory), type_(Type::String) { + new (&string_v) TString(other, memory_); + } + + /** Construct a copy using the given utils::MemoryResource */ + explicit TypedValue(const std::vector<TypedValue> &value, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::List) { + new (&list_v) TVector(memory_); + list_v.reserve(value.size()); + list_v.assign(value.begin(), value.end()); + } + + /** + * Construct a copy of other. + * utils::MemoryResource is obtained by calling + * std::allocator_traits<>:: + * select_on_container_copy_construction(other.get_allocator()). + * Since we use utils::Allocator, which does not propagate, this means that + * memory_ will be the default utils::NewDeleteResource(). 
+ */ + explicit TypedValue(const TVector &other) + : TypedValue(other, std::allocator_traits<utils::Allocator<TypedValue>>::select_on_container_copy_construction( + other.get_allocator()) + .GetMemoryResource()) {} + + /** Construct a copy using the given utils::MemoryResource */ + TypedValue(const TVector &value, utils::MemoryResource *memory) : memory_(memory), type_(Type::List) { + new (&list_v) TVector(value, memory_); + } + + /** Construct a copy using the given utils::MemoryResource */ + explicit TypedValue(const std::map<std::string, TypedValue> &value, + utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::Map) { + new (&map_v) TMap(memory_); + for (const auto &kv : value) map_v.emplace(kv.first, kv.second); + } + + /** + * Construct a copy of other. + * utils::MemoryResource is obtained by calling + * std::allocator_traits<>:: + * select_on_container_copy_construction(other.get_allocator()). + * Since we use utils::Allocator, which does not propagate, this means that + * memory_ will be the default utils::NewDeleteResource(). 
+ */ + explicit TypedValue(const TMap &other) + : TypedValue(other, std::allocator_traits<utils::Allocator<TypedValue>>::select_on_container_copy_construction( + other.get_allocator()) + .GetMemoryResource()) {} + + /** Construct a copy using the given utils::MemoryResource */ + TypedValue(const TMap &value, utils::MemoryResource *memory) : memory_(memory), type_(Type::Map) { + new (&map_v) TMap(value, memory_); + } + + explicit TypedValue(const VertexAccessor &vertex, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::Vertex) { + new (&vertex_v) VertexAccessor(vertex); + } + + explicit TypedValue(const EdgeAccessor &edge, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::Edge) { + new (&edge_v) EdgeAccessor(edge); + } + + explicit TypedValue(const Path &path, utils::MemoryResource *memory = utils::NewDeleteResource()) + : memory_(memory), type_(Type::Path) { + new (&path_v) Path(path, memory_); + } + + /** Construct a copy using default utils::NewDeleteResource() */ + explicit TypedValue(const storage::v3::PropertyValue &value); + + /** Construct a copy using the given utils::MemoryResource */ + TypedValue(const storage::v3::PropertyValue &value, utils::MemoryResource *memory); + + // move constructors for non-primitive types + + /** + * Construct with the value of other. + * utils::MemoryResource is obtained from other. After the move, other will be + * left in unspecified state. + */ + explicit TypedValue(TString &&other) noexcept + : TypedValue(std::move(other), other.get_allocator().GetMemoryResource()) {} + + /** + * Construct with the value of other and use the given MemoryResource + * After the move, other will be left in unspecified state. 
+ */ + TypedValue(TString &&other, utils::MemoryResource *memory) : memory_(memory), type_(Type::String) { + new (&string_v) TString(std::move(other), memory_); + } + + /** + * Perform an element-wise move using default utils::NewDeleteResource(). + * Other will be not be empty, though elements may be Null. + */ + explicit TypedValue(std::vector<TypedValue> &&other) : TypedValue(std::move(other), utils::NewDeleteResource()) {} + + /** + * Perform an element-wise move of the other and use the given MemoryResource. + * Other will be not be left empty, though elements may be Null. + */ + TypedValue(std::vector<TypedValue> &&other, utils::MemoryResource *memory) : memory_(memory), type_(Type::List) { + new (&list_v) TVector(memory_); + list_v.reserve(other.size()); + // std::vector<TypedValue> has std::allocator and there's no move + // constructor for std::vector using different allocator types. Since + // std::allocator is not propagated to elements, it is possible that some + // TypedValue elements have a MemoryResource that is the same as the one we + // are given. In such a case we would like to move those TypedValue + // instances, so we use move_iterator. + list_v.assign(std::make_move_iterator(other.begin()), std::make_move_iterator(other.end())); + } + + /** + * Construct with the value of other. + * utils::MemoryResource is obtained from other. After the move, other will be + * left empty. + */ + explicit TypedValue(TVector &&other) noexcept + : TypedValue(std::move(other), other.get_allocator().GetMemoryResource()) {} + + /** + * Construct with the value of other and use the given MemoryResource. + * If `other.get_allocator() != *memory`, this call will perform an + * element-wise move and other is not guaranteed to be empty. 
+ */ + TypedValue(TVector &&other, utils::MemoryResource *memory) : memory_(memory), type_(Type::List) { + new (&list_v) TVector(std::move(other), memory_); + } + + /** + * Perform an element-wise move using default utils::NewDeleteResource(). + * Other will not be left empty, i.e. keys will exist but their values may + * be Null. + */ + explicit TypedValue(std::map<std::string, TypedValue> &&other) + : TypedValue(std::move(other), utils::NewDeleteResource()) {} + + /** + * Perform an element-wise move using the given MemoryResource. + * Other will not be left empty, i.e. keys will exist but their values may + * be Null. + */ + TypedValue(std::map<std::string, TypedValue> &&other, utils::MemoryResource *memory) + : memory_(memory), type_(Type::Map) { + new (&map_v) TMap(memory_); + for (auto &kv : other) map_v.emplace(kv.first, std::move(kv.second)); + } + + /** + * Construct with the value of other. + * utils::MemoryResource is obtained from other. After the move, other will be + * left empty. + */ + explicit TypedValue(TMap &&other) noexcept + : TypedValue(std::move(other), other.get_allocator().GetMemoryResource()) {} + + /** + * Construct with the value of other and use the given MemoryResource. + * If `other.get_allocator() != *memory`, this call will perform an + * element-wise move and other is not guaranteed to be empty, i.e. keys may + * exist but their values may be Null. 
+ */ + TypedValue(TMap &&other, utils::MemoryResource *memory) : memory_(memory), type_(Type::Map) { + new (&map_v) TMap(std::move(other), memory_); + } + + explicit TypedValue(VertexAccessor &&vertex, utils::MemoryResource *memory = utils::NewDeleteResource()) noexcept + : memory_(memory), type_(Type::Vertex) { + new (&vertex_v) VertexAccessor(std::move(vertex)); + } + + explicit TypedValue(EdgeAccessor &&edge, utils::MemoryResource *memory = utils::NewDeleteResource()) noexcept + : memory_(memory), type_(Type::Edge) { + new (&edge_v) EdgeAccessor(std::move(edge)); + } + + /** + * Construct with the value of path. + * utils::MemoryResource is obtained from path. After the move, path will be + * left empty. + */ + explicit TypedValue(Path &&path) noexcept : TypedValue(std::move(path), path.GetMemoryResource()) {} + + /** + * Construct with the value of path and use the given MemoryResource. + * If `*path.GetMemoryResource() != *memory`, this call will perform an + * element-wise move and path is not guaranteed to be empty. + */ + TypedValue(Path &&path, utils::MemoryResource *memory) : memory_(memory), type_(Type::Path) { + new (&path_v) Path(std::move(path), memory_); + } + + /** + * Construct with the value of other. + * Default utils::NewDeleteResource() is used for allocations. After the move, + * other will be set to Null. + */ + explicit TypedValue(storage::v3::PropertyValue &&other); + + /** + * Construct with the value of other, but use the given utils::MemoryResource. + * After the move, other will be set to Null. 
+ */ + TypedValue(storage::v3::PropertyValue &&other, utils::MemoryResource *memory); + + // Copy assignment operators (one overload per supported source type) + TypedValue &operator=(const char *); + TypedValue &operator=(int); + TypedValue &operator=(bool); + TypedValue &operator=(int64_t); + TypedValue &operator=(double); + TypedValue &operator=(std::string_view); + TypedValue &operator=(const TVector &); + TypedValue &operator=(const std::vector<TypedValue> &); + TypedValue &operator=(const TMap &); + TypedValue &operator=(const std::map<std::string, TypedValue> &); + TypedValue &operator=(const VertexAccessor &); + TypedValue &operator=(const EdgeAccessor &); + TypedValue &operator=(const Path &); + TypedValue &operator=(const utils::Date &); + TypedValue &operator=(const utils::LocalTime &); + TypedValue &operator=(const utils::LocalDateTime &); + TypedValue &operator=(const utils::Duration &); + + /** Copy assign other; the utils::MemoryResource of `this` is used. */ + TypedValue &operator=(const TypedValue &other); + + /** + * Move assign other; the utils::MemoryResource of `this` is used. + * Explicitly declared noexcept(false), so callers must not assume that + * this move assignment cannot throw. + */ + TypedValue &operator=(TypedValue &&other) noexcept(false); + + // Move assignment operators + TypedValue &operator=(TString &&); + TypedValue &operator=(TVector &&); + TypedValue &operator=(std::vector<TypedValue> &&); + TypedValue &operator=(TMap &&); + TypedValue &operator=(std::map<std::string, TypedValue> &&); + TypedValue &operator=(Path &&); + + ~TypedValue(); + + /** Returns the Type of the currently stored value. */ + Type type() const { return type_; } + + // TODO: consider adding getters for primitives by value (and not by ref) + +#define DECLARE_VALUE_AND_TYPE_GETTERS(type_param, field) \ + /** Gets the value of type field. Throws if value is not field*/ \ + type_param &Value##field(); \ + /** Gets the value of type field.
Throws if value is not field*/ \ + const type_param &Value##field() const; \ + /** Checks if the value is of the given type */ \ + bool Is##field() const; + + DECLARE_VALUE_AND_TYPE_GETTERS(bool, Bool) + DECLARE_VALUE_AND_TYPE_GETTERS(int64_t, Int) + DECLARE_VALUE_AND_TYPE_GETTERS(double, Double) + DECLARE_VALUE_AND_TYPE_GETTERS(TString, String) + + /** + * Get the list value. + * @throw TypedValueException if stored value is not a list. + */ + TVector &ValueList(); + + /** Const overload of ValueList(). */ + const TVector &ValueList() const; + + /** Check if the stored value is a list value. */ + bool IsList() const; + + DECLARE_VALUE_AND_TYPE_GETTERS(TMap, Map) + DECLARE_VALUE_AND_TYPE_GETTERS(VertexAccessor, Vertex) + DECLARE_VALUE_AND_TYPE_GETTERS(EdgeAccessor, Edge) + DECLARE_VALUE_AND_TYPE_GETTERS(Path, Path) + + DECLARE_VALUE_AND_TYPE_GETTERS(utils::Date, Date) + DECLARE_VALUE_AND_TYPE_GETTERS(utils::LocalTime, LocalTime) + DECLARE_VALUE_AND_TYPE_GETTERS(utils::LocalDateTime, LocalDateTime) + DECLARE_VALUE_AND_TYPE_GETTERS(utils::Duration, Duration) + +#undef DECLARE_VALUE_AND_TYPE_GETTERS + + /** Checks if value is a TypedValue::Null. */ + bool IsNull() const; + + /** Convenience function for checking if this TypedValue is either + * an integer or a double. */ + bool IsNumeric() const; + + /** Convenience function for checking if this TypedValue can be converted into + * storage::v3::PropertyValue. */ + bool IsPropertyValue() const; + + /** Returns the utils::MemoryResource stored in this TypedValue. */ + utils::MemoryResource *GetMemoryResource() const { return memory_; } + + private: + void DestroyValue(); + + // Memory resource for allocations of non-primitive values + utils::MemoryResource *memory_{utils::NewDeleteResource()}; + + // Storage for the value of the property + union { + bool bool_v; + int64_t int_v; + double double_v; + // Since this is used in query runtime, size of union is not critical so + // string and vector are used instead of pointers. It requires copy of data, + // but most of algorithms (concatenations, serialisation...)
has linear time + // complexity so it shouldn't be a problem. This is maybe even faster + // because of data locality. + TString string_v; + TVector list_v; + TMap map_v; + VertexAccessor vertex_v; + EdgeAccessor edge_v; + Path path_v; + utils::Date date_v; + utils::LocalTime local_time_v; + utils::LocalDateTime local_date_time_v; + utils::Duration duration_v; + }; + + /** + * The Type of the value currently stored in the union. + */ + Type type_; +}; + +/** + * An exception raised by the TypedValue system. Typically when + * trying to perform operations (such as addition) on TypedValues + * of incompatible Types. + * Inherits all constructors of utils::BasicException. + */ +class TypedValueException : public utils::BasicException { + public: + using utils::BasicException::BasicException; +}; + +// binary bool operators + +/** + * Perform logical 'and' on TypedValues. + * + * If any of the values is false, return false. Otherwise, if any value is + * Null, return Null. All other cases return true. The resulting value uses + * the same MemoryResource as the left hand side argument. + * Being an overloaded operator, both operands are always evaluated (there is + * no short-circuiting). + * + * @throw TypedValueException if arguments are not boolean or Null. + */ +TypedValue operator&&(const TypedValue &a, const TypedValue &b); + +/** + * Perform logical 'or' on TypedValues. + * + * If any of the values is true, return true. Otherwise, if any value is + * Null, return Null. All other cases return false. The resulting value uses + * the same MemoryResource as the left hand side argument. + * Being an overloaded operator, both operands are always evaluated (there is + * no short-circuiting). + * + * @throw TypedValueException if arguments are not boolean or Null. + */ +TypedValue operator||(const TypedValue &a, const TypedValue &b); + +/** + * Logically negate a TypedValue. + * + * Negating Null returns Null. Values other than boolean or Null raise an + * exception. + * The resulting value uses the same MemoryResource as the argument. + * + * @throw TypedValueException if TypedValue is not a boolean or Null.
+ */ +TypedValue operator!(const TypedValue &a); + +// Logical (boolean) xor, not the power operator. +// Be careful: ^ is a bitwise operator in C++, while || and && are logical +// operators, so ^ binds more tightly than either of them. +TypedValue operator^(const TypedValue &a, const TypedValue &b); + +// comparison operators + +/** + * Test TypedValues for equality and return true, false or Null. + * + * Null is returned if either of the two values is Null. + * Since each TypedValue may have a different MemoryResource for allocations, + * the result is allocated using the MemoryResource obtained from the left + * hand side. + */ +TypedValue operator==(const TypedValue &a, const TypedValue &b); + +/** + * Test TypedValues for inequality and return true, false or Null. + * Implemented as the logical negation of operator==. + * + * Null is returned if either of the two values is Null. + * Since each TypedValue may have a different MemoryResource for allocations, + * the result is allocated using the MemoryResource obtained from the left + * hand side. + */ +inline TypedValue operator!=(const TypedValue &a, const TypedValue &b) { return !(a == b); } + +/** + * Compare TypedValues (a < b) and return true, false or Null. + * + * Null is returned if either of the two values is Null. + * The resulting value uses the same MemoryResource as the left hand side + * argument. + * + * @throw TypedValueException if the values cannot be compared, i.e. they are + * not either Null, numeric or a character string type. + */ +TypedValue operator<(const TypedValue &a, const TypedValue &b); + +/** + * Compare TypedValues and return true, false or Null. + * + * Null is returned if either of the two values is Null. + * The resulting value uses the same MemoryResource as the left hand side + * argument. + * + * @throw TypedValueException if the values cannot be compared, i.e. they are + * not either Null, numeric or a character string type.
+ */ +inline TypedValue operator<=(const TypedValue &a, const TypedValue &b) { return a < b || a == b; } + +/** + * Compare TypedValues (a > b) and return true, false or Null. + * Implemented as the logical negation of operator<=. + * + * Null is returned if either of the two values is Null. + * The resulting value uses the same MemoryResource as the left hand side + * argument. + * NOTE(review): being defined purely by negation, this yields true whenever + * operator<= yields false, which presumably includes unordered operands such + * as NaN doubles -- confirm that this is intended. + * + * @throw TypedValueException if the values cannot be compared, i.e. they are + * not either Null, numeric or a character string type. + */ +inline TypedValue operator>(const TypedValue &a, const TypedValue &b) { return !(a <= b); } + +/** + * Compare TypedValues (a >= b) and return true, false or Null. + * Implemented as the logical negation of operator<. + * + * Null is returned if either of the two values is Null. + * The resulting value uses the same MemoryResource as the left hand side + * argument. + * NOTE(review): being defined purely by negation, this yields true whenever + * operator< yields false, which presumably includes unordered operands such + * as NaN doubles -- confirm that this is intended. + * + * @throw TypedValueException if the values cannot be compared, i.e. they are + * not either Null, numeric or a character string type. + */ +inline TypedValue operator>=(const TypedValue &a, const TypedValue &b) { return !(a < b); } + +// arithmetic operators + +/** + * Arithmetically negate a value. + * + * If the value is Null, then Null is returned. + * The resulting value uses the same MemoryResource as the argument. + * + * @throw TypedValueException if the value is not numeric or Null. + */ +TypedValue operator-(const TypedValue &a); + +/** + * Apply the unary plus operator to a value. + * + * If the value is Null, then Null is returned. + * The resulting value uses the same MemoryResource as the argument. + * + * @throw TypedValueException if the value is not numeric or Null. + */ +TypedValue operator+(const TypedValue &a); + +/** + * Perform addition or concatenation on two values. + * + * Numeric values are summed, while lists and character strings are + * concatenated. If either value is Null, then Null is returned. The resulting + * value uses the same MemoryResource as the left hand side argument. + * + * @throw TypedValueException if values cannot be summed or concatenated.
+ */ +TypedValue operator+(const TypedValue &a, const TypedValue &b); + +/** + * Subtract two values (a - b). + * + * If any of the values is Null, then Null is returned. + * The resulting value uses the same MemoryResource as the left hand side + * argument. + * + * @throw TypedValueException if the values are not numeric or Null. + */ +TypedValue operator-(const TypedValue &a, const TypedValue &b); + +/** + * Divide two values (a / b). + * + * If any of the values is Null, then Null is returned. + * The resulting value uses the same MemoryResource as the left hand side + * argument. + * + * @throw TypedValueException if the values are not numeric or Null, or if + * dividing two integer values by zero. + */ +TypedValue operator/(const TypedValue &a, const TypedValue &b); + +/** + * Multiply two values (a * b). + * + * If any of the values is Null, then Null is returned. + * The resulting value uses the same MemoryResource as the left hand side + * argument. + * + * @throw TypedValueException if the values are not numeric or Null. + */ +TypedValue operator*(const TypedValue &a, const TypedValue &b); + +/** + * Perform the modulo operation on two values (a % b). + * + * If any of the values is Null, then Null is returned. + * The resulting value uses the same MemoryResource as the left hand side + * argument. + * NOTE(review): unlike operator/, a zero right hand side is not mentioned + * here -- confirm how modulo by zero is handled. + * + * @throw TypedValueException if the values are not numeric or Null. + */ +TypedValue operator%(const TypedValue &a, const TypedValue &b); + +/** Output the TypedValue::Type value to the stream os as a string. */ +std::ostream &operator<<(std::ostream &os, const TypedValue::Type &type); + +} // namespace memgraph::query::v2