diff --git a/.github/workflows/diff.yaml b/.github/workflows/diff.yaml index 8b47b3bb3..a7faa7d22 100644 --- a/.github/workflows/diff.yaml +++ b/.github/workflows/diff.yaml @@ -99,7 +99,7 @@ jobs: echo ${file} if [[ ${file} == *.py ]]; then python3 -m black --check --diff ${file} - python3 -m isort --check-only --diff ${file} + python3 -m isort --check-only --profile "black" --diff ${file} fi done diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5abd746c9..26b7c8e05 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,7 @@ repos: hooks: - id: isort name: isort (python) + args: ["--profile", "black"] - repo: https://github.com/pre-commit/mirrors-clang-format rev: v13.0.0 hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index 916389424..36e648e8b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -182,7 +182,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) # c99-designator is disabled because of required mixture of designated and # non-designated initializers in Python Query Module code (`py_module.cpp`). set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall \ - -Werror=switch -Werror=switch-bool -Werror=return-type \ + -Werror=switch -Werror=switch-bool -Werror=implicit-fallthrough \ + -Werror=return-type \ -Werror=return-stack-address \ -Wno-c99-designator \ -DBOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8b84bb26a..81fc61836 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(auth) add_subdirectory(parser) add_subdirectory(expr) add_subdirectory(coordinator) +add_subdirectory(functions) if (MG_ENTERPRISE) add_subdirectory(audit) diff --git a/src/communication/bolt/v1/states/executing.hpp b/src/communication/bolt/v1/states/executing.hpp index 54504985b..de4b2a00e 100644 --- a/src/communication/bolt/v1/states/executing.hpp +++ b/src/communication/bolt/v1/states/executing.hpp @@ -74,7 +74,7 @@ State RunHandlerV4(Signature signature, TSession &session, State state, Marker m } case Signature::Route: { if constexpr (bolt_minor >= 3) { - if (signature == Signature::Route) return HandleRoute(session); + return HandleRoute(session); } else { spdlog::trace("Supported only in bolt v4.3"); return State::Close; diff --git a/src/expr/CMakeLists.txt b/src/expr/CMakeLists.txt index e529512b7..31cbfa493 100644 --- a/src/expr/CMakeLists.txt +++ b/src/expr/CMakeLists.txt @@ -17,4 +17,4 @@ target_include_directories(mg-expr PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(mg-expr PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/ast) target_include_directories(mg-expr PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/interpret) target_include_directories(mg-expr PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/semantic) -target_link_libraries(mg-expr cppitertools Boost::headers mg-utils mg-parser) +target_link_libraries(mg-expr cppitertools Boost::headers mg-utils mg-parser mg-functions) diff --git a/src/expr/interpret/eval.hpp b/src/expr/interpret/eval.hpp index a7d027ede..8fb300a83 100644 --- a/src/expr/interpret/eval.hpp +++ b/src/expr/interpret/eval.hpp @@ -24,6 +24,7 @@ #include "expr/exceptions.hpp" #include "expr/interpret/frame.hpp" #include "expr/semantic/symbol_table.hpp" +#include "functions/awesome_memgraph_functions.hpp" #include "utils/exceptions.hpp" namespace memgraph::expr { @@ -427,8 +428,7 @@ class ExpressionEvaluator : public ExpressionVisitor { typename TReturnType = std::enable_if_t, bool>> TReturnType HasLabelImpl(const VertexAccessor &vertex, const LabelIx &label_ix, QueryEngineTag /*tag*/) { auto label = typename VertexAccessor::Label{LabelId::FromUint(label_ix.ix)}; - auto has_label = vertex.HasLabel(label); - return !has_label; + return vertex.HasLabel(label); } TypedValue Visit(LabelsTest &labels_test) override { @@ -491,7 +491,7 @@ class ExpressionEvaluator : public ExpressionVisitor { } TypedValue Visit(Function &function) override { - FunctionContext function_ctx{dba_, ctx_->memory, ctx_->timestamp, &ctx_->counters, view_}; + functions::FunctionContext function_ctx{dba_, ctx_->memory, ctx_->timestamp, &ctx_->counters, view_}; // Stack allocate evaluated arguments when there's a small number of them. if (function.arguments_.size() <= 8) { TypedValue arguments[8] = {TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory), diff --git a/src/functions/CMakeLists.txt b/src/functions/CMakeLists.txt new file mode 100644 index 000000000..3a3d430cd --- /dev/null +++ b/src/functions/CMakeLists.txt @@ -0,0 +1 @@ +add_library(mg-functions INTERFACE) diff --git a/src/functions/awesome_memgraph_functions.hpp b/src/functions/awesome_memgraph_functions.hpp new file mode 100644 index 000000000..7e716d970 --- /dev/null +++ b/src/functions/awesome_memgraph_functions.hpp @@ -0,0 +1,1423 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include + +#include "storage/v3/result.hpp" +#include "storage/v3/shard.hpp" +#include "storage/v3/view.hpp" +#include "utils/algorithm.hpp" +#include "utils/cast.hpp" +#include "utils/concepts.hpp" +#include "utils/memory.hpp" +#include "utils/pmr/string.hpp" +#include "utils/string.hpp" +#include "utils/temporal.hpp" + +namespace memgraph::functions { + +class FunctionRuntimeException : public utils::BasicException { + using utils::BasicException::BasicException; +}; + +template +struct FunctionContext { + TAccessor *db_accessor; + utils::MemoryResource *memory; + int64_t timestamp; + std::unordered_map *counters; + storage::v3::View view; +}; + +// Tags for the NameToFunction() function template +struct StorageEngineTag {}; +struct QueryEngineTag {}; + +/// Return the function implementation with the given name. +/// +/// Note, returned function signature uses C-style access to an array to allow +/// having an array stored anywhere the caller likes, as long as it is +/// contiguous in memory. Since most functions don't take many arguments, it's +/// convenient to have them stored in the calling stack frame. +template +std::function +NameToFunction(const std::string &function_name); + +inline constexpr char kStartsWith[] = "STARTSWITH"; +inline constexpr char kEndsWith[] = "ENDSWITH"; +inline constexpr char kContains[] = "CONTAINS"; +inline constexpr char kId[] = "ID"; + +} // namespace memgraph::functions + +namespace memgraph::functions::impl { + +//////////////////////////////////////////////////////////////////////////////// +// eDSL using template magic for describing a type of an awesome memgraph +// function and checking if the passed in arguments match the description. +// +// To use the type checking eDSL, you should put a `FType` invocation in the +// body of your awesome Memgraph function. `FType` takes type arguments as the +// description of the function type signature. Each runtime argument will be +// checked in order corresponding to given compile time type arguments. These +// type arguments can come in two forms: +// +// * final, primitive type descriptor and +// * combinator type descriptor. +// +// The primitive type descriptors are defined as empty structs, they are right +// below this documentation. +// +// Combinator type descriptors are defined as structs taking additional type +// parameters, you can find these further below in the implementation. Of +// primary interest are `Or` and `Optional` type combinators. +// +// With `Or` you can describe that an argument can be any of the types listed in +// `Or`. For example, `Or` allows an argument to be either +// `Null` or a boolean or an integer. +// +// The `Optional` combinator is used to define optional arguments to a function. +// These must come as the last positional arguments. Naturally, you can use `Or` +// inside `Optional`. So for example, `Optional, Integer>` +// describes that a function takes 2 optional arguments. The 1st one must be +// either a `Null` or a boolean, while the 2nd one must be an integer. The type +// signature check will succeed in the following cases. +// +// * No optional arguments were supplied. +// * One argument was supplied and it passes `Or` check. +// * Two arguments were supplied, the 1st one passes `Or` check +// and the 2nd one passes `Integer` check. +// +// Runtime arguments to `FType` are: function name, pointer to arguments and the +// number of received arguments. +// +// Full example. +// +// FType, NonNegativeInteger, +// Optional>("substring", args, nargs); +// +// The above will check that `substring` function received the 2 required +// arguments. Optionally, the function may take a 3rd argument. The 1st argument +// must be either a `Null` or a character string. The 2nd argument is required +// to be a non-negative integer. If the 3rd argument was supplied, it will also +// be checked that it is a non-negative integer. If any of these checks fail, +// `FType` will throw a `FunctionRuntimeException` with an appropriate error +// message. +//////////////////////////////////////////////////////////////////////////////// + +struct Null {}; +struct Bool {}; +struct Integer {}; +struct PositiveInteger {}; +struct NonZeroInteger {}; +struct NonNegativeInteger {}; +struct Double {}; +struct Number {}; +struct List {}; +struct String {}; +struct Map {}; +struct Edge {}; +struct Vertex {}; +struct Path {}; +struct Date {}; +struct LocalTime {}; +struct LocalDateTime {}; +struct Duration {}; + +template +bool ArgIsType(const TypedValueT &arg) { + if constexpr (std::is_same_v) { + return arg.IsNull(); + } else if constexpr (std::is_same_v) { + return arg.IsBool(); + } else if constexpr (std::is_same_v) { + return arg.IsInt(); + } else if constexpr (std::is_same_v) { + return arg.IsInt() && arg.ValueInt() > 0; + } else if constexpr (std::is_same_v) { + return arg.IsInt() && arg.ValueInt() != 0; + } else if constexpr (std::is_same_v) { + return arg.IsInt() && arg.ValueInt() >= 0; + } else if constexpr (std::is_same_v) { + return arg.IsDouble(); + } else if constexpr (std::is_same_v) { + return arg.IsNumeric(); + } else if constexpr (std::is_same_v) { + return arg.IsList(); + } else if constexpr (std::is_same_v) { + return arg.IsString(); + } else if constexpr (std::is_same_v) { + return arg.IsMap(); + } else if constexpr (std::is_same_v) { + return arg.IsVertex(); + } else if constexpr (std::is_same_v) { + return arg.IsEdge(); + } else if constexpr (std::is_same_v) { + return arg.IsPath(); + } else if constexpr (std::is_same_v) { + return arg.IsDate(); + } else if constexpr (std::is_same_v) { + return arg.IsLocalTime(); + } else if constexpr (std::is_same_v) { + return arg.IsLocalDateTime(); + } else if constexpr (std::is_same_v) { + return arg.IsDuration(); + } else if constexpr (std::is_same_v) { + return true; + } else { + static_assert(std::is_same_v, "Unknown ArgType"); + } + return false; +} + +template +constexpr const char *ArgTypeName() { + // The type names returned should be standardized openCypher type names. + // https://github.com/opencypher/openCypher/blob/master/docs/openCypher9.pdf + if constexpr (std::is_same_v) { + return "null"; + } else if constexpr (std::is_same_v) { + return "boolean"; + } else if constexpr (std::is_same_v) { + return "integer"; + } else if constexpr (std::is_same_v) { + return "positive integer"; + } else if constexpr (std::is_same_v) { + return "non-zero integer"; + } else if constexpr (std::is_same_v) { + return "non-negative integer"; + } else if constexpr (std::is_same_v) { + return "float"; + } else if constexpr (std::is_same_v) { + return "number"; + } else if constexpr (std::is_same_v) { + return "list"; + } else if constexpr (std::is_same_v) { + return "string"; + } else if constexpr (std::is_same_v) { + return "map"; + } else if constexpr (std::is_same_v) { + return "node"; + } else if constexpr (std::is_same_v) { + return "relationship"; + } else if constexpr (std::is_same_v) { + return "path"; + } else if constexpr (std::is_same_v) { + return "void"; + } else if constexpr (std::is_same_v) { + return "Date"; + } else if constexpr (std::is_same_v) { + return "LocalTime"; + } else if constexpr (std::is_same_v) { + return "LocalDateTime"; + } else if constexpr (std::is_same_v) { + return "Duration"; + } else { + static_assert(std::is_same_v, "Unknown ArgType"); + } + return ""; +} + +template +struct Or; + +template +struct Or { + template + static bool Check(const TypedValueT &arg) { + return ArgIsType(arg); + } + + static std::string TypeNames() { return ArgTypeName(); } +}; + +template +struct Or { + template + static bool Check(const TypedValueT &arg) { + if (ArgIsType(arg)) return true; + return Or::Check(arg); + } + + static std::string TypeNames() { + if constexpr (sizeof...(ArgTypes) > 1) { + return fmt::format("'{}', {}", ArgTypeName(), Or::TypeNames()); + } else { + return fmt::format("'{}' or '{}'", ArgTypeName(), Or::TypeNames()); + } + } +}; + +template +struct IsOrType { + static constexpr bool value = false; +}; + +template +struct IsOrType> { + static constexpr bool value = true; +}; + +template +struct Optional; + +template +struct Optional { + static constexpr size_t size = 1; + + template + static void Check(const char *name, const TypedValueT *args, int64_t nargs, int64_t pos) { + if (nargs == 0) return; + const TypedValueT &arg = args[0]; + if constexpr (IsOrType::value) { + if (!ArgType::Check(arg)) { + throw FunctionRuntimeException("Optional '{}' argument at position {} must be either {}.", name, pos, + ArgType::TypeNames()); + } + } else { + if (!ArgIsType(arg)) + throw FunctionRuntimeException("Optional '{}' argument at position {} must be '{}'.", name, pos, + ArgTypeName()); + } + } +}; + +template +struct Optional { + static constexpr size_t size = 1 + sizeof...(ArgTypes); + + template + static void Check(const char *name, const TypedValueT *args, int64_t nargs, int64_t pos) { + if (nargs == 0) return; + Optional::Check(name, args, nargs, pos); + Optional::Check(name, args + 1, nargs - 1, pos + 1); + } +}; + +template +struct IsOptional { + static constexpr bool value = false; +}; + +template +struct IsOptional> { + static constexpr bool value = true; +}; + +template +constexpr size_t FTypeRequiredArgs() { + if constexpr (IsOptional::value) { + static_assert(sizeof...(ArgTypes) == 0, "Optional arguments must be last!"); + return 0; + } else if constexpr (sizeof...(ArgTypes) == 0) { + return 1; + } else { + return 1U + FTypeRequiredArgs(); + } +} + +template +constexpr size_t FTypeOptionalArgs() { + if constexpr (IsOptional::value) { + static_assert(sizeof...(ArgTypes) == 0, "Optional arguments must be last!"); + return ArgType::size; + } else if constexpr (sizeof...(ArgTypes) == 0) { + return 0; + } else { + return FTypeOptionalArgs(); + } +} + +template +void FType(const char *name, const TypedValueT *args, int64_t nargs, int64_t pos = 1) { + if constexpr (std::is_same_v) { + if (nargs != 0) { + throw FunctionRuntimeException("'{}' requires no arguments.", name); + } + return; + } + static constexpr int64_t required_args = FTypeRequiredArgs(); + static constexpr int64_t optional_args = FTypeOptionalArgs(); + static constexpr int64_t total_args = required_args + optional_args; + if constexpr (optional_args > 0) { + if (nargs < required_args || nargs > total_args) { + throw FunctionRuntimeException("'{}' requires between {} and {} arguments.", name, required_args, total_args); + } + } else { + if (nargs != required_args) { + throw FunctionRuntimeException("'{}' requires exactly {} {}.", name, required_args, + required_args == 1 ? "argument" : "arguments"); + } + } + const TypedValueT &arg = args[0]; + if constexpr (IsOrType::value) { + if (!ArgType::Check(arg)) { + throw FunctionRuntimeException("'{}' argument at position {} must be either {}.", name, pos, + ArgType::TypeNames()); + } + } else if constexpr (IsOptional::value) { + static_assert(sizeof...(ArgTypes) == 0, "Optional arguments must be last!"); + ArgType::Check(name, args, nargs, pos); + } else { + if (!ArgIsType(arg)) { + throw FunctionRuntimeException("'{}' argument at position {} must be '{}'", name, pos, ArgTypeName()); + } + } + if constexpr (sizeof...(ArgTypes) > 0) { + FType(name, args + 1, nargs - 1, pos + 1); + } +} + +//////////////////////////////////////////////////////////////////////////////// +// END function type description eDSL +//////////////////////////////////////////////////////////////////////////////// + +// Predicate functions. +// Neo4j has all, any, exists, none, single +// Those functions are a little bit different since they take a filterExpression +// as an argument. +// There is all, any, none and single productions in opencypher grammar, but it +// will be trivial to also add exists. +// TODO: Implement this. + +// Scalar functions. +// We don't have a way to implement id function since we don't store any. If it +// is really neccessary we could probably map vlist* to id. +// TODO: Implement length (it works on a path, but we didn't define path +// structure yet). +// TODO: Implement size(pattern), for example size((a)-[:X]-()) should return +// number of results of this pattern. I don't think we will ever do this. +// TODO: Implement rest of the list functions. +// TODO: Implement degrees, haversin, radians +// TODO: Implement spatial functions + +template +TypedValueT EndNode(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("endNode", args, nargs); + if (args[0].IsNull()) { + return TypedValueT(ctx.memory); + } + return TypedValueT(args[0].ValueEdge().To(), ctx.memory); +} + +template +TypedValueT Head(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("head", args, nargs); + if (args[0].IsNull()) return TypedValueT(ctx.memory); + const auto &list = args[0].ValueList(); + if (list.empty()) return TypedValueT(ctx.memory); + return TypedValueT(list[0], ctx.memory); +} + +template +TypedValueT Last(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("last", args, nargs); + if (args[0].IsNull()) return TypedValueT(ctx.memory); + const auto &list = args[0].ValueList(); + if (list.empty()) return TypedValueT(ctx.memory); + return TypedValueT(list.back(), ctx.memory); +} + +template +TypedValueT Properties(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("properties", args, nargs); + auto *dba = ctx.db_accessor; + auto get_properties = [&](const auto &record_accessor) { + typename TypedValueT::TMap properties(ctx.memory); + Conv conv; + if constexpr (std::is_same_v) { + auto maybe_props = record_accessor.Properties(ctx.view); + if (maybe_props.HasError()) { + switch (maybe_props.GetError().code) { + case common::ErrorCode::DELETED_OBJECT: + throw functions::FunctionRuntimeException("Trying to get properties from a deleted object."); + case common::ErrorCode::NONEXISTENT_OBJECT: + throw functions::FunctionRuntimeException("Trying to get properties from an object that doesn't exist."); + case common::ErrorCode::SERIALIZATION_ERROR: + case common::ErrorCode::VERTEX_HAS_EDGES: + case common::ErrorCode::PROPERTIES_DISABLED: + case common::ErrorCode::VERTEX_ALREADY_INSERTED: + case common::ErrorCode::SCHEMA_NO_SCHEMA_DEFINED_FOR_LABEL: + case common::ErrorCode::SCHEMA_VERTEX_PROPERTY_WRONG_TYPE: + case common::ErrorCode::SCHEMA_VERTEX_UPDATE_PRIMARY_KEY: + case common::ErrorCode::SCHEMA_VERTEX_UPDATE_PRIMARY_LABEL: + case common::ErrorCode::SCHEMA_VERTEX_SECONDARY_LABEL_IS_PRIMARY: + case common::ErrorCode::SCHEMA_VERTEX_PRIMARY_PROPERTIES_UNDEFINED: + case common::ErrorCode::OBJECT_NOT_FOUND: + throw functions::FunctionRuntimeException("Unexpected error when getting properties."); + } + } + for (const auto &property : *maybe_props) { + properties.emplace(dba->PropertyToName(property.first), conv(property.second)); + } + } else { + for (const auto &property : record_accessor.Properties()) { + properties.emplace(utils::pmr::string(dba->PropertyToName(property.first), ctx.memory), + conv(property.second, dba)); + } + } + return TypedValueT(std::move(properties)); + }; + + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValueT(ctx.memory); + } + if (value.IsVertex()) { + return get_properties(value.ValueVertex()); + } + return get_properties(value.ValueEdge()); +} + +template +TypedValueT Size(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("size", args, nargs); + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValueT(ctx.memory); + } + if (value.IsList()) { + return TypedValueT(static_cast(value.ValueList().size()), ctx.memory); + } + if (value.IsString()) { + return TypedValueT(static_cast(value.ValueString().size()), ctx.memory); + } + if (value.IsMap()) { + // neo4j doesn't implement size for map, but I don't see a good reason not + // to do it. + return TypedValueT(static_cast(value.ValueMap().size()), ctx.memory); + } + return TypedValueT(static_cast(value.ValuePath().edges().size()), ctx.memory); +} + +template +TypedValueT StartNode(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("startNode", args, nargs); + if (args[0].IsNull()) { + return TypedValueT(ctx.memory); + } + return TypedValueT(args[0].ValueEdge().From(), ctx.memory); +} + +// This is needed because clang-tidy fails to identify the use of this function in the if-constexpr branch +// NOLINTNEXTLINE(clang-diagnostic-unused-function) +inline size_t UnwrapDegreeResult(storage::v3::ShardResult maybe_degree) { + if (maybe_degree.HasError()) { + switch (maybe_degree.GetError().code) { + case common::ErrorCode::DELETED_OBJECT: + throw functions::FunctionRuntimeException("Trying to get degree of a deleted node."); + case common::ErrorCode::NONEXISTENT_OBJECT: + throw functions::FunctionRuntimeException("Trying to get degree of a node that doesn't exist."); + case common::ErrorCode::SERIALIZATION_ERROR: + case common::ErrorCode::VERTEX_HAS_EDGES: + case common::ErrorCode::PROPERTIES_DISABLED: + case common::ErrorCode::VERTEX_ALREADY_INSERTED: + case common::ErrorCode::SCHEMA_NO_SCHEMA_DEFINED_FOR_LABEL: + case common::ErrorCode::SCHEMA_VERTEX_PROPERTY_WRONG_TYPE: + case common::ErrorCode::SCHEMA_VERTEX_UPDATE_PRIMARY_KEY: + case common::ErrorCode::SCHEMA_VERTEX_UPDATE_PRIMARY_LABEL: + case common::ErrorCode::SCHEMA_VERTEX_SECONDARY_LABEL_IS_PRIMARY: + case common::ErrorCode::SCHEMA_VERTEX_PRIMARY_PROPERTIES_UNDEFINED: + case common::ErrorCode::OBJECT_NOT_FOUND: + throw functions::FunctionRuntimeException("Unexpected error when getting node degree."); + } + } + return *maybe_degree; +} + +template +TypedValueT Degree(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("degree", args, nargs); + if (args[0].IsNull()) return TypedValueT(ctx.memory); + const auto &vertex = args[0].ValueVertex(); + size_t out_degree = 0; + size_t in_degree = 0; + if constexpr (std::same_as) { + out_degree = UnwrapDegreeResult(vertex.OutDegree(ctx.view)); + in_degree = UnwrapDegreeResult(vertex.InDegree(ctx.view)); + } else { + out_degree = vertex.OutDegree(); + in_degree = vertex.InDegree(); + } + return TypedValueT(static_cast(out_degree + in_degree), ctx.memory); +} + +template +TypedValueT InDegree(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("inDegree", args, nargs); + if (args[0].IsNull()) return TypedValueT(ctx.memory); + const auto &vertex = args[0].ValueVertex(); + size_t in_degree = 0; + if constexpr (std::same_as) { + in_degree = UnwrapDegreeResult(vertex.InDegree(ctx.view)); + } else { + in_degree = vertex.InDegree(); + } + return TypedValueT(static_cast(in_degree), ctx.memory); +} + +template +TypedValueT OutDegree(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("outDegree", args, nargs); + if (args[0].IsNull()) return TypedValueT(ctx.memory); + const auto &vertex = args[0].ValueVertex(); + size_t out_degree = 0; + if constexpr (std::same_as) { + out_degree = UnwrapDegreeResult(vertex.OutDegree(ctx.view)); + } else { + out_degree = vertex.OutDegree(); + } + return TypedValueT(static_cast(out_degree), ctx.memory); +} + +template +TypedValueT ToBoolean(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("toBoolean", args, nargs); + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValueT(ctx.memory); + } + if (value.IsBool()) { + return TypedValueT(value.ValueBool(), ctx.memory); + } + if (value.IsInt()) { + return TypedValueT(value.ValueInt() != 0L, ctx.memory); + } + auto s = utils::ToUpperCase(utils::Trim(value.ValueString())); + if (s == "TRUE") return TypedValueT(true, ctx.memory); + if (s == "FALSE") return TypedValueT(false, ctx.memory); + // I think this is just stupid and that exception should be thrown, but + // neo4j does it this way... + return TypedValueT(ctx.memory); +} + +template +TypedValueT ToFloat(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("toFloat", args, nargs); + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValueT(ctx.memory); + } + if (value.IsInt()) { + return TypedValueT(static_cast(value.ValueInt()), ctx.memory); + } + if (value.IsDouble()) { + return TypedValueT(value, ctx.memory); + } + try { + return TypedValueT(utils::ParseDouble(utils::Trim(value.ValueString())), ctx.memory); + } catch (const utils::BasicException &) { + return TypedValueT(ctx.memory); + } +} + +template +TypedValueT ToInteger(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("toInteger", args, nargs); + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValueT(ctx.memory); + } + if (value.IsBool()) { + return TypedValueT(value.ValueBool() ? 1L : 0L, ctx.memory); + } + if (value.IsInt()) { + return TypedValueT(value, ctx.memory); + } + if (value.IsDouble()) { + return TypedValueT(static_cast(value.ValueDouble()), ctx.memory); + } + try { + // Yup, this is correct. String is valid if it has floating point + // number, then it is parsed and converted to int. + return TypedValueT(static_cast(utils::ParseDouble(utils::Trim(value.ValueString()))), ctx.memory); + } catch (const utils::BasicException &) { + return TypedValueT(ctx.memory); + } +} + +template +TypedValueT Type(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("type", args, nargs); + auto *dba = ctx.db_accessor; + if (args[0].IsNull()) return TypedValueT(ctx.memory); + return TypedValueT(dba->EdgeTypeToName(args[0].ValueEdge().EdgeType()), ctx.memory); +} + +template +TypedValueT ValueType(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("type", args, nargs); + // The type names returned should be standardized openCypher type names. + // https://github.com/opencypher/openCypher/blob/master/docs/openCypher9.pdf + switch (args[0].type()) { + case TypedValueT::Type::Null: + return TypedValueT("NULL", ctx.memory); + case TypedValueT::Type::Bool: + return TypedValueT("BOOLEAN", ctx.memory); + case TypedValueT::Type::Int: + return TypedValueT("INTEGER", ctx.memory); + case TypedValueT::Type::Double: + return TypedValueT("FLOAT", ctx.memory); + case TypedValueT::Type::String: + return TypedValueT("STRING", ctx.memory); + case TypedValueT::Type::List: + return TypedValueT("LIST", ctx.memory); + case TypedValueT::Type::Map: + return TypedValueT("MAP", ctx.memory); + case TypedValueT::Type::Vertex: + return TypedValueT("NODE", ctx.memory); + case TypedValueT::Type::Edge: + return TypedValueT("RELATIONSHIP", ctx.memory); + case TypedValueT::Type::Path: + return TypedValueT("PATH", ctx.memory); + case TypedValueT::Type::Date: + return TypedValueT("DATE", ctx.memory); + case TypedValueT::Type::LocalTime: + return TypedValueT("LOCAL_TIME", ctx.memory); + case TypedValueT::Type::LocalDateTime: + return TypedValueT("LOCAL_DATE_TIME", ctx.memory); + case TypedValueT::Type::Duration: + return TypedValueT("DURATION", ctx.memory); + } +} + +template +TypedValueT Labels(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("labels", args, nargs); + auto *dba = ctx.db_accessor; + if (args[0].IsNull()) return TypedValueT(ctx.memory); + typename TypedValueT::TVector labels(ctx.memory); + if constexpr (std::is_same_v) { + auto maybe_labels = args[0].ValueVertex().Labels(ctx.view); + if (maybe_labels.HasError()) { + switch (maybe_labels.GetError().code) { + case common::ErrorCode::DELETED_OBJECT: + throw functions::FunctionRuntimeException("Trying to get labels from a deleted node."); + case common::ErrorCode::NONEXISTENT_OBJECT: + throw functions::FunctionRuntimeException("Trying to get labels from a node that doesn't exist."); + case common::ErrorCode::SERIALIZATION_ERROR: + case common::ErrorCode::VERTEX_HAS_EDGES: + case common::ErrorCode::PROPERTIES_DISABLED: + case common::ErrorCode::VERTEX_ALREADY_INSERTED: + case common::ErrorCode::SCHEMA_NO_SCHEMA_DEFINED_FOR_LABEL: + case common::ErrorCode::SCHEMA_VERTEX_PROPERTY_WRONG_TYPE: + case common::ErrorCode::SCHEMA_VERTEX_UPDATE_PRIMARY_KEY: + case common::ErrorCode::SCHEMA_VERTEX_UPDATE_PRIMARY_LABEL: + case common::ErrorCode::SCHEMA_VERTEX_SECONDARY_LABEL_IS_PRIMARY: + case common::ErrorCode::SCHEMA_VERTEX_PRIMARY_PROPERTIES_UNDEFINED: + case common::ErrorCode::OBJECT_NOT_FOUND: + throw functions::FunctionRuntimeException("Unexpected error when getting labels."); + } + } + for (const auto &label : *maybe_labels) { + labels.emplace_back(dba->LabelToName(label)); + } + } else { + auto vertex = args[0].ValueVertex(); + for (const auto &label : vertex.Labels()) { + labels.emplace_back(dba->LabelToName(label.id)); + } + } + return TypedValueT(std::move(labels)); +} + +template +TypedValueT Nodes(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("nodes", args, nargs); + if (args[0].IsNull()) return TypedValueT(ctx.memory); + const auto &vertices = args[0].ValuePath().vertices(); + typename TypedValueT::TVector values(ctx.memory); + values.reserve(vertices.size()); + for (const auto &v : vertices) values.emplace_back(v); + return TypedValueT(std::move(values)); +} + +template +TypedValueT Relationships(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("relationships", args, nargs); + if (args[0].IsNull()) return TypedValueT(ctx.memory); + const auto &edges = args[0].ValuePath().edges(); + typename TypedValueT::TVector values(ctx.memory); + values.reserve(edges.size()); + for (const auto &e : edges) values.emplace_back(e); + return TypedValueT(std::move(values)); +} + +template +TypedValueT Range(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType, Or, Optional>>("range", args, nargs); + for (int64_t i = 0; i < nargs; ++i) + if (args[i].IsNull()) return TypedValueT(ctx.memory); + auto lbound = args[0].ValueInt(); + auto rbound = args[1].ValueInt(); + int64_t step = nargs == 3 ? args[2].ValueInt() : 1; + typename TypedValueT::TVector list(ctx.memory); + if (lbound <= rbound && step > 0) { + for (auto i = lbound; i <= rbound; i += step) { + list.emplace_back(i); + } + } else if (lbound >= rbound && step < 0) { + for (auto i = lbound; i >= rbound; i += step) { + list.emplace_back(i); + } + } + return TypedValueT(std::move(list)); +} + +template +TypedValueT Tail(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("tail", args, nargs); + if (args[0].IsNull()) return TypedValueT(ctx.memory); + typename TypedValueT::TVector list(args[0].ValueList(), ctx.memory); + if (list.empty()) return TypedValueT(std::move(list)); + list.erase(list.begin()); + return TypedValueT(std::move(list)); +} + +template +TypedValueT UniformSample(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType, Or>("uniformSample", args, nargs); + static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()}; + if (args[0].IsNull() || args[1].IsNull()) return TypedValueT(ctx.memory); + const auto &population = args[0].ValueList(); + auto population_size = population.size(); + if (population_size == 0) return TypedValueT(ctx.memory); + auto desired_length = args[1].ValueInt(); + std::uniform_int_distribution rand_dist{0, population_size - 1}; + typename TypedValueT::TVector sampled(ctx.memory); + sampled.reserve(desired_length); + for (int64_t i = 0; i < desired_length; ++i) { + sampled.emplace_back(population[rand_dist(pseudo_rand_gen_)]); + } + return TypedValueT(std::move(sampled)); +} + +template +TypedValueT Abs(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("abs", args, nargs); + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValueT(ctx.memory); + } + if (value.IsInt()) { + return TypedValueT(std::abs(value.ValueInt()), ctx.memory); + } + return TypedValueT(std::abs(value.ValueDouble()), ctx.memory); +} + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define WRAP_CMATH_FLOAT_FUNCTION(name, lowercased_name) \ + template \ + TypedValueT name(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { \ + FType>(#lowercased_name, args, nargs); \ + const auto &value = args[0]; \ + if (value.IsNull()) { \ + return TypedValueT(ctx.memory); \ + } \ + if (value.IsInt()) { \ + return TypedValueT(lowercased_name(value.ValueInt()), ctx.memory); \ + } \ + return TypedValueT(lowercased_name(value.ValueDouble()), ctx.memory); \ + } + +WRAP_CMATH_FLOAT_FUNCTION(Ceil, ceil) +WRAP_CMATH_FLOAT_FUNCTION(Floor, floor) +// We are not completely compatible with neoj4 in this function because, +// neo4j rounds -0.5, -1.5, -2.5... to 0, -1, -2... +WRAP_CMATH_FLOAT_FUNCTION(Round, round) +WRAP_CMATH_FLOAT_FUNCTION(Exp, exp) +WRAP_CMATH_FLOAT_FUNCTION(Log, log) +WRAP_CMATH_FLOAT_FUNCTION(Log10, log10) +WRAP_CMATH_FLOAT_FUNCTION(Sqrt, sqrt) +WRAP_CMATH_FLOAT_FUNCTION(Acos, acos) +WRAP_CMATH_FLOAT_FUNCTION(Asin, asin) +WRAP_CMATH_FLOAT_FUNCTION(Atan, atan) +WRAP_CMATH_FLOAT_FUNCTION(Cos, cos) +WRAP_CMATH_FLOAT_FUNCTION(Sin, sin) +WRAP_CMATH_FLOAT_FUNCTION(Tan, tan) + +#undef WRAP_CMATH_FLOAT_FUNCTION + +template +TypedValueT Atan2(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType, Or>("atan2", args, nargs); + if (args[0].IsNull() || args[1].IsNull()) return TypedValueT(ctx.memory); + auto to_double = [](const TypedValueT &t) -> double { + if (t.IsInt()) { + return t.ValueInt(); + } + return t.ValueDouble(); + }; + double y = to_double(args[0]); + double x = to_double(args[1]); + return TypedValueT(atan2(y, x), ctx.memory); +} + +template +TypedValueT Sign(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("sign", args, nargs); + auto sign = [&](auto x) { return TypedValueT((0 < x) - (x < 0), ctx.memory); }; + const auto &value = args[0]; + if (value.IsNull()) { + return TypedValueT(ctx.memory); + } + if (value.IsInt()) { + return sign(value.ValueInt()); + } + return sign(value.ValueDouble()); +} + +template +TypedValueT E(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType("e", args, nargs); + return TypedValueT(M_E, ctx.memory); +} + +template +TypedValueT Pi(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType("pi", args, nargs); + return TypedValueT(M_PI, ctx.memory); +} + +template +TypedValueT Rand(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType("rand", args, nargs); + static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()}; + static thread_local std::uniform_real_distribution<> rand_dist_{0, 1}; + return TypedValueT(rand_dist_(pseudo_rand_gen_), ctx.memory); +} + +template +TypedValueT StringMatchOperator(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType, Or>(TPredicate::name, args, nargs); + if (args[0].IsNull() || args[1].IsNull()) return TypedValueT(ctx.memory); + const auto &s1 = args[0].ValueString(); + const auto &s2 = args[1].ValueString(); + return TypedValueT(TPredicate{}(s1, s2), ctx.memory); +} + +// Check if s1 starts with s2. +template +struct StartsWithPredicate { + static constexpr const char *name = "startsWith"; + bool operator()(const typename TypedValueT::TString &s1, const typename TypedValueT::TString &s2) const { + if (s1.size() < s2.size()) return false; + return std::equal(s2.begin(), s2.end(), s1.begin()); + } +}; + +template +inline const auto StartsWith = StringMatchOperator, TypedValueT, FunctionContextT>; + +// Check if s1 ends with s2. +template +struct EndsWithPredicate { + static constexpr const char *name = "endsWith"; + bool operator()(const typename TypedValueT::TString &s1, const typename TypedValueT::TString &s2) const { + if (s1.size() < s2.size()) return false; + return std::equal(s2.rbegin(), s2.rend(), s1.rbegin()); + } +}; + +template +inline const auto EndsWith = StringMatchOperator, TypedValueT, FunctionContextT>; + +// Check if s1 contains s2. +template +struct ContainsPredicate { + static constexpr const char *name = "contains"; + bool operator()(const typename TypedValueT::TString &s1, const typename TypedValueT::TString &s2) const { + if (s1.size() < s2.size()) return false; + return s1.find(s2) != std::string::npos; + } +}; + +template +inline const auto Contains = StringMatchOperator, TypedValueT, FunctionContextT>; + +template +TypedValueT Assert(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("assert", args, nargs); + if (!args[0].ValueBool()) { + std::string message("Assertion failed"); + if (nargs == 2) { + message += ": "; + message += args[1].ValueString(); + } + message += "."; + throw FunctionRuntimeException(message); + } + return TypedValueT(args[0], ctx.memory); +} + +template +TypedValueT Counter(const TypedValueT *args, int64_t nargs, const FunctionContextT &context) { + FType>("counter", args, nargs); + int64_t step = 1; + if (nargs == 3) { + step = args[2].ValueInt(); + } + + auto [it, inserted] = context.counters->emplace(args[0].ValueString(), args[1].ValueInt()); + auto value = it->second; + it->second += step; + + return TypedValueT(value, context.memory); +} + +template +TypedValueT Id(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("id", args, nargs); + const auto &arg = args[0]; + if (arg.IsNull()) { + return TypedValueT(ctx.memory); + } + return TypedValueT(static_cast(arg.ValueEdge().CypherId()), ctx.memory); +} + +template +TypedValueT ToString(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("toString", args, nargs); + const auto &arg = args[0]; + if (arg.IsNull()) { + return TypedValueT(ctx.memory); + } + if (arg.IsString()) { + return TypedValueT(arg, ctx.memory); + } + if (arg.IsInt()) { + // TODO: This is making a pointless copy of std::string, we may want to + // use a different conversion to string + return TypedValueT(std::to_string(arg.ValueInt()), ctx.memory); + } + if (arg.IsDouble()) { + return TypedValueT(std::to_string(arg.ValueDouble()), ctx.memory); + } + if (arg.IsDate()) { + return TypedValueT(arg.ValueDate().ToString(), ctx.memory); + } + if (arg.IsLocalTime()) { + return TypedValueT(arg.ValueLocalTime().ToString(), ctx.memory); + } + if (arg.IsLocalDateTime()) { + return TypedValueT(arg.ValueLocalDateTime().ToString(), ctx.memory); + } + if (arg.IsDuration()) { + return TypedValueT(arg.ValueDuration().ToString(), ctx.memory); + } + + return TypedValueT(arg.ValueBool() ? "true" : "false", ctx.memory); +} + +template +TypedValueT Timestamp(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>>("timestamp", args, nargs); + const auto &arg = *args; + if (arg.IsDate()) { + return TypedValueT(arg.ValueDate().MicrosecondsSinceEpoch(), ctx.memory); + } + if (arg.IsLocalTime()) { + return TypedValueT(arg.ValueLocalTime().MicrosecondsSinceEpoch(), ctx.memory); + } + if (arg.IsLocalDateTime()) { + return TypedValueT(arg.ValueLocalDateTime().MicrosecondsSinceEpoch(), ctx.memory); + } + if (arg.IsDuration()) { + return TypedValueT(arg.ValueDuration().microseconds, ctx.memory); + } + return TypedValueT(ctx.timestamp, ctx.memory); +} + +template +TypedValueT Left(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType, Or>("left", args, nargs); + if (args[0].IsNull() || args[1].IsNull()) return TypedValueT(ctx.memory); + return TypedValueT(utils::Substr(args[0].ValueString(), 0, args[1].ValueInt()), ctx.memory); +} + +template +TypedValueT Right(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType, Or>("right", args, nargs); + if (args[0].IsNull() || args[1].IsNull()) return TypedValueT(ctx.memory); + const auto &str = args[0].ValueString(); + auto len = args[1].ValueInt(); + return len <= str.size() ? TypedValueT(utils::Substr(str, str.size() - len, len), ctx.memory) + : TypedValueT(str, ctx.memory); +} + +template +TypedValueT CallStringFunction( + const TypedValueT *args, int64_t nargs, utils::MemoryResource *memory, const char *name, + std::function fun) { + FType>(name, args, nargs); + if (args[0].IsNull()) return TypedValueT(memory); + return TypedValueT(fun(args[0].ValueString()), memory); +} + +template +TypedValueT LTrim(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "lTrim", [&](const auto &str) { + return typename TypedValueT::TString(utils::LTrim(str), ctx.memory); + }); +} + +template +TypedValueT RTrim(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "rTrim", [&](const auto &str) { + return typename TypedValueT::TString(utils::RTrim(str), ctx.memory); + }); +} + +template +TypedValueT Trim(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "trim", [&](const auto &str) { + return typename TypedValueT::TString(utils::Trim(str), ctx.memory); + }); +} + +template +TypedValueT Reverse(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + return CallStringFunction( + args, nargs, ctx.memory, "reverse", [&](const auto &str) { return utils::Reversed(str, ctx.memory); }); +} + +template +TypedValueT ToLower(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "toLower", [&](const auto &str) { + typename TypedValueT::TString res(ctx.memory); + utils::ToLowerCase(&res, str); + return res; + }); +} + +template +TypedValueT ToUpper(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + return CallStringFunction(args, nargs, ctx.memory, "toUpper", [&](const auto &str) { + typename TypedValueT::TString res(ctx.memory); + utils::ToUpperCase(&res, str); + return res; + }); +} + +template +TypedValueT Replace(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType, Or, Or>("replace", args, nargs); + if (args[0].IsNull() || args[1].IsNull() || args[2].IsNull()) { + return TypedValueT(ctx.memory); + } + typename TypedValueT::TString replaced(ctx.memory); + utils::Replace(&replaced, args[0].ValueString(), args[1].ValueString(), args[2].ValueString()); + return TypedValueT(std::move(replaced)); +} + +template +TypedValueT Split(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType, Or>("split", args, nargs); + if (args[0].IsNull() || args[1].IsNull()) { + return TypedValueT(ctx.memory); + } + typename TypedValueT::TVector result(ctx.memory); + utils::Split(&result, args[0].ValueString(), args[1].ValueString()); + return TypedValueT(std::move(result)); +} + +template +TypedValueT Substring(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType, NonNegativeInteger, Optional>("substring", args, nargs); + if (args[0].IsNull()) return TypedValueT(ctx.memory); + const auto &str = args[0].ValueString(); + auto start = args[1].ValueInt(); + if (nargs == 2) return TypedValueT(utils::Substr(str, start), ctx.memory); + auto len = args[2].ValueInt(); + return TypedValueT(utils::Substr(str, start, len), ctx.memory); +} + +template +TypedValueT ToByteString(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType("toByteString", args, nargs); + const auto &str = args[0].ValueString(); + if (str.empty()) return TypedValueT("", ctx.memory); + if (!utils::StartsWith(str, "0x") && !utils::StartsWith(str, "0X")) { + throw FunctionRuntimeException("'toByteString' argument must start with '0x'"); + } + const auto &hex_str = utils::Substr(str, 2); + auto read_hex = [](const char ch) -> unsigned char { + if (ch >= '0' && ch <= '9') return ch - '0'; + if (ch >= 'a' && ch <= 'f') return ch - 'a' + 10; + if (ch >= 'A' && ch <= 'F') return ch - 'A' + 10; + throw FunctionRuntimeException("'toByteString' argument has an invalid character '{}'", ch); + }; + utils::pmr::string bytes(ctx.memory); + bytes.reserve((1 + hex_str.size()) / 2); + size_t i = 0; + // Treat odd length hex string as having a leading zero. + if (hex_str.size() % 2) bytes.append(1, read_hex(hex_str[i++])); + for (; i < hex_str.size(); i += 2) { + unsigned char byte = read_hex(hex_str[i]) * 16U + read_hex(hex_str[i + 1]); + // MemcpyCast in case we are converting to a signed value, so as to avoid + // undefined behaviour. + bytes.append(1, utils::MemcpyCast(byte)); + } + return TypedValueT(std::move(bytes)); +} + +template +TypedValueT FromByteString(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("fromByteString", args, nargs); + const auto &bytes = args[0].ValueString(); + if (bytes.empty()) return TypedValueT("", ctx.memory); + size_t min_length = bytes.size(); + if (nargs == 2) min_length = std::max(min_length, static_cast(args[1].ValueInt())); + utils::pmr::string str(ctx.memory); + str.reserve(min_length * 2 + 2); + str.append("0x"); + for (size_t pad = 0; pad < min_length - bytes.size(); ++pad) str.append(2, '0'); + // Convert the bytes to a character string in hex representation. + // Unfortunately, we don't know whether the default `char` is signed or + // unsigned, so we have to work around any potential undefined behaviour when + // conversions between the 2 occur. That's why this function is more + // complicated than it should be. + auto to_hex = [](const unsigned char val) -> char { + unsigned char ch = val < 10U ? static_cast('0') + val : static_cast('a') + val - 10U; + return utils::MemcpyCast(ch); + }; + for (unsigned char byte : bytes) { + str.append(1, to_hex(byte / 16U)); + str.append(1, to_hex(byte % 16U)); + } + return TypedValueT(std::move(str)); +} + +template +concept IsNumberOrInteger = utils::SameAsAnyOf; + +template +void MapNumericParameters(auto ¶meter_mappings, const auto &input_parameters) { + for (const auto &[key, value] : input_parameters) { + if (auto it = parameter_mappings.find(key); it != parameter_mappings.end()) { + if (value.IsInt()) { + *it->second = value.ValueInt(); + } else if (std::is_same_v && value.IsDouble()) { + *it->second = value.ValueDouble(); + } else { + std::string_view error = std::is_same_v ? "an integer." : "a numeric value."; + throw FunctionRuntimeException("Invalid value for key '{}'. Expected {}", key, error); + } + } else { + throw FunctionRuntimeException("Unknown key '{}'.", key); + } + } +} + +template +TypedValueT Date(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>>("date", args, nargs); + if (nargs == 0) { + return TypedValueT(utils::LocalDateTime(ctx.timestamp).date, ctx.memory); + } + + if (args[0].IsString()) { + const auto &[date_parameters, is_extended] = utils::ParseDateParameters(args[0].ValueString()); + return TypedValueT(utils::Date(date_parameters), ctx.memory); + } + + utils::DateParameters date_parameters; + + using namespace std::literals; + std::unordered_map parameter_mappings = {std::pair{"year"sv, &date_parameters.year}, + std::pair{"month"sv, &date_parameters.month}, + std::pair{"day"sv, &date_parameters.day}}; + + MapNumericParameters(parameter_mappings, args[0].ValueMap()); + return TypedValueT(utils::Date(date_parameters), ctx.memory); +} + +template +TypedValueT LocalTime(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>>("localtime", args, nargs); + + if (nargs == 0) { + return TypedValueT(utils::LocalDateTime(ctx.timestamp).local_time, ctx.memory); + } + + if (args[0].IsString()) { + const auto &[local_time_parameters, is_extended] = utils::ParseLocalTimeParameters(args[0].ValueString()); + return TypedValueT(utils::LocalTime(local_time_parameters), ctx.memory); + } + + utils::LocalTimeParameters local_time_parameters; + + using namespace std::literals; + std::unordered_map parameter_mappings{ + std::pair{"hour"sv, &local_time_parameters.hour}, + std::pair{"minute"sv, &local_time_parameters.minute}, + std::pair{"second"sv, &local_time_parameters.second}, + std::pair{"millisecond"sv, &local_time_parameters.millisecond}, + std::pair{"microsecond"sv, &local_time_parameters.microsecond}, + }; + + MapNumericParameters(parameter_mappings, args[0].ValueMap()); + return TypedValueT(utils::LocalTime(local_time_parameters), ctx.memory); +} + +template +TypedValueT LocalDateTime(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>>("localdatetime", args, nargs); + + if (nargs == 0) { + return TypedValueT(utils::LocalDateTime(ctx.timestamp), ctx.memory); + } + + if (args[0].IsString()) { + const auto &[date_parameters, local_time_parameters] = ParseLocalDateTimeParameters(args[0].ValueString()); + return TypedValueT(utils::LocalDateTime(date_parameters, local_time_parameters), ctx.memory); + } + + utils::DateParameters date_parameters; + utils::LocalTimeParameters local_time_parameters; + using namespace std::literals; + std::unordered_map parameter_mappings{ + std::pair{"year"sv, &date_parameters.year}, + std::pair{"month"sv, &date_parameters.month}, + std::pair{"day"sv, &date_parameters.day}, + std::pair{"hour"sv, &local_time_parameters.hour}, + std::pair{"minute"sv, &local_time_parameters.minute}, + std::pair{"second"sv, &local_time_parameters.second}, + std::pair{"millisecond"sv, &local_time_parameters.millisecond}, + std::pair{"microsecond"sv, &local_time_parameters.microsecond}, + }; + + MapNumericParameters(parameter_mappings, args[0].ValueMap()); + return TypedValueT(utils::LocalDateTime(date_parameters, local_time_parameters), ctx.memory); +} + +template +TypedValueT Duration(const TypedValueT *args, int64_t nargs, const FunctionContextT &ctx) { + FType>("duration", args, nargs); + + if (args[0].IsString()) { + return TypedValueT(utils::Duration(ParseDurationParameters(args[0].ValueString())), ctx.memory); + } + + utils::DurationParameters duration_parameters; + using namespace std::literals; + std::unordered_map parameter_mappings{std::pair{"day"sv, &duration_parameters.day}, + std::pair{"hour"sv, &duration_parameters.hour}, + std::pair{"minute"sv, &duration_parameters.minute}, + std::pair{"second"sv, &duration_parameters.second}, + std::pair{"millisecond"sv, &duration_parameters.millisecond}, + std::pair{"microsecond"sv, &duration_parameters.microsecond}}; + MapNumericParameters(parameter_mappings, args[0].ValueMap()); + return TypedValueT(utils::Duration(duration_parameters), ctx.memory); +} + +} // namespace memgraph::functions::impl + +namespace memgraph::functions { + +template +std::function +NameToFunction(const std::string &function_name) { + // Scalar functions + if (function_name == "DEGREE") return functions::impl::Degree; + if (function_name == "INDEGREE") return functions::impl::InDegree; + if (function_name == "OUTDEGREE") return functions::impl::OutDegree; + if (function_name == "HEAD") return functions::impl::Head; + if (function_name == kId) return functions::impl::Id; + if (function_name == "LAST") return functions::impl::Last; + if (function_name == "PROPERTIES") return functions::impl::Properties; + if (function_name == "SIZE") return functions::impl::Size; + if (function_name == "TIMESTAMP") return functions::impl::Timestamp; + if (function_name == "TOBOOLEAN") return functions::impl::ToBoolean; + if (function_name == "TOFLOAT") return functions::impl::ToFloat; + if (function_name == "TOINTEGER") return functions::impl::ToInteger; + if (function_name == "TYPE") return functions::impl::Type; + if (function_name == "VALUETYPE") return functions::impl::ValueType; + // Only on QE + if constexpr (std::is_same_v) { + if (function_name == "STARTNODE") return functions::impl::StartNode; + if (function_name == "ENDNODE") return functions::impl::EndNode; + } + + // List functions + if (function_name == "LABELS") return functions::impl::Labels; + if (function_name == "NODES") return functions::impl::Nodes; + if (function_name == "RANGE") return functions::impl::Range; + if (function_name == "RELATIONSHIPS") return functions::impl::Relationships; + if (function_name == "TAIL") return functions::impl::Tail; + if (function_name == "UNIFORMSAMPLE") return functions::impl::UniformSample; + + // Mathematical functions - numeric + if (function_name == "ABS") return functions::impl::Abs; + if (function_name == "CEIL") return functions::impl::Ceil; + if (function_name == "FLOOR") return functions::impl::Floor; + if (function_name == "RAND") return functions::impl::Rand; + if (function_name == "ROUND") return functions::impl::Round; + if (function_name == "SIGN") return functions::impl::Sign; + + // Mathematical functions - logarithmic + if (function_name == "E") return functions::impl::E; + if (function_name == "EXP") return functions::impl::Exp; + if (function_name == "LOG") return functions::impl::Log; + if (function_name == "LOG10") return functions::impl::Log10; + if (function_name == "SQRT") return functions::impl::Sqrt; + + // Mathematical functions - trigonometric + if (function_name == "ACOS") return functions::impl::Acos; + if (function_name == "ASIN") return functions::impl::Asin; + if (function_name == "ATAN") return functions::impl::Atan; + if (function_name == "ATAN2") return functions::impl::Atan2; + if (function_name == "COS") return functions::impl::Cos; + if (function_name == "PI") return functions::impl::Pi; + if (function_name == "SIN") return functions::impl::Sin; + if (function_name == "TAN") return functions::impl::Tan; + + // String functions + if (function_name == kContains) return functions::impl::Contains; + if (function_name == kEndsWith) return functions::impl::EndsWith; + if (function_name == "LEFT") return functions::impl::Left; + if (function_name == "LTRIM") return functions::impl::LTrim; + if (function_name == "REPLACE") return functions::impl::Replace; + if (function_name == "REVERSE") return functions::impl::Reverse; + if (function_name == "RIGHT") return functions::impl::Right; + if (function_name == "RTRIM") return functions::impl::RTrim; + if (function_name == "SPLIT") return functions::impl::Split; + if (function_name == kStartsWith) return functions::impl::StartsWith; + if (function_name == "SUBSTRING") return functions::impl::Substring; + if (function_name == "TOLOWER") return functions::impl::ToLower; + if (function_name == "TOSTRING") return functions::impl::ToString; + if (function_name == "TOUPPER") return functions::impl::ToUpper; + if (function_name == "TRIM") return functions::impl::Trim; + + // Memgraph specific functions + if (function_name == "ASSERT") return functions::impl::Assert; + if (function_name == "TOBYTESTRING") return functions::impl::ToByteString; + if (function_name == "FROMBYTESTRING") return functions::impl::FromByteString; + // Only on QE + if constexpr (std::is_same_v) { + if (function_name == "COUNTER") return functions::impl::Counter; + } + + // Functions for temporal types + if (function_name == "DATE") return functions::impl::Date; + if (function_name == "LOCALTIME") return functions::impl::LocalTime; + if (function_name == "LOCALDATETIME") return functions::impl::LocalDateTime; + if (function_name == "DURATION") return functions::impl::Duration; + + return nullptr; +} + +} // namespace memgraph::functions diff --git a/src/glue/v2/communication.cpp b/src/glue/v2/communication.cpp index eedf699bb..42228652f 100644 --- a/src/glue/v2/communication.cpp +++ b/src/glue/v2/communication.cpp @@ -18,8 +18,8 @@ #include "common/errors.hpp" #include "coordinator/shard_map.hpp" #include "query/v2/accessors.hpp" +#include "query/v2/request_router.hpp" #include "query/v2/requests.hpp" -#include "query/v2/shard_request_manager.hpp" #include "storage/v3/edge_accessor.hpp" #include "storage/v3/id_types.hpp" #include "storage/v3/shard.hpp" @@ -72,7 +72,7 @@ query::v2::TypedValue ToTypedValue(const Value &value) { } communication::bolt::Vertex ToBoltVertex(const query::v2::accessors::VertexAccessor &vertex, - const msgs::ShardRequestManagerInterface *shard_request_manager, + const query::v2::RequestRouterInterface *request_router, storage::v3::View /*view*/) { auto id = communication::bolt::Id::FromUint(0); @@ -80,43 +80,43 @@ communication::bolt::Vertex ToBoltVertex(const query::v2::accessors::VertexAcces std::vector new_labels; new_labels.reserve(labels.size()); for (const auto &label : labels) { - new_labels.push_back(shard_request_manager->LabelToName(label.id)); + new_labels.push_back(request_router->LabelToName(label.id)); } auto properties = vertex.Properties(); std::map new_properties; for (const auto &[prop, property_value] : properties) { - new_properties[shard_request_manager->PropertyToName(prop)] = ToBoltValue(property_value); + new_properties[request_router->PropertyToName(prop)] = ToBoltValue(property_value); } return communication::bolt::Vertex{id, new_labels, new_properties}; } communication::bolt::Edge ToBoltEdge(const query::v2::accessors::EdgeAccessor &edge, - const msgs::ShardRequestManagerInterface *shard_request_manager, + const query::v2::RequestRouterInterface *request_router, storage::v3::View /*view*/) { // TODO(jbajic) Fix bolt communication auto id = communication::bolt::Id::FromUint(0); auto from = communication::bolt::Id::FromUint(0); auto to = communication::bolt::Id::FromUint(0); - const auto &type = shard_request_manager->EdgeTypeToName(edge.EdgeType()); + const auto &type = request_router->EdgeTypeToName(edge.EdgeType()); auto properties = edge.Properties(); std::map new_properties; for (const auto &[prop, property_value] : properties) { - new_properties[shard_request_manager->PropertyToName(prop)] = ToBoltValue(property_value); + new_properties[request_router->PropertyToName(prop)] = ToBoltValue(property_value); } return communication::bolt::Edge{id, from, to, type, new_properties}; } communication::bolt::Path ToBoltPath(const query::v2::accessors::Path & /*edge*/, - const msgs::ShardRequestManagerInterface * /*shard_request_manager*/, + const query::v2::RequestRouterInterface * /*request_router*/, storage::v3::View /*view*/) { // TODO(jbajic) Fix bolt communication MG_ASSERT(false, "Path is unimplemented!"); return {}; } -Value ToBoltValue(const query::v2::TypedValue &value, const msgs::ShardRequestManagerInterface *shard_request_manager, +Value ToBoltValue(const query::v2::TypedValue &value, const query::v2::RequestRouterInterface *request_router, storage::v3::View view) { switch (value.type()) { case query::v2::TypedValue::Type::Null: @@ -133,7 +133,7 @@ Value ToBoltValue(const query::v2::TypedValue &value, const msgs::ShardRequestMa std::vector values; values.reserve(value.ValueList().size()); for (const auto &v : value.ValueList()) { - auto value = ToBoltValue(v, shard_request_manager, view); + auto value = ToBoltValue(v, request_router, view); values.emplace_back(std::move(value)); } return {std::move(values)}; @@ -141,21 +141,21 @@ Value ToBoltValue(const query::v2::TypedValue &value, const msgs::ShardRequestMa case query::v2::TypedValue::Type::Map: { std::map map; for (const auto &kv : value.ValueMap()) { - auto value = ToBoltValue(kv.second, shard_request_manager, view); + auto value = ToBoltValue(kv.second, request_router, view); map.emplace(kv.first, std::move(value)); } return {std::move(map)}; } case query::v2::TypedValue::Type::Vertex: { - auto vertex = ToBoltVertex(value.ValueVertex(), shard_request_manager, view); + auto vertex = ToBoltVertex(value.ValueVertex(), request_router, view); return {std::move(vertex)}; } case query::v2::TypedValue::Type::Edge: { - auto edge = ToBoltEdge(value.ValueEdge(), shard_request_manager, view); + auto edge = ToBoltEdge(value.ValueEdge(), request_router, view); return {std::move(edge)}; } case query::v2::TypedValue::Type::Path: { - auto path = ToBoltPath(value.ValuePath(), shard_request_manager, view); + auto path = ToBoltPath(value.ValuePath(), request_router, view); return {std::move(path)}; } case query::v2::TypedValue::Type::Date: diff --git a/src/glue/v2/communication.hpp b/src/glue/v2/communication.hpp index a20162176..c1661b521 100644 --- a/src/glue/v2/communication.hpp +++ b/src/glue/v2/communication.hpp @@ -15,7 +15,7 @@ #include "communication/bolt/v1/value.hpp" #include "coordinator/shard_map.hpp" #include "query/v2/bindings/typed_value.hpp" -#include "query/v2/shard_request_manager.hpp" +#include "query/v2/request_router.hpp" #include "storage/v3/property_value.hpp" #include "storage/v3/result.hpp" #include "storage/v3/shard.hpp" @@ -32,40 +32,37 @@ namespace memgraph::glue::v2 { /// @param storage::v3::VertexAccessor for converting to /// communication::bolt::Vertex. -/// @param msgs::ShardRequestManagerInterface *shard_request_manager getting label and property names. +/// @param query::v2::RequestRouterInterface *request_router getting label and property names. /// @param storage::v3::View for deciding which vertex attributes are visible. /// /// @throw std::bad_alloc communication::bolt::Vertex ToBoltVertex(const storage::v3::VertexAccessor &vertex, - const msgs::ShardRequestManagerInterface *shard_request_manager, + const query::v2::RequestRouterInterface *request_router, storage::v3::View view); /// @param storage::v3::EdgeAccessor for converting to communication::bolt::Edge. -/// @param msgs::ShardRequestManagerInterface *shard_request_manager getting edge type and property names. +/// @param query::v2::RequestRouterInterface *request_router getting edge type and property names. /// @param storage::v3::View for deciding which edge attributes are visible. /// /// @throw std::bad_alloc communication::bolt::Edge ToBoltEdge(const storage::v3::EdgeAccessor &edge, - const msgs::ShardRequestManagerInterface *shard_request_manager, - storage::v3::View view); + const query::v2::RequestRouterInterface *request_router, storage::v3::View view); /// @param query::v2::Path for converting to communication::bolt::Path. -/// @param msgs::ShardRequestManagerInterface *shard_request_manager ToBoltVertex and ToBoltEdge. +/// @param query::v2::RequestRouterInterface *request_router ToBoltVertex and ToBoltEdge. /// @param storage::v3::View for ToBoltVertex and ToBoltEdge. /// /// @throw std::bad_alloc communication::bolt::Path ToBoltPath(const query::v2::accessors::Path &path, - const msgs::ShardRequestManagerInterface *shard_request_manager, - storage::v3::View view); + const query::v2::RequestRouterInterface *request_router, storage::v3::View view); /// @param query::v2::TypedValue for converting to communication::bolt::Value. -/// @param msgs::ShardRequestManagerInterface *shard_request_manager ToBoltVertex and ToBoltEdge. +/// @param query::v2::RequestRouterInterface *request_router ToBoltVertex and ToBoltEdge. /// @param storage::v3::View for ToBoltVertex and ToBoltEdge. /// /// @throw std::bad_alloc communication::bolt::Value ToBoltValue(const query::v2::TypedValue &value, - const msgs::ShardRequestManagerInterface *shard_request_manager, - storage::v3::View view); + const query::v2::RequestRouterInterface *request_router, storage::v3::View view); query::v2::TypedValue ToTypedValue(const communication::bolt::Value &value); @@ -75,8 +72,7 @@ storage::v3::PropertyValue ToPropertyValue(const communication::bolt::Value &val communication::bolt::Value ToBoltValue(msgs::Value value); -communication::bolt::Value ToBoltValue(msgs::Value value, - const msgs::ShardRequestManagerInterface *shard_request_manager, +communication::bolt::Value ToBoltValue(msgs::Value value, const query::v2::RequestRouterInterface *request_router, storage::v3::View view); } // namespace memgraph::glue::v2 diff --git a/src/io/rsm/rsm_client.hpp b/src/io/rsm/rsm_client.hpp index b60380b08..920866c7a 100644 --- a/src/io/rsm/rsm_client.hpp +++ b/src/io/rsm/rsm_client.hpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include "io/address.hpp" @@ -36,6 +37,21 @@ using memgraph::io::rsm::WriteRequest; using memgraph::io::rsm::WriteResponse; using memgraph::utils::BasicResult; +class AsyncRequestToken { + size_t id_; + + public: + explicit AsyncRequestToken(size_t id) : id_(id) {} + size_t GetId() const { return id_; } +}; + +template +struct AsyncRequest { + Time start_time; + RequestT request; + ResponseFuture future; +}; + template class RsmClient { @@ -47,23 +63,17 @@ class RsmClient { /// State for single async read/write operations. In the future this could become a map /// of async operations that can be accessed via an ID etc... - std::optional