From 131d7f2a744d07c9c0d03bb7c57dd251284ec9f6 Mon Sep 17 00:00:00 2001
From: jeremy <jeremy.bailleux@memgraph.io>
Date: Mon, 14 Nov 2022 18:21:03 +0100
Subject: [PATCH] OrderByElements: no longer templated over vertice/edge types.
 For edges, we always need to have access to the corresponding vertex_accessor
 (ex of sorting expr needing both : "vertex.map[edge]") ComputeExpression:
 made assert instead of if check

---
 src/storage/v3/expr.cpp           |  6 ++-
 src/storage/v3/request_helper.hpp | 87 ++++++++++++++++++++-----------
 src/storage/v3/shard_rsm.cpp      | 12 ++---
 3 files changed, 68 insertions(+), 37 deletions(-)

diff --git a/src/storage/v3/expr.cpp b/src/storage/v3/expr.cpp
index eaff472bd..53146f595 100644
--- a/src/storage/v3/expr.cpp
+++ b/src/storage/v3/expr.cpp
@@ -186,7 +186,8 @@ TypedValue ComputeExpression(DbAccessor &dba, const std::optional<memgraph::stor
   expr::SymbolGenerator symbol_generator(&symbol_table, identifiers);
   (std::any_cast<Expression *>(expr))->Accept(symbol_generator);
 
-  if (node_identifier.symbol_pos_ != -1 && v_acc.has_value()) {
+  if (node_identifier.symbol_pos_ != -1) {
+    MG_ASSERT(v_acc.has_value());
     MG_ASSERT(std::find_if(symbol_table.table().begin(), symbol_table.table().end(),
                            [&node_name](const std::pair<int32_t, Symbol> &position_symbol_pair) {
                              return position_symbol_pair.second.name() == node_name;
@@ -195,7 +196,8 @@ TypedValue ComputeExpression(DbAccessor &dba, const std::optional<memgraph::stor
     frame[symbol_table.at(node_identifier)] = *v_acc;
   }
 
-  if (edge_identifier.symbol_pos_ != -1 && e_acc.has_value()) {
+  if (edge_identifier.symbol_pos_ != -1) {
+    MG_ASSERT(e_acc.has_value());
     MG_ASSERT(std::find_if(symbol_table.table().begin(), symbol_table.table().end(),
                            [&edge_name](const std::pair<int32_t, Symbol> &position_symbol_pair) {
                              return position_symbol_pair.second.name() == edge_name;
diff --git a/src/storage/v3/request_helper.hpp b/src/storage/v3/request_helper.hpp
index e3c5b9d3f..955059f26 100644
--- a/src/storage/v3/request_helper.hpp
+++ b/src/storage/v3/request_helper.hpp
@@ -112,44 +112,73 @@ struct Element {
   TObjectAccessor object_acc;
 };
 
-template <ObjectAccessor TObjectAccessor, typename TIterable>
-std::vector<Element<TObjectAccessor>> OrderByElements(Shard::Accessor &acc, DbAccessor &dba, TIterable &iterable,
-                                                      std::vector<msgs::OrderBy> &order_bys) {
-  std::vector<Element<TObjectAccessor>> ordered;
-  ordered.reserve(acc.ApproximateVertexCount());
+template <typename TIterable>
+std::vector<Element<VertexAccessor>> OrderByVertices(Shard::Accessor &acc, DbAccessor &dba, TIterable &iterable,
+                                                     std::vector<msgs::OrderBy> &order_bys) {
+  static_assert(std::is_same_v<TIterable, VerticesIterable> || std::is_same_v<TIterable, std::vector<VertexAccessor>>);
+
   std::vector<Ordering> ordering;
   ordering.reserve(order_bys.size());
-  for (const auto &order : order_bys) {
-    switch (order.direction) {
-      case memgraph::msgs::OrderingDirection::ASCENDING: {
-        ordering.push_back(Ordering::ASC);
-        break;
-      }
-      case memgraph::msgs::OrderingDirection::DESCENDING: {
-        ordering.push_back(Ordering::DESC);
-        break;
-      }
+  std::transform(order_bys.begin(), order_bys.end(), std::back_inserter(ordering), [](const auto &order_by) {
+    if (memgraph::msgs::OrderingDirection::ASCENDING == order_by.direction) {
+      return Ordering::ASC;
     }
-  }
-  auto compare_typed_values = TypedValueVectorCompare(ordering);
-  auto it = iterable.begin();
-  for (; it != iterable.end(); ++it) {
+    MG_ASSERT(memgraph::msgs::OrderingDirection::DESCENDING == order_by.direction);
+    return Ordering::DESC;
+  });
+
+  std::vector<Element<VertexAccessor>> ordered;
+  ordered.reserve(acc.ApproximateVertexCount());
+  for (auto it = iterable.begin(); it != iterable.end(); ++it) {
     std::vector<TypedValue> properties_order_by;
     properties_order_by.reserve(order_bys.size());
 
-    for (const auto &order_by : order_bys) {
-      if constexpr (std::is_same_v<TIterable, VerticesIterable> ||
-                    std::is_same_v<TIterable, std::vector<VertexAccessor>>) {
-        properties_order_by.push_back(ComputeExpression(dba, *it, std::nullopt, order_by.expression.expression,
-                                                        expr::identifier_node_symbol, expr::identifier_edge_symbol));
-      } else {
-        properties_order_by.push_back(ComputeExpression(dba, std::nullopt, *it, order_by.expression.expression,
-                                                        expr::identifier_node_symbol, expr::identifier_edge_symbol));
-      }
-    }
+    std::transform(order_bys.begin(), order_bys.end(), std::back_inserter(properties_order_by),
+                   [&dba, &it](const auto &order_by) {
+                     return ComputeExpression(dba, *it, std::nullopt /*e_acc*/, order_by.expression.expression,
+                                              expr::identifier_node_symbol, expr::identifier_edge_symbol);
+                   });
+
     ordered.push_back({std::move(properties_order_by), *it});
   }
 
+  auto compare_typed_values = TypedValueVectorCompare(ordering);
+  std::sort(ordered.begin(), ordered.end(), [compare_typed_values](const auto &pair1, const auto &pair2) {
+    return compare_typed_values(pair1.properties_order_by, pair2.properties_order_by);
+  });
+  return ordered;
+}
+
+template <typename TIterable>
+std::vector<Element<EdgeAccessor>> OrderByEdges(Shard::Accessor &acc, DbAccessor &dba, TIterable &iterable,
+                                                std::vector<msgs::OrderBy> &order_bys,
+                                                const VertexAccessor &vertex_acc) {
+  static_assert(std::is_same_v<TIterable, std::vector<EdgeAccessor>>);  // Can be extended if needed
+
+  std::vector<Ordering> ordering;
+  ordering.reserve(order_bys.size());
+  std::transform(order_bys.begin(), order_bys.end(), std::back_inserter(ordering), [](const auto &order_by) {
+    if (memgraph::msgs::OrderingDirection::ASCENDING == order_by.direction) {
+      return Ordering::ASC;
+    }
+    MG_ASSERT(memgraph::msgs::OrderingDirection::DESCENDING == order_by.direction);
+    return Ordering::DESC;
+  });
+
+  std::vector<Element<EdgeAccessor>> ordered;
+  for (auto it = iterable.begin(); it != iterable.end(); ++it) {
+    std::vector<TypedValue> properties_order_by;
+    properties_order_by.reserve(order_bys.size());
+    std::transform(order_bys.begin(), order_bys.end(), std::back_inserter(properties_order_by),
+                   [&dba, &vertex_acc, &it](const auto &order_by) {
+                     return ComputeExpression(dba, vertex_acc, *it, order_by.expression.expression,
+                                              expr::identifier_node_symbol, expr::identifier_edge_symbol);
+                   });
+
+    ordered.push_back({std::move(properties_order_by), *it});
+  }
+
+  auto compare_typed_values = TypedValueVectorCompare(ordering);
   std::sort(ordered.begin(), ordered.end(), [compare_typed_values](const auto &pair1, const auto &pair2) {
     return compare_typed_values(pair1.properties_order_by, pair2.properties_order_by);
   });
diff --git a/src/storage/v3/shard_rsm.cpp b/src/storage/v3/shard_rsm.cpp
index 5f6cb5959..55b726c79 100644
--- a/src/storage/v3/shard_rsm.cpp
+++ b/src/storage/v3/shard_rsm.cpp
@@ -875,7 +875,7 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::ScanVerticesRequest &&req) {
   uint64_t sample_counter{0};
   auto vertex_iterable = acc.Vertices(view);
   if (!req.order_bys.empty()) {
-    const auto ordered = OrderByElements<VertexAccessor>(acc, dba, vertex_iterable, req.order_bys);
+    const auto ordered = OrderByVertices(acc, dba, vertex_iterable, req.order_bys);
     // we are traversing Elements
     auto it = GetStartOrderedElementsIterator(ordered, start_id, View(req.storage_view));
     for (; it != ordered.end(); ++it) {
@@ -956,9 +956,9 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::ExpandOneRequest &&req) {
   }
 
   if (!req.order_by.empty()) {
-    // #NoCommit can we do differently to avoid this? We need OrderByElements but currently
-    // #NoCommit it returns vector<Element>, so this workaround is here to avoid more duplication later
-    auto sorted_vertices = OrderByElements<VertexAccessor>(acc, dba, vertex_accessors, req.order_by);
+    // Can we do differently to avoid this? We need OrderByElements but currently it returns vector<Element>, so this
+    // workaround is here to avoid more duplication later
+    auto sorted_vertices = OrderByVertices(acc, dba, vertex_accessors, req.order_by);
     vertex_accessors.clear();
     std::transform(sorted_vertices.begin(), sorted_vertices.end(), std::back_inserter(vertex_accessors),
                    [](auto &vertex) { return vertex.object_acc; });
@@ -987,8 +987,8 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::ExpandOneRequest &&req) {
 
     } else {
       auto [in_edge_accessors, out_edge_accessors] = GetEdgesFromVertex(src_vertex_acc, req.direction);
-      const auto in_ordered_edges = OrderByElements<EdgeAccessor>(acc, dba, in_edge_accessors, req.order_by);
-      const auto out_ordered_edges = OrderByElements<EdgeAccessor>(acc, dba, out_edge_accessors, req.order_by);
+      const auto in_ordered_edges = OrderByEdges(acc, dba, in_edge_accessors, req.order_by, src_vertex_acc);
+      const auto out_ordered_edges = OrderByEdges(acc, dba, out_edge_accessors, req.order_by, src_vertex_acc);
 
       std::vector<EdgeAccessor> in_edge_ordered_accessors;
       std::transform(in_ordered_edges.begin(), in_ordered_edges.end(), std::back_inserter(in_edge_ordered_accessors),