diff --git a/libs/setup.sh b/libs/setup.sh index bb737d432..55288f535 100755 --- a/libs/setup.sh +++ b/libs/setup.sh @@ -171,7 +171,7 @@ benchmark_tag="v1.6.0" repo_clone_try_double "${primary_urls[gbenchmark]}" "${secondary_urls[gbenchmark]}" "benchmark" "$benchmark_tag" true # google test -googletest_tag="release-1.8.0" +googletest_tag="release-1.12.1" repo_clone_try_double "${primary_urls[gtest]}" "${secondary_urls[gtest]}" "googletest" "$googletest_tag" true # libbcrypt diff --git a/src/query/v2/multiframe.cpp b/src/query/v2/multiframe.cpp index 02d396fda..2cb591153 100644 --- a/src/query/v2/multiframe.cpp +++ b/src/query/v2/multiframe.cpp @@ -24,7 +24,7 @@ static_assert(std::forward_iterator<ValidFramesModifier::Iterator>); static_assert(std::forward_iterator<ValidFramesConsumer::Iterator>); static_assert(std::forward_iterator<InvalidFramesPopulator::Iterator>); -MultiFrame::MultiFrame(int64_t size_of_frame, size_t number_of_frames, utils::MemoryResource *execution_memory) +MultiFrame::MultiFrame(size_t size_of_frame, size_t number_of_frames, utils::MemoryResource *execution_memory) : frames_(utils::pmr::vector<FrameWithValidity>( number_of_frames, FrameWithValidity(size_of_frame, execution_memory), execution_memory)) { MG_ASSERT(number_of_frames > 0); diff --git a/src/query/v2/multiframe.hpp b/src/query/v2/multiframe.hpp index 8588993a7..7d9b73700 100644 --- a/src/query/v2/multiframe.hpp +++ b/src/query/v2/multiframe.hpp @@ -30,7 +30,7 @@ class MultiFrame { friend class ValidFramesReader; friend class InvalidFramesPopulator; - MultiFrame(int64_t size_of_frame, size_t number_of_frames, utils::MemoryResource *execution_memory); + MultiFrame(size_t size_of_frame, size_t number_of_frames, utils::MemoryResource *execution_memory); ~MultiFrame() = default; MultiFrame(const MultiFrame &other); @@ -168,7 +168,7 @@ class ValidFramesModifier { Iterator &operator++() { do { ptr_++; - } while (*this != iterator_wrapper_->end() && ptr_->IsValid()); + } while (*this != iterator_wrapper_->end() && !ptr_->IsValid()); return *this; } diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp index eeb5cd6b4..8681c6a2a 100644 --- a/src/query/v2/plan/operator.cpp +++ b/src/query/v2/plan/operator.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -2376,6 +2376,22 @@ class DistributedCreateExpandCursor : public Cursor { return true; } + void PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("CreateExpandMF"); + input_cursor_->PullMultiple(multi_frame, context); + auto request_vertices = ExpandCreationInfoToRequests(multi_frame, context); + { + SCOPED_REQUEST_WAIT_PROFILE; + auto &request_router = context.request_router; + auto results = request_router->CreateExpand(std::move(request_vertices)); + for (const auto &result : results) { + if (result.error) { + throw std::runtime_error("CreateExpand Request failed"); + } + } + } + } + void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { @@ -2450,6 +2466,63 @@ class DistributedCreateExpandCursor : public Cursor { return edge_requests; } + std::vector<msgs::NewExpand> ExpandCreationInfoToRequests(MultiFrame &multi_frame, ExecutionContext &context) const { + std::vector<msgs::NewExpand> edge_requests; + auto frames_modifier = multi_frame.GetValidFramesModifier(); + + for (auto &frame : frames_modifier) { + const auto &edge_info = self_.edge_info_; + msgs::NewExpand request{.id = {context.edge_ids_alloc->AllocateId()}}; + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, nullptr, + storage::v3::View::NEW); + request.type = {edge_info.edge_type}; + if (const auto *edge_info_properties = std::get_if<PropertiesMapList>(&edge_info.properties)) { + for (const auto &[property, value_expression] : *edge_info_properties) { + TypedValue val = value_expression->Accept(evaluator); + request.properties.emplace_back(property, storage::v3::TypedValueToValue(val)); + } + } else { + // handle parameter + auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(edge_info.properties)).ValueMap(); + for (const auto &[property, value] : property_map) { + const auto property_id = context.request_router->NameToProperty(std::string(property)); + request.properties.emplace_back(property_id, storage::v3::TypedValueToValue(value)); + } + } + + TypedValue &v1_value = frame[self_.input_symbol_]; + const auto &v1 = v1_value.ValueVertex(); + const auto &v2 = OtherVertex(frame); + msgs::Edge edge{.src = request.src_vertex, + .dst = request.dest_vertex, + .properties = request.properties, + .id = request.id, + .type = request.type}; + frame[self_.edge_info_.symbol] = TypedValue(accessors::EdgeAccessor(std::move(edge), context.request_router)); + + // Set src and dest vertices + // TODO(jbajic) Currently we are only handling scenario where vertices + // are matched + switch (edge_info.direction) { + case EdgeAtom::Direction::IN: { + request.src_vertex = v2.Id(); + request.dest_vertex = v1.Id(); + break; + } + case EdgeAtom::Direction::OUT: { + request.src_vertex = v1.Id(); + request.dest_vertex = v2.Id(); + break; + } + case EdgeAtom::Direction::BOTH: + LOG_FATAL("Must indicate exact expansion direction here"); + } + + edge_requests.push_back(std::move(request)); + } + return edge_requests; + } + private: void ResetExecutionState() {} diff --git a/src/query/v2/request_router.hpp b/src/query/v2/request_router.hpp index d6c484e86..0515633c9 100644 --- a/src/query/v2/request_router.hpp +++ b/src/query/v2/request_router.hpp @@ -305,7 +305,8 @@ class RequestRouter : public RequestRouterInterface { MG_ASSERT(!new_edges.empty()); // create requests - std::vector<ShardRequestState<msgs::CreateExpandRequest>> requests_to_be_sent = RequestsForCreateExpand(new_edges); + std::vector<ShardRequestState<msgs::CreateExpandRequest>> requests_to_be_sent = + RequestsForCreateExpand(std::move(new_edges)); // begin all requests in parallel RunningRequests<msgs::CreateExpandRequest> running_requests = {}; @@ -436,7 +437,7 @@ class RequestRouter : public RequestRouterInterface { } std::vector<ShardRequestState<msgs::CreateExpandRequest>> RequestsForCreateExpand( - const std::vector<msgs::NewExpand> &new_expands) { + std::vector<msgs::NewExpand> new_expands) { std::map<ShardMetadata, msgs::CreateExpandRequest> per_shard_request_table; auto ensure_shard_exists_in_table = [&per_shard_request_table, transaction_id = transaction_id_](const ShardMetadata &shard) { diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 5bfa26afd..d547c58f8 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -407,3 +407,7 @@ target_link_libraries(${test_prefix}high_density_shard_create_scan mg-io mg-coor # Tests for awesome_memgraph_functions add_unit_test(query_v2_expression_evaluator.cpp) target_link_libraries(${test_prefix}query_v2_expression_evaluator mg-query-v2) + +# Tests for multiframes +add_unit_test(query_v2_create_expand_multiframe.cpp) +target_link_libraries(${test_prefix}query_v2_create_expand_multiframe mg-query-v2) diff --git a/tests/unit/mock_helpers.hpp b/tests/unit/mock_helpers.hpp new file mode 100644 index 000000000..c522b8602 --- /dev/null +++ b/tests/unit/mock_helpers.hpp @@ -0,0 +1,83 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "query/v2/common.hpp" +#include "query/v2/context.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/request_router.hpp" + +namespace memgraph::query::v2::tests { +class MockedRequestRouter : public RequestRouterInterface { + public: + MOCK_METHOD(std::vector<VertexAccessor>, ScanVertices, (std::optional<std::string> label)); + MOCK_METHOD(std::vector<msgs::CreateVerticesResponse>, CreateVertices, (std::vector<msgs::NewVertex>)); + MOCK_METHOD(std::vector<msgs::ExpandOneResultRow>, ExpandOne, (msgs::ExpandOneRequest)); + MOCK_METHOD(std::vector<msgs::CreateExpandResponse>, CreateExpand, (std::vector<msgs::NewExpand>)); + MOCK_METHOD(std::vector<msgs::GetPropertiesResultRow>, GetProperties, (msgs::GetPropertiesRequest)); + MOCK_METHOD(void, StartTransaction, ()); + MOCK_METHOD(void, Commit, ()); + + MOCK_METHOD(storage::v3::EdgeTypeId, NameToEdgeType, (const std::string &), (const)); + MOCK_METHOD(storage::v3::PropertyId, NameToProperty, (const std::string &), (const)); + MOCK_METHOD(storage::v3::LabelId, NameToLabel, (const std::string &), (const)); + MOCK_METHOD(storage::v3::LabelId, LabelToName, (const std::string &), (const)); + MOCK_METHOD(const std::string &, PropertyToName, (storage::v3::PropertyId), (const)); + MOCK_METHOD(const std::string &, LabelToName, (storage::v3::LabelId label), (const)); + MOCK_METHOD(const std::string &, EdgeTypeToName, (storage::v3::EdgeTypeId type), (const)); + MOCK_METHOD(std::optional<storage::v3::PropertyId>, MaybeNameToProperty, (const std::string &), (const)); + MOCK_METHOD(std::optional<storage::v3::EdgeTypeId>, MaybeNameToEdgeType, (const std::string &), (const)); + MOCK_METHOD(std::optional<storage::v3::LabelId>, MaybeNameToLabel, (const std::string &), (const)); + MOCK_METHOD(bool, IsPrimaryLabel, (storage::v3::LabelId), (const)); + MOCK_METHOD(bool, IsPrimaryKey, (storage::v3::LabelId, storage::v3::PropertyId), (const)); +}; + +class MockedLogicalOperator : public plan::LogicalOperator { + public: + MOCK_METHOD(plan::UniqueCursorPtr, MakeCursor, (utils::MemoryResource *), (const)); + MOCK_METHOD(std::vector<expr::Symbol>, ModifiedSymbols, (const expr::SymbolTable &), (const)); + MOCK_METHOD(bool, HasSingleInput, (), (const)); + MOCK_METHOD(std::shared_ptr<LogicalOperator>, input, (), (const)); + MOCK_METHOD(void, set_input, (std::shared_ptr<LogicalOperator>)); + MOCK_METHOD(std::unique_ptr<LogicalOperator>, Clone, (AstStorage * storage), (const)); + MOCK_METHOD(bool, Accept, (plan::HierarchicalLogicalOperatorVisitor & visitor)); +}; + +class MockedCursor : public plan::Cursor { + public: + MOCK_METHOD(bool, Pull, (Frame &, expr::ExecutionContext &)); + MOCK_METHOD(void, PullMultiple, (MultiFrame &, expr::ExecutionContext &)); + MOCK_METHOD(void, Reset, ()); + MOCK_METHOD(void, Shutdown, ()); +}; + +inline expr::ExecutionContext MakeContext(const expr::AstStorage &storage, const expr::SymbolTable &symbol_table, + RequestRouterInterface *router, IdAllocator *id_alloc) { + expr::ExecutionContext context; + context.symbol_table = symbol_table; + context.evaluation_context.properties = NamesToProperties(storage.properties_, router); + context.evaluation_context.labels = NamesToLabels(storage.labels_, router); + context.edge_ids_alloc = id_alloc; + context.request_router = router; + return context; +} + +inline MockedLogicalOperator &BaseToMock(plan::LogicalOperator &op) { + return dynamic_cast<MockedLogicalOperator &>(op); +} + +inline MockedCursor &BaseToMock(plan::Cursor &cursor) { return dynamic_cast<MockedCursor &>(cursor); } + +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/query_v2_create_expand_multiframe.cpp b/tests/unit/query_v2_create_expand_multiframe.cpp new file mode 100644 index 000000000..ebdc4a9a7 --- /dev/null +++ b/tests/unit/query_v2_create_expand_multiframe.cpp @@ -0,0 +1,94 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <memory> +#include "mock_helpers.hpp" + +#include "query/v2/bindings/frame.hpp" +#include "query/v2/bindings/symbol_table.hpp" +#include "query/v2/common.hpp" +#include "query/v2/context.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/requests.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/shard.hpp" +#include "utils/logging.hpp" +#include "utils/memory.hpp" + +namespace memgraph::query::v2::tests { + +MultiFrame CreateMultiFrame(const size_t max_pos, const Symbol &src, const Symbol &dst, MockedRequestRouter *router) { + static constexpr size_t number_of_frames = 100; + MultiFrame multi_frame(max_pos, number_of_frames, utils::NewDeleteResource()); + auto frames_populator = multi_frame.GetInvalidFramesPopulator(); + size_t i = 0; + for (auto &frame : frames_populator) { + auto &src_acc = frame.at(src); + auto &dst_acc = frame.at(dst); + auto v1 = msgs::Vertex{.id = {{msgs::LabelId::FromUint(1)}, {msgs::Value(static_cast<int64_t>(i++))}}}; + auto v2 = msgs::Vertex{.id = {{msgs::LabelId::FromUint(1)}, {msgs::Value(static_cast<int64_t>(i++))}}}; + std::map<msgs::PropertyId, msgs::Value> mp; + src_acc = TypedValue(query::v2::accessors::VertexAccessor(v1, mp, router)); + dst_acc = TypedValue(query::v2::accessors::VertexAccessor(v2, mp, router)); + } + + multi_frame.MakeAllFramesInvalid(); + + return multi_frame; +} + +TEST(CreateExpandTest, Cursor) { + using testing::_; + using testing::Return; + + AstStorage ast; + SymbolTable symbol_table; + + plan::NodeCreationInfo node; + plan::EdgeCreationInfo edge; + edge.edge_type = msgs::EdgeTypeId::FromUint(1); + edge.direction = EdgeAtom::Direction::IN; + edge.symbol = symbol_table.CreateSymbol("e", true); + auto id_alloc = IdAllocator(0, 100); + + const auto &src = symbol_table.CreateSymbol("n", true); + node.symbol = symbol_table.CreateSymbol("u", true); + + auto once_op = std::make_shared<plan::Once>(); + auto once_cur = once_op->MakeCursor(utils::NewDeleteResource()); + + auto create_expand = plan::CreateExpand(node, edge, once_op, src, true); + auto cursor = create_expand.MakeCursor(utils::NewDeleteResource()); + + MockedRequestRouter router; + EXPECT_CALL(router, CreateExpand(_)) + .Times(1) + .WillOnce(Return(std::vector<msgs::CreateExpandResponse>{msgs::CreateExpandResponse{}})); + auto context = MakeContext(ast, symbol_table, &router, &id_alloc); + auto multi_frame = CreateMultiFrame(context.symbol_table.max_position(), src, node.symbol, &router); + cursor->PullMultiple(multi_frame, context); + + auto frames = multi_frame.GetValidFramesReader(); + auto number_of_valid_frames = 0; + for (auto &frame : frames) { + ++number_of_valid_frames; + EXPECT_EQ(frame[edge.symbol].IsEdge(), true); + const auto &e = frame[edge.symbol].ValueEdge(); + EXPECT_EQ(e.EdgeType(), edge.edge_type); + } + EXPECT_EQ(number_of_valid_frames, 1); + + auto invalid_frames = multi_frame.GetInvalidFramesPopulator(); + auto number_of_invalid_frames = std::distance(invalid_frames.begin(), invalid_frames.end()); + EXPECT_EQ(number_of_invalid_frames, 99); +} + +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/utils_settings.cpp b/tests/unit/utils_settings.cpp index 388e467ed..20262e1f5 100644 --- a/tests/unit/utils_settings.cpp +++ b/tests/unit/utils_settings.cpp @@ -11,7 +11,7 @@ #include <filesystem> -#include <gmock/gmock-generated-matchers.h> +#include <gmock/gmock-matchers.h> #include <gtest/gtest.h> #include "utils/settings.hpp"