diff --git a/src/query/db_accessor.hpp b/src/query/db_accessor.hpp index a93d50616..9873401d9 100644 --- a/src/query/db_accessor.hpp +++ b/src/query/db_accessor.hpp @@ -314,8 +314,8 @@ class DbAccessor final { return std::nullopt; } - std::optional<EdgeAccessor> FindEdge(storage::Gid gid) { - auto maybe_edge = accessor_->FindEdge(gid); + std::optional<EdgeAccessor> FindEdge(storage::Gid edge_id, storage::Gid vertex_id) { + auto maybe_edge = accessor_->FindEdge(edge_id, vertex_id); if (maybe_edge) return EdgeAccessor(*maybe_edge); return std::nullopt; } diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index 91c84cacb..dc69ed195 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -651,12 +651,13 @@ TypedValue Labels(const TypedValue *args, int64_t nargs, const FunctionContext & } TypedValue GetEdgeById(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { - FType<Or<Null, Integer>>("edge_id", args, nargs); + FType<Or<Null, Integer>, Or<Null, Integer>>("edge_id", args, nargs); auto *dba = ctx.db_accessor; - if (args[0].IsNull()) return TypedValue(ctx.memory); - auto id = args[0].ValueInt(); - auto maybe_edge = dba->FindEdge(storage::Gid::FromUint(id)); + if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory); + auto edge_id = args[0].ValueInt(); + auto vertex_id = args[1].ValueInt(); + auto maybe_edge = dba->FindEdge(storage::Gid::FromUint(edge_id), storage::Gid::FromUint(vertex_id)); if (!maybe_edge) throw query::QueryRuntimeException("Edge doesn't exist."); return TypedValue(*maybe_edge, ctx.memory); diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index 7f8ff2547..9535cfa2c 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -15,6 +15,7 @@ #include <memory> #include <mutex> #include <optional> +#include <ranges> #include <variant> #include <gflags/gflags.h> @@ -526,19 +527,28 @@ std::optional<VertexAccessor> Storage::Accessor::FindVertex(Gid gid, View view) return VertexAccessor::Create(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, view); } -std::optional<EdgeAccessor> Storage::Accessor::FindEdge(Gid gid) { - auto edge_acc = storage_->edges_.access(); +std::optional<EdgeAccessor> Storage::Accessor::FindEdge(Gid edge_id, Gid vertex_id) { auto vertex_acc = storage_->vertices_.access(); + auto vertex = &*vertex_acc.find(vertex_id); + auto it_in_edges = + std::ranges::find_if(vertex->in_edges.begin(), vertex->in_edges.end(), [edge_id](const auto &item) { + return (get<2>(item).ptr && get<2>(item).ptr->gid == edge_id) || get<2>(item).gid == edge_id; + }); - auto maybe_edge = edge_acc.find(gid); - if (maybe_edge == edge_acc.end()) return std::nullopt; + if (it_in_edges != vertex->in_edges.end()) + return EdgeAccessor{get<2>(*it_in_edges), get<0>(*it_in_edges), get<1>(*it_in_edges), vertex, + &transaction_, &storage_->indices_, &storage_->constraints_, config_}; - auto edge = &*maybe_edge; - auto vertex_from = vertex_acc.find(edge->vertex_gid_from); - auto vertex_to = vertex_acc.find(edge->vertex_gid_to); + auto it_out_edges = + std::ranges::find_if(vertex->out_edges.begin(), vertex->out_edges.end(), [edge_id](const auto &item) { + return (get<2>(item).ptr && get<2>(item).ptr->gid == edge_id) || get<2>(item).gid == edge_id; + }); - return EdgeAccessor{EdgeRef{edge}, edge->edge_type_id, &*vertex_from, &*vertex_to, - &transaction_, &storage_->indices_, &storage_->constraints_, config_}; + if (it_out_edges != vertex->out_edges.end()) + return EdgeAccessor{get<2>(*it_out_edges), get<0>(*it_out_edges), vertex, get<1>(*it_out_edges), &transaction_, + &storage_->indices_, &storage_->constraints_, config_}; + + return std::nullopt; } Result<std::optional<VertexAccessor>> Storage::Accessor::DeleteVertex(VertexAccessor *vertex) { diff --git a/src/storage/v2/storage.hpp b/src/storage/v2/storage.hpp index ec4508b49..4e8ed92b0 100644 --- a/src/storage/v2/storage.hpp +++ b/src/storage/v2/storage.hpp @@ -216,7 +216,7 @@ class Storage final { std::optional<VertexAccessor> FindVertex(Gid gid, View view); - std::optional<EdgeAccessor> FindEdge(Gid gid); + std::optional<EdgeAccessor> FindEdge(Gid edge_id, Gid vertex_id); VerticesIterable Vertices(View view) { return VerticesIterable(AllVerticesIterable(storage_->vertices_.access(), &transaction_, view,