diff --git a/src/expr/interpret/frame.hpp b/src/expr/interpret/frame.hpp index 1cd6a99ce..457806680 100644 --- a/src/expr/interpret/frame.hpp +++ b/src/expr/interpret/frame.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -23,9 +23,9 @@ namespace memgraph::expr { class Frame { public: /// Create a Frame of given size backed by a utils::NewDeleteResource() - explicit Frame(int64_t size) : elems_(size, utils::NewDeleteResource()) { MG_ASSERT(size >= 0); } + explicit Frame(size_t size) : elems_(size, utils::NewDeleteResource()) { MG_ASSERT(size >= 0); } - Frame(int64_t size, utils::MemoryResource *memory) : elems_(size, memory) { MG_ASSERT(size >= 0); } + Frame(size_t size, utils::MemoryResource *memory) : elems_(size, memory) { MG_ASSERT(size >= 0); } TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position()]; } const TypedValue &operator[](const Symbol &symbol) const { return elems_[symbol.position()]; } @@ -34,6 +34,7 @@ class Frame { const TypedValue &at(const Symbol &symbol) const { return elems_.at(symbol.position()); } auto &elems() { return elems_; } + const auto &elems() const { return elems_; } utils::MemoryResource *GetMemoryResource() const { return elems_.get_allocator().GetMemoryResource(); } @@ -43,9 +44,9 @@ class Frame { class FrameWithValidity final : public Frame { public: - explicit FrameWithValidity(int64_t size) : Frame(size), is_valid_(false) {} + explicit FrameWithValidity(size_t size) : Frame(size), is_valid_(false) {} - FrameWithValidity(int64_t size, utils::MemoryResource *memory) : Frame(size, memory), is_valid_(false) {} + FrameWithValidity(size_t size, utils::MemoryResource *memory) : Frame(size, memory), is_valid_(false) {} bool IsValid() const noexcept { return is_valid_; } void MakeValid() noexcept { is_valid_ = true; } diff --git a/src/io/local_transport/local_system.hpp b/src/io/local_transport/local_system.hpp index 2e54f8d75..7b0cda537 100644 --- a/src/io/local_transport/local_system.hpp +++ b/src/io/local_transport/local_system.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/io/simulator/simulator.hpp b/src/io/simulator/simulator.hpp index 8afc073af..4ff8e7ae0 100644 --- a/src/io/simulator/simulator.hpp +++ b/src/io/simulator/simulator.hpp @@ -41,7 +41,7 @@ class Simulator { Io<SimulatorTransport> Register(Address address) { std::uniform_int_distribution<uint64_t> seed_distrib; uint64_t seed = seed_distrib(rng_); - return Io{SimulatorTransport{simulator_handle_, address, seed}, address}; + return Io{SimulatorTransport(simulator_handle_, address, seed), address}; } void IncrementServerCountAndWaitForQuiescentState(Address address) { @@ -50,8 +50,12 @@ class Simulator { SimulatorStats Stats() { return simulator_handle_->Stats(); } + std::shared_ptr<SimulatorHandle> GetSimulatorHandle() const { return simulator_handle_; } + std::function<bool()> GetSimulatorTickClosure() { - std::function<bool()> tick_closure = [handle_copy = simulator_handle_] { return handle_copy->MaybeTickSimulator(); }; + std::function<bool()> tick_closure = [handle_copy = simulator_handle_] { + return handle_copy->MaybeTickSimulator(); + }; return tick_closure; } }; diff --git a/src/io/simulator/simulator_transport.hpp b/src/io/simulator/simulator_transport.hpp index 038cfeb03..1272a04a1 100644 --- a/src/io/simulator/simulator_transport.hpp +++ b/src/io/simulator/simulator_transport.hpp @@ -26,7 +26,7 @@ using memgraph::io::Time; class SimulatorTransport { std::shared_ptr<SimulatorHandle> simulator_handle_; - const Address address_; + Address address_; std::mt19937 rng_; public: @@ -36,7 +36,9 @@ class SimulatorTransport { template <Message RequestT, Message ResponseT> ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request, std::function<void()> notification, Duration timeout) { - std::function<bool()> tick_simulator = [handle_copy = simulator_handle_] { return handle_copy->MaybeTickSimulator(); }; + std::function<bool()> tick_simulator = [handle_copy = simulator_handle_] { + return handle_copy->MaybeTickSimulator(); + }; return simulator_handle_->template SubmitRequest<RequestT, ResponseT>( to_address, from_address, std::move(request), timeout, std::move(tick_simulator), std::move(notification)); diff --git a/src/memgraph.cpp b/src/memgraph.cpp index d825cc0e7..35fd20ad7 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -640,6 +640,8 @@ int main(int argc, char **argv) { memgraph::machine_manager::MachineManager<memgraph::io::local_transport::LocalTransport> mm{io, config, coordinator}; std::jthread mm_thread([&mm] { mm.Run(); }); + auto rr_factory = std::make_unique<memgraph::query::v2::LocalRequestRouterFactory>(io); + memgraph::query::v2::InterpreterContext interpreter_context{ (memgraph::storage::v3::Shard *)(nullptr), {.query = {.allow_load_csv = FLAGS_allow_load_csv}, @@ -650,7 +652,7 @@ int main(int argc, char **argv) { .stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries, .stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)}, FLAGS_data_directory, - std::move(io), + std::move(rr_factory), mm.CoordinatorAddress()}; SessionData session_data{&interpreter_context}; diff --git a/src/query/v2/context.hpp b/src/query/v2/context.hpp index cb30a9ced..58f5ada97 100644 --- a/src/query/v2/context.hpp +++ b/src/query/v2/context.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/query/v2/interpreter.cpp b/src/query/v2/interpreter.cpp index fde11ac00..aa220d764 100644 --- a/src/query/v2/interpreter.cpp +++ b/src/query/v2/interpreter.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -704,7 +704,6 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par ctx_.request_router = request_router; ctx_.edge_ids_alloc = &interpreter_context->edge_ids_alloc; } - std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::PullMultiple(AnyStream *stream, std::optional<int> n, const std::vector<Symbol> &output_symbols, std::map<std::string, TypedValue> *summary) { @@ -732,10 +731,7 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::PullMultiple(AnyStrea } // Returns true if a result was pulled. - const auto pull_result = [&]() -> bool { - cursor_->PullMultiple(multi_frame_, ctx_); - return multi_frame_.HasValidFrame(); - }; + const auto pull_result = [&]() -> bool { return cursor_->PullMultiple(multi_frame_, ctx_); }; const auto stream_values = [&output_symbols, &stream](const Frame &frame) { // TODO: The streamed values should also probably use the above memory. @@ -755,13 +751,14 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::PullMultiple(AnyStrea int i = 0; if (has_unsent_results_ && !output_symbols.empty()) { // stream unsent results from previous pull - - auto iterator_for_valid_frame_only = multi_frame_.GetValidFramesReader(); - for (const auto &frame : iterator_for_valid_frame_only) { + for (auto &frame : multi_frame_.GetValidFramesConsumer()) { stream_values(frame); + frame.MakeInvalid(); ++i; + if (i == n) { + break; + } } - multi_frame_.MakeAllFramesInvalid(); } for (; !n || i < n;) { @@ -770,13 +767,17 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::PullMultiple(AnyStrea } if (!output_symbols.empty()) { - auto iterator_for_valid_frame_only = multi_frame_.GetValidFramesReader(); - for (const auto &frame : iterator_for_valid_frame_only) { + for (auto &frame : multi_frame_.GetValidFramesConsumer()) { stream_values(frame); + frame.MakeInvalid(); ++i; + if (i == n) { + break; + } } + } else { + multi_frame_.MakeAllFramesInvalid(); } - multi_frame_.MakeAllFramesInvalid(); } // If we finished because we streamed the requested n results, @@ -906,34 +907,24 @@ using RWType = plan::ReadWriteTypeChecker::RWType; InterpreterContext::InterpreterContext(storage::v3::Shard *db, const InterpreterConfig config, const std::filesystem::path & /*data_directory*/, - io::Io<io::local_transport::LocalTransport> io, + std::unique_ptr<RequestRouterFactory> request_router_factory, coordinator::Address coordinator_addr) - : db(db), config(config), io{std::move(io)}, coordinator_address{coordinator_addr} {} + : db(db), + config(config), + coordinator_address{coordinator_addr}, + request_router_factory_{std::move(request_router_factory)} {} Interpreter::Interpreter(InterpreterContext *interpreter_context) : interpreter_context_(interpreter_context) { MG_ASSERT(interpreter_context_, "Interpreter context must not be NULL"); - // TODO(tyler) make this deterministic so that it can be tested. - auto random_uuid = boost::uuids::uuid{boost::uuids::random_generator()()}; - auto query_io = interpreter_context_->io.ForkLocal(random_uuid); + request_router_ = + interpreter_context_->request_router_factory_->CreateRequestRouter(interpreter_context_->coordinator_address); - request_router_ = std::make_unique<RequestRouter<io::local_transport::LocalTransport>>( - coordinator::CoordinatorClient<io::local_transport::LocalTransport>( - query_io, interpreter_context_->coordinator_address, std::vector{interpreter_context_->coordinator_address}), - std::move(query_io)); // Get edge ids - coordinator::CoordinatorWriteRequests requests{coordinator::AllocateEdgeIdBatchRequest{.batch_size = 1000000}}; - io::rsm::WriteRequest<coordinator::CoordinatorWriteRequests> ww; - ww.operation = requests; - auto resp = interpreter_context_->io - .Request<io::rsm::WriteRequest<coordinator::CoordinatorWriteRequests>, - io::rsm::WriteResponse<coordinator::CoordinatorWriteResponses>>( - interpreter_context_->coordinator_address, ww) - .Wait(); - if (resp.HasValue()) { - const auto alloc_edge_id_reps = - std::get<coordinator::AllocateEdgeIdBatchResponse>(resp.GetValue().message.write_return); - interpreter_context_->edge_ids_alloc = {alloc_edge_id_reps.low, alloc_edge_id_reps.high}; + const auto edge_ids_alloc_min_max_pair = + request_router_->AllocateInitialEdgeIds(interpreter_context_->coordinator_address); + if (edge_ids_alloc_min_max_pair) { + interpreter_context_->edge_ids_alloc = {edge_ids_alloc_min_max_pair->first, edge_ids_alloc_min_max_pair->second}; } } diff --git a/src/query/v2/interpreter.hpp b/src/query/v2/interpreter.hpp index 985c9a90c..4efc85c22 100644 --- a/src/query/v2/interpreter.hpp +++ b/src/query/v2/interpreter.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -16,7 +16,6 @@ #include "coordinator/coordinator.hpp" #include "coordinator/coordinator_client.hpp" -#include "io/local_transport/local_transport.hpp" #include "io/transport.hpp" #include "query/v2/auth_checker.hpp" #include "query/v2/bindings/cypher_main_visitor.hpp" @@ -172,7 +171,8 @@ struct PreparedQuery { struct InterpreterContext { explicit InterpreterContext(storage::v3::Shard *db, InterpreterConfig config, const std::filesystem::path &data_directory, - io::Io<io::local_transport::LocalTransport> io, coordinator::Address coordinator_addr); + std::unique_ptr<RequestRouterFactory> request_router_factory, + coordinator::Address coordinator_addr); storage::v3::Shard *db; @@ -188,26 +188,24 @@ struct InterpreterContext { const InterpreterConfig config; IdAllocator edge_ids_alloc; - // TODO (antaljanosbenjamin) Figure out an abstraction for io::Io to make it possible to construct an interpreter - // context with a simulator transport without templatizing it. - io::Io<io::local_transport::LocalTransport> io; coordinator::Address coordinator_address; + std::unique_ptr<RequestRouterFactory> request_router_factory_; storage::v3::LabelId NameToLabelId(std::string_view label_name) { - return storage::v3::LabelId::FromUint(query_id_mapper.NameToId(label_name)); + return storage::v3::LabelId::FromUint(query_id_mapper_.NameToId(label_name)); } storage::v3::PropertyId NameToPropertyId(std::string_view property_name) { - return storage::v3::PropertyId::FromUint(query_id_mapper.NameToId(property_name)); + return storage::v3::PropertyId::FromUint(query_id_mapper_.NameToId(property_name)); } storage::v3::EdgeTypeId NameToEdgeTypeId(std::string_view edge_type_name) { - return storage::v3::EdgeTypeId::FromUint(query_id_mapper.NameToId(edge_type_name)); + return storage::v3::EdgeTypeId::FromUint(query_id_mapper_.NameToId(edge_type_name)); } private: // TODO Replace with local map of labels, properties and edge type ids - storage::v3::NameIdMapper query_id_mapper; + storage::v3::NameIdMapper query_id_mapper_; }; /// Function that is used to tell all active interpreters that they should stop @@ -297,12 +295,15 @@ class Interpreter final { void Abort(); const RequestRouterInterface *GetRequestRouter() const { return request_router_.get(); } + void InstallSimulatorTicker(std::function<bool()> &&tick_simulator) { + request_router_->InstallSimulatorTicker(tick_simulator); + } private: struct QueryExecution { - std::optional<PreparedQuery> prepared_query; utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; utils::ResourceWithOutOfMemoryException execution_memory_with_exception{&execution_memory}; + std::optional<PreparedQuery> prepared_query; std::map<std::string, TypedValue> summary; std::vector<Notification> notifications; diff --git a/src/query/v2/multiframe.cpp b/src/query/v2/multiframe.cpp index 2cb591153..477ef6c0c 100644 --- a/src/query/v2/multiframe.cpp +++ b/src/query/v2/multiframe.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -45,20 +45,26 @@ void MultiFrame::MakeAllFramesInvalid() noexcept { } bool MultiFrame::HasValidFrame() const noexcept { - return std::any_of(frames_.begin(), frames_.end(), [](auto &frame) { return frame.IsValid(); }); + return std::any_of(frames_.begin(), frames_.end(), [](const auto &frame) { return frame.IsValid(); }); +} + +bool MultiFrame::HasInvalidFrame() const noexcept { + return std::any_of(frames_.rbegin(), frames_.rend(), [](const auto &frame) { return !frame.IsValid(); }); } // NOLINTNEXTLINE (bugprone-exception-escape) void MultiFrame::DefragmentValidFrames() noexcept { - /* - from: https://en.cppreference.com/w/cpp/algorithm/remove - "Removing is done by shifting (by means of copy assignment (until C++11)move assignment (since C++11)) the elements - in the range in such a way that the elements that are not to be removed appear in the beginning of the range. - Relative order of the elements that remain is preserved and the physical size of the container is unchanged." - */ - - // NOLINTNEXTLINE (bugprone-unused-return-value) - std::remove_if(frames_.begin(), frames_.end(), [](auto &frame) { return !frame.IsValid(); }); + static constexpr auto kIsValid = [](const FrameWithValidity &frame) { return frame.IsValid(); }; + static constexpr auto kIsInvalid = [](const FrameWithValidity &frame) { return !frame.IsValid(); }; + auto first_invalid_frame = std::find_if(frames_.begin(), frames_.end(), kIsInvalid); + auto following_first_valid = std::find_if(first_invalid_frame, frames_.end(), kIsValid); + while (first_invalid_frame != frames_.end() && following_first_valid != frames_.end()) { + std::swap(*first_invalid_frame, *following_first_valid); + first_invalid_frame++; + first_invalid_frame = std::find_if(first_invalid_frame, frames_.end(), kIsInvalid); + following_first_valid++; + following_first_valid = std::find_if(following_first_valid, frames_.end(), kIsValid); + } } ValidFramesReader MultiFrame::GetValidFramesReader() { return ValidFramesReader{*this}; } diff --git a/src/query/v2/multiframe.hpp b/src/query/v2/multiframe.hpp index 0365b449f..6958ffbe8 100644 --- a/src/query/v2/multiframe.hpp +++ b/src/query/v2/multiframe.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -33,6 +33,7 @@ class MultiFrame { MultiFrame(size_t size_of_frame, size_t number_of_frames, utils::MemoryResource *execution_memory); ~MultiFrame() = default; + // Assigning and moving the MultiFrame is not allowed if any accessor from the above ones are alive. MultiFrame(const MultiFrame &other); MultiFrame(MultiFrame &&other) noexcept; MultiFrame &operator=(const MultiFrame &other) = delete; @@ -81,6 +82,7 @@ class MultiFrame { void MakeAllFramesInvalid() noexcept; bool HasValidFrame() const noexcept; + bool HasInvalidFrame() const noexcept; inline utils::MemoryResource *GetMemoryResource() { return frames_[0].GetMemoryResource(); } @@ -96,9 +98,9 @@ class ValidFramesReader { ~ValidFramesReader() = default; ValidFramesReader(const ValidFramesReader &other) = delete; - ValidFramesReader(ValidFramesReader &&other) noexcept = delete; + ValidFramesReader(ValidFramesReader &&other) noexcept = default; ValidFramesReader &operator=(const ValidFramesReader &other) = delete; - ValidFramesReader &operator=(ValidFramesReader &&other) noexcept = delete; + ValidFramesReader &operator=(ValidFramesReader &&other) noexcept = default; struct Iterator { using iterator_category = std::forward_iterator_tag; @@ -146,9 +148,9 @@ class ValidFramesModifier { ~ValidFramesModifier() = default; ValidFramesModifier(const ValidFramesModifier &other) = delete; - ValidFramesModifier(ValidFramesModifier &&other) noexcept = delete; + ValidFramesModifier(ValidFramesModifier &&other) noexcept = default; ValidFramesModifier &operator=(const ValidFramesModifier &other) = delete; - ValidFramesModifier &operator=(ValidFramesModifier &&other) noexcept = delete; + ValidFramesModifier &operator=(ValidFramesModifier &&other) noexcept = default; struct Iterator { using iterator_category = std::forward_iterator_tag; @@ -200,10 +202,10 @@ class ValidFramesConsumer { explicit ValidFramesConsumer(MultiFrame &multiframe); ~ValidFramesConsumer() noexcept; - ValidFramesConsumer(const ValidFramesConsumer &other) = default; + ValidFramesConsumer(const ValidFramesConsumer &other) = delete; ValidFramesConsumer(ValidFramesConsumer &&other) noexcept = default; - ValidFramesConsumer &operator=(const ValidFramesConsumer &other) = default; - ValidFramesConsumer &operator=(ValidFramesConsumer &&other) noexcept = delete; + ValidFramesConsumer &operator=(const ValidFramesConsumer &other) = delete; + ValidFramesConsumer &operator=(ValidFramesConsumer &&other) noexcept = default; struct Iterator { using iterator_category = std::forward_iterator_tag; @@ -255,9 +257,9 @@ class InvalidFramesPopulator { ~InvalidFramesPopulator() = default; InvalidFramesPopulator(const InvalidFramesPopulator &other) = delete; - InvalidFramesPopulator(InvalidFramesPopulator &&other) noexcept = delete; + InvalidFramesPopulator(InvalidFramesPopulator &&other) noexcept = default; InvalidFramesPopulator &operator=(const InvalidFramesPopulator &other) = delete; - InvalidFramesPopulator &operator=(InvalidFramesPopulator &&other) noexcept = delete; + InvalidFramesPopulator &operator=(InvalidFramesPopulator &&other) noexcept = default; struct Iterator { using iterator_category = std::forward_iterator_tag; diff --git a/src/query/v2/plan/cost_estimator.hpp b/src/query/v2/plan/cost_estimator.hpp index 8cba2505f..f497d14d5 100644 --- a/src/query/v2/plan/cost_estimator.hpp +++ b/src/query/v2/plan/cost_estimator.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -149,8 +149,6 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { return true; } - // TODO: Cost estimate ScanAllById? - // For the given op first increments the cardinality and then cost. #define POST_VISIT_CARD_FIRST(NAME) \ bool PostVisit(NAME &) override { \ diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp index a048f799d..95a7f6c40 100644 --- a/src/query/v2/plan/operator.cpp +++ b/src/query/v2/plan/operator.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -34,6 +34,7 @@ #include "query/v2/bindings/eval.hpp" #include "query/v2/bindings/symbol_table.hpp" #include "query/v2/context.hpp" +#include "query/v2/conversions.hpp" #include "query/v2/db_accessor.hpp" #include "query/v2/exceptions.hpp" #include "query/v2/frontend/ast/ast.hpp" @@ -93,7 +94,7 @@ extern const Event ScanAllByLabelOperator; extern const Event ScanAllByLabelPropertyRangeOperator; extern const Event ScanAllByLabelPropertyValueOperator; extern const Event ScanAllByLabelPropertyOperator; -extern const Event ScanAllByIdOperator; +extern const Event ScanByPrimaryKeyOperator; extern const Event ExpandOperator; extern const Event ExpandVariableOperator; extern const Event ConstructNamedPathOperator; @@ -171,9 +172,8 @@ uint64_t ComputeProfilingKey(const T *obj) { class DistributedCreateNodeCursor : public Cursor { public: using InputOperator = std::shared_ptr<memgraph::query::v2::plan::LogicalOperator>; - DistributedCreateNodeCursor(const InputOperator &op, utils::MemoryResource *mem, - std::vector<const NodeCreationInfo *> nodes_info) - : input_cursor_(op->MakeCursor(mem)), nodes_info_(std::move(nodes_info)) {} + DistributedCreateNodeCursor(const InputOperator &op, utils::MemoryResource *mem, const NodeCreationInfo &node_info) + : input_cursor_(op->MakeCursor(mem)), node_info_(node_info) {} bool Pull(Frame &frame, ExecutionContext &context) override { SCOPED_PROFILE_OP("CreateNode"); @@ -190,15 +190,19 @@ class DistributedCreateNodeCursor : public Cursor { return false; } - void PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) override { + bool PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) override { SCOPED_PROFILE_OP("CreateNodeMF"); - input_cursor_->PullMultiple(multi_frame, context); - auto &request_router = context.request_router; + + auto *request_router = context.request_router; + if (!input_cursor_->PullMultiple(multi_frame, context)) { + return false; + } { SCOPED_REQUEST_WAIT_PROFILE; request_router->CreateVertices(NodeCreationInfoToRequests(context, multi_frame)); } PlaceNodesOnTheMultiFrame(multi_frame, context); + return false; } void Shutdown() override { input_cursor_->Shutdown(); } @@ -207,27 +211,75 @@ class DistributedCreateNodeCursor : public Cursor { void PlaceNodeOnTheFrame(Frame &frame, ExecutionContext &context) { // TODO(kostasrim) Make this work with batching - const auto primary_label = msgs::Label{.id = nodes_info_[0]->labels[0]}; + const auto primary_label = msgs::Label{.id = node_info_.labels[0]}; msgs::Vertex v{.id = std::make_pair(primary_label, primary_keys_[0])}; - frame[nodes_info_.front()->symbol] = + frame[node_info_.symbol] = TypedValue(query::v2::accessors::VertexAccessor(std::move(v), src_vertex_props_[0], context.request_router)); } std::vector<msgs::NewVertex> NodeCreationInfoToRequest(ExecutionContext &context, Frame &frame) { std::vector<msgs::NewVertex> requests; - // TODO(kostasrim) this assertion should be removed once we support multiple vertex creation - MG_ASSERT(nodes_info_.size() == 1); msgs::PrimaryKey pk; - for (const auto &node_info : nodes_info_) { + msgs::NewVertex rqst; + MG_ASSERT(!node_info_.labels.empty(), "Cannot determine primary label"); + const auto primary_label = node_info_.labels[0]; + // TODO(jbajic) Send also the properties that are not part of primary key + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, nullptr, + storage::v3::View::NEW); + if (const auto *node_info_properties = std::get_if<PropertiesMapList>(&node_info_.properties)) { + for (const auto &[key, value_expression] : *node_info_properties) { + TypedValue val = value_expression->Accept(evaluator); + if (context.request_router->IsPrimaryKey(primary_label, key)) { + rqst.primary_key.push_back(TypedValueToValue(val)); + pk.push_back(TypedValueToValue(val)); + } + } + } else { + auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info_.properties)).ValueMap(); + for (const auto &[key, value] : property_map) { + auto key_str = std::string(key); + auto property_id = context.request_router->NameToProperty(key_str); + if (context.request_router->IsPrimaryKey(primary_label, property_id)) { + rqst.primary_key.push_back(TypedValueToValue(value)); + pk.push_back(TypedValueToValue(value)); + } + } + } + + // TODO(kostasrim) Copy non primary labels as well + rqst.label_ids.push_back(msgs::Label{.id = primary_label}); + src_vertex_props_.push_back(rqst.properties); + requests.push_back(std::move(rqst)); + + primary_keys_.push_back(std::move(pk)); + return requests; + } + + void PlaceNodesOnTheMultiFrame(MultiFrame &multi_frame, ExecutionContext &context) { + auto multi_frame_modifier = multi_frame.GetValidFramesModifier(); + size_t i = 0; + MG_ASSERT(std::distance(multi_frame_modifier.begin(), multi_frame_modifier.end())); + for (auto &frame : multi_frame_modifier) { + const auto primary_label = msgs::Label{.id = node_info_.labels[0]}; + msgs::Vertex v{.id = std::make_pair(primary_label, primary_keys_[i])}; + frame[node_info_.symbol] = TypedValue( + query::v2::accessors::VertexAccessor(std::move(v), src_vertex_props_[i++], context.request_router)); + } + } + + std::vector<msgs::NewVertex> NodeCreationInfoToRequests(ExecutionContext &context, MultiFrame &multi_frame) { + std::vector<msgs::NewVertex> requests; + auto multi_frame_modifier = multi_frame.GetValidFramesModifier(); + for (auto &frame : multi_frame_modifier) { + msgs::PrimaryKey pk; msgs::NewVertex rqst; - MG_ASSERT(!node_info->labels.empty(), "Cannot determine primary label"); - const auto primary_label = node_info->labels[0]; - // TODO(jbajic) Fix properties not send, - // suggestion: ignore distinction between properties and primary keys - // since schema validation is done on storage side + MG_ASSERT(!node_info_.labels.empty(), "Cannot determine primary label"); + const auto primary_label = node_info_.labels[0]; + MG_ASSERT(context.request_router->IsPrimaryLabel(primary_label), "First label has to be a primary label!"); + // TODO(jbajic) Send also the properties that are not part of primary key ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, nullptr, storage::v3::View::NEW); - if (const auto *node_info_properties = std::get_if<PropertiesMapList>(&node_info->properties)) { + if (const auto *node_info_properties = std::get_if<PropertiesMapList>(&node_info_.properties)) { for (const auto &[key, value_expression] : *node_info_properties) { TypedValue val = value_expression->Accept(evaluator); if (context.request_router->IsPrimaryKey(primary_label, key)) { @@ -236,7 +288,7 @@ class DistributedCreateNodeCursor : public Cursor { } } } else { - auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info->properties)).ValueMap(); + auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info_.properties)).ValueMap(); for (const auto &[key, value] : property_map) { auto key_str = std::string(key); auto property_id = context.request_router->NameToProperty(key_str); @@ -247,80 +299,19 @@ class DistributedCreateNodeCursor : public Cursor { } } - if (node_info->labels.empty()) { - throw QueryRuntimeException("Primary label must be defined!"); - } // TODO(kostasrim) Copy non primary labels as well rqst.label_ids.push_back(msgs::Label{.id = primary_label}); src_vertex_props_.push_back(rqst.properties); requests.push_back(std::move(rqst)); - } - primary_keys_.push_back(std::move(pk)); - return requests; - } - - void PlaceNodesOnTheMultiFrame(MultiFrame &multi_frame, ExecutionContext &context) { - auto multi_frame_reader = multi_frame.GetValidFramesConsumer(); - size_t i = 0; - MG_ASSERT(std::distance(multi_frame_reader.begin(), multi_frame_reader.end())); - for (auto &frame : multi_frame_reader) { - const auto primary_label = msgs::Label{.id = nodes_info_[0]->labels[0]}; - msgs::Vertex v{.id = std::make_pair(primary_label, primary_keys_[i])}; - frame[nodes_info_.front()->symbol] = TypedValue( - query::v2::accessors::VertexAccessor(std::move(v), src_vertex_props_[i++], context.request_router)); - } - } - - std::vector<msgs::NewVertex> NodeCreationInfoToRequests(ExecutionContext &context, MultiFrame &multi_frame) { - std::vector<msgs::NewVertex> requests; - auto multi_frame_reader = multi_frame.GetValidFramesConsumer(); - for (auto &frame : multi_frame_reader) { - msgs::PrimaryKey pk; - for (const auto &node_info : nodes_info_) { - msgs::NewVertex rqst; - MG_ASSERT(!node_info->labels.empty(), "Cannot determine primary label"); - const auto primary_label = node_info->labels[0]; - // TODO(jbajic) Fix properties not send, - // suggestion: ignore distinction between properties and primary keys - // since schema validation is done on storage side - ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, nullptr, - storage::v3::View::NEW); - if (const auto *node_info_properties = std::get_if<PropertiesMapList>(&node_info->properties)) { - for (const auto &[key, value_expression] : *node_info_properties) { - TypedValue val = value_expression->Accept(evaluator); - if (context.request_router->IsPrimaryKey(primary_label, key)) { - rqst.primary_key.push_back(TypedValueToValue(val)); - pk.push_back(TypedValueToValue(val)); - } - } - } else { - auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info->properties)).ValueMap(); - for (const auto &[key, value] : property_map) { - auto key_str = std::string(key); - auto property_id = context.request_router->NameToProperty(key_str); - if (context.request_router->IsPrimaryKey(primary_label, property_id)) { - rqst.primary_key.push_back(TypedValueToValue(value)); - pk.push_back(TypedValueToValue(value)); - } - } - } - - if (node_info->labels.empty()) { - throw QueryRuntimeException("Primary label must be defined!"); - } - // TODO(kostasrim) Copy non primary labels as well - rqst.label_ids.push_back(msgs::Label{.id = primary_label}); - src_vertex_props_.push_back(rqst.properties); - requests.push_back(std::move(rqst)); - } primary_keys_.push_back(std::move(pk)); } + return requests; } private: const UniqueCursorPtr input_cursor_; - std::vector<const NodeCreationInfo *> nodes_info_; + NodeCreationInfo node_info_; std::vector<std::vector<std::pair<storage::v3::PropertyId, msgs::Value>>> src_vertex_props_; std::vector<msgs::PrimaryKey> primary_keys_; }; @@ -335,14 +326,16 @@ bool Once::OnceCursor::Pull(Frame &, ExecutionContext &context) { return false; } -void Once::OnceCursor::PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) { +bool Once::OnceCursor::PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) { SCOPED_PROFILE_OP("OnceMF"); if (!did_pull_) { auto &first_frame = multi_frame.GetFirstFrame(); first_frame.MakeValid(); did_pull_ = true; + return true; } + return false; } UniqueCursorPtr Once::MakeCursor(utils::MemoryResource *mem) const { @@ -365,7 +358,7 @@ ACCEPT_WITH_INPUT(CreateNode) UniqueCursorPtr CreateNode::MakeCursor(utils::MemoryResource *mem) const { EventCounter::IncrementCounter(EventCounter::CreateNodeOperator); - return MakeUniqueCursorPtr<DistributedCreateNodeCursor>(mem, input_, mem, std::vector{&this->node_info_}); + return MakeUniqueCursorPtr<DistributedCreateNodeCursor>(mem, input_, mem, this->node_info_); } std::vector<Symbol> CreateNode::ModifiedSymbols(const SymbolTable &table) const { @@ -463,114 +456,109 @@ class DistributedScanAllAndFilterCursor : public Cursor { ResetExecutionState(); } - enum class State : int8_t { INITIALIZING, COMPLETED }; - using VertexAccessor = accessors::VertexAccessor; - bool MakeRequest(RequestRouterInterface &request_router, ExecutionContext &context) { + bool MakeRequest(ExecutionContext &context) { { SCOPED_REQUEST_WAIT_PROFILE; std::optional<std::string> request_label = std::nullopt; if (label_.has_value()) { - request_label = request_router.LabelToName(*label_); + request_label = context.request_router->LabelToName(*label_); } - current_batch_ = request_router.ScanVertices(request_label); + current_batch_ = context.request_router->ScanVertices(request_label); } current_vertex_it_ = current_batch_.begin(); - request_state_ = State::COMPLETED; return !current_batch_.empty(); } bool Pull(Frame &frame, ExecutionContext &context) override { SCOPED_PROFILE_OP(op_name_); - auto &request_router = *context.request_router; while (true) { if (MustAbort(context)) { throw HintedAbortError(); } - if (request_state_ == State::INITIALIZING) { - if (!input_cursor_->Pull(frame, context)) { + if (current_vertex_it_ == current_batch_.end()) { + ResetExecutionState(); + if (!input_cursor_->Pull(frame, context) || !MakeRequest(context)) { return false; } } - if (current_vertex_it_ == current_batch_.end() && - (request_state_ == State::COMPLETED || !MakeRequest(request_router, context))) { - ResetExecutionState(); - continue; - } - frame[output_symbol_] = TypedValue(std::move(*current_vertex_it_)); ++current_vertex_it_; return true; } } - void PrepareNextFrames(ExecutionContext &context) { - auto &request_router = *context.request_router; - - input_cursor_->PullMultiple(*own_multi_frames_, context); - valid_frames_consumer_ = own_multi_frames_->GetValidFramesConsumer(); - valid_frames_it_ = valid_frames_consumer_->begin(); - - MakeRequest(request_router, context); - } - - inline bool HasNextFrame() { - return current_vertex_it_ != current_batch_.end() && valid_frames_it_ != valid_frames_consumer_->end(); - } - - FrameWithValidity GetNextFrame(ExecutionContext &context) { - MG_ASSERT(HasNextFrame()); - - auto frame = *valid_frames_it_; - frame[output_symbol_] = TypedValue(*current_vertex_it_); - - ++current_vertex_it_; - if (current_vertex_it_ == current_batch_.end()) { - valid_frames_it_->MakeInvalid(); - ++valid_frames_it_; - - if (valid_frames_it_ == valid_frames_consumer_->end()) { - PrepareNextFrames(context); - } else { - current_vertex_it_ = current_batch_.begin(); - } - }; - - return frame; - } - - void PullMultiple(MultiFrame &input_multi_frame, ExecutionContext &context) override { + bool PullMultiple(MultiFrame &output_multi_frame, ExecutionContext &context) override { SCOPED_PROFILE_OP(op_name_); - if (!own_multi_frames_.has_value()) { - own_multi_frames_.emplace(MultiFrame(input_multi_frame.GetFirstFrame().elems().size(), - kNumberOfFramesInMultiframe, input_multi_frame.GetMemoryResource())); - PrepareNextFrames(context); + if (!own_multi_frame_.has_value()) { + own_multi_frame_.emplace(MultiFrame(output_multi_frame.GetFirstFrame().elems().size(), + kNumberOfFramesInMultiframe, output_multi_frame.GetMemoryResource())); + own_frames_consumer_.emplace(own_multi_frame_->GetValidFramesConsumer()); + own_frames_it_ = own_frames_consumer_->begin(); } + auto output_frames_populator = output_multi_frame.GetInvalidFramesPopulator(); + auto populated_any = false; + while (true) { - if (MustAbort(context)) { - throw HintedAbortError(); - } + switch (state_) { + case State::PullInput: { + if (!input_cursor_->PullMultiple(*own_multi_frame_, context)) { + state_ = State::Exhausted; + return populated_any; + } + own_frames_consumer_.emplace(own_multi_frame_->GetValidFramesConsumer()); + own_frames_it_ = own_frames_consumer_->begin(); + state_ = State::FetchVertices; + break; + } + case State::FetchVertices: { + if (own_frames_it_ == own_frames_consumer_->end()) { + state_ = State::PullInput; + continue; + } + if (!filter_expressions_->empty() || property_expression_pair_.has_value() || current_batch_.empty()) { + MakeRequest(context); + } else { + // We can reuse the vertices as they don't depend on any value from the frames + current_vertex_it_ = current_batch_.begin(); + } + state_ = State::PopulateOutput; + break; + } + case State::PopulateOutput: { + if (!output_multi_frame.HasInvalidFrame()) { + return populated_any; + } + if (current_vertex_it_ == current_batch_.end()) { + own_frames_it_->MakeInvalid(); + ++own_frames_it_; + state_ = State::FetchVertices; + continue; + } - auto invalid_frames_populator = input_multi_frame.GetInvalidFramesPopulator(); - auto invalid_frame_it = invalid_frames_populator.begin(); - auto has_modified_at_least_one_frame = false; - - while (invalid_frames_populator.end() != invalid_frame_it && HasNextFrame()) { - has_modified_at_least_one_frame = true; - *invalid_frame_it = GetNextFrame(context); - ++invalid_frame_it; - } - - if (!has_modified_at_least_one_frame) { - return; + for (auto output_frame_it = output_frames_populator.begin(); + output_frame_it != output_frames_populator.end() && current_vertex_it_ != current_batch_.end(); + ++output_frame_it) { + auto &output_frame = *output_frame_it; + output_frame = *own_frames_it_; + output_frame[output_symbol_] = TypedValue(*current_vertex_it_); + current_vertex_it_++; + populated_any = true; + } + break; + } + case State::Exhausted: { + return populated_any; + } } } + return populated_any; }; void Shutdown() override { input_cursor_->Shutdown(); } @@ -578,7 +566,6 @@ class DistributedScanAllAndFilterCursor : public Cursor { void ResetExecutionState() { current_batch_.clear(); current_vertex_it_ = current_batch_.end(); - request_state_ = State::INITIALIZING; } void Reset() override { @@ -587,19 +574,98 @@ class DistributedScanAllAndFilterCursor : public Cursor { } private: + enum class State { PullInput, FetchVertices, PopulateOutput, Exhausted }; + + State state_{State::PullInput}; const Symbol output_symbol_; const UniqueCursorPtr input_cursor_; const char *op_name_; std::vector<VertexAccessor> current_batch_; - std::vector<VertexAccessor>::iterator current_vertex_it_; - State request_state_ = State::INITIALIZING; + std::vector<VertexAccessor>::iterator current_vertex_it_{current_batch_.begin()}; std::optional<storage::v3::LabelId> label_; std::optional<std::pair<storage::v3::PropertyId, Expression *>> property_expression_pair_; std::optional<std::vector<Expression *>> filter_expressions_; - std::optional<MultiFrame> own_multi_frames_; - std::optional<ValidFramesConsumer> valid_frames_consumer_; - ValidFramesConsumer::Iterator valid_frames_it_; - std::queue<FrameWithValidity> frames_buffer_; + std::optional<MultiFrame> own_multi_frame_; + std::optional<ValidFramesConsumer> own_frames_consumer_; + ValidFramesConsumer::Iterator own_frames_it_; +}; + +class DistributedScanByPrimaryKeyCursor : public Cursor { + public: + explicit DistributedScanByPrimaryKeyCursor(Symbol output_symbol, UniqueCursorPtr input_cursor, const char *op_name, + storage::v3::LabelId label, + std::optional<std::vector<Expression *>> filter_expressions, + std::vector<Expression *> primary_key) + : output_symbol_(output_symbol), + input_cursor_(std::move(input_cursor)), + op_name_(op_name), + label_(label), + filter_expressions_(filter_expressions), + primary_key_(primary_key) {} + + enum class State : int8_t { INITIALIZING, COMPLETED }; + + using VertexAccessor = accessors::VertexAccessor; + + std::optional<VertexAccessor> MakeRequestSingleFrame(Frame &frame, RequestRouterInterface &request_router, + ExecutionContext &context) { + // Evaluate the expressions that hold the PrimaryKey. + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.request_router, + storage::v3::View::NEW); + + std::vector<msgs::Value> pk; + for (auto *primary_property : primary_key_) { + pk.push_back(TypedValueToValue(primary_property->Accept(evaluator))); + } + + msgs::Label label = {.id = msgs::LabelId::FromUint(label_.AsUint())}; + + msgs::GetPropertiesRequest req = {.vertex_ids = {std::make_pair(label, pk)}}; + auto get_prop_result = std::invoke([&context, &request_router, &req]() mutable { + SCOPED_REQUEST_WAIT_PROFILE; + return request_router.GetProperties(req); + }); + MG_ASSERT(get_prop_result.size() <= 1); + + if (get_prop_result.empty()) { + return std::nullopt; + } + auto properties = get_prop_result[0].props; + // TODO (gvolfing) figure out labels when relevant. + msgs::Vertex vertex = {.id = get_prop_result[0].vertex, .labels = {}}; + + return VertexAccessor(vertex, properties, &request_router); + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP(op_name_); + + if (MustAbort(context)) { + throw HintedAbortError(); + } + + while (input_cursor_->Pull(frame, context)) { + auto &request_router = *context.request_router; + auto vertex = MakeRequestSingleFrame(frame, request_router, context); + if (vertex) { + frame[output_symbol_] = TypedValue(std::move(*vertex)); + return true; + } + } + return false; + } + + void Reset() override { input_cursor_->Reset(); } + + void Shutdown() override { input_cursor_->Shutdown(); } + + private: + const Symbol output_symbol_; + const UniqueCursorPtr input_cursor_; + const char *op_name_; + storage::v3::LabelId label_; + std::optional<std::vector<Expression *>> filter_expressions_; + std::vector<Expression *> primary_key_; }; ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::v3::View view) @@ -607,8 +673,6 @@ ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_sy ACCEPT_WITH_INPUT(ScanAll) -class DistributedScanAllCursor; - UniqueCursorPtr ScanAll::MakeCursor(utils::MemoryResource *mem) const { EventCounter::IncrementCounter(EventCounter::ScanAllOperator); @@ -698,22 +762,21 @@ UniqueCursorPtr ScanAllByLabelProperty::MakeCursor(utils::MemoryResource *mem) c throw QueryRuntimeException("ScanAllByLabelProperty is not supported"); } -ScanAllById::ScanAllById(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression, - storage::v3::View view) - : ScanAll(input, output_symbol, view), expression_(expression) { - MG_ASSERT(expression); +ScanByPrimaryKey::ScanByPrimaryKey(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, + storage::v3::LabelId label, std::vector<query::v2::Expression *> primary_key, + storage::v3::View view) + : ScanAll(input, output_symbol, view), label_(label), primary_key_(primary_key) { + MG_ASSERT(primary_key.front()); } -ACCEPT_WITH_INPUT(ScanAllById) +ACCEPT_WITH_INPUT(ScanByPrimaryKey) -UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const { - EventCounter::IncrementCounter(EventCounter::ScanAllByIdOperator); - // TODO Reimplement when we have reliable conversion between hash value and pk - auto vertices = [](Frame & /*frame*/, ExecutionContext & /*context*/) -> std::optional<std::vector<VertexAccessor>> { - return std::nullopt; - }; - return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem), - std::move(vertices), "ScanAllById"); +UniqueCursorPtr ScanByPrimaryKey::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ScanByPrimaryKeyOperator); + + return MakeUniqueCursorPtr<DistributedScanByPrimaryKeyCursor>(mem, output_symbol_, input_->MakeCursor(mem), + "ScanByPrimaryKey", label_, + std::nullopt /*filter_expressions*/, primary_key_); } Expand::Expand(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol, @@ -856,6 +919,27 @@ bool Filter::FilterCursor::Pull(Frame &frame, ExecutionContext &context) { return false; } +bool Filter::FilterCursor::PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("Filter"); + auto populated_any = false; + + while (multi_frame.HasInvalidFrame()) { + if (!input_cursor_->PullMultiple(multi_frame, context)) { + return populated_any; + } + for (auto &frame : multi_frame.GetValidFramesConsumer()) { + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.request_router, + storage::v3::View::OLD); + if (!EvaluateFilter(evaluator, self_.expression_)) { + frame.MakeInvalid(); + } else { + populated_any = true; + } + } + } + return populated_any; +} + void Filter::FilterCursor::Shutdown() { input_cursor_->Shutdown(); } void Filter::FilterCursor::Reset() { input_cursor_->Reset(); } @@ -891,19 +975,22 @@ bool Produce::ProduceCursor::Pull(Frame &frame, ExecutionContext &context) { // Produce should always yield the latest results. ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.request_router, storage::v3::View::NEW); - for (auto named_expr : self_.named_expressions_) named_expr->Accept(evaluator); + for (auto *named_expr : self_.named_expressions_) named_expr->Accept(evaluator); return true; } return false; } -void Produce::ProduceCursor::PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) { +bool Produce::ProduceCursor::PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) { SCOPED_PROFILE_OP("ProduceMF"); - input_cursor_->PullMultiple(multi_frame, context); + if (!input_cursor_->PullMultiple(multi_frame, context)) { + return false; + } auto iterator_for_valid_frame_only = multi_frame.GetValidFramesModifier(); + for (auto &frame : iterator_for_valid_frame_only) { // Produce should always yield the latest results. ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.request_router, @@ -913,7 +1000,9 @@ void Produce::ProduceCursor::PullMultiple(MultiFrame &multi_frame, ExecutionCont named_expr->Accept(evaluator); } } -}; + + return true; +} void Produce::ProduceCursor::Shutdown() { input_cursor_->Shutdown(); } @@ -938,6 +1027,8 @@ Delete::DeleteCursor::DeleteCursor(const Delete &self, utils::MemoryResource *me bool Delete::DeleteCursor::Pull(Frame & /*frame*/, ExecutionContext & /*context*/) { return false; } +bool Delete::DeleteCursor::PullMultiple(MultiFrame & /*multi_frame*/, ExecutionContext & /*context*/) { return false; } + void Delete::DeleteCursor::Shutdown() { input_cursor_->Shutdown(); } void Delete::DeleteCursor::Reset() { input_cursor_->Reset(); } @@ -2632,9 +2723,11 @@ class DistributedCreateExpandCursor : public Cursor { return true; } - void PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) override { + bool PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) override { SCOPED_PROFILE_OP("CreateExpandMF"); - input_cursor_->PullMultiple(multi_frame, context); + if (!input_cursor_->PullMultiple(multi_frame, context)) { + return false; + } auto request_vertices = ExpandCreationInfoToRequests(multi_frame, context); { SCOPED_REQUEST_WAIT_PROFILE; @@ -2646,6 +2739,7 @@ class DistributedCreateExpandCursor : public Cursor { } } } + return true; } void Shutdown() override { input_cursor_->Shutdown(); } @@ -2759,8 +2853,20 @@ class DistributedCreateExpandCursor : public Cursor { // Set src and dest vertices // TODO(jbajic) Currently we are only handling scenario where vertices // are matched - request.src_vertex = v1.Id(); - request.dest_vertex = v2.Id(); + switch (edge_info.direction) { + case EdgeAtom::Direction::IN: { + request.src_vertex = v2.Id(); + request.dest_vertex = v1.Id(); + break; + } + case EdgeAtom::Direction::OUT: { + request.src_vertex = v1.Id(); + request.dest_vertex = v2.Id(); + break; + } + case EdgeAtom::Direction::BOTH: + LOG_FATAL("Must indicate exact expansion direction here"); + } edge_requests.push_back(std::move(request)); } @@ -2776,7 +2882,7 @@ class DistributedCreateExpandCursor : public Cursor { class DistributedExpandCursor : public Cursor { public: - explicit DistributedExpandCursor(const Expand &self, utils::MemoryResource *mem) + DistributedExpandCursor(const Expand &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)), current_in_edge_it_(current_in_edges_.begin()), @@ -2813,16 +2919,15 @@ class DistributedExpandCursor : public Cursor { throw std::runtime_error("EdgeDirection Both not implemented"); } }; - msgs::ExpandOneRequest request; + + msgs::GetPropertiesRequest request; // to not fetch any properties of the edges - request.edge_properties.emplace(); - request.src_vertices.push_back(get_dst_vertex(edge, direction)); - request.direction = (direction == EdgeAtom::Direction::IN) ? msgs::EdgeDirection::OUT : msgs::EdgeDirection::IN; - auto result_rows = context.request_router->ExpandOne(std::move(request)); + request.vertex_ids.push_back(get_dst_vertex(edge, direction)); + auto result_rows = context.request_router->GetProperties(std::move(request)); MG_ASSERT(result_rows.size() == 1); auto &result_row = result_rows.front(); - frame[self_.common_.node_symbol] = accessors::VertexAccessor( - msgs::Vertex{result_row.src_vertex}, result_row.src_vertex_properties, context.request_router); + frame[self_.common_.node_symbol] = + accessors::VertexAccessor(msgs::Vertex{result_row.vertex}, result_row.props, context.request_router); } bool InitEdges(Frame &frame, ExecutionContext &context) { @@ -2931,19 +3036,245 @@ class DistributedExpandCursor : public Cursor { } } + void InitEdgesMultiple() { + // This function won't work if any vertex id is duplicated in the input, because: + // 1. vertex_id_to_result_row is not a multimap + // 2. if self_.common_.existing_node is true, then we erase edges that might be necessary for the input vertex on a + // later frame + const auto &frame = (*own_frames_it_); + const auto &vertex_value = frame[self_.input_symbol_]; + + if (vertex_value.IsNull()) { + ResetMultiFrameEdgeIts(); + return; + } + + ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + const auto &vertex = vertex_value.ValueVertex(); + + current_vertex_ = &vertex; + + auto &ref_counted_result_row = vertex_id_to_result_row.at(vertex.Id()); + auto &result_row = *ref_counted_result_row.result_row; + + current_in_edge_mf_it_ = result_row.in_edges_with_specific_properties.begin(); + in_edges_end_it_ = result_row.in_edges_with_specific_properties.end(); + AdvanceUntilSuitableEdge(current_in_edge_mf_it_, in_edges_end_it_); + current_out_edge_mf_it_ = result_row.out_edges_with_specific_properties.begin(); + out_edges_end_it_ = result_row.out_edges_with_specific_properties.end(); + AdvanceUntilSuitableEdge(current_out_edge_mf_it_, out_edges_end_it_); + + if (ref_counted_result_row.ref_count == 1) { + vertex_id_to_result_row.erase(vertex.Id()); + } else { + ref_counted_result_row.ref_count--; + } + } + + bool PullInputFrames(ExecutionContext &context) { + const auto pulled_any = input_cursor_->PullMultiple(*own_multi_frame_, context); + // These needs to be updated regardless of the result of the pull, otherwise the consumer and iterator might + // get corrupted because of the operations done on our MultiFrame. + own_frames_consumer_ = own_multi_frame_->GetValidFramesConsumer(); + own_frames_it_ = own_frames_consumer_->begin(); + if (!pulled_any) { + return false; + } + + vertex_id_to_result_row.clear(); + + msgs::ExpandOneRequest request; + request.direction = DirectionToMsgsDirection(self_.common_.direction); + // to not fetch any properties of the edges + request.edge_properties.emplace(); + for (const auto &frame : own_multi_frame_->GetValidFramesReader()) { + const auto &vertex_value = frame[self_.input_symbol_]; + + // Null check due to possible failed optional match. + MG_ASSERT(!vertex_value.IsNull()); + + ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + const auto &vertex = vertex_value.ValueVertex(); + auto [it, inserted] = vertex_id_to_result_row.try_emplace(vertex.Id(), RefCountedResultRow{1U, nullptr}); + + if (inserted) { + request.src_vertices.push_back(vertex.Id()); + } else { + it->second.ref_count++; + } + } + + result_rows_ = std::invoke([&context, &request]() mutable { + SCOPED_REQUEST_WAIT_PROFILE; + return context.request_router->ExpandOne(std::move(request)); + }); + for (auto &row : result_rows_) { + vertex_id_to_result_row[row.src_vertex.id].result_row = &row; + } + + return true; + } + + bool PullMultiple(MultiFrame &output_multi_frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("DistributedExpandMF"); + EnsureOwnMultiFrameIsGood(output_multi_frame); + // A helper function for expanding a node from an edge. + + auto output_frames_populator = output_multi_frame.GetInvalidFramesPopulator(); + auto populated_any = false; + + while (true) { + switch (state_) { + case State::PullInputAndEdges: { + if (!PullInputFrames(context)) { + state_ = State::Exhausted; + return populated_any; + } + state_ = State::InitInOutEdgesIt; + break; + } + case State::InitInOutEdgesIt: { + if (own_frames_it_ == own_frames_consumer_->end()) { + state_ = State::PullInputAndEdges; + } else { + InitEdgesMultiple(); + state_ = State::PopulateOutput; + } + break; + } + case State::PopulateOutput: { + if (!output_multi_frame.HasInvalidFrame()) { + return populated_any; + } + if (current_in_edge_mf_it_ == in_edges_end_it_ && current_out_edge_mf_it_ == out_edges_end_it_) { + own_frames_it_->MakeInvalid(); + ++own_frames_it_; + state_ = State::InitInOutEdgesIt; + continue; + } + auto populate_edges = [this, &context, &output_frames_populator, &populated_any]( + const EdgeAtom::Direction direction, EdgesIterator ¤t, + const EdgesIterator &end) { + for (auto output_frame_it = output_frames_populator.begin(); + output_frame_it != output_frames_populator.end() && current != end; ++output_frame_it) { + auto &edge = *current; + auto &output_frame = *output_frame_it; + output_frame = *own_frames_it_; + switch (direction) { + case EdgeAtom::Direction::IN: { + output_frame[self_.common_.edge_symbol] = + EdgeAccessor{msgs::Edge{edge.other_end, current_vertex_->Id(), {}, {edge.gid}, edge.type}, + context.request_router}; + break; + } + case EdgeAtom::Direction::OUT: { + output_frame[self_.common_.edge_symbol] = + EdgeAccessor{msgs::Edge{current_vertex_->Id(), edge.other_end, {}, {edge.gid}, edge.type}, + context.request_router}; + break; + } + case EdgeAtom::Direction::BOTH: { + LOG_FATAL("Must indicate exact expansion direction here"); + } + }; + PullDstVertex(output_frame, context, direction); + ++current; + AdvanceUntilSuitableEdge(current, end); + populated_any = true; + } + }; + populate_edges(EdgeAtom::Direction::IN, current_in_edge_mf_it_, in_edges_end_it_); + populate_edges(EdgeAtom::Direction::OUT, current_out_edge_mf_it_, out_edges_end_it_); + break; + } + case State::Exhausted: { + return populated_any; + } + } + } + return populated_any; + } + + void EnsureOwnMultiFrameIsGood(MultiFrame &output_multi_frame) { + if (!own_multi_frame_.has_value()) { + own_multi_frame_.emplace(MultiFrame(output_multi_frame.GetFirstFrame().elems().size(), + kNumberOfFramesInMultiframe, output_multi_frame.GetMemoryResource())); + own_frames_consumer_.emplace(own_multi_frame_->GetValidFramesConsumer()); + own_frames_it_ = own_frames_consumer_->begin(); + } + MG_ASSERT(output_multi_frame.GetFirstFrame().elems().size() == own_multi_frame_->GetFirstFrame().elems().size()); + } + void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); + vertex_id_to_result_row.clear(); + result_rows_.clear(); + own_frames_it_ = ValidFramesConsumer::Iterator{}; + own_frames_consumer_.reset(); + own_multi_frame_->MakeAllFramesInvalid(); + state_ = State::PullInputAndEdges; + current_in_edges_.clear(); current_out_edges_.clear(); current_in_edge_it_ = current_in_edges_.end(); current_out_edge_it_ = current_out_edges_.end(); + + ResetMultiFrameEdgeIts(); } private: + enum class State { PullInputAndEdges, InitInOutEdgesIt, PopulateOutput, Exhausted }; + + struct RefCountedResultRow { + size_t ref_count{0U}; + msgs::ExpandOneResultRow *result_row{nullptr}; + }; + + using EdgeWithSpecificProperties = msgs::ExpandOneResultRow::EdgeWithSpecificProperties; + using EdgesVector = std::vector<EdgeWithSpecificProperties>; + using EdgesIterator = EdgesVector::iterator; + + void ResetMultiFrameEdgeIts() { + in_edges_end_it_ = EdgesIterator{}; + current_in_edge_mf_it_ = in_edges_end_it_; + out_edges_end_it_ = EdgesIterator{}; + current_out_edge_mf_it_ = out_edges_end_it_; + } + + void AdvanceUntilSuitableEdge(EdgesIterator ¤t, const EdgesIterator &end) { + if (!self_.common_.existing_node) { + return; + } + + const auto &existing_node_value = (*own_frames_it_)[self_.common_.node_symbol]; + if (existing_node_value.IsNull()) { + current = end; + return; + } + const auto &existing_node = existing_node_value.ValueVertex(); + current = std::find_if(current, end, [&existing_node](const EdgeWithSpecificProperties &edge) { + return edge.other_end == existing_node.Id(); + }); + } + const Expand &self_; const UniqueCursorPtr input_cursor_; + EdgesIterator current_in_edge_mf_it_; + EdgesIterator in_edges_end_it_; + EdgesIterator current_out_edge_mf_it_; + EdgesIterator out_edges_end_it_; + State state_{State::PullInputAndEdges}; + std::optional<MultiFrame> own_multi_frame_; + std::optional<ValidFramesConsumer> own_frames_consumer_; + const VertexAccessor *current_vertex_{nullptr}; + ValidFramesConsumer::Iterator own_frames_it_; + std::vector<msgs::ExpandOneResultRow> result_rows_; + // This won't work if any vertex id is duplicated in the input + std::unordered_map<msgs::VertexId, RefCountedResultRow> vertex_id_to_result_row; + + // TODO(antaljanosbenjamin): Remove when single frame approach is removed std::vector<EdgeAccessor> current_in_edges_; std::vector<EdgeAccessor> current_out_edges_; std::vector<EdgeAccessor>::iterator current_in_edge_it_; diff --git a/src/query/v2/plan/operator.lcp b/src/query/v2/plan/operator.lcp index efa0d5df0..4f34cc061 100644 --- a/src/query/v2/plan/operator.lcp +++ b/src/query/v2/plan/operator.lcp @@ -72,7 +72,21 @@ class Cursor { /// @throws QueryRuntimeException if something went wrong with execution virtual bool Pull(Frame &, ExecutionContext &) = 0; - virtual void PullMultiple(MultiFrame &, ExecutionContext &) { LOG_FATAL("PullMultipleIsNotImplemented"); } + /// Run an iteration of a @c LogicalOperator with MultiFrame. + /// + /// Since operators may be chained, the iteration may pull results from + /// multiple operators. + /// + /// @param MultiFrame May be read from or written to while performing the + /// iteration. + /// @param ExecutionContext Used to get the position of symbols in frame and + /// other information. + /// @return True if the operator was able to populate at least one Frame on the MultiFrame, + /// thus if an operator returns true, that means there is at least one valid Frame in the + /// MultiFrame. + /// + /// @throws QueryRuntimeException if something went wrong with execution + virtual bool PullMultiple(MultiFrame &, ExecutionContext &) {MG_ASSERT(false, "PullMultipleIsNotImplemented"); return false; } /// Resets the Cursor to its initial state. virtual void Reset() = 0; @@ -113,7 +127,7 @@ class ScanAllByLabel; class ScanAllByLabelPropertyRange; class ScanAllByLabelPropertyValue; class ScanAllByLabelProperty; -class ScanAllById; +class ScanByPrimaryKey; class Expand; class ExpandVariable; class ConstructNamedPath; @@ -144,7 +158,7 @@ class Foreach; using LogicalOperatorCompositeVisitor = utils::CompositeVisitor< Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange, ScanAllByLabelPropertyValue, - ScanAllByLabelProperty, ScanAllById, + ScanAllByLabelProperty, ScanByPrimaryKey, Expand, ExpandVariable, ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, @@ -335,7 +349,7 @@ and false on every following Pull.") class OnceCursor : public Cursor { public: OnceCursor() {} - void PullMultiple(MultiFrame &, ExecutionContext &) override; + bool PullMultiple(MultiFrame &, ExecutionContext &) override; bool Pull(Frame &, ExecutionContext &) override; void Shutdown() override; void Reset() override; @@ -845,19 +859,21 @@ given label and property. (:serialize (:slk)) (:clone)) - - -(lcp:define-class scan-all-by-id (scan-all) - ((expression "Expression *" :scope :public +(lcp:define-class scan-by-primary-key (scan-all) + ((label "::storage::v3::LabelId" :scope :public) + (primary-key "std::vector<Expression*>" :scope :public) + (expression "Expression *" :scope :public :slk-save #'slk-save-ast-pointer :slk-load (slk-load-ast-pointer "Expression"))) (:documentation - "ScanAll producing a single node with ID equal to evaluated expression") + "ScanAll producing a single node with specified by the label and primary key") (:public #>cpp - ScanAllById() {} - ScanAllById(const std::shared_ptr<LogicalOperator> &input, - Symbol output_symbol, Expression *expression, + ScanByPrimaryKey() {} + ScanByPrimaryKey(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, + storage::v3::LabelId label, + std::vector<query::v2::Expression*> primary_key, storage::v3::View view = storage::v3::View::OLD); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; @@ -1160,6 +1176,7 @@ a boolean value.") public: FilterCursor(const Filter &, utils::MemoryResource *); bool Pull(Frame &, ExecutionContext &) override; + bool PullMultiple(MultiFrame &, ExecutionContext &) override; void Shutdown() override; void Reset() override; @@ -1211,7 +1228,7 @@ RETURN clause) the Produce's pull succeeds exactly once.") public: ProduceCursor(const Produce &, utils::MemoryResource *); bool Pull(Frame &, ExecutionContext &) override; - void PullMultiple(MultiFrame &, ExecutionContext &) override; + bool PullMultiple(MultiFrame &, ExecutionContext &) override; void Shutdown() override; void Reset() override; @@ -1259,6 +1276,7 @@ Has a flag for using DETACH DELETE when deleting vertices.") public: DeleteCursor(const Delete &, utils::MemoryResource *); bool Pull(Frame &, ExecutionContext &) override; + bool PullMultiple(MultiFrame &, ExecutionContext &) override; void Shutdown() override; void Reset() override; diff --git a/src/query/v2/plan/pretty_print.cpp b/src/query/v2/plan/pretty_print.cpp index bc30a3890..7eb1dd5a9 100644 --- a/src/query/v2/plan/pretty_print.cpp +++ b/src/query/v2/plan/pretty_print.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -86,10 +86,10 @@ bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelProperty &op) { return true; } -bool PlanPrinter::PreVisit(ScanAllById &op) { +bool PlanPrinter::PreVisit(query::v2::plan::ScanByPrimaryKey &op) { WithPrintLn([&](auto &out) { - out << "* ScanAllById" - << " (" << op.output_symbol_.name() << ")"; + out << "* ScanByPrimaryKey" + << " (" << op.output_symbol_.name() << " :" << request_router_->LabelToName(op.label_) << ")"; }); return true; } @@ -487,12 +487,15 @@ bool PlanToJsonVisitor::PreVisit(ScanAllByLabelProperty &op) { return false; } -bool PlanToJsonVisitor::PreVisit(ScanAllById &op) { +bool PlanToJsonVisitor::PreVisit(ScanByPrimaryKey &op) { json self; - self["name"] = "ScanAllById"; + self["name"] = "ScanByPrimaryKey"; + self["label"] = ToJson(op.label_, *request_router_); self["output_symbol"] = ToJson(op.output_symbol_); + op.input_->Accept(*this); self["input"] = PopOutput(); + output_ = std::move(self); return false; } diff --git a/src/query/v2/plan/pretty_print.hpp b/src/query/v2/plan/pretty_print.hpp index 4094d7c81..31ff17b18 100644 --- a/src/query/v2/plan/pretty_print.hpp +++ b/src/query/v2/plan/pretty_print.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -67,7 +67,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelProperty &) override; - bool PreVisit(ScanAllById &) override; + bool PreVisit(ScanByPrimaryKey & /*unused*/) override; bool PreVisit(Expand &) override; bool PreVisit(ExpandVariable &) override; @@ -194,7 +194,7 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelProperty &) override; - bool PreVisit(ScanAllById &) override; + bool PreVisit(ScanByPrimaryKey & /*unused*/) override; bool PreVisit(Produce &) override; bool PreVisit(Accumulate &) override; diff --git a/src/query/v2/plan/read_write_type_checker.cpp b/src/query/v2/plan/read_write_type_checker.cpp index 6cc38cedf..28c1e8d0a 100644 --- a/src/query/v2/plan/read_write_type_checker.cpp +++ b/src/query/v2/plan/read_write_type_checker.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,10 +11,12 @@ #include "query/v2/plan/read_write_type_checker.hpp" -#define PRE_VISIT(TOp, RWType, continue_visiting) \ - bool ReadWriteTypeChecker::PreVisit(TOp &op) { \ - UpdateType(RWType); \ - return continue_visiting; \ +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define PRE_VISIT(TOp, RWType, continue_visiting) \ + /*NOLINTNEXTLINE(bugprone-macro-parentheses)*/ \ + bool ReadWriteTypeChecker::PreVisit(TOp & /*op*/) { \ + UpdateType(RWType); \ + return continue_visiting; \ } namespace memgraph::query::v2::plan { @@ -35,7 +37,7 @@ PRE_VISIT(ScanAllByLabel, RWType::R, true) PRE_VISIT(ScanAllByLabelPropertyRange, RWType::R, true) PRE_VISIT(ScanAllByLabelPropertyValue, RWType::R, true) PRE_VISIT(ScanAllByLabelProperty, RWType::R, true) -PRE_VISIT(ScanAllById, RWType::R, true) +PRE_VISIT(ScanByPrimaryKey, RWType::R, true) PRE_VISIT(Expand, RWType::R, true) PRE_VISIT(ExpandVariable, RWType::R, true) diff --git a/src/query/v2/plan/read_write_type_checker.hpp b/src/query/v2/plan/read_write_type_checker.hpp index a3c2f1a46..626cf7af3 100644 --- a/src/query/v2/plan/read_write_type_checker.hpp +++ b/src/query/v2/plan/read_write_type_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -59,7 +59,7 @@ class ReadWriteTypeChecker : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelProperty &) override; - bool PreVisit(ScanAllById &) override; + bool PreVisit(ScanByPrimaryKey & /*unused*/) override; bool PreVisit(Expand &) override; bool PreVisit(ExpandVariable &) override; diff --git a/src/query/v2/plan/rewrite/index_lookup.hpp b/src/query/v2/plan/rewrite/index_lookup.hpp index 57ddba54e..17996d952 100644 --- a/src/query/v2/plan/rewrite/index_lookup.hpp +++ b/src/query/v2/plan/rewrite/index_lookup.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -25,8 +25,10 @@ #include <gflags/gflags.h> +#include "query/v2/frontend/ast/ast.hpp" #include "query/v2/plan/operator.hpp" #include "query/v2/plan/preprocess.hpp" +#include "storage/v3/id_types.hpp" DECLARE_int64(query_vertex_count_to_expand_existing); @@ -271,11 +273,12 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { return true; } - bool PreVisit(ScanAllById &op) override { + bool PreVisit(ScanByPrimaryKey &op) override { prev_ops_.push_back(&op); return true; } - bool PostVisit(ScanAllById &) override { + + bool PostVisit(ScanByPrimaryKey & /*unused*/) override { prev_ops_.pop_back(); return true; } @@ -487,6 +490,12 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { storage::v3::PropertyId GetProperty(PropertyIx prop) { return db_->NameToProperty(prop.name); } + void EraseLabelFilters(const memgraph::query::v2::Symbol &node_symbol, memgraph::query::v2::LabelIx prim_label) { + std::vector<query::v2::Expression *> removed_expressions; + filters_.EraseLabelFilter(node_symbol, prim_label, &removed_expressions); + filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end()); + } + std::optional<LabelIx> FindBestLabelIndex(const std::unordered_set<LabelIx> &labels) { MG_ASSERT(!labels.empty(), "Trying to find the best label without any labels."); std::optional<LabelIx> best_label; @@ -559,31 +568,81 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { const auto &view = scan.view_; const auto &modified_symbols = scan.ModifiedSymbols(*symbol_table_); std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(), modified_symbols.end()); - auto are_bound = [&bound_symbols](const auto &used_symbols) { - for (const auto &used_symbol : used_symbols) { - if (!utils::Contains(bound_symbols, used_symbol)) { - return false; - } - } - return true; - }; - // First, try to see if we can find a vertex by ID. - if (!max_vertex_count || *max_vertex_count >= 1) { - for (const auto &filter : filters_.IdFilters(node_symbol)) { - if (filter.id_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) continue; - auto *value = filter.id_filter->value_; - filter_exprs_for_removal_.insert(filter.expression); - filters_.EraseFilter(filter); - return std::make_unique<ScanAllById>(input, node_symbol, value, view); - } - } - // Now try to see if we can use label+property index. If not, try to use - // just the label index. + + // Try to see if we can use label + primary-key or label + property index. + // If not, try to use just the label index. const auto labels = filters_.FilteredLabels(node_symbol); if (labels.empty()) { // Without labels, we cannot generate any indexed ScanAll. return nullptr; } + + // First, try to see if we can find a vertex based on the possibly + // supplied primary key. + auto property_filters = filters_.PropertyFilters(node_symbol); + query::v2::LabelIx prim_label; + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> primary_key; + + auto extract_primary_key = [this](storage::v3::LabelId label, + std::vector<query::v2::plan::FilterInfo> property_filters) + -> std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> { + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> pk_temp; + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> pk; + std::vector<memgraph::storage::v3::SchemaProperty> schema = db_->GetSchemaForLabel(label); + + std::vector<storage::v3::PropertyId> schema_properties; + schema_properties.reserve(schema.size()); + + std::transform(schema.begin(), schema.end(), std::back_inserter(schema_properties), + [](const auto &schema_elem) { return schema_elem.property_id; }); + + for (const auto &property_filter : property_filters) { + const auto &property_id = db_->NameToProperty(property_filter.property_filter->property_.name); + if (std::find(schema_properties.begin(), schema_properties.end(), property_id) != schema_properties.end()) { + pk_temp.emplace_back(std::make_pair(property_filter.expression, property_filter)); + } + } + + // Make sure pk is in the same order as schema_properties. + for (const auto &schema_prop : schema_properties) { + for (auto &pk_temp_prop : pk_temp) { + const auto &property_id = db_->NameToProperty(pk_temp_prop.second.property_filter->property_.name); + if (schema_prop == property_id) { + pk.push_back(pk_temp_prop); + } + } + } + MG_ASSERT(pk.size() == pk_temp.size(), + "The two vectors should represent the same primary key with a possibly different order of contained " + "elements."); + + return pk.size() == schema_properties.size() + ? pk + : std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>>{}; + }; + + if (!property_filters.empty()) { + for (const auto &label : labels) { + if (db_->PrimaryLabelExists(GetLabel(label))) { + prim_label = label; + primary_key = extract_primary_key(GetLabel(prim_label), property_filters); + break; + } + } + if (!primary_key.empty()) { + // Mark the expressions so they won't be used for an additional, unnecessary filter. + for (const auto &primary_property : primary_key) { + filter_exprs_for_removal_.insert(primary_property.first); + filters_.EraseFilter(primary_property.second); + } + EraseLabelFilters(node_symbol, prim_label); + std::vector<query::v2::Expression *> pk_expressions; + std::transform(primary_key.begin(), primary_key.end(), std::back_inserter(pk_expressions), + [](const auto &exp) { return exp.second.property_filter->value_; }); + return std::make_unique<ScanByPrimaryKey>(input, node_symbol, GetLabel(prim_label), pk_expressions); + } + } + auto found_index = FindBestLabelPropertyIndex(node_symbol, bound_symbols); if (found_index && // Use label+property index if we satisfy max_vertex_count. @@ -597,9 +656,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { filter_exprs_for_removal_.insert(found_index->filter.expression); } filters_.EraseFilter(found_index->filter); - std::vector<Expression *> removed_expressions; - filters_.EraseLabelFilter(node_symbol, found_index->label, &removed_expressions); - filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end()); + EraseLabelFilters(node_symbol, found_index->label); if (prop_filter.lower_bound_ || prop_filter.upper_bound_) { return std::make_unique<ScanAllByLabelPropertyRange>( input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_), diff --git a/src/query/v2/plan/vertex_count_cache.hpp b/src/query/v2/plan/vertex_count_cache.hpp index e68ce1220..ae6cdad31 100644 --- a/src/query/v2/plan/vertex_count_cache.hpp +++ b/src/query/v2/plan/vertex_count_cache.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -12,14 +12,17 @@ /// @file #pragma once +#include <iterator> #include <optional> #include "query/v2/bindings/typed_value.hpp" +#include "query/v2/plan/preprocess.hpp" #include "query/v2/request_router.hpp" #include "storage/v3/conversions.hpp" #include "storage/v3/id_types.hpp" #include "storage/v3/property_value.hpp" #include "utils/bound.hpp" +#include "utils/exceptions.hpp" #include "utils/fnv.hpp" namespace memgraph::query::v2::plan { @@ -52,11 +55,16 @@ class VertexCountCache { return 1; } - // For now return true if label is primary label - bool LabelIndexExists(storage::v3::LabelId label) { return request_router_->IsPrimaryLabel(label); } + bool LabelIndexExists(storage::v3::LabelId label) { return PrimaryLabelExists(label); } + + bool PrimaryLabelExists(storage::v3::LabelId label) { return request_router_->IsPrimaryLabel(label); } bool LabelPropertyIndexExists(storage::v3::LabelId /*label*/, storage::v3::PropertyId /*property*/) { return false; } + const std::vector<memgraph::storage::v3::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId label) { + return request_router_->GetSchemaForLabel(label); + } + RequestRouterInterface *request_router_; }; diff --git a/src/query/v2/request_router.hpp b/src/query/v2/request_router.hpp index 3dd2f164b..bf8c93566 100644 --- a/src/query/v2/request_router.hpp +++ b/src/query/v2/request_router.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -24,14 +24,18 @@ #include <stdexcept> #include <thread> #include <unordered_map> +#include <variant> #include <vector> +#include <boost/uuid/uuid.hpp> + #include "coordinator/coordinator.hpp" #include "coordinator/coordinator_client.hpp" #include "coordinator/coordinator_rsm.hpp" #include "coordinator/shard_map.hpp" #include "io/address.hpp" #include "io/errors.hpp" +#include "io/local_transport/local_transport.hpp" #include "io/notifier.hpp" #include "io/rsm/raft.hpp" #include "io/rsm/rsm_client.hpp" @@ -114,6 +118,10 @@ class RequestRouterInterface { virtual std::optional<storage::v3::LabelId> MaybeNameToLabel(const std::string &name) const = 0; virtual bool IsPrimaryLabel(storage::v3::LabelId label) const = 0; virtual bool IsPrimaryKey(storage::v3::LabelId primary_label, storage::v3::PropertyId property) const = 0; + + virtual std::optional<std::pair<uint64_t, uint64_t>> AllocateInitialEdgeIds(io::Address coordinator_address) = 0; + virtual void InstallSimulatorTicker(std::function<bool()> tick_simulator) = 0; + virtual const std::vector<coordinator::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId label) const = 0; }; // TODO(kostasrim)rename this class template @@ -138,7 +146,7 @@ class RequestRouter : public RequestRouterInterface { ~RequestRouter() override {} - void InstallSimulatorTicker(std::function<bool()> tick_simulator) { + void InstallSimulatorTicker(std::function<bool()> tick_simulator) override { notifier_.InstallSimulatorTicker(tick_simulator); } @@ -232,12 +240,17 @@ class RequestRouter : public RequestRouterInterface { }) != schema_it->second.end(); } + const std::vector<coordinator::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId label) const override { + return shards_map_.schemas.at(label); + } + bool IsPrimaryLabel(storage::v3::LabelId label) const override { return shards_map_.label_spaces.contains(label); } // TODO(kostasrim) Simplify return result std::vector<VertexAccessor> ScanVertices(std::optional<std::string> label) override { // create requests - std::vector<ShardRequestState<msgs::ScanVerticesRequest>> requests_to_be_sent = RequestsForScanVertices(label); + auto requests_to_be_sent = RequestsForScanVertices(label); + spdlog::trace("created {} ScanVertices requests", requests_to_be_sent.size()); // begin all requests in parallel @@ -299,7 +312,8 @@ class RequestRouter : public RequestRouterInterface { MG_ASSERT(!new_edges.empty()); // create requests - std::vector<ShardRequestState<msgs::CreateExpandRequest>> requests_to_be_sent = RequestsForCreateExpand(new_edges); + std::vector<ShardRequestState<msgs::CreateExpandRequest>> requests_to_be_sent = + RequestsForCreateExpand(std::move(new_edges)); // begin all requests in parallel RunningRequests<msgs::CreateExpandRequest> running_requests = {}; @@ -359,6 +373,7 @@ class RequestRouter : public RequestRouterInterface { } std::vector<msgs::GetPropertiesResultRow> GetProperties(msgs::GetPropertiesRequest requests) override { + requests.transaction_id = transaction_id_; // create requests std::vector<ShardRequestState<msgs::GetPropertiesRequest>> requests_to_be_sent = RequestsForGetProperties(std::move(requests)); @@ -430,7 +445,7 @@ class RequestRouter : public RequestRouterInterface { } std::vector<ShardRequestState<msgs::CreateExpandRequest>> RequestsForCreateExpand( - const std::vector<msgs::NewExpand> &new_expands) { + std::vector<msgs::NewExpand> new_expands) { std::map<ShardMetadata, msgs::CreateExpandRequest> per_shard_request_table; auto ensure_shard_exists_in_table = [&per_shard_request_table, transaction_id = transaction_id_](const ShardMetadata &shard) { @@ -707,6 +722,23 @@ class RequestRouter : public RequestRouterInterface { edge_types_.StoreMapping(std::move(id_to_name)); } + std::optional<std::pair<uint64_t, uint64_t>> AllocateInitialEdgeIds(io::Address coordinator_address) override { + coordinator::CoordinatorWriteRequests requests{coordinator::AllocateEdgeIdBatchRequest{.batch_size = 1000000}}; + + io::rsm::WriteRequest<coordinator::CoordinatorWriteRequests> ww; + ww.operation = requests; + auto resp = + io_.template Request<io::rsm::WriteRequest<coordinator::CoordinatorWriteRequests>, + io::rsm::WriteResponse<coordinator::CoordinatorWriteResponses>>(coordinator_address, ww) + .Wait(); + if (resp.HasValue()) { + const auto alloc_edge_id_reps = + std::get<coordinator::AllocateEdgeIdBatchResponse>(resp.GetValue().message.write_return); + return std::make_pair(alloc_edge_id_reps.low, alloc_edge_id_reps.high); + } + return {}; + } + ShardMap shards_map_; storage::v3::NameIdMapper properties_; storage::v3::NameIdMapper edge_types_; @@ -718,4 +750,66 @@ class RequestRouter : public RequestRouterInterface { io::Notifier notifier_ = {}; // TODO(kostasrim) Add batch prefetching }; + +class RequestRouterFactory { + public: + RequestRouterFactory() = default; + RequestRouterFactory(const RequestRouterFactory &) = delete; + RequestRouterFactory &operator=(const RequestRouterFactory &) = delete; + RequestRouterFactory(RequestRouterFactory &&) = delete; + RequestRouterFactory &operator=(RequestRouterFactory &&) = delete; + + virtual ~RequestRouterFactory() = default; + + virtual std::unique_ptr<RequestRouterInterface> CreateRequestRouter( + const coordinator::Address &coordinator_address) const = 0; +}; + +class LocalRequestRouterFactory : public RequestRouterFactory { + using LocalTransportIo = io::Io<io::local_transport::LocalTransport>; + LocalTransportIo &io_; + + public: + explicit LocalRequestRouterFactory(LocalTransportIo &io) : io_(io) {} + + std::unique_ptr<RequestRouterInterface> CreateRequestRouter( + const coordinator::Address &coordinator_address) const override { + using TransportType = io::local_transport::LocalTransport; + + auto query_io = io_.ForkLocal(boost::uuids::uuid{boost::uuids::random_generator()()}); + auto local_transport_io = io_.ForkLocal(boost::uuids::uuid{boost::uuids::random_generator()()}); + + return std::make_unique<RequestRouter<TransportType>>( + coordinator::CoordinatorClient<TransportType>(query_io, coordinator_address, {coordinator_address}), + std::move(local_transport_io)); + } +}; + +class SimulatedRequestRouterFactory : public RequestRouterFactory { + io::simulator::Simulator *simulator_; + + public: + explicit SimulatedRequestRouterFactory(io::simulator::Simulator &simulator) : simulator_(&simulator) {} + + std::unique_ptr<RequestRouterInterface> CreateRequestRouter( + const coordinator::Address &coordinator_address) const override { + using TransportType = io::simulator::SimulatorTransport; + auto actual_transport_handle = simulator_->GetSimulatorHandle(); + + boost::uuids::uuid random_uuid; + io::Address unique_local_addr_query; + + // The simulated RR should not introduce stochastic behavior. + random_uuid = boost::uuids::uuid{3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + unique_local_addr_query = {.unique_id = boost::uuids::uuid{4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; + + auto io = simulator_->Register(unique_local_addr_query); + auto query_io = io.ForkLocal(random_uuid); + + return std::make_unique<RequestRouter<TransportType>>( + coordinator::CoordinatorClient<TransportType>(query_io, coordinator_address, {coordinator_address}), + std::move(io)); + } +}; + } // namespace memgraph::query::v2 diff --git a/src/query/v2/requests.hpp b/src/query/v2/requests.hpp index 9ff9a1bae..b2d7f9123 100644 --- a/src/query/v2/requests.hpp +++ b/src/query/v2/requests.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -25,6 +25,7 @@ #include "storage/v3/id_types.hpp" #include "storage/v3/property_value.hpp" #include "storage/v3/result.hpp" +#include "utils/fnv.hpp" namespace memgraph::msgs { @@ -36,6 +37,7 @@ struct Value; struct Label { LabelId id; friend bool operator==(const Label &lhs, const Label &rhs) { return lhs.id == rhs.id; } + friend bool operator==(const Label &lhs, const LabelId &rhs) { return lhs.id == rhs; } }; // TODO(kostasrim) update this with CompoundKey, same for the rest of the file. @@ -578,3 +580,48 @@ using WriteResponses = std::variant<CreateVerticesResponse, DeleteVerticesRespon CreateExpandResponse, DeleteEdgesResponse, UpdateEdgesResponse, CommitResponse>; } // namespace memgraph::msgs + +namespace std { + +template <> +struct hash<memgraph::msgs::Value>; + +template <> +struct hash<memgraph::msgs::VertexId> { + size_t operator()(const memgraph::msgs::VertexId &id) const { + using LabelId = memgraph::storage::v3::LabelId; + using Value = memgraph::msgs::Value; + return memgraph::utils::HashCombine<LabelId, std::vector<Value>, std::hash<LabelId>, + memgraph::utils::FnvCollection<std::vector<Value>, Value>>{}(id.first.id, + id.second); + } +}; + +template <> +struct hash<memgraph::msgs::Value> { + size_t operator()(const memgraph::msgs::Value &value) const { + using Type = memgraph::msgs::Value::Type; + switch (value.type) { + case Type::Null: + return std::hash<size_t>{}(0U); + case Type::Bool: + return std::hash<bool>{}(value.bool_v); + case Type::Int64: + return std::hash<int64_t>{}(value.int_v); + case Type::Double: + return std::hash<double>{}(value.double_v); + case Type::String: + return std::hash<std::string>{}(value.string_v); + case Type::List: + LOG_FATAL("Add hash for lists"); + case Type::Map: + LOG_FATAL("Add hash for maps"); + case Type::Vertex: + LOG_FATAL("Add hash for vertices"); + case Type::Edge: + LOG_FATAL("Add hash for edges"); + } + } +}; + +} // namespace std diff --git a/src/storage/v3/request_helper.cpp b/src/storage/v3/request_helper.cpp index 6b889fe16..f13c5a82e 100644 --- a/src/storage/v3/request_helper.cpp +++ b/src/storage/v3/request_helper.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -321,12 +321,15 @@ EdgeFiller InitializeEdgeFillerFunction(const msgs::ExpandOneRequest &req) { value_properties.insert(std::make_pair(prop_key, FromPropertyValueToValue(std::move(prop_val)))); } using EdgeWithAllProperties = msgs::ExpandOneResultRow::EdgeWithAllProperties; - EdgeWithAllProperties edges{ToMsgsVertexId(edge.From()), msgs::EdgeType{edge.EdgeType()}, edge.Gid().AsUint(), - std::move(value_properties)}; + if (is_in_edge) { - result_row.in_edges_with_all_properties.push_back(std::move(edges)); + result_row.in_edges_with_all_properties.push_back( + EdgeWithAllProperties{ToMsgsVertexId(edge.From()), msgs::EdgeType{edge.EdgeType()}, edge.Gid().AsUint(), + std::move(value_properties)}); } else { - result_row.out_edges_with_all_properties.push_back(std::move(edges)); + result_row.out_edges_with_all_properties.push_back( + EdgeWithAllProperties{ToMsgsVertexId(edge.To()), msgs::EdgeType{edge.EdgeType()}, edge.Gid().AsUint(), + std::move(value_properties)}); } return {}; }; @@ -346,12 +349,15 @@ EdgeFiller InitializeEdgeFillerFunction(const msgs::ExpandOneRequest &req) { value_properties.emplace_back(FromPropertyValueToValue(std::move(property_result.GetValue()))); } using EdgeWithSpecificProperties = msgs::ExpandOneResultRow::EdgeWithSpecificProperties; - EdgeWithSpecificProperties edges{ToMsgsVertexId(edge.From()), msgs::EdgeType{edge.EdgeType()}, - edge.Gid().AsUint(), std::move(value_properties)}; + if (is_in_edge) { - result_row.in_edges_with_specific_properties.push_back(std::move(edges)); + result_row.in_edges_with_specific_properties.push_back( + EdgeWithSpecificProperties{ToMsgsVertexId(edge.From()), msgs::EdgeType{edge.EdgeType()}, + edge.Gid().AsUint(), std::move(value_properties)}); } else { - result_row.out_edges_with_specific_properties.push_back(std::move(edges)); + result_row.out_edges_with_specific_properties.push_back( + EdgeWithSpecificProperties{ToMsgsVertexId(edge.To()), msgs::EdgeType{edge.EdgeType()}, edge.Gid().AsUint(), + std::move(value_properties)}); } return {}; }; diff --git a/src/storage/v3/shard_rsm.cpp b/src/storage/v3/shard_rsm.cpp index b919d217c..881796a70 100644 --- a/src/storage/v3/shard_rsm.cpp +++ b/src/storage/v3/shard_rsm.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -535,13 +535,17 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::GetPropertiesRequest &&req) { return result; }; - auto collect_props = [&req](const VertexAccessor &v_acc, - const std::optional<EdgeAccessor> &e_acc) -> ShardResult<std::map<PropertyId, Value>> { + auto collect_props = [this, &req]( + const VertexAccessor &v_acc, + const std::optional<EdgeAccessor> &e_acc) -> ShardResult<std::map<PropertyId, Value>> { if (!req.property_ids) { if (e_acc) { return CollectAllPropertiesFromAccessor(*e_acc, view); } - return CollectAllPropertiesFromAccessor(v_acc, view); + const auto *schema = shard_->GetSchema(shard_->PrimaryLabel()); + MG_ASSERT(schema); + + return CollectAllPropertiesFromAccessor(v_acc, view, *schema); } if (e_acc) { diff --git a/src/utils/event_counter.cpp b/src/utils/event_counter.cpp index 634b6eae2..2928027ee 100644 --- a/src/utils/event_counter.cpp +++ b/src/utils/event_counter.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -25,6 +25,7 @@ M(ScanAllByLabelPropertyValueOperator, "Number of times ScanAllByLabelPropertyValue operator was used.") \ M(ScanAllByLabelPropertyOperator, "Number of times ScanAllByLabelProperty operator was used.") \ M(ScanAllByIdOperator, "Number of times ScanAllById operator was used.") \ + M(ScanByPrimaryKeyOperator, "Number of times ScanByPrimaryKey operator was used.") \ M(ExpandOperator, "Number of times Expand operator was used.") \ M(ExpandVariableOperator, "Number of times ExpandVariable operator was used.") \ M(ConstructNamedPathOperator, "Number of times ConstructNamedPath operator was used.") \ diff --git a/tests/mgbench/client.cpp b/tests/mgbench/client.cpp index e4b63d477..000c199fa 100644 --- a/tests/mgbench/client.cpp +++ b/tests/mgbench/client.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/simulation/CMakeLists.txt b/tests/simulation/CMakeLists.txt index cd5fc0a4a..f3d4870f0 100644 --- a/tests/simulation/CMakeLists.txt +++ b/tests/simulation/CMakeLists.txt @@ -17,7 +17,7 @@ function(add_simulation_test test_cpp) # requires unique logical target names set_target_properties(${target_name} PROPERTIES OUTPUT_NAME ${exec_name}) - target_link_libraries(${target_name} mg-storage-v3 mg-communication mg-utils mg-io mg-io-simulator mg-coordinator mg-query-v2) + target_link_libraries(${target_name} mg-communication mg-utils mg-io mg-io-simulator mg-coordinator mg-query-v2 mg-storage-v3) target_link_libraries(${target_name} Boost::headers) target_link_libraries(${target_name} gtest gtest_main gmock rapidcheck rapidcheck_gtest) @@ -32,4 +32,5 @@ add_simulation_test(trial_query_storage/query_storage_test.cpp) add_simulation_test(sharded_map.cpp) add_simulation_test(shard_rsm.cpp) add_simulation_test(cluster_property_test.cpp) +add_simulation_test(cluster_property_test_cypher_queries.cpp) add_simulation_test(request_router.cpp) diff --git a/tests/simulation/cluster_property_test_cypher_queries.cpp b/tests/simulation/cluster_property_test_cypher_queries.cpp new file mode 100644 index 000000000..e35edc033 --- /dev/null +++ b/tests/simulation/cluster_property_test_cypher_queries.cpp @@ -0,0 +1,64 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +// This test serves as an example of a property-based model test. +// It generates a cluster configuration and a set of operations to +// apply against both the real system and a greatly simplified model. + +#include <chrono> + +#include <gtest/gtest.h> +#include <rapidcheck.h> +#include <rapidcheck/gtest.h> +#include <spdlog/cfg/env.h> + +#include "generated_operations.hpp" +#include "io/simulator/simulator_config.hpp" +#include "io/time.hpp" +#include "storage/v3/shard_manager.hpp" +#include "test_cluster.hpp" + +namespace memgraph::tests::simulation { + +using io::Duration; +using io::Time; +using io::simulator::SimulatorConfig; +using storage::v3::kMaximumCronInterval; + +RC_GTEST_PROP(RandomClusterConfig, HappyPath, (ClusterConfig cluster_config, NonEmptyOpVec ops, uint64_t rng_seed)) { + spdlog::cfg::load_env_levels(); + + SimulatorConfig sim_config{ + .drop_percent = 0, + .perform_timeouts = false, + .scramble_messages = true, + .rng_seed = rng_seed, + .start_time = Time::min(), + .abort_time = Time::max(), + }; + + std::vector<std::string> queries = {"CREATE (n:test_label{property_1: 0, property_2: 0});", "MATCH (n) RETURN n;"}; + + auto [sim_stats_1, latency_stats_1] = RunClusterSimulationWithQueries(sim_config, cluster_config, queries); + auto [sim_stats_2, latency_stats_2] = RunClusterSimulationWithQueries(sim_config, cluster_config, queries); + + if (latency_stats_1 != latency_stats_2) { + spdlog::error("simulator stats diverged across runs"); + spdlog::error("run 1 simulator stats: {}", sim_stats_1); + spdlog::error("run 2 simulator stats: {}", sim_stats_2); + spdlog::error("run 1 latency:\n{}", latency_stats_1.SummaryTable()); + spdlog::error("run 2 latency:\n{}", latency_stats_2.SummaryTable()); + RC_ASSERT(latency_stats_1 == latency_stats_2); + RC_ASSERT(sim_stats_1 == sim_stats_2); + } +} + +} // namespace memgraph::tests::simulation diff --git a/tests/simulation/request_router.cpp b/tests/simulation/request_router.cpp index 4248e7876..037674b66 100644 --- a/tests/simulation/request_router.cpp +++ b/tests/simulation/request_router.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/simulation/shard_rsm.cpp b/tests/simulation/shard_rsm.cpp index 768217945..5c35b822a 100644 --- a/tests/simulation/shard_rsm.cpp +++ b/tests/simulation/shard_rsm.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -1305,7 +1305,7 @@ void TestGetProperties(ShardClient &client) { MG_ASSERT(!result.error); MG_ASSERT(result.result_row.size() == 2); for (const auto &elem : result.result_row) { - MG_ASSERT(elem.props.size() == 3); + MG_ASSERT(elem.props.size() == 4); } } { diff --git a/tests/simulation/simulation_interpreter.hpp b/tests/simulation/simulation_interpreter.hpp new file mode 100644 index 000000000..e83980787 --- /dev/null +++ b/tests/simulation/simulation_interpreter.hpp @@ -0,0 +1,93 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "io/simulator/simulator_handle.hpp" +#include "machine_manager/machine_config.hpp" +#include "machine_manager/machine_manager.hpp" +#include "query/v2/config.hpp" +#include "query/v2/discard_value_stream.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/interpreter.hpp" +#include "query/v2/request_router.hpp" + +#include <string> +#include <vector> + +// TODO(gvolfing) +// -How to set up the entire raft cluster with the QE. Also provide abrstraction for that. +// -Pass an argument to the setup to determine, how many times the retry of a query should happen. + +namespace memgraph::io::simulator { + +class SimulatedInterpreter { + using ResultStream = query::v2::DiscardValueResultStream; + + public: + explicit SimulatedInterpreter(std::unique_ptr<query::v2::InterpreterContext> interpreter_context) + : interpreter_context_(std::move(interpreter_context)) { + interpreter_ = std::make_unique<memgraph::query::v2::Interpreter>(interpreter_context_.get()); + } + + SimulatedInterpreter(const SimulatedInterpreter &) = delete; + SimulatedInterpreter &operator=(const SimulatedInterpreter &) = delete; + SimulatedInterpreter(SimulatedInterpreter &&) = delete; + SimulatedInterpreter &operator=(SimulatedInterpreter &&) = delete; + ~SimulatedInterpreter() = default; + + void InstallSimulatorTicker(Simulator &simulator) { + interpreter_->InstallSimulatorTicker(simulator.GetSimulatorTickClosure()); + } + + std::vector<ResultStream> RunQueries(const std::vector<std::string> &queries) { + std::vector<ResultStream> results; + results.reserve(queries.size()); + + for (const auto &query : queries) { + results.emplace_back(RunQuery(query)); + } + return results; + } + + private: + ResultStream RunQuery(const std::string &query) { + ResultStream stream; + + std::map<std::string, memgraph::storage::v3::PropertyValue> params; + const std::string *username = nullptr; + + interpreter_->Prepare(query, params, username); + interpreter_->PullAll(&stream); + + return stream; + } + + std::unique_ptr<query::v2::InterpreterContext> interpreter_context_; + std::unique_ptr<query::v2::Interpreter> interpreter_; +}; + +SimulatedInterpreter SetUpInterpreter(Address coordinator_address, Simulator &simulator) { + auto rr_factory = std::make_unique<memgraph::query::v2::SimulatedRequestRouterFactory>(simulator); + + auto interpreter_context = std::make_unique<memgraph::query::v2::InterpreterContext>( + nullptr, + memgraph::query::v2::InterpreterConfig{.query = {.allow_load_csv = true}, + .execution_timeout_sec = 600, + .replication_replica_check_frequency = std::chrono::seconds(1), + .default_kafka_bootstrap_servers = "", + .default_pulsar_service_url = "", + .stream_transaction_conflict_retries = 30, + .stream_transaction_retry_interval = std::chrono::milliseconds(500)}, + std::filesystem::path("mg_data"), std::move(rr_factory), coordinator_address); + + return SimulatedInterpreter(std::move(interpreter_context)); +} + +} // namespace memgraph::io::simulator diff --git a/tests/simulation/test_cluster.hpp b/tests/simulation/test_cluster.hpp index 2e8bdf92f..f10e88e61 100644 --- a/tests/simulation/test_cluster.hpp +++ b/tests/simulation/test_cluster.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -36,6 +36,8 @@ #include "utils/print_helpers.hpp" #include "utils/variant_helpers.hpp" +#include "simulation_interpreter.hpp" + namespace memgraph::tests::simulation { using coordinator::Coordinator; @@ -279,4 +281,65 @@ std::pair<SimulatorStats, LatencyHistogramSummaries> RunClusterSimulation(const return std::make_pair(stats, histo); } +std::pair<SimulatorStats, LatencyHistogramSummaries> RunClusterSimulationWithQueries( + const SimulatorConfig &sim_config, const ClusterConfig &cluster_config, const std::vector<std::string> &queries) { + spdlog::info("========================== NEW SIMULATION =========================="); + + auto simulator = Simulator(sim_config); + + auto machine_1_addr = Address::TestAddress(1); + auto cli_addr = Address::TestAddress(2); + auto cli_addr_2 = Address::TestAddress(3); + + Io<SimulatorTransport> cli_io = simulator.Register(cli_addr); + Io<SimulatorTransport> cli_io_2 = simulator.Register(cli_addr_2); + + auto coordinator_addresses = std::vector{ + machine_1_addr, + }; + + ShardMap initialization_sm = TestShardMap(cluster_config.shards - 1, cluster_config.replication_factor); + + auto mm_1 = MkMm(simulator, coordinator_addresses, machine_1_addr, initialization_sm); + Address coordinator_address = mm_1.CoordinatorAddress(); + + auto mm_thread_1 = std::jthread(RunMachine, std::move(mm_1)); + simulator.IncrementServerCountAndWaitForQuiescentState(machine_1_addr); + + auto detach_on_error = DetachIfDropped{.handle = mm_thread_1}; + + // TODO(tyler) clarify addresses of coordinator etc... as it's a mess + + CoordinatorClient<SimulatorTransport> coordinator_client(cli_io, coordinator_address, {coordinator_address}); + WaitForShardsToInitialize(coordinator_client); + + auto simulated_interpreter = io::simulator::SetUpInterpreter(coordinator_address, simulator); + simulated_interpreter.InstallSimulatorTicker(simulator); + + auto query_results = simulated_interpreter.RunQueries(queries); + + // We have now completed our workload without failing any assertions, so we can + // disable detaching the worker thread, which will cause the mm_thread_1 jthread + // to be joined when this function returns. + detach_on_error.detach = false; + + simulator.ShutDown(); + + mm_thread_1.join(); + + SimulatorStats stats = simulator.Stats(); + + spdlog::info("total messages: {}", stats.total_messages); + spdlog::info("dropped messages: {}", stats.dropped_messages); + spdlog::info("timed out requests: {}", stats.timed_out_requests); + spdlog::info("total requests: {}", stats.total_requests); + spdlog::info("total responses: {}", stats.total_responses); + spdlog::info("simulator ticks: {}", stats.simulator_ticks); + + auto histo = cli_io_2.ResponseLatencies(); + + spdlog::info("========================== SUCCESS :) =========================="); + return std::make_pair(stats, histo); +} + } // namespace memgraph::tests::simulation diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index e7efb7473..ae747385b 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -93,6 +93,9 @@ target_link_libraries(${test_prefix}query_expression_evaluator mg-query) add_unit_test(query_plan.cpp) target_link_libraries(${test_prefix}query_plan mg-query) +add_unit_test(query_v2_plan.cpp) +target_link_libraries(${test_prefix}query_v2_plan mg-query-v2) + add_unit_test(query_plan_accumulate_aggregate.cpp) target_link_libraries(${test_prefix}query_plan_accumulate_aggregate mg-query) diff --git a/tests/unit/bfs_single_node.cpp b/tests/unit/bfs_single_node.cpp index 93002eef5..eab4f3973 100644 --- a/tests/unit/bfs_single_node.cpp +++ b/tests/unit/bfs_single_node.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -88,14 +88,14 @@ TEST_P(SingleNodeBfsTest, All) { std::unique_ptr<SingleNodeDb> SingleNodeBfsTest::db_{nullptr}; -INSTANTIATE_TEST_CASE_P(DirectionAndExpansionDepth, SingleNodeBfsTest, - testing::Combine(testing::Range(-1, kVertexCount), testing::Range(-1, kVertexCount), - testing::Values(EdgeAtom::Direction::OUT, EdgeAtom::Direction::IN, - EdgeAtom::Direction::BOTH), - testing::Values(std::vector<std::string>{}), testing::Bool(), - testing::Values(FilterLambdaType::NONE))); +INSTANTIATE_TEST_SUITE_P(DirectionAndExpansionDepth, SingleNodeBfsTest, + testing::Combine(testing::Range(-1, kVertexCount), testing::Range(-1, kVertexCount), + testing::Values(EdgeAtom::Direction::OUT, EdgeAtom::Direction::IN, + EdgeAtom::Direction::BOTH), + testing::Values(std::vector<std::string>{}), testing::Bool(), + testing::Values(FilterLambdaType::NONE))); -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( EdgeType, SingleNodeBfsTest, testing::Combine(testing::Values(-1), testing::Values(-1), testing::Values(EdgeAtom::Direction::OUT, EdgeAtom::Direction::IN, EdgeAtom::Direction::BOTH), @@ -103,11 +103,11 @@ INSTANTIATE_TEST_CASE_P( std::vector<std::string>{"b"}, std::vector<std::string>{"a", "b"}), testing::Bool(), testing::Values(FilterLambdaType::NONE))); -INSTANTIATE_TEST_CASE_P(FilterLambda, SingleNodeBfsTest, - testing::Combine(testing::Values(-1), testing::Values(-1), - testing::Values(EdgeAtom::Direction::OUT, EdgeAtom::Direction::IN, - EdgeAtom::Direction::BOTH), - testing::Values(std::vector<std::string>{}), testing::Bool(), - testing::Values(FilterLambdaType::NONE, FilterLambdaType::USE_FRAME, - FilterLambdaType::USE_FRAME_NULL, FilterLambdaType::USE_CTX, - FilterLambdaType::ERROR))); +INSTANTIATE_TEST_SUITE_P(FilterLambda, SingleNodeBfsTest, + testing::Combine(testing::Values(-1), testing::Values(-1), + testing::Values(EdgeAtom::Direction::OUT, EdgeAtom::Direction::IN, + EdgeAtom::Direction::BOTH), + testing::Values(std::vector<std::string>{}), testing::Bool(), + testing::Values(FilterLambdaType::NONE, FilterLambdaType::USE_FRAME, + FilterLambdaType::USE_FRAME_NULL, FilterLambdaType::USE_CTX, + FilterLambdaType::ERROR))); diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 40aeb0161..93f0653b4 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -294,7 +294,7 @@ std::shared_ptr<Base> gAstGeneratorTypes[] = { std::make_shared<CachedAstGenerator>(), }; -INSTANTIATE_TEST_CASE_P(AstGeneratorTypes, CypherMainVisitorTest, ::testing::ValuesIn(gAstGeneratorTypes)); +INSTANTIATE_TEST_SUITE_P(AstGeneratorTypes, CypherMainVisitorTest, ::testing::ValuesIn(gAstGeneratorTypes)); // NOTE: The above used to use *Typed Tests* functionality of gtest library. // Unfortunately, the compilation time of this test increased to full 2 minutes! @@ -308,7 +308,7 @@ INSTANTIATE_TEST_CASE_P(AstGeneratorTypes, CypherMainVisitorTest, ::testing::Val // ClonedAstGenerator, CachedAstGenerator> // AstGeneratorTypes; // -// TYPED_TEST_CASE(CypherMainVisitorTest, AstGeneratorTypes); +// TYPED_TEST_SUITE(CypherMainVisitorTest, AstGeneratorTypes); TEST_P(CypherMainVisitorTest, SyntaxException) { auto &ast_generator = *GetParam(); diff --git a/tests/unit/high_density_shard_create_scan.cpp b/tests/unit/high_density_shard_create_scan.cpp index 9fabf6ccc..4a90b98b2 100644 --- a/tests/unit/high_density_shard_create_scan.cpp +++ b/tests/unit/high_density_shard_create_scan.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/unit/machine_manager.cpp b/tests/unit/machine_manager.cpp index 74b7d3863..6cc6a3ff1 100644 --- a/tests/unit/machine_manager.cpp +++ b/tests/unit/machine_manager.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/unit/mock_helpers.hpp b/tests/unit/mock_helpers.hpp index 6c03889ba..15f264cac 100644 --- a/tests/unit/mock_helpers.hpp +++ b/tests/unit/mock_helpers.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -13,12 +13,13 @@ #include <gmock/gmock.h> #include <gtest/gtest.h> + #include "query/v2/common.hpp" #include "query/v2/context.hpp" #include "query/v2/plan/operator.hpp" #include "query/v2/request_router.hpp" -namespace memgraph::query::v2 { +namespace memgraph::query::v2::tests { class MockedRequestRouter : public RequestRouterInterface { public: MOCK_METHOD(std::vector<VertexAccessor>, ScanVertices, (std::optional<std::string> label)); @@ -41,6 +42,9 @@ class MockedRequestRouter : public RequestRouterInterface { MOCK_METHOD(std::optional<storage::v3::LabelId>, MaybeNameToLabel, (const std::string &), (const)); MOCK_METHOD(bool, IsPrimaryLabel, (storage::v3::LabelId), (const)); MOCK_METHOD(bool, IsPrimaryKey, (storage::v3::LabelId, storage::v3::PropertyId), (const)); + MOCK_METHOD((std::optional<std::pair<uint64_t, uint64_t>>), AllocateInitialEdgeIds, (io::Address)); + MOCK_METHOD(void, InstallSimulatorTicker, (std::function<bool()>)); + MOCK_METHOD(const std::vector<coordinator::SchemaProperty> &, GetSchemaForLabel, (storage::v3::LabelId), (const)); }; class MockedLogicalOperator : public plan::LogicalOperator { @@ -57,7 +61,7 @@ class MockedLogicalOperator : public plan::LogicalOperator { class MockedCursor : public plan::Cursor { public: MOCK_METHOD(bool, Pull, (Frame &, expr::ExecutionContext &)); - MOCK_METHOD(void, PullMultiple, (MultiFrame &, expr::ExecutionContext &)); + MOCK_METHOD(bool, PullMultiple, (MultiFrame &, expr::ExecutionContext &)); MOCK_METHOD(void, Reset, ()); MOCK_METHOD(void, Shutdown, ()); }; @@ -79,4 +83,4 @@ inline MockedLogicalOperator &BaseToMock(plan::LogicalOperator &op) { inline MockedCursor &BaseToMock(plan::Cursor &cursor) { return dynamic_cast<MockedCursor &>(cursor); } -} // namespace memgraph::query::v2 +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/pretty_print_ast_to_original_expression_test.cpp b/tests/unit/pretty_print_ast_to_original_expression_test.cpp index e5d77ae0e..a2144c8aa 100644 --- a/tests/unit/pretty_print_ast_to_original_expression_test.cpp +++ b/tests/unit/pretty_print_ast_to_original_expression_test.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -67,7 +67,7 @@ TEST_P(ExpressiontoStringTest, Example) { EXPECT_EQ(rewritten_expression, rewritten_expression2); } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( PARAMETER, ExpressiontoStringTest, ::testing::Values( std::make_pair(std::string("2 / 1"), std::string("(2 / 1)")), diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 903cacf12..784df5e21 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -501,6 +501,8 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec storage.Create<memgraph::query::MapLiteral>( \ std::unordered_map<memgraph::query::PropertyIx, memgraph::query::Expression *>{__VA_ARGS__}) #define PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToProperty(property_name)) +#define PRIMARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToPrimaryProperty(property_name)) +#define SECONDARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToSecondaryProperty(property_name)) #define PROPERTY_LOOKUP(...) memgraph::query::test_common::GetPropertyLookup(storage, dba, __VA_ARGS__) #define PARAMETER_LOOKUP(token_position) storage.Create<memgraph::query::ParameterLookup>((token_position)) #define NEXPR(name, expr) storage.Create<memgraph::query::NamedExpression>((name), (expr)) diff --git a/tests/unit/query_plan.cpp b/tests/unit/query_plan.cpp index 93d2f33c7..935709cec 100644 --- a/tests/unit/query_plan.cpp +++ b/tests/unit/query_plan.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -90,7 +90,7 @@ void DeleteListContent(std::list<BaseOpChecker *> *list) { delete ptr; } } -TYPED_TEST_CASE(TestPlanner, PlannerTypes); +TYPED_TEST_SUITE(TestPlanner, PlannerTypes); TYPED_TEST(TestPlanner, MatchNodeReturn) { // Test MATCH (n) RETURN n diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp index 335b6ab2b..7907f167f 100644 --- a/tests/unit/query_plan_checker.hpp +++ b/tests/unit/query_plan_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -17,6 +17,7 @@ #include "query/plan/operator.hpp" #include "query/plan/planner.hpp" #include "query/plan/preprocess.hpp" +#include "query/v2/plan/operator.hpp" namespace memgraph::query::plan { @@ -90,7 +91,7 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { } PRE_VISIT(Unwind); PRE_VISIT(Distinct); - + bool PreVisit(Foreach &op) override { CheckOp(op); return false; @@ -336,6 +337,35 @@ class ExpectScanAllByLabelProperty : public OpChecker<ScanAllByLabelProperty> { memgraph::storage::PropertyId property_; }; +class ExpectScanByPrimaryKey : public OpChecker<v2::plan::ScanByPrimaryKey> { + public: + ExpectScanByPrimaryKey(memgraph::storage::v3::LabelId label, const std::vector<Expression *> &properties) + : label_(label), properties_(properties) {} + + void ExpectOp(v2::plan::ScanByPrimaryKey &scan_all, const SymbolTable &) override { + EXPECT_EQ(scan_all.label_, label_); + + bool primary_property_match = true; + for (const auto &expected_prop : properties_) { + bool has_match = false; + for (const auto &prop : scan_all.primary_key_) { + if (typeid(prop).hash_code() == typeid(expected_prop).hash_code()) { + has_match = true; + } + } + if (!has_match) { + primary_property_match = false; + } + } + + EXPECT_TRUE(primary_property_match); + } + + private: + memgraph::storage::v3::LabelId label_; + std::vector<Expression *> properties_; +}; + class ExpectCartesian : public OpChecker<Cartesian> { public: ExpectCartesian(const std::list<std::unique_ptr<BaseOpChecker>> &left, diff --git a/tests/unit/query_plan_checker_v2.hpp b/tests/unit/query_plan_checker_v2.hpp new file mode 100644 index 000000000..014f1c2bb --- /dev/null +++ b/tests/unit/query_plan_checker_v2.hpp @@ -0,0 +1,452 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "query/frontend/semantic/symbol_generator.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/plan/planner.hpp" +#include "query/v2/plan/preprocess.hpp" +#include "utils/exceptions.hpp" + +namespace memgraph::query::v2::plan { + +class BaseOpChecker { + public: + virtual ~BaseOpChecker() {} + + virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0; +}; + +class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { + public: + using HierarchicalLogicalOperatorVisitor::PostVisit; + using HierarchicalLogicalOperatorVisitor::PreVisit; + using HierarchicalLogicalOperatorVisitor::Visit; + + PlanChecker(const std::list<std::unique_ptr<BaseOpChecker>> &checkers, const SymbolTable &symbol_table) + : symbol_table_(symbol_table) { + for (const auto &checker : checkers) checkers_.emplace_back(checker.get()); + } + + PlanChecker(const std::list<BaseOpChecker *> &checkers, const SymbolTable &symbol_table) + : checkers_(checkers), symbol_table_(symbol_table) {} + +#define PRE_VISIT(TOp) \ + bool PreVisit(TOp &op) override { \ + CheckOp(op); \ + return true; \ + } + +#define VISIT(TOp) \ + bool Visit(TOp &op) override { \ + CheckOp(op); \ + return true; \ + } + + PRE_VISIT(CreateNode); + PRE_VISIT(CreateExpand); + PRE_VISIT(Delete); + PRE_VISIT(ScanAll); + PRE_VISIT(ScanAllByLabel); + PRE_VISIT(ScanAllByLabelPropertyValue); + PRE_VISIT(ScanAllByLabelPropertyRange); + PRE_VISIT(ScanAllByLabelProperty); + PRE_VISIT(ScanByPrimaryKey); + PRE_VISIT(Expand); + PRE_VISIT(ExpandVariable); + PRE_VISIT(Filter); + PRE_VISIT(ConstructNamedPath); + PRE_VISIT(Produce); + PRE_VISIT(SetProperty); + PRE_VISIT(SetProperties); + PRE_VISIT(SetLabels); + PRE_VISIT(RemoveProperty); + PRE_VISIT(RemoveLabels); + PRE_VISIT(EdgeUniquenessFilter); + PRE_VISIT(Accumulate); + PRE_VISIT(Aggregate); + PRE_VISIT(Skip); + PRE_VISIT(Limit); + PRE_VISIT(OrderBy); + bool PreVisit(Merge &op) override { + CheckOp(op); + op.input()->Accept(*this); + return false; + } + bool PreVisit(Optional &op) override { + CheckOp(op); + op.input()->Accept(*this); + return false; + } + PRE_VISIT(Unwind); + PRE_VISIT(Distinct); + + bool PreVisit(Foreach &op) override { + CheckOp(op); + return false; + } + + bool Visit(Once &) override { + // Ignore checking Once, it is implicitly at the end. + return true; + } + + bool PreVisit(Cartesian &op) override { + CheckOp(op); + return false; + } + + PRE_VISIT(CallProcedure); + +#undef PRE_VISIT +#undef VISIT + + void CheckOp(LogicalOperator &op) { + ASSERT_FALSE(checkers_.empty()); + checkers_.back()->CheckOp(op, symbol_table_); + checkers_.pop_back(); + } + + std::list<BaseOpChecker *> checkers_; + const SymbolTable &symbol_table_; +}; + +template <class TOp> +class OpChecker : public BaseOpChecker { + public: + void CheckOp(LogicalOperator &op, const SymbolTable &symbol_table) override { + auto *expected_op = dynamic_cast<TOp *>(&op); + ASSERT_TRUE(expected_op) << "op is '" << op.GetTypeInfo().name << "' expected '" << TOp::kType.name << "'!"; + ExpectOp(*expected_op, symbol_table); + } + + virtual void ExpectOp(TOp &, const SymbolTable &) {} +}; + +using ExpectCreateNode = OpChecker<CreateNode>; +using ExpectCreateExpand = OpChecker<CreateExpand>; +using ExpectDelete = OpChecker<Delete>; +using ExpectScanAll = OpChecker<ScanAll>; +using ExpectScanAllByLabel = OpChecker<ScanAllByLabel>; +using ExpectExpand = OpChecker<Expand>; +using ExpectFilter = OpChecker<Filter>; +using ExpectConstructNamedPath = OpChecker<ConstructNamedPath>; +using ExpectProduce = OpChecker<Produce>; +using ExpectSetProperty = OpChecker<SetProperty>; +using ExpectSetProperties = OpChecker<SetProperties>; +using ExpectSetLabels = OpChecker<SetLabels>; +using ExpectRemoveProperty = OpChecker<RemoveProperty>; +using ExpectRemoveLabels = OpChecker<RemoveLabels>; +using ExpectEdgeUniquenessFilter = OpChecker<EdgeUniquenessFilter>; +using ExpectSkip = OpChecker<Skip>; +using ExpectLimit = OpChecker<Limit>; +using ExpectOrderBy = OpChecker<OrderBy>; +using ExpectUnwind = OpChecker<Unwind>; +using ExpectDistinct = OpChecker<Distinct>; + +class ExpectScanAllByLabelPropertyValue : public OpChecker<ScanAllByLabelPropertyValue> { + public: + ExpectScanAllByLabelPropertyValue(memgraph::storage::v3::LabelId label, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair, + memgraph::query::v2::Expression *expression) + : label_(label), property_(prop_pair.second), expression_(expression) {} + + void ExpectOp(ScanAllByLabelPropertyValue &scan_all, const SymbolTable &) override { + EXPECT_EQ(scan_all.label_, label_); + EXPECT_EQ(scan_all.property_, property_); + // TODO: Proper expression equality + EXPECT_EQ(typeid(scan_all.expression_).hash_code(), typeid(expression_).hash_code()); + } + + private: + memgraph::storage::v3::LabelId label_; + memgraph::storage::v3::PropertyId property_; + memgraph::query::v2::Expression *expression_; +}; + +class ExpectScanByPrimaryKey : public OpChecker<v2::plan::ScanByPrimaryKey> { + public: + ExpectScanByPrimaryKey(memgraph::storage::v3::LabelId label, const std::vector<Expression *> &properties) + : label_(label), properties_(properties) {} + + void ExpectOp(v2::plan::ScanByPrimaryKey &scan_all, const SymbolTable &) override { + EXPECT_EQ(scan_all.label_, label_); + + bool primary_property_match = true; + for (const auto &expected_prop : properties_) { + bool has_match = false; + for (const auto &prop : scan_all.primary_key_) { + if (typeid(prop).hash_code() == typeid(expected_prop).hash_code()) { + has_match = true; + } + } + if (!has_match) { + primary_property_match = false; + } + } + + EXPECT_TRUE(primary_property_match); + } + + private: + memgraph::storage::v3::LabelId label_; + std::vector<Expression *> properties_; +}; + +class ExpectCartesian : public OpChecker<Cartesian> { + public: + ExpectCartesian(const std::list<std::unique_ptr<BaseOpChecker>> &left, + const std::list<std::unique_ptr<BaseOpChecker>> &right) + : left_(left), right_(right) {} + + void ExpectOp(Cartesian &op, const SymbolTable &symbol_table) override { + ASSERT_TRUE(op.left_op_); + PlanChecker left_checker(left_, symbol_table); + op.left_op_->Accept(left_checker); + ASSERT_TRUE(op.right_op_); + PlanChecker right_checker(right_, symbol_table); + op.right_op_->Accept(right_checker); + } + + private: + const std::list<std::unique_ptr<BaseOpChecker>> &left_; + const std::list<std::unique_ptr<BaseOpChecker>> &right_; +}; + +class ExpectCallProcedure : public OpChecker<CallProcedure> { + public: + ExpectCallProcedure(const std::string &name, const std::vector<memgraph::query::Expression *> &args, + const std::vector<std::string> &fields, const std::vector<Symbol> &result_syms) + : name_(name), args_(args), fields_(fields), result_syms_(result_syms) {} + + void ExpectOp(CallProcedure &op, const SymbolTable &symbol_table) override { + EXPECT_EQ(op.procedure_name_, name_); + EXPECT_EQ(op.arguments_.size(), args_.size()); + for (size_t i = 0; i < args_.size(); ++i) { + const auto *op_arg = op.arguments_[i]; + const auto *expected_arg = args_[i]; + // TODO: Proper expression equality + EXPECT_EQ(op_arg->GetTypeInfo(), expected_arg->GetTypeInfo()); + } + EXPECT_EQ(op.result_fields_, fields_); + EXPECT_EQ(op.result_symbols_, result_syms_); + } + + private: + std::string name_; + std::vector<memgraph::query::Expression *> args_; + std::vector<std::string> fields_; + std::vector<Symbol> result_syms_; +}; + +template <class T> +std::list<std::unique_ptr<BaseOpChecker>> MakeCheckers(T arg) { + std::list<std::unique_ptr<BaseOpChecker>> l; + l.emplace_back(std::make_unique<T>(arg)); + return l; +} + +template <class T, class... Rest> +std::list<std::unique_ptr<BaseOpChecker>> MakeCheckers(T arg, Rest &&...rest) { + auto l = MakeCheckers(std::forward<Rest>(rest)...); + l.emplace_front(std::make_unique<T>(arg)); + return std::move(l); +} + +template <class TPlanner, class TDbAccessor> +TPlanner MakePlanner(TDbAccessor *dba, AstStorage &storage, SymbolTable &symbol_table, CypherQuery *query) { + auto planning_context = MakePlanningContext(&storage, &symbol_table, query, dba); + auto query_parts = CollectQueryParts(symbol_table, storage, query); + auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; + return TPlanner(single_query_parts, planning_context); +} + +class FakeDistributedDbAccessor { + public: + int64_t VerticesCount(memgraph::storage::v3::LabelId label) const { + auto found = label_index_.find(label); + if (found != label_index_.end()) return found->second; + return 0; + } + + int64_t VerticesCount(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property) const { + for (auto &index : label_property_index_) { + if (std::get<0>(index) == label && std::get<1>(index) == property) { + return std::get<2>(index); + } + } + return 0; + } + + bool LabelIndexExists(memgraph::storage::v3::LabelId label) const { + throw utils::NotYetImplemented("Label indicies are yet to be implemented."); + } + + bool LabelPropertyIndexExists(memgraph::storage::v3::LabelId label, + memgraph::storage::v3::PropertyId property) const { + for (auto &index : label_property_index_) { + if (std::get<0>(index) == label && std::get<1>(index) == property) { + return true; + } + } + return false; + } + + bool PrimaryLabelExists(storage::v3::LabelId label) { return label_index_.find(label) != label_index_.end(); } + + void SetIndexCount(memgraph::storage::v3::LabelId label, int64_t count) { label_index_[label] = count; } + + void SetIndexCount(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property, int64_t count) { + for (auto &index : label_property_index_) { + if (std::get<0>(index) == label && std::get<1>(index) == property) { + std::get<2>(index) = count; + return; + } + } + label_property_index_.emplace_back(label, property, count); + } + + memgraph::storage::v3::LabelId NameToLabel(const std::string &name) { + auto found = primary_labels_.find(name); + if (found != primary_labels_.end()) return found->second; + return primary_labels_.emplace(name, memgraph::storage::v3::LabelId::FromUint(primary_labels_.size())) + .first->second; + } + + memgraph::storage::v3::LabelId Label(const std::string &name) { return NameToLabel(name); } + + memgraph::storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) { + auto found = edge_types_.find(name); + if (found != edge_types_.end()) return found->second; + return edge_types_.emplace(name, memgraph::storage::v3::EdgeTypeId::FromUint(edge_types_.size())).first->second; + } + + memgraph::storage::v3::PropertyId NameToPrimaryProperty(const std::string &name) { + auto found = primary_properties_.find(name); + if (found != primary_properties_.end()) return found->second; + return primary_properties_.emplace(name, memgraph::storage::v3::PropertyId::FromUint(primary_properties_.size())) + .first->second; + } + + memgraph::storage::v3::PropertyId NameToSecondaryProperty(const std::string &name) { + auto found = secondary_properties_.find(name); + if (found != secondary_properties_.end()) return found->second; + return secondary_properties_ + .emplace(name, memgraph::storage::v3::PropertyId::FromUint(secondary_properties_.size())) + .first->second; + } + + memgraph::storage::v3::PropertyId PrimaryProperty(const std::string &name) { return NameToPrimaryProperty(name); } + memgraph::storage::v3::PropertyId SecondaryProperty(const std::string &name) { return NameToSecondaryProperty(name); } + + std::string PrimaryPropertyToName(memgraph::storage::v3::PropertyId property) const { + for (const auto &kv : primary_properties_) { + if (kv.second == property) return kv.first; + } + LOG_FATAL("Unable to find primary property name"); + } + + std::string SecondaryPropertyToName(memgraph::storage::v3::PropertyId property) const { + for (const auto &kv : secondary_properties_) { + if (kv.second == property) return kv.first; + } + LOG_FATAL("Unable to find secondary property name"); + } + + std::string PrimaryPropertyName(memgraph::storage::v3::PropertyId property) const { + return PrimaryPropertyToName(property); + } + std::string SecondaryPropertyName(memgraph::storage::v3::PropertyId property) const { + return SecondaryPropertyToName(property); + } + + memgraph::storage::v3::PropertyId NameToProperty(const std::string &name) { + auto find_in_prim_properties = primary_properties_.find(name); + if (find_in_prim_properties != primary_properties_.end()) { + return find_in_prim_properties->second; + } + auto find_in_secondary_properties = secondary_properties_.find(name); + if (find_in_secondary_properties != secondary_properties_.end()) { + return find_in_secondary_properties->second; + } + + LOG_FATAL("The property does not exist as a primary or a secondary property."); + return memgraph::storage::v3::PropertyId::FromUint(0); + } + + std::vector<memgraph::storage::v3::SchemaProperty> GetSchemaForLabel(storage::v3::LabelId label) { + auto schema_properties = schemas_.at(label); + std::vector<memgraph::storage::v3::SchemaProperty> ret; + std::transform(schema_properties.begin(), schema_properties.end(), std::back_inserter(ret), [](const auto &prop) { + memgraph::storage::v3::SchemaProperty schema_prop = { + .property_id = prop, + // This should not be hardcoded, but for testing purposes it will suffice. + .type = memgraph::common::SchemaType::INT}; + + return schema_prop; + }); + return ret; + } + + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> ExtractPrimaryKey( + storage::v3::LabelId label, std::vector<query::v2::plan::FilterInfo> property_filters) { + MG_ASSERT(schemas_.contains(label), + "You did not specify the Schema for this label! Use FakeDistributedDbAccessor::CreateSchema(...)."); + + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> pk; + const auto schema = GetSchemaPropertiesForLabel(label); + + std::vector<storage::v3::PropertyId> schema_properties; + schema_properties.reserve(schema.size()); + + std::transform(schema.begin(), schema.end(), std::back_inserter(schema_properties), + [](const auto &schema_elem) { return schema_elem; }); + + for (const auto &property_filter : property_filters) { + const auto &property_id = NameToProperty(property_filter.property_filter->property_.name); + if (std::find(schema_properties.begin(), schema_properties.end(), property_id) != schema_properties.end()) { + pk.emplace_back(std::make_pair(property_filter.expression, property_filter)); + } + } + + return pk.size() == schema_properties.size() + ? pk + : std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>>{}; + } + + std::vector<memgraph::storage::v3::PropertyId> GetSchemaPropertiesForLabel(storage::v3::LabelId label) { + return schemas_.at(label); + } + + void CreateSchema(const memgraph::storage::v3::LabelId primary_label, + const std::vector<memgraph::storage::v3::PropertyId> &schemas_types) { + MG_ASSERT(!schemas_.contains(primary_label), "You already created the schema for this label!"); + schemas_.emplace(primary_label, schemas_types); + } + + private: + std::unordered_map<std::string, memgraph::storage::v3::LabelId> primary_labels_; + std::unordered_map<std::string, memgraph::storage::v3::LabelId> secondary_labels_; + std::unordered_map<std::string, memgraph::storage::v3::EdgeTypeId> edge_types_; + std::unordered_map<std::string, memgraph::storage::v3::PropertyId> primary_properties_; + std::unordered_map<std::string, memgraph::storage::v3::PropertyId> secondary_properties_; + + std::unordered_map<memgraph::storage::v3::LabelId, int64_t> label_index_; + std::vector<std::tuple<memgraph::storage::v3::LabelId, memgraph::storage::v3::PropertyId, int64_t>> + label_property_index_; + + std::unordered_map<memgraph::storage::v3::LabelId, std::vector<memgraph::storage::v3::PropertyId>> schemas_; +}; + +} // namespace memgraph::query::v2::plan diff --git a/tests/unit/query_v2_common.hpp b/tests/unit/query_v2_common.hpp new file mode 100644 index 000000000..1b9d6807c --- /dev/null +++ b/tests/unit/query_v2_common.hpp @@ -0,0 +1,603 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +/// This file provides macros for easier construction of openCypher query AST. +/// The usage of macros is very similar to how one would write openCypher. For +/// example: +/// +/// AstStorage storage; // Macros rely on storage being in scope. +/// // PROPERTY_LOOKUP and PROPERTY_PAIR macros +/// // rely on a DbAccessor *reference* named dba. +/// database::GraphDb db; +/// auto dba_ptr = db.Access(); +/// auto &dba = *dba_ptr; +/// +/// QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), +/// WHERE(LESS(PROPERTY_LOOKUP("e", edge_prop), LITERAL(3))), +/// RETURN(SUM(PROPERTY_LOOKUP("m", prop)), AS("sum"), +/// ORDER_BY(IDENT("sum")), +/// SKIP(ADD(LITERAL(1), LITERAL(2))))); +/// +/// Each of the macros is accompanied by a function. The functions use overload +/// resolution and template magic to provide a type safe way of constructing +/// queries. Although the functions can be used by themselves, it is more +/// convenient to use the macros. + +#pragma once + +#include <map> +#include <sstream> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "query/frontend/ast/pretty_print.hpp" // not sure if that is ok... +#include "query/v2/frontend/ast/ast.hpp" +#include "storage/v3/id_types.hpp" +#include "utils/string.hpp" + +#include "query/v2/frontend/ast/ast.hpp" + +namespace memgraph::query::v2 { + +namespace test_common { + +auto ToIntList(const TypedValue &t) { + std::vector<int64_t> list; + for (auto x : t.ValueList()) { + list.push_back(x.ValueInt()); + } + return list; +}; + +auto ToIntMap(const TypedValue &t) { + std::map<std::string, int64_t> map; + for (const auto &kv : t.ValueMap()) map.emplace(kv.first, kv.second.ValueInt()); + return map; +}; + +std::string ToString(Expression *expr) { + std::ostringstream ss; + // PrintExpression(expr, &ss); + return ss.str(); +} + +std::string ToString(NamedExpression *expr) { + std::ostringstream ss; + // PrintExpression(expr, &ss); + return ss.str(); +} + +// Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions, +// so that they can be used to resolve function calls. +struct OrderBy { + std::vector<SortItem> expressions; +}; + +struct Skip { + Expression *expression = nullptr; +}; + +struct Limit { + Expression *expression = nullptr; +}; + +struct OnMatch { + std::vector<Clause *> set; +}; +struct OnCreate { + std::vector<Clause *> set; +}; + +// Helper functions for filling the OrderBy with expressions. +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering = Ordering::ASC) { + order_by.expressions.push_back({ordering, expression}); +} +template <class... T> +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering, T... rest) { + FillOrderBy(order_by, expression, ordering); + FillOrderBy(order_by, rest...); +} +template <class... T> +auto FillOrderBy(OrderBy &order_by, Expression *expression, T... rest) { + FillOrderBy(order_by, expression); + FillOrderBy(order_by, rest...); +} + +/// Create OrderBy expressions. +/// +/// The supported combination of arguments is: (Expression, [Ordering])+ +/// Since the Ordering is optional, by default it is ascending. +template <class... T> +auto GetOrderBy(T... exprs) { + OrderBy order_by; + FillOrderBy(order_by, exprs...); + return order_by; +} + +/// Create PropertyLookup with given name and property. +/// +/// Name is used to create the Identifier which is used for property lookup. +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, const std::string &name, + memgraph::storage::v3::PropertyId property) { + return storage.Create<PropertyLookup>(storage.Create<Identifier>(name), + storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, + memgraph::storage::v3::PropertyId property) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, const std::string &property) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(property)); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, const std::string &name, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair) { + return storage.Create<PropertyLookup>(storage.Create<Identifier>(name), storage.GetPropertyIx(prop_pair.first)); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, Expression *expr, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(prop_pair.first)); +} + +/// Create an EdgeAtom with given name, direction and edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdge(AstStorage &storage, const std::string &name, EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector<std::string> &edge_types = {}) { + std::vector<EdgeTypeIx> types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + return storage.Create<EdgeAtom>(storage.Create<Identifier>(name), EdgeAtom::Type::SINGLE, dir, types); +} + +/// Create a variable length expansion EdgeAtom with given name, direction and +/// edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdgeVariable(AstStorage &storage, const std::string &name, EdgeAtom::Type type = EdgeAtom::Type::DEPTH_FIRST, + EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector<std::string> &edge_types = {}, Identifier *flambda_inner_edge = nullptr, + Identifier *flambda_inner_node = nullptr, Identifier *wlambda_inner_edge = nullptr, + Identifier *wlambda_inner_node = nullptr, Expression *wlambda_expression = nullptr, + Identifier *total_weight = nullptr) { + std::vector<EdgeTypeIx> types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + auto r_val = storage.Create<EdgeAtom>(storage.Create<Identifier>(name), type, dir, types); + + r_val->filter_lambda_.inner_edge = + flambda_inner_edge ? flambda_inner_edge : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->filter_lambda_.inner_node = + flambda_inner_node ? flambda_inner_node : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + + if (type == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) { + r_val->weight_lambda_.inner_edge = + wlambda_inner_edge ? wlambda_inner_edge : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.inner_node = + wlambda_inner_node ? wlambda_inner_node : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.expression = + wlambda_expression ? wlambda_expression : storage.Create<memgraph::query::v2::PrimitiveLiteral>(1); + + r_val->total_weight_ = total_weight; + } + + return r_val; +} + +/// Create a NodeAtom with given name and label. +/// +/// Name is used to create the Identifier which is assigned to the node. +auto GetNode(AstStorage &storage, const std::string &name, std::optional<std::string> label = std::nullopt) { + auto node = storage.Create<NodeAtom>(storage.Create<Identifier>(name)); + if (label) node->labels_.emplace_back(storage.GetLabelIx(*label)); + return node; +} + +/// Create a Pattern with given atoms. +auto GetPattern(AstStorage &storage, std::vector<PatternAtom *> atoms) { + auto pattern = storage.Create<Pattern>(); + pattern->identifier_ = storage.Create<Identifier>(memgraph::utils::RandomString(20), false); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// Create a Pattern with given name and atoms. +auto GetPattern(AstStorage &storage, const std::string &name, std::vector<PatternAtom *> atoms) { + auto pattern = storage.Create<Pattern>(); + pattern->identifier_ = storage.Create<Identifier>(name, true); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// This function fills an AST node which with given patterns. +/// +/// The function is most commonly used to create Match and Create clauses. +template <class TWithPatterns> +auto GetWithPatterns(TWithPatterns *with_patterns, std::vector<Pattern *> patterns) { + with_patterns->patterns_.insert(with_patterns->patterns_.begin(), patterns.begin(), patterns.end()); + return with_patterns; +} + +/// Create a query with given clauses. + +auto GetSingleQuery(SingleQuery *single_query, Clause *clause) { + single_query->clauses_.emplace_back(clause); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return single_query; +} +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where, T *...clauses) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return GetSingleQuery(single_query, clauses...); +} +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where, T *...clauses) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return GetSingleQuery(single_query, clauses...); +} + +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, Clause *clause, T *...clauses) { + single_query->clauses_.emplace_back(clause); + return GetSingleQuery(single_query, clauses...); +} + +auto GetCypherUnion(CypherUnion *cypher_union, SingleQuery *single_query) { + cypher_union->single_query_ = single_query; + return cypher_union; +} + +auto GetQuery(AstStorage &storage, SingleQuery *single_query) { + auto *query = storage.Create<CypherQuery>(); + query->single_query_ = single_query; + return query; +} + +template <class... T> +auto GetQuery(AstStorage &storage, SingleQuery *single_query, T *...cypher_unions) { + auto *query = storage.Create<CypherQuery>(); + query->single_query_ = single_query; + query->cypher_unions_ = std::vector<CypherUnion *>{cypher_unions...}; + return query; +} + +// Helper functions for constructing RETURN and WITH clauses. +void FillReturnBody(AstStorage &, ReturnBody &body, NamedExpression *named_expr) { + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name) { + if (name == "*") { + body.all_identifiers = true; + } else { + auto *ident = storage.Create<memgraph::query::v2::Identifier>(name); + auto *named_expr = storage.Create<memgraph::query::v2::NamedExpression>(name, ident); + body.named_expressions.emplace_back(named_expr); + } +} +void FillReturnBody(AstStorage &, ReturnBody &body, Limit limit) { body.limit = limit.expression; } +void FillReturnBody(AstStorage &, ReturnBody &body, Skip skip, Limit limit = Limit{}) { + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Skip skip, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, Expression *expr, NamedExpression *named_expr) { + // This overload supports `RETURN(expr, AS(name))` construct, since + // NamedExpression does not inherit Expression. + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr) { + named_expr->expression_ = storage.Create<memgraph::query::v2::Identifier>(name); + body.named_expressions.emplace_back(named_expr); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, Expression *expr, NamedExpression *named_expr, T... rest) { + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, NamedExpression *named_expr, T... rest) { + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr, + T... rest) { + named_expr->expression_ = storage.Create<memgraph::query::v2::Identifier>(name); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, T... rest) { + auto *ident = storage.Create<memgraph::query::v2::Identifier>(name); + auto *named_expr = storage.Create<memgraph::query::v2::NamedExpression>(name, ident); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} + +/// Create the return clause with given expressions. +/// +/// The supported expression combination of arguments is: +/// +/// (String | NamedExpression | (Expression NamedExpression))+ +/// [OrderBy] [Skip] [Limit] +/// +/// When the pair (Expression NamedExpression) is given, the Expression will be +/// moved inside the NamedExpression. This is done, so that the constructs like +/// RETURN(expr, AS("name"), ...) are supported. Taking a String is a shorthand +/// for RETURN(IDENT(string), AS(string), ....). +/// +/// @sa GetWith +template <class... T> +auto GetReturn(AstStorage &storage, bool distinct, T... exprs) { + auto ret = storage.Create<Return>(); + ret->body_.distinct = distinct; + FillReturnBody(storage, ret->body_, exprs...); + return ret; +} + +/// Create the with clause with given expressions. +/// +/// The supported expression combination is the same as for @c GetReturn. +/// +/// @sa GetReturn +template <class... T> +auto GetWith(AstStorage &storage, bool distinct, T... exprs) { + auto with = storage.Create<With>(); + with->body_.distinct = distinct; + FillReturnBody(storage, with->body_, exprs...); + return with; +} + +/// Create the UNWIND clause with given named expression. +auto GetUnwind(AstStorage &storage, NamedExpression *named_expr) { + return storage.Create<memgraph::query::v2::Unwind>(named_expr); +} +auto GetUnwind(AstStorage &storage, Expression *expr, NamedExpression *as) { + as->expression_ = expr; + return GetUnwind(storage, as); +} + +/// Create the delete clause with given named expressions. +auto GetDelete(AstStorage &storage, std::vector<Expression *> exprs, bool detach = false) { + auto del = storage.Create<Delete>(); + del->expressions_.insert(del->expressions_.begin(), exprs.begin(), exprs.end()); + del->detach_ = detach; + return del; +} + +/// Create a set property clause for given property lookup and the right hand +/// side expression. +auto GetSet(AstStorage &storage, PropertyLookup *prop_lookup, Expression *expr) { + return storage.Create<SetProperty>(prop_lookup, expr); +} + +/// Create a set properties clause for given identifier name and the right hand +/// side expression. +auto GetSet(AstStorage &storage, const std::string &name, Expression *expr, bool update = false) { + return storage.Create<SetProperties>(storage.Create<Identifier>(name), expr, update); +} + +/// Create a set labels clause for given identifier name and labels. +auto GetSet(AstStorage &storage, const std::string &name, std::vector<std::string> label_names) { + std::vector<LabelIx> labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create<SetLabels>(storage.Create<Identifier>(name), labels); +} + +/// Create a remove property clause for given property lookup +auto GetRemove(AstStorage &storage, PropertyLookup *prop_lookup) { return storage.Create<RemoveProperty>(prop_lookup); } + +/// Create a remove labels clause for given identifier name and labels. +auto GetRemove(AstStorage &storage, const std::string &name, std::vector<std::string> label_names) { + std::vector<LabelIx> labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create<RemoveLabels>(storage.Create<Identifier>(name), labels); +} + +/// Create a Merge clause for given Pattern with optional OnMatch and OnCreate +/// parts. +auto GetMerge(AstStorage &storage, Pattern *pattern, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create<memgraph::query::v2::Merge>(); + merge->pattern_ = pattern; + merge->on_create_ = on_create.set; + return merge; +} +auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create<memgraph::query::v2::Merge>(); + merge->pattern_ = pattern; + merge->on_match_ = on_match.set; + merge->on_create_ = on_create.set; + return merge; +} + +auto GetCallProcedure(AstStorage &storage, std::string procedure_name, + std::vector<memgraph::query::v2::Expression *> arguments = {}) { + auto *call_procedure = storage.Create<memgraph::query::v2::CallProcedure>(); + call_procedure->procedure_name_ = std::move(procedure_name); + call_procedure->arguments_ = std::move(arguments); + return call_procedure; +} + +/// Create the FOREACH clause with given named expression. +auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vector<query::v2::Clause *> &clauses) { + return storage.Create<query::v2::Foreach>(named_expr, clauses); +} + +} // namespace test_common + +} // namespace memgraph::query::v2 + +/// All the following macros implicitly pass `storage` variable to functions. +/// You need to have `AstStorage storage;` somewhere in scope to use them. +/// Refer to function documentation to see what the macro does. +/// +/// Example usage: +/// +/// // Create MATCH (n) -[r]- (m) RETURN m AS new_name +/// AstStorage storage; +/// auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), +/// RETURN(NEXPR("new_name"), IDENT("m"))); +#define NODE(...) memgraph::expr::test_common::GetNode(storage, __VA_ARGS__) +#define EDGE(...) memgraph::expr::test_common::GetEdge(storage, __VA_ARGS__) +#define EDGE_VARIABLE(...) memgraph::expr::test_common::GetEdgeVariable(storage, __VA_ARGS__) +#define PATTERN(...) memgraph::expr::test_common::GetPattern(storage, {__VA_ARGS__}) +#define PATTERN(...) memgraph::expr::test_common::GetPattern(storage, {__VA_ARGS__}) +#define NAMED_PATTERN(name, ...) memgraph::expr::test_common::GetPattern(storage, name, {__VA_ARGS__}) +#define OPTIONAL_MATCH(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Match>(true), {__VA_ARGS__}) +#define MATCH(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Match>(), {__VA_ARGS__}) +#define WHERE(expr) storage.Create<memgraph::query::v2::Where>((expr)) +#define CREATE(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Create>(), {__VA_ARGS__}) +#define IDENT(...) storage.Create<memgraph::query::v2::Identifier>(__VA_ARGS__) +#define LITERAL(val) storage.Create<memgraph::query::v2::PrimitiveLiteral>((val)) +#define LIST(...) \ + storage.Create<memgraph::query::v2::ListLiteral>(std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define MAP(...) \ + storage.Create<memgraph::query::v2::MapLiteral>( \ + std::unordered_map<memgraph::query::v2::PropertyIx, memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define PROPERTY_PAIR(property_name) \ + std::make_pair(property_name, dba.NameToProperty(property_name)) // This one might not be needed at all +#define PRIMARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToPrimaryProperty(property_name)) +#define SECONDARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToSecondaryProperty(property_name)) +#define PROPERTY_LOOKUP(...) memgraph::expr::test_common::GetPropertyLookup(storage, dba, __VA_ARGS__) +#define PARAMETER_LOOKUP(token_position) storage.Create<memgraph::query::v2::ParameterLookup>((token_position)) +#define NEXPR(name, expr) storage.Create<memgraph::query::v2::NamedExpression>((name), (expr)) +// AS is alternative to NEXPR which does not initialize NamedExpression with +// Expression. It should be used with RETURN or WITH. For example: +// RETURN(IDENT("n"), AS("n")) vs. RETURN(NEXPR("n", IDENT("n"))). +#define AS(name) storage.Create<memgraph::query::v2::NamedExpression>((name)) +#define RETURN(...) memgraph::expr::test_common::GetReturn(storage, false, __VA_ARGS__) +#define WITH(...) memgraph::expr::test_common::GetWith(storage, false, __VA_ARGS__) +#define RETURN_DISTINCT(...) memgraph::expr::test_common::GetReturn(storage, true, __VA_ARGS__) +#define WITH_DISTINCT(...) memgraph::expr::test_common::GetWith(storage, true, __VA_ARGS__) +#define UNWIND(...) memgraph::expr::test_common::GetUnwind(storage, __VA_ARGS__) +#define ORDER_BY(...) memgraph::expr::test_common::GetOrderBy(__VA_ARGS__) +#define SKIP(expr) \ + memgraph::expr::test_common::Skip { (expr) } +#define LIMIT(expr) \ + memgraph::expr::test_common::Limit { (expr) } +#define DELETE(...) memgraph::expr::test_common::GetDelete(storage, {__VA_ARGS__}) +#define DETACH_DELETE(...) memgraph::expr::test_common::GetDelete(storage, {__VA_ARGS__}, true) +#define SET(...) memgraph::expr::test_common::GetSet(storage, __VA_ARGS__) +#define REMOVE(...) memgraph::expr::test_common::GetRemove(storage, __VA_ARGS__) +#define MERGE(...) memgraph::expr::test_common::GetMerge(storage, __VA_ARGS__) +#define ON_MATCH(...) \ + memgraph::expr::test_common::OnMatch { \ + std::vector<memgraph::query::v2::Clause *> { __VA_ARGS__ } \ + } +#define ON_CREATE(...) \ + memgraph::expr::test_common::OnCreate { \ + std::vector<memgraph::query::v2::Clause *> { __VA_ARGS__ } \ + } +#define CREATE_INDEX_ON(label, property) \ + storage.Create<memgraph::query::v2::IndexQuery>(memgraph::query::v2::IndexQuery::Action::CREATE, (label), \ + std::vector<memgraph::query::v2::PropertyIx>{(property)}) +#define QUERY(...) memgraph::expr::test_common::GetQuery(storage, __VA_ARGS__) +#define SINGLE_QUERY(...) memgraph::expr::test_common::GetSingleQuery(storage.Create<SingleQuery>(), __VA_ARGS__) +#define UNION(...) memgraph::expr::test_common::GetCypherUnion(storage.Create<CypherUnion>(true), __VA_ARGS__) +#define UNION_ALL(...) memgraph::expr::test_common::GetCypherUnion(storage.Create<CypherUnion>(false), __VA_ARGS__) +#define FOREACH(...) memgraph::expr::test_common::GetForeach(storage, __VA_ARGS__) +// Various operators +#define NOT(expr) storage.Create<memgraph::query::v2::NotOperator>((expr)) +#define UPLUS(expr) storage.Create<memgraph::query::v2::UnaryPlusOperator>((expr)) +#define UMINUS(expr) storage.Create<memgraph::query::v2::UnaryMinusOperator>((expr)) +#define IS_NULL(expr) storage.Create<memgraph::query::v2::IsNullOperator>((expr)) +#define ADD(expr1, expr2) storage.Create<memgraph::query::v2::AdditionOperator>((expr1), (expr2)) +#define LESS(expr1, expr2) storage.Create<memgraph::query::v2::LessOperator>((expr1), (expr2)) +#define LESS_EQ(expr1, expr2) storage.Create<memgraph::query::v2::LessEqualOperator>((expr1), (expr2)) +#define GREATER(expr1, expr2) storage.Create<memgraph::query::v2::GreaterOperator>((expr1), (expr2)) +#define GREATER_EQ(expr1, expr2) storage.Create<memgraph::query::v2::GreaterEqualOperator>((expr1), (expr2)) +#define SUM(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::SUM) +#define COUNT(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::COUNT) +#define AVG(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::AVG) +#define COLLECT_LIST(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::COLLECT_LIST) +#define EQ(expr1, expr2) storage.Create<memgraph::query::v2::EqualOperator>((expr1), (expr2)) +#define NEQ(expr1, expr2) storage.Create<memgraph::query::v2::NotEqualOperator>((expr1), (expr2)) +#define AND(expr1, expr2) storage.Create<memgraph::query::v2::AndOperator>((expr1), (expr2)) +#define OR(expr1, expr2) storage.Create<memgraph::query::v2::OrOperator>((expr1), (expr2)) +#define IN_LIST(expr1, expr2) storage.Create<memgraph::query::v2::InListOperator>((expr1), (expr2)) +#define IF(cond, then, else) storage.Create<memgraph::query::v2::IfOperator>((cond), (then), (else)) +// Function call +#define FN(function_name, ...) \ + storage.Create<memgraph::query::v2::Function>(memgraph::utils::ToUpperCase(function_name), \ + std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +// List slicing +#define SLICE(list, lower_bound, upper_bound) \ + storage.Create<memgraph::query::v2::ListSlicingOperator>(list, lower_bound, upper_bound) +// all(variable IN list WHERE predicate) +#define ALL(variable, list, where) \ + storage.Create<memgraph::query::v2::All>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define SINGLE(variable, list, where) \ + storage.Create<memgraph::query::v2::Single>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define ANY(variable, list, where) \ + storage.Create<memgraph::query::v2::Any>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define NONE(variable, list, where) \ + storage.Create<memgraph::query::v2::None>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define REDUCE(accumulator, initializer, variable, list, expr) \ + storage.Create<memgraph::query::v2::Reduce>(storage.Create<memgraph::query::v2::Identifier>(accumulator), \ + initializer, storage.Create<memgraph::query::v2::Identifier>(variable), \ + list, expr) +#define COALESCE(...) \ + storage.Create<memgraph::query::v2::Coalesce>(std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define EXTRACT(variable, list, expr) \ + storage.Create<memgraph::query::v2::Extract>(storage.Create<memgraph::query::v2::Identifier>(variable), list, expr) +#define AUTH_QUERY(action, user, role, user_or_role, password, privileges) \ + storage.Create<memgraph::query::v2::AuthQuery>((action), (user), (role), (user_or_role), password, (privileges)) +#define DROP_USER(usernames) storage.Create<memgraph::query::v2::DropUser>((usernames)) +#define CALL_PROCEDURE(...) memgraph::query::v2::test_common::GetCallProcedure(storage, __VA_ARGS__) diff --git a/tests/unit/query_v2_create_expand_multiframe.cpp b/tests/unit/query_v2_create_expand_multiframe.cpp index 8720656fd..77eb7f2ea 100644 --- a/tests/unit/query_v2_create_expand_multiframe.cpp +++ b/tests/unit/query_v2_create_expand_multiframe.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -23,7 +23,7 @@ #include "utils/logging.hpp" #include "utils/memory.hpp" -namespace memgraph::query::v2 { +namespace memgraph::query::v2::tests { MultiFrame CreateMultiFrame(const size_t max_pos, const Symbol &src, const Symbol &dst, MockedRequestRouter *router) { static constexpr size_t number_of_frames = 100; @@ -63,7 +63,6 @@ TEST(CreateExpandTest, Cursor) { node.symbol = symbol_table.CreateSymbol("u", true); auto once_op = std::make_shared<plan::Once>(); - auto once_cur = once_op->MakeCursor(utils::NewDeleteResource()); auto create_expand = plan::CreateExpand(node, edge, once_op, src, true); auto cursor = create_expand.MakeCursor(utils::NewDeleteResource()); @@ -91,4 +90,4 @@ TEST(CreateExpandTest, Cursor) { EXPECT_EQ(number_of_invalid_frames, 99); } -} // namespace memgraph::query::v2 +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/query_v2_create_node_multiframe.cpp b/tests/unit/query_v2_create_node_multiframe.cpp index f19082783..b298d2781 100644 --- a/tests/unit/query_v2_create_node_multiframe.cpp +++ b/tests/unit/query_v2_create_node_multiframe.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,6 +9,7 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. +#include "gmock/gmock.h" #include "mock_helpers.hpp" #include "query/v2/bindings/frame.hpp" @@ -22,54 +23,60 @@ #include "storage/v3/shard.hpp" #include "utils/memory.hpp" -namespace memgraph::query::v2 { +namespace memgraph::query::v2::tests { MultiFrame CreateMultiFrame(const size_t max_pos) { static constexpr size_t frame_size = 100; MultiFrame multi_frame(max_pos, frame_size, utils::NewDeleteResource()); - auto frames_populator = multi_frame.GetInvalidFramesPopulator(); - for (auto &frame : frames_populator) { - frame.MakeValid(); - } return multi_frame; } TEST(CreateNodeTest, CreateNodeCursor) { using testing::_; + using testing::IsEmpty; using testing::Return; AstStorage ast; SymbolTable symbol_table; plan::NodeCreationInfo node; - plan::EdgeCreationInfo edge; - edge.edge_type = msgs::EdgeTypeId::FromUint(1); - edge.direction = EdgeAtom::Direction::IN; auto id_alloc = IdAllocator(0, 100); node.symbol = symbol_table.CreateSymbol("n", true); - node.labels.push_back(msgs::LabelId::FromUint(2)); + const auto primary_label_id = msgs::LabelId::FromUint(2); + node.labels.push_back(primary_label_id); auto literal = PrimitiveLiteral(); literal.value_ = TypedValue(static_cast<int64_t>(200)); auto p = plan::PropertiesMapList{}; - p.push_back(std::make_pair(msgs::PropertyId::FromUint(2), &literal)); + p.push_back(std::make_pair(msgs::PropertyId::FromUint(3), &literal)); node.properties.emplace<0>(std::move(p)); - auto once_cur = plan::MakeUniqueCursorPtr<MockedCursor>(utils::NewDeleteResource()); - EXPECT_CALL(BaseToMock(once_cur.get()), PullMultiple(_, _)).Times(1); - - std::shared_ptr<plan::LogicalOperator> once_op = std::make_shared<MockedLogicalOperator>(); - EXPECT_CALL(BaseToMock(once_op.get()), MakeCursor(_)).Times(1).WillOnce(Return(std::move(once_cur))); + auto once_op = std::make_shared<plan::Once>(); auto create_expand = plan::CreateNode(once_op, node); auto cursor = create_expand.MakeCursor(utils::NewDeleteResource()); MockedRequestRouter router; - EXPECT_CALL(router, CreateVertices(testing::_)) - .Times(1) - .WillOnce(::testing::Return(std::vector<msgs::CreateVerticesResponse>{})); - EXPECT_CALL(router, IsPrimaryKey(testing::_, testing::_)).WillRepeatedly(::testing::Return(true)); + EXPECT_CALL(router, CreateVertices(_)).Times(1).WillOnce(Return(std::vector<msgs::CreateVerticesResponse>{})); + EXPECT_CALL(router, IsPrimaryLabel(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(router, IsPrimaryKey(_, _)).WillRepeatedly(Return(true)); auto context = MakeContext(ast, symbol_table, &router, &id_alloc); auto multi_frame = CreateMultiFrame(context.symbol_table.max_position()); cursor->PullMultiple(multi_frame, context); + + auto frames = multi_frame.GetValidFramesReader(); + auto number_of_valid_frames = 0; + for (auto &frame : frames) { + ++number_of_valid_frames; + EXPECT_EQ(frame[node.symbol].IsVertex(), true); + const auto &n = frame[node.symbol].ValueVertex(); + EXPECT_THAT(n.Labels(), IsEmpty()); + EXPECT_EQ(n.PrimaryLabel(), primary_label_id); + // TODO(antaljanosbenjamin): Check primary key + } + EXPECT_EQ(number_of_valid_frames, 1); + + auto invalid_frames = multi_frame.GetInvalidFramesPopulator(); + auto number_of_invalid_frames = std::distance(invalid_frames.begin(), invalid_frames.end()); + EXPECT_EQ(number_of_invalid_frames, 99); } -} // namespace memgraph::query::v2 +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/query_v2_cypher_main_visitor.cpp b/tests/unit/query_v2_cypher_main_visitor.cpp index d7e4169e2..4d5311a6e 100644 --- a/tests/unit/query_v2_cypher_main_visitor.cpp +++ b/tests/unit/query_v2_cypher_main_visitor.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -299,7 +299,7 @@ std::shared_ptr<Base> gAstGeneratorTypes[] = { std::make_shared<CachedAstGenerator>(), }; -INSTANTIATE_TEST_CASE_P(AstGeneratorTypes, CypherMainVisitorTest, ::testing::ValuesIn(gAstGeneratorTypes)); +INSTANTIATE_TEST_SUITE_P(AstGeneratorTypes, CypherMainVisitorTest, ::testing::ValuesIn(gAstGeneratorTypes)); // NOTE: The above used to use *Typed Tests* functionality of gtest library. // Unfortunately, the compilation time of this test increased to full 2 minutes! @@ -313,7 +313,7 @@ INSTANTIATE_TEST_CASE_P(AstGeneratorTypes, CypherMainVisitorTest, ::testing::Val // ClonedAstGenerator, CachedAstGenerator> // AstGeneratorTypes; // -// TYPED_TEST_CASE(CypherMainVisitorTest, AstGeneratorTypes); +// TYPED_TEST_SUITE(CypherMainVisitorTest, AstGeneratorTypes); TEST_P(CypherMainVisitorTest, SyntaxException) { auto &ast_generator = *GetParam(); diff --git a/tests/unit/query_v2_expression_evaluator.cpp b/tests/unit/query_v2_expression_evaluator.cpp index 5e91d0d5a..0000e62b2 100644 --- a/tests/unit/query_v2_expression_evaluator.cpp +++ b/tests/unit/query_v2_expression_evaluator.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -125,6 +125,16 @@ class MockedRequestRouter : public RequestRouterInterface { bool IsPrimaryKey(LabelId primary_label, PropertyId property) const override { return true; } + std::optional<std::pair<uint64_t, uint64_t>> AllocateInitialEdgeIds(io::Address coordinator_address) override { + return {}; + } + + void InstallSimulatorTicker(std::function<bool()> tick_simulator) override {} + const std::vector<coordinator::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId /*label*/) const override { + static std::vector<coordinator::SchemaProperty> schema; + return schema; + }; + private: void SetUpNameIdMappers() { std::unordered_map<uint64_t, std::string> id_to_name; diff --git a/tests/unit/query_v2_plan.cpp b/tests/unit/query_v2_plan.cpp new file mode 100644 index 000000000..8fc5e11c7 --- /dev/null +++ b/tests/unit/query_v2_plan.cpp @@ -0,0 +1,178 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query_plan_checker_v2.hpp" + +#include <iostream> +#include <list> +#include <sstream> +#include <tuple> +#include <typeinfo> +#include <unordered_set> +#include <variant> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "expr/semantic/symbol_generator.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/plan/planner.hpp" + +#include "query_v2_common.hpp" + +namespace memgraph::query { +::std::ostream &operator<<(::std::ostream &os, const Symbol &sym) { + return os << "Symbol{\"" << sym.name() << "\" [" << sym.position() << "] " << Symbol::TypeToString(sym.type()) << "}"; +} +} // namespace memgraph::query + +// using namespace memgraph::query::v2::plan; +using namespace memgraph::expr::plan; +using memgraph::query::Symbol; +using memgraph::query::SymbolGenerator; +using memgraph::query::v2::AstStorage; +using memgraph::query::v2::SingleQuery; +using memgraph::query::v2::SymbolTable; +using Type = memgraph::query::v2::EdgeAtom::Type; +using Direction = memgraph::query::v2::EdgeAtom::Direction; +using Bound = ScanAllByLabelPropertyRange::Bound; + +namespace { + +class Planner { + public: + template <class TDbAccessor> + Planner(std::vector<SingleQueryPart> single_query_parts, PlanningContext<TDbAccessor> context) { + memgraph::expr::Parameters parameters; + PostProcessor post_processor(parameters); + plan_ = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(single_query_parts, &context); + plan_ = post_processor.Rewrite(std::move(plan_), &context); + } + + auto &plan() { return *plan_; } + + private: + std::unique_ptr<LogicalOperator> plan_; +}; + +template <class... TChecker> +auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, TChecker... checker) { + std::list<BaseOpChecker *> checkers{&checker...}; + PlanChecker plan_checker(checkers, symbol_table); + plan.Accept(plan_checker); + EXPECT_TRUE(plan_checker.checkers_.empty()); +} + +template <class TPlanner, class... TChecker> +auto CheckPlan(memgraph::query::v2::CypherQuery *query, AstStorage &storage, TChecker... checker) { + auto symbol_table = memgraph::expr::MakeSymbolTable(query); + FakeDistributedDbAccessor dba; + auto planner = MakePlanner<TPlanner>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, checker...); +} + +template <class T> +class TestPlanner : public ::testing::Test {}; + +using PlannerTypes = ::testing::Types<Planner>; + +TYPED_TEST_CASE(TestPlanner, PlannerTypes); + +TYPED_TEST(TestPlanner, MatchFilterPropIsNotNull) { + const char *prim_label_name = "prim_label_one"; + // Exact primary key match, one elem as PK. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second}); + + memgraph::query::v2::AstStorage storage; + + memgraph::query::v2::Expression *expected_primary_key; + expected_primary_key = PROPERTY_LOOKUP("n", prim_prop_one); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1))), RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanByPrimaryKey(label, {expected_primary_key}), ExpectProduce()); + } + // Exact primary key match, two elem as PK. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + + auto prim_prop_two = PRIMARY_PROPERTY_PAIR("prim_prop_two"); + auto sec_prop_one = PRIMARY_PROPERTY_PAIR("sec_prop_one"); + auto sec_prop_two = PRIMARY_PROPERTY_PAIR("sec_prop_two"); + auto sec_prop_three = PRIMARY_PROPERTY_PAIR("sec_prop_three"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second, prim_prop_two.second}); + + dba.SetIndexCount(label, prim_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_one.second, 1); + dba.SetIndexCount(label, sec_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_three.second, 1); + memgraph::query::v2::AstStorage storage; + + memgraph::query::v2::Expression *expected_primary_key; + expected_primary_key = PROPERTY_LOOKUP("n", prim_prop_one); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(AND(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1)), + EQ(PROPERTY_LOOKUP("n", prim_prop_two), LITERAL(1)))), + RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanByPrimaryKey(label, {expected_primary_key}), ExpectProduce()); + } + // One elem is missing from PK, default to ScanAllByLabelPropertyValue. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + auto prim_prop_two = PRIMARY_PROPERTY_PAIR("prim_prop_two"); + + auto sec_prop_one = PRIMARY_PROPERTY_PAIR("sec_prop_one"); + auto sec_prop_two = PRIMARY_PROPERTY_PAIR("sec_prop_two"); + auto sec_prop_three = PRIMARY_PROPERTY_PAIR("sec_prop_three"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second, prim_prop_two.second}); + + dba.SetIndexCount(label, prim_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_one.second, 1); + dba.SetIndexCount(label, sec_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_three.second, 1); + memgraph::query::v2::AstStorage storage; + + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1))), RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, prim_prop_one, IDENT("n")), + ExpectProduce()); + } +} + +} // namespace diff --git a/tests/unit/storage_v3.cpp b/tests/unit/storage_v3.cpp index 6f6902d1c..1832165f9 100644 --- a/tests/unit/storage_v3.cpp +++ b/tests/unit/storage_v3.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -82,8 +82,8 @@ class StorageV3 : public ::testing::TestWithParam<bool> { Config{.gc = {.reclamation_interval = reclamation_interval}}}; coordinator::Hlc last_hlc{0, io::Time{}}; }; -INSTANTIATE_TEST_CASE_P(WithGc, StorageV3, ::testing::Values(true)); -INSTANTIATE_TEST_CASE_P(WithoutGc, StorageV3, ::testing::Values(false)); +INSTANTIATE_TEST_SUITE_P(WithGc, StorageV3, ::testing::Values(true)); +INSTANTIATE_TEST_SUITE_P(WithoutGc, StorageV3, ::testing::Values(false)); // NOLINTNEXTLINE(hicpp-special-member-functions) TEST_P(StorageV3, Commit) { diff --git a/tests/unit/storage_v3_edge.cpp b/tests/unit/storage_v3_edge.cpp index 3d2ab8bbd..8637db11c 100644 --- a/tests/unit/storage_v3_edge.cpp +++ b/tests/unit/storage_v3_edge.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -60,8 +60,8 @@ class StorageEdgeTest : public ::testing::TestWithParam<bool> { coordinator::Hlc last_hlc{0, io::Time{}}; }; -INSTANTIATE_TEST_CASE_P(EdgesWithProperties, StorageEdgeTest, ::testing::Values(true)); -INSTANTIATE_TEST_CASE_P(EdgesWithoutProperties, StorageEdgeTest, ::testing::Values(false)); +INSTANTIATE_TEST_SUITE_P(EdgesWithProperties, StorageEdgeTest, ::testing::Values(true)); +INSTANTIATE_TEST_SUITE_P(EdgesWithoutProperties, StorageEdgeTest, ::testing::Values(false)); // NOLINTNEXTLINE(hicpp-special-member-functions) TEST_P(StorageEdgeTest, EdgeCreateFromSmallerCommit) { diff --git a/tests/unit/storage_v3_isolation_level.cpp b/tests/unit/storage_v3_isolation_level.cpp index 6b661f632..bd0f27b6b 100644 --- a/tests/unit/storage_v3_isolation_level.cpp +++ b/tests/unit/storage_v3_isolation_level.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -135,6 +135,6 @@ TEST_P(StorageIsolationLevelTest, Visibility) { } } -INSTANTIATE_TEST_CASE_P(ParameterizedStorageIsolationLevelTests, StorageIsolationLevelTest, - ::testing::ValuesIn(isolation_levels), StorageIsolationLevelTest::PrintToStringParamName()); +INSTANTIATE_TEST_SUITE_P(ParameterizedStorageIsolationLevelTests, StorageIsolationLevelTest, + ::testing::ValuesIn(isolation_levels), StorageIsolationLevelTest::PrintToStringParamName()); } // namespace memgraph::storage::v3::tests diff --git a/tests/unit/utils_csv_parsing.cpp b/tests/unit/utils_csv_parsing.cpp index 3c852b171..e8d6c2241 100644 --- a/tests/unit/utils_csv_parsing.cpp +++ b/tests/unit/utils_csv_parsing.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -330,4 +330,4 @@ TEST_P(CsvReaderTest, EmptyColumns) { } } -INSTANTIATE_TEST_CASE_P(NewlineParameterizedTest, CsvReaderTest, ::testing::Values("\n", "\r\n")); +INSTANTIATE_TEST_SUITE_P(NewlineParameterizedTest, CsvReaderTest, ::testing::Values("\n", "\r\n")); diff --git a/tests/unit/utils_file_locker.cpp b/tests/unit/utils_file_locker.cpp index 21c71b1a3..f2217953e 100644 --- a/tests/unit/utils_file_locker.cpp +++ b/tests/unit/utils_file_locker.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -190,9 +190,9 @@ TEST_P(FileLockerParameterizedTest, RemovePath) { std::filesystem::current_path(save_path); } -INSTANTIATE_TEST_CASE_P(FileLockerPathVariantTests, FileLockerParameterizedTest, - ::testing::Values(std::make_tuple(false, false), std::make_tuple(false, true), - std::make_tuple(true, false), std::make_tuple(true, true))); +INSTANTIATE_TEST_SUITE_P(FileLockerPathVariantTests, FileLockerParameterizedTest, + ::testing::Values(std::make_tuple(false, false), std::make_tuple(false, true), + std::make_tuple(true, false), std::make_tuple(true, true))); TEST_F(FileLockerTest, MultipleLockers) { CreateFiles(3); diff --git a/tests/unit/utils_memory.cpp b/tests/unit/utils_memory.cpp index 70bf85653..73f78d545 100644 --- a/tests/unit/utils_memory.cpp +++ b/tests/unit/utils_memory.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -393,7 +393,7 @@ class AllocatorTest : public ::testing::Test {}; using ContainersWithAllocators = ::testing::Types<ContainerWithAllocatorLast, ContainerWithAllocatorFirst>; -TYPED_TEST_CASE(AllocatorTest, ContainersWithAllocators); +TYPED_TEST_SUITE(AllocatorTest, ContainersWithAllocators); TYPED_TEST(AllocatorTest, PropagatesToStdUsesAllocator) { std::vector<TypeParam, memgraph::utils::Allocator<TypeParam>> vec(memgraph::utils::NewDeleteResource());