From 034b54cb7222877de1b0f77e4b532e36319a2afd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bruno=20Sa=C4=8Dari=C4=87?= <bruno.sacaric@gmail.com>
Date: Wed, 25 Jan 2023 15:32:00 +0100
Subject: [PATCH] Fix bug on all shortest paths with an upper bound  (#737)

---
 src/query/plan/operator.cpp                   | 43 ++++++++-----------
 .../features/memgraph_allshortest.feature     | 20 +++++++++
 2 files changed, 39 insertions(+), 24 deletions(-)

diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp
index debbc36bc..6583a52a3 100644
--- a/src/query/plan/operator.cpp
+++ b/src/query/plan/operator.cpp
@@ -1,4 +1,4 @@
-// Copyright 2022 Memgraph Ltd.
+// Copyright 2023 Memgraph Ltd.
 //
 // Use of this software is governed by the Business Source License
 // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@@ -2020,7 +2020,7 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor {
         next_edges_.clear();
         traversal_stack_.clear();
 
-        pq_.push({TypedValue(), 0, *start_vertex, std::nullopt});
+        expand_from_vertex(*start_vertex, TypedValue(), 0);
         visited_cost_.emplace(*start_vertex, 0);
         frame[self_.common_.edge_symbol] = TypedValue::TVector(memory);
       }
@@ -2029,33 +2029,28 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor {
       while (!pq_.empty()) {
         if (MustAbort(context)) throw HintedAbortError();
 
-        auto [current_weight, current_depth, current_vertex, maybe_directed_edge] = pq_.top();
+        const auto [current_weight, current_depth, current_vertex, directed_edge] = pq_.top();
         pq_.pop();
 
+        const auto &[current_edge, direction, weight] = directed_edge;
+        if (expanded_.contains(current_edge)) continue;
+        expanded_.emplace(current_edge);
+
         // Expand only if what we've just expanded is less than max depth.
         if (current_depth < upper_bound_) {
-          if (maybe_directed_edge) {
-            auto &[current_edge, direction, weight] = *maybe_directed_edge;
-            if (expanded_.find(current_edge) != expanded_.end()) continue;
-            expanded_.emplace(current_edge);
-          }
           expand_from_vertex(current_vertex, current_weight, current_depth);
         }
 
-        // if current vertex is not starting vertex, maybe_directed_edge will not be nullopt
-        if (maybe_directed_edge) {
-          auto &[current_edge, direction, weight] = *maybe_directed_edge;
-          // Searching for a previous vertex in the expansion
-          auto prev_vertex = direction == EdgeAtom::Direction::IN ? current_edge.To() : current_edge.From();
+        // Searching for a previous vertex in the expansion
+        auto prev_vertex = direction == EdgeAtom::Direction::IN ? current_edge.To() : current_edge.From();
 
-          // Update the parent
-          if (next_edges_.find({prev_vertex, current_depth - 1}) == next_edges_.end()) {
-            utils::pmr::list<DirectedEdge> empty(memory);
-            next_edges_[{prev_vertex, current_depth - 1}] = std::move(empty);
-          }
-
-          next_edges_.at({prev_vertex, current_depth - 1}).emplace_back(*maybe_directed_edge);
+        // Update the parent
+        if (next_edges_.find({prev_vertex, current_depth - 1}) == next_edges_.end()) {
+          utils::pmr::list<DirectedEdge> empty(memory);
+          next_edges_[{prev_vertex, current_depth - 1}] = std::move(empty);
         }
+
+        next_edges_.at({prev_vertex, current_depth - 1}).emplace_back(directed_edge);
       }
 
       if (start_vertex && next_edges_.find({*start_vertex, 0}) != next_edges_.end()) {
@@ -2112,8 +2107,8 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor {
   // Priority queue comparator. Keep lowest weight on top of the queue.
   class PriorityQueueComparator {
    public:
-    bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<DirectedEdge>> &lhs,
-                    const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<DirectedEdge>> &rhs) {
+    bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge> &lhs,
+                    const std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge> &rhs) {
       const auto &lhs_weight = std::get<0>(lhs);
       const auto &rhs_weight = std::get<0>(rhs);
       // Null defines minimum value for all types
@@ -2132,8 +2127,8 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor {
 
   // Priority queue - core element of the algorithm.
   // Stores: {weight, depth, next vertex, edge and direction}
-  std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<DirectedEdge>>,
-                      utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<DirectedEdge>>>,
+  std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge>,
+                      utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge>>,
                       PriorityQueueComparator>
       pq_;
 
diff --git a/tests/gql_behave/tests/memgraph_V1/features/memgraph_allshortest.feature b/tests/gql_behave/tests/memgraph_V1/features/memgraph_allshortest.feature
index 8a8cd76cc..7f224a016 100644
--- a/tests/gql_behave/tests/memgraph_V1/features/memgraph_allshortest.feature
+++ b/tests/gql_behave/tests/memgraph_V1/features/memgraph_allshortest.feature
@@ -15,6 +15,26 @@ Feature: All Shortest Path
           | '1' |
           | '3' |
 
+  Scenario: Test match allShortest upper bound 2
+      Given an empty graph
+      And having executed:
+          """
+          CREATE (a {a:'0'})-[:r {w: 2}]->(b {a:'1'})-[:r {w: 3}]->(c {a:'2'}),
+            (a)-[:re {w: 2}]->(b),
+            (b)-[:re {w:3}]->(c),
+            ({a: '4'})<-[:r {w: 1}]-(a),
+            ({a: '5'})<-[:r {w: 1}]-(a),
+            (c)-[:r {w: 1}]->({a: '6'}),
+            (c)-[:r {w: 1}]->({a: '7'})
+          """
+      When executing query:
+          """
+          MATCH path=(n {a:'0'})-[r *allShortest ..2 (e, n | 1 ) w]->(m {a:'2'}) RETURN COUNT(path) AS c
+          """
+      Then the result should be:
+          | c |
+          | 4 |
+
   Scenario: Test match allShortest filtered
       Given an empty graph
       And having executed: