From 1a78c3695d4787bf4624b03a56fe46af2c7bd976 Mon Sep 17 00:00:00 2001
From: Antonio Andelic <antonio2368@users.noreply.github.com>
Date: Tue, 19 Oct 2021 14:39:23 +0200
Subject: [PATCH] Support multiple types for weighted shortest path (#278)

---
 src/query/plan/operator.cpp                   | 75 +++++++++++++++----
 .../features/memgraph_wshortest.feature       | 46 +++++++++++-
 .../features/wsp.feature                      |  2 +-
 tests/unit/interpreter.cpp                    | 75 +++++++++++++------
 4 files changed, 155 insertions(+), 43 deletions(-)

diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp
index d6184be3b..bf415e2d4 100644
--- a/src/query/plan/operator.cpp
+++ b/src/query/plan/operator.cpp
@@ -49,6 +49,7 @@
 #include "utils/pmr/vector.hpp"
 #include "utils/readable_size.hpp"
 #include "utils/string.hpp"
+#include "utils/temporal.hpp"
 
 // macro for the default implementation of LogicalOperator::Accept
 // that accepts the visitor and visits it's input_ operator
@@ -1445,7 +1446,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
     // satisfy the "where" condition. if so, places them in the priority
     // queue.
     auto expand_pair = [this, &evaluator, &frame, &create_state](const EdgeAccessor &edge, const VertexAccessor &vertex,
-                                                                 double weight, int64_t depth) {
+                                                                 const TypedValue &total_weight, int64_t depth) {
       auto *memory = evaluator.GetMemoryResource();
       if (self_.filter_lambda_.expression) {
         frame[self_.filter_lambda_.inner_edge_symbol] = edge;
@@ -1457,27 +1458,48 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
       frame[self_.weight_lambda_->inner_edge_symbol] = edge;
       frame[self_.weight_lambda_->inner_node_symbol] = vertex;
 
-      TypedValue typed_weight = self_.weight_lambda_->expression->Accept(evaluator);
+      TypedValue current_weight = self_.weight_lambda_->expression->Accept(evaluator);
 
-      if (!typed_weight.IsNumeric()) {
-        throw QueryRuntimeException("Calculated weight must be numeric, got {}.", typed_weight.type());
+      if (!current_weight.IsNumeric() && !current_weight.IsDuration()) {
+        throw QueryRuntimeException("Calculated weight must be numeric or a Duration, got {}.", current_weight.type());
       }
-      if ((typed_weight < TypedValue(0, memory)).ValueBool()) {
+
+      const auto is_valid_numeric = [&] {
+        return current_weight.IsNumeric() && (current_weight >= TypedValue(0, memory)).ValueBool();
+      };
+
+      const auto is_valid_duration = [&] {
+        return current_weight.IsDuration() && (current_weight >= TypedValue(utils::Duration(0), memory)).ValueBool();
+      };
+
+      if (!is_valid_numeric() && !is_valid_duration()) {
         throw QueryRuntimeException("Calculated weight must be non-negative!");
       }
 
       auto next_state = create_state(vertex, depth);
-      auto next_weight = TypedValue(weight, memory) + typed_weight;
-      auto found_it = total_cost_.find(next_state);
-      if (found_it != total_cost_.end() && found_it->second.ValueDouble() <= next_weight.ValueDouble()) return;
 
-      pq_.push({next_weight.ValueDouble(), depth + 1, vertex, edge});
+      TypedValue next_weight = std::invoke([&] {
+        if (total_weight.IsNull()) {
+          return current_weight;
+        }
+
+        ValidateWeightTypes(current_weight, total_weight);
+
+        return TypedValue(current_weight, memory) + total_weight;
+      });
+
+      auto found_it = total_cost_.find(next_state);
+      if (found_it != total_cost_.end() && (found_it->second.IsNull() || (found_it->second <= next_weight).ValueBool()))
+        return;
+
+      pq_.push({next_weight, depth + 1, vertex, edge});
     };
 
     // Populates the priority queue structure with expansions
     // from the given vertex. skips expansions that don't satisfy
     // the "where" condition.
-    auto expand_from_vertex = [this, &expand_pair](const VertexAccessor &vertex, double weight, int64_t depth) {
+    auto expand_from_vertex = [this, &expand_pair](const VertexAccessor &vertex, const TypedValue &weight,
+                                                   int64_t depth) {
       if (self_.common_.direction != EdgeAtom::Direction::IN) {
         auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types));
         for (const auto &edge : out_edges) {
@@ -1522,7 +1544,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
         total_cost_.clear();
         yielded_vertices_.clear();
 
-        pq_.push({0.0, 0, vertex, std::nullopt});
+        pq_.push({TypedValue(), 0, vertex, std::nullopt});
         // We are adding the starting vertex to the set of yielded vertices
         // because we don't want to yield paths that end with the starting
         // vertex.
@@ -1622,17 +1644,38 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
   // Keeps track of vertices for which we yielded a path already.
   utils::pmr::unordered_set<VertexAccessor> yielded_vertices_;
 
+  static void ValidateWeightTypes(const TypedValue &lhs, const TypedValue &rhs) {
+    if (!((lhs.IsNumeric() && lhs.IsNumeric()) || (rhs.IsDuration() && rhs.IsDuration()))) {
+      throw QueryRuntimeException(utils::MessageWithLink(
+          "All weights should be of the same type, either numeric or a Duration. Please update the weight "
+          "expression or the filter expression.",
+          "https://memgr.ph/wsp"));
+    }
+  }
+
   // Priority queue comparator. Keep lowest weight on top of the queue.
   class PriorityQueueComparator {
    public:
-    bool operator()(const std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &lhs,
-                    const std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &rhs) {
-      return std::get<0>(lhs) > std::get<0>(rhs);
+    bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &lhs,
+                    const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &rhs) {
+      const auto &lhs_weight = std::get<0>(lhs);
+      const auto &rhs_weight = std::get<0>(rhs);
+      // Null defines minimum value for all types
+      if (lhs_weight.IsNull()) {
+        return false;
+      }
+
+      if (rhs_weight.IsNull()) {
+        return true;
+      }
+
+      ValidateWeightTypes(lhs_weight, rhs_weight);
+      return (lhs_weight > rhs_weight).ValueBool();
     }
   };
 
-  std::priority_queue<std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>>,
-                      utils::pmr::vector<std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>>>,
+  std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>,
+                      utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>>,
                       PriorityQueueComparator>
       pq_;
 
diff --git a/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature b/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature
index 007dc273e..1c98c2830 100644
--- a/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature
+++ b/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature
@@ -40,10 +40,10 @@ Feature: Weighted Shortest Path
           MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
           """
       Then the result should be:
-          | m.a | s | w   |
-          | '1' | 1 | 1.0 |
-          | '2' | 2 | 3.0 |
-          | '3' | 1 | 4.0 |
+          | m.a | s | w |
+          | '1' | 1 | 1 |
+          | '2' | 2 | 3 |
+          | '3' | 1 | 4 |
 
   Scenario: Test match wShortest single edge type filtered
       Given an empty graph
@@ -116,4 +116,42 @@ Feature: Weighted Shortest Path
           """
       Then an error should be raised
 
+  Scenario: Test match wShortest weight duration
+      Given an empty graph
+      And having executed:
+          """
+          CREATE (n {a:'0'})-[:r {w: DURATION('PT1S')}]->({a:'1'})-[:r {w: DURATION('PT2S')}]->({a:'2'}), (n)-[:r {w: DURATION('PT4S')}]->({a:'3'})
+          """
+      When executing query:
+          """
+          MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
+          """
+      Then the result should be:
+          | m.a | s | w    |
+          | '1' | 1 | PT1S |
+          | '2' | 2 | PT3S |
+          | '3' | 1 | PT4S |
 
+  Scenario: Test match wShortest weight negative duration
+      Given an empty graph
+      And having executed:
+          """
+          CREATE (n {a:'0'})-[:r {w: DURATION({seconds: -1})}]->({a:'1'})-[:r {w: DURATION('PT2S')}]->({a:'2'}), (n)-[:r {w: DURATION('PT4S')}]->({a:'3'})
+          """
+      When executing query:
+          """
+          MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
+          """
+      Then an error should be raised
+
+  Scenario: Test match wShortest weight mixed numeric and duration as weights
+      Given an empty graph
+      And having executed:
+          """
+          CREATE (n {a:'0'})-[:r {w: 2}]->({a:'1'})-[:r {w: DURATION('PT2S')}]->({a:'2'}), (n)-[:r {w: DURATION('PT4S')}]->({a:'3'})
+          """
+      When executing query:
+          """
+          MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
+          """
+      Then an error should be raised
diff --git a/tests/gql_behave/tests/stackoverflow_answers/features/wsp.feature b/tests/gql_behave/tests/stackoverflow_answers/features/wsp.feature
index d7a0e9035..af486a50f 100644
--- a/tests/gql_behave/tests/stackoverflow_answers/features/wsp.feature
+++ b/tests/gql_behave/tests/stackoverflow_answers/features/wsp.feature
@@ -26,5 +26,5 @@ Feature: Queries related to all Stackoverflow questions related to WSP
       """
     Then the result should be:
       | hops        | total_weight |
-      |'1 -> 3 -> 4'| 17.0         |
+      |'1 -> 3 -> 4'| 17           |
 
diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp
index abd0d9417..63d897650 100644
--- a/tests/unit/interpreter.cpp
+++ b/tests/unit/interpreter.cpp
@@ -1,3 +1,14 @@
+// Copyright 2021 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
 #include <cstdlib>
 #include <filesystem>
 
@@ -410,38 +421,58 @@ TEST_F(InterpreterTest, Bfs) {
 
 // Test shortest path end to end.
 TEST_F(InterpreterTest, ShortestPath) {
-  Interpret(
-      "CREATE (n:A {x: 1}), (m:B {x: 2}), (l:C {x: 1}), (n)-[:r1 {w: 1 "
-      "}]->(m)-[:r2 {w: 2}]->(l), (n)-[:r3 {w: 4}]->(l)");
+  const auto test_shortest_path = [this](const bool use_duration) {
+    const auto get_weight = [use_duration](const auto value) {
+      return fmt::format(use_duration ? "DURATION('PT{}S')" : "{}", value);
+    };
 
-  auto stream = Interpret("MATCH (n)-[e *wshortest 5 (e, n | e.w) ]->(m) return e");
+    Interpret(
+        fmt::format("CREATE (n:A {{x: 1}}), (m:B {{x: 2}}), (l:C {{x: 1}}), (n)-[:r1 {{w: {} "
+                    "}}]->(m)-[:r2 {{w: {}}}]->(l), (n)-[:r3 {{w: {}}}]->(l)",
+                    get_weight(1), get_weight(2), get_weight(4)));
 
-  ASSERT_EQ(stream.GetHeader().size(), 1U);
-  EXPECT_EQ(stream.GetHeader()[0], "e");
-  ASSERT_EQ(stream.GetResults().size(), 3U);
+    auto stream = Interpret("MATCH (n)-[e *wshortest 5 (e, n | e.w) ]->(m) return e");
 
-  auto dba = db_.Access();
-  std::vector<std::vector<std::string>> expected_results{{"r1"}, {"r2"}, {"r1", "r2"}};
+    ASSERT_EQ(stream.GetHeader().size(), 1U);
+    EXPECT_EQ(stream.GetHeader()[0], "e");
+    ASSERT_EQ(stream.GetResults().size(), 3U);
 
-  for (const auto &result : stream.GetResults()) {
-    const auto &edges = ToEdgeList(result[0]);
+    auto dba = db_.Access();
+    std::vector<std::vector<std::string>> expected_results{{"r1"}, {"r2"}, {"r1", "r2"}};
 
-    std::vector<std::string> datum;
-    datum.reserve(edges.size());
+    for (const auto &result : stream.GetResults()) {
+      const auto &edges = ToEdgeList(result[0]);
 
-    for (const auto &edge : edges) {
-      datum.push_back(edge.type);
-    }
+      std::vector<std::string> datum;
+      datum.reserve(edges.size());
 
-    bool any_match = false;
-    for (const auto &expected : expected_results) {
-      if (expected == datum) {
-        any_match = true;
-        break;
+      for (const auto &edge : edges) {
+        datum.push_back(edge.type);
       }
+
+      bool any_match = false;
+      for (const auto &expected : expected_results) {
+        if (expected == datum) {
+          any_match = true;
+          break;
+        }
+      }
+
+      EXPECT_TRUE(any_match);
     }
 
-    EXPECT_TRUE(any_match);
+    Interpret("MATCH (n) DETACH DELETE n");
+  };
+
+  constexpr bool kUseNumeric{false};
+  constexpr bool kUseDuration{true};
+  {
+    SCOPED_TRACE("Test with numeric values");
+    test_shortest_path(kUseNumeric);
+  }
+  {
+    SCOPED_TRACE("Test with Duration values");
+    test_shortest_path(kUseDuration);
   }
 }