diff --git a/src/expr/interpret/frame.hpp b/src/expr/interpret/frame.hpp index 72bfcf245..457806680 100644 --- a/src/expr/interpret/frame.hpp +++ b/src/expr/interpret/frame.hpp @@ -34,6 +34,7 @@ class Frame { const TypedValue &at(const Symbol &symbol) const { return elems_.at(symbol.position()); } auto &elems() { return elems_; } + const auto &elems() const { return elems_; } utils::MemoryResource *GetMemoryResource() const { return elems_.get_allocator().GetMemoryResource(); } diff --git a/src/query/v2/interpreter.cpp b/src/query/v2/interpreter.cpp index 8cec32688..bd1bcdb86 100644 --- a/src/query/v2/interpreter.cpp +++ b/src/query/v2/interpreter.cpp @@ -778,6 +778,8 @@ std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::PullMultiple(AnyStrea break; } } + } else { + multi_frame_.MakeAllFramesInvalid(); } } diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp index 243902e94..4b5999138 100644 --- a/src/query/v2/plan/operator.cpp +++ b/src/query/v2/plan/operator.cpp @@ -194,6 +194,9 @@ class DistributedCreateNodeCursor : public Cursor { SCOPED_PROFILE_OP("CreateNodeMF"); input_cursor_->PullMultiple(multi_frame, context); auto *request_router = context.request_router; + if (!multi_frame.HasValidFrame()) { + return; + } { SCOPED_REQUEST_WAIT_PROFILE; request_router->CreateVertices(NodeCreationInfoToRequests(context, multi_frame)); @@ -519,12 +522,12 @@ class DistributedScanAllAndFilterCursor : public Cursor { return true; } - void PullMultiple(MultiFrame &input_multi_frame, ExecutionContext &context) override { + void PullMultiple(MultiFrame &output_multi_frame, ExecutionContext &context) override { SCOPED_PROFILE_OP(op_name_); if (!own_multi_frame_.has_value()) { - own_multi_frame_.emplace(MultiFrame(input_multi_frame.GetFirstFrame().elems().size(), kNumberOfFramesInMultiframe, - input_multi_frame.GetMemoryResource())); + own_multi_frame_.emplace(MultiFrame(output_multi_frame.GetFirstFrame().elems().size(), + kNumberOfFramesInMultiframe, output_multi_frame.GetMemoryResource())); MakeRequest(context); PullNextFrames(context); @@ -534,7 +537,7 @@ class DistributedScanAllAndFilterCursor : public Cursor { return; } - for (auto &frame : input_multi_frame.GetInvalidFramesPopulator()) { + for (auto &frame : output_multi_frame.GetInvalidFramesPopulator()) { if (MustAbort(context)) { throw HintedAbortError(); } @@ -2566,6 +2569,9 @@ class DistributedCreateExpandCursor : public Cursor { void PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) override { SCOPED_PROFILE_OP("CreateExpandMF"); input_cursor_->PullMultiple(multi_frame, context); + if (!multi_frame.HasValidFrame()) { + return; + } auto request_vertices = ExpandCreationInfoToRequests(multi_frame, context); { SCOPED_REQUEST_WAIT_PROFILE; @@ -2719,7 +2725,7 @@ class DistributedCreateExpandCursor : public Cursor { class DistributedExpandCursor : public Cursor { public: - explicit DistributedExpandCursor(const Expand &self, utils::MemoryResource *mem) + DistributedExpandCursor(const Expand &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)), current_in_edge_it_(current_in_edges_.begin()), @@ -2756,16 +2762,15 @@ class DistributedExpandCursor : public Cursor { throw std::runtime_error("EdgeDirection Both not implemented"); } }; - msgs::ExpandOneRequest request; + + msgs::GetPropertiesRequest request; // to not fetch any properties of the edges - request.edge_properties.emplace(); - request.src_vertices.push_back(get_dst_vertex(edge, direction)); - request.direction = (direction == EdgeAtom::Direction::IN) ? msgs::EdgeDirection::OUT : msgs::EdgeDirection::IN; - auto result_rows = context.request_router->ExpandOne(std::move(request)); + request.vertex_ids.push_back(get_dst_vertex(edge, direction)); + auto result_rows = context.request_router->GetProperties(std::move(request)); MG_ASSERT(result_rows.size() == 1); auto &result_row = result_rows.front(); - frame[self_.common_.node_symbol] = accessors::VertexAccessor( - msgs::Vertex{result_row.src_vertex}, result_row.src_vertex_properties, context.request_router); + frame[self_.common_.node_symbol] = + accessors::VertexAccessor(msgs::Vertex{result_row.vertex}, result_row.props, context.request_router); } bool InitEdges(Frame &frame, ExecutionContext &context) { @@ -2874,10 +2879,173 @@ class DistributedExpandCursor : public Cursor { } } + void InitEdgesMultiple(ExecutionContext &context) { + TypedValue &vertex_value = (*own_frames_it_)[self_.input_symbol_]; + + if (vertex_value.IsNull()) { + return; + } + + ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + auto &vertex = vertex_value.ValueVertex(); + + const auto convert_edges = [&vertex, &context]( + std::vector<msgs::ExpandOneResultRow::EdgeWithSpecificProperties> &&edge_messages, + const EdgeAtom::Direction direction) { + std::vector<EdgeAccessor> edge_accessors; + edge_accessors.reserve(edge_messages.size()); + + switch (direction) { + case EdgeAtom::Direction::IN: { + for (auto &edge : edge_messages) { + edge_accessors.emplace_back(msgs::Edge{std::move(edge.other_end), vertex.Id(), {}, {edge.gid}, edge.type}, + context.request_router); + } + break; + } + case EdgeAtom::Direction::OUT: { + for (auto &edge : edge_messages) { + edge_accessors.emplace_back(msgs::Edge{vertex.Id(), std::move(edge.other_end), {}, {edge.gid}, edge.type}, + context.request_router); + } + break; + } + case EdgeAtom::Direction::BOTH: { + LOG_FATAL("Must indicate exact expansion direction here"); + } + } + return edge_accessors; + }; + + auto *result_row = vertex_id_to_result_row[vertex.Id()]; + current_in_edges_.clear(); + current_in_edges_ = + convert_edges(std::move(result_row->in_edges_with_specific_properties), EdgeAtom::Direction::IN); + current_in_edge_it_ = current_in_edges_.begin(); + current_out_edges_ = + convert_edges(std::move(result_row->out_edges_with_specific_properties), EdgeAtom::Direction::OUT); + current_out_edge_it_ = current_out_edges_.begin(); + vertex_id_to_result_row.erase(vertex.Id()); + } + + bool PullInputFrames(ExecutionContext &context) { + input_cursor_->PullMultiple(*own_multi_frame_, context); + // These needs to be updated regardless of the result of the pull, otherwise the consumer and iterator might + // get corrupted because of the operations done on our MultiFrame. + own_frames_consumer_ = own_multi_frame_->GetValidFramesConsumer(); + own_frames_it_ = own_frames_consumer_->begin(); + if (!own_multi_frame_->HasValidFrame()) { + return false; + } + + msgs::ExpandOneRequest request; + request.direction = DirectionToMsgsDirection(self_.common_.direction); + // to not fetch any properties of the edges + request.edge_properties.emplace(); + for (const auto &frame : own_multi_frame_->GetValidFramesReader()) { + const auto &vertex_value = frame[self_.input_symbol_]; + + // Null check due to possible failed optional match. + MG_ASSERT(!vertex_value.IsNull()); + + ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + const auto &vertex = vertex_value.ValueVertex(); + request.src_vertices.push_back(vertex.Id()); + } + + result_rows_ = std::invoke([&context, &request]() mutable { + SCOPED_REQUEST_WAIT_PROFILE; + return context.request_router->ExpandOne(std::move(request)); + }); + vertex_id_to_result_row.clear(); + for (auto &row : result_rows_) { + vertex_id_to_result_row[row.src_vertex.id] = &row; + } + + return true; + } + + void PullMultiple(MultiFrame &output_multi_frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("DistributedExpandMF"); + MG_ASSERT(!self_.common_.existing_node); + EnsureOwnMultiFrameIsGood(output_multi_frame); + // A helper function for expanding a node from an edge. + + auto output_frames_populator = output_multi_frame.GetInvalidFramesPopulator(); + + while (true) { + switch (state_) { + case State::PullInputAndEdges: { + if (!PullInputFrames(context)) { + state_ = State::Exhausted; + return; + } + state_ = State::InitInOutEdgesIt; + break; + } + case State::InitInOutEdgesIt: { + if (own_frames_it_ == own_frames_consumer_->end()) { + state_ = State::PullInputAndEdges; + } else { + InitEdges(*own_frames_it_, context); + state_ = State::PopulateOutput; + } + break; + } + case State::PopulateOutput: { + if (!output_multi_frame.HasInvalidFrame()) { + return; + } + if (current_in_edge_it_ == current_in_edges_.end() && current_out_edge_it_ == current_out_edges_.end()) { + own_frames_it_->MakeInvalid(); + ++own_frames_it_; + state_ = State::InitInOutEdgesIt; + continue; + } + auto populate_edges = [this, &context, &output_frames_populator]( + const EdgeAtom::Direction direction, std::vector<EdgeAccessor>::iterator ¤t, + const std::vector<EdgeAccessor>::iterator &end) { + for (auto output_frame_it = output_frames_populator.begin(); + output_frame_it != output_frames_populator.end() && current != end; ++output_frame_it) { + auto &edge = *current; + ++current; + auto &output_frame = *output_frame_it; + output_frame = *own_frames_it_; + output_frame[self_.common_.edge_symbol] = edge; + PullDstVertex(output_frame, context, direction); + } + }; + populate_edges(EdgeAtom::Direction::IN, current_in_edge_it_, current_in_edges_.end()); + populate_edges(EdgeAtom::Direction::OUT, current_out_edge_it_, current_out_edges_.end()); + break; + } + case State::Exhausted: { + return; + } + } + } + } + + void EnsureOwnMultiFrameIsGood(MultiFrame &output_multi_frame) { + if (!own_multi_frame_.has_value()) { + own_multi_frame_.emplace(MultiFrame(output_multi_frame.GetFirstFrame().elems().size(), + kNumberOfFramesInMultiframe, output_multi_frame.GetMemoryResource())); + own_frames_consumer_.emplace(own_multi_frame_->GetValidFramesConsumer()); + own_frames_it_ = own_frames_consumer_->begin(); + } + MG_ASSERT(output_multi_frame.GetFirstFrame().elems().size() == own_multi_frame_->GetFirstFrame().elems().size()); + } + void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); + vertex_id_to_result_row.clear(); + result_rows_.clear(); + own_frames_it_ = ValidFramesConsumer::Iterator{}; + own_frames_consumer_.reset(); + own_multi_frame_->MakeAllFramesInvalid(); + state_ = State::PullInputAndEdges; current_in_edges_.clear(); current_out_edges_.clear(); current_in_edge_it_ = current_in_edges_.end(); @@ -2885,12 +3053,21 @@ class DistributedExpandCursor : public Cursor { } private: + enum class State { PullInputAndEdges, InitInOutEdgesIt, PopulateOutput, Exhausted }; + const Expand &self_; const UniqueCursorPtr input_cursor_; std::vector<EdgeAccessor> current_in_edges_; std::vector<EdgeAccessor> current_out_edges_; std::vector<EdgeAccessor>::iterator current_in_edge_it_; std::vector<EdgeAccessor>::iterator current_out_edge_it_; + State state_{State::PullInputAndEdges}; + std::optional<MultiFrame> own_multi_frame_; + std::optional<ValidFramesConsumer> own_frames_consumer_; + ValidFramesConsumer::Iterator own_frames_it_; + std::vector<msgs::ExpandOneResultRow> result_rows_; + // This won't work if any vertex id is duplicated in the input + std::unordered_map<msgs::VertexId, msgs::ExpandOneResultRow *> vertex_id_to_result_row; }; } // namespace memgraph::query::v2::plan diff --git a/src/query/v2/requests.hpp b/src/query/v2/requests.hpp index 2335fea7d..b2d7f9123 100644 --- a/src/query/v2/requests.hpp +++ b/src/query/v2/requests.hpp @@ -25,6 +25,7 @@ #include "storage/v3/id_types.hpp" #include "storage/v3/property_value.hpp" #include "storage/v3/result.hpp" +#include "utils/fnv.hpp" namespace memgraph::msgs { @@ -579,3 +580,48 @@ using WriteResponses = std::variant<CreateVerticesResponse, DeleteVerticesRespon CreateExpandResponse, DeleteEdgesResponse, UpdateEdgesResponse, CommitResponse>; } // namespace memgraph::msgs + +namespace std { + +template <> +struct hash<memgraph::msgs::Value>; + +template <> +struct hash<memgraph::msgs::VertexId> { + size_t operator()(const memgraph::msgs::VertexId &id) const { + using LabelId = memgraph::storage::v3::LabelId; + using Value = memgraph::msgs::Value; + return memgraph::utils::HashCombine<LabelId, std::vector<Value>, std::hash<LabelId>, + memgraph::utils::FnvCollection<std::vector<Value>, Value>>{}(id.first.id, + id.second); + } +}; + +template <> +struct hash<memgraph::msgs::Value> { + size_t operator()(const memgraph::msgs::Value &value) const { + using Type = memgraph::msgs::Value::Type; + switch (value.type) { + case Type::Null: + return std::hash<size_t>{}(0U); + case Type::Bool: + return std::hash<bool>{}(value.bool_v); + case Type::Int64: + return std::hash<int64_t>{}(value.int_v); + case Type::Double: + return std::hash<double>{}(value.double_v); + case Type::String: + return std::hash<std::string>{}(value.string_v); + case Type::List: + LOG_FATAL("Add hash for lists"); + case Type::Map: + LOG_FATAL("Add hash for maps"); + case Type::Vertex: + LOG_FATAL("Add hash for vertices"); + case Type::Edge: + LOG_FATAL("Add hash for edges"); + } + } +}; + +} // namespace std