From 956705293e7019a65100cb9e218ffd27185f65b3 Mon Sep 17 00:00:00 2001
From: gvolfing <gabor.volfinger@memgraph.io>
Date: Mon, 11 Mar 2024 09:56:00 +0100
Subject: [PATCH] Add logical operator for id based edge scans

---
 src/query/plan/hint_provider.hpp              |  3 ++
 src/query/plan/operator.cpp                   | 42 +++++++++++++---
 src/query/plan/operator.hpp                   | 49 ++++++++++++++-----
 src/query/plan/operator_type_info.cpp         |  5 ++
 src/query/plan/pretty_print.cpp               | 15 ++++++
 src/query/plan/pretty_print.hpp               |  2 +
 .../plan/rewrite/edge_type_index_lookup.hpp   |  9 ++++
 src/utils/event_counter.cpp                   |  3 +-
 tests/unit/query_plan_checker.hpp             |  2 +
 9 files changed, 111 insertions(+), 19 deletions(-)

diff --git a/src/query/plan/hint_provider.hpp b/src/query/plan/hint_provider.hpp
index 3c8510561..4c981bfb0 100644
--- a/src/query/plan/hint_provider.hpp
+++ b/src/query/plan/hint_provider.hpp
@@ -117,6 +117,9 @@ class PlanHintsProvider final : public HierarchicalLogicalOperatorVisitor {
   bool PreVisit(ScanAllByEdgeType & /*unused*/) override { return true; }
   bool PostVisit(ScanAllByEdgeType & /*unused*/) override { return true; }
 
+  bool PreVisit(ScanAllByEdgeId & /*unused*/) override { return true; }
+  bool PostVisit(ScanAllByEdgeId & /*unused*/) override { return true; }
+
   bool PreVisit(ConstructNamedPath & /*unused*/) override { return true; }
   bool PostVisit(ConstructNamedPath & /*unused*/) override { return true; }
 
diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp
index 7cd506050..6cbc8de4d 100644
--- a/src/query/plan/operator.cpp
+++ b/src/query/plan/operator.cpp
@@ -106,6 +106,7 @@ extern const Event ScanAllByLabelPropertyValueOperator;
 extern const Event ScanAllByLabelPropertyOperator;
 extern const Event ScanAllByIdOperator;
 extern const Event ScanAllByEdgeTypeOperator;
+extern const Event ScanAllByEdgeIdOperator;
 extern const Event ExpandOperator;
 extern const Event ExpandVariableOperator;
 extern const Event ConstructNamedPathOperator;
@@ -521,7 +522,7 @@ class ScanAllCursor : public Cursor {
 template <typename TEdgesFun>
 class ScanAllByEdgeTypeCursor : public Cursor {
  public:
-  explicit ScanAllByEdgeTypeCursor(const ScanAllByEdgeType &self, Symbol output_symbol, UniqueCursorPtr input_cursor,
+  explicit ScanAllByEdgeTypeCursor(const ScanAll &self, Symbol output_symbol, UniqueCursorPtr input_cursor,
                                    storage::View view, TEdgesFun get_edges, const char *op_name)
       : self_(self),
         output_symbol_(std::move(output_symbol)),
@@ -561,7 +562,7 @@ class ScanAllByEdgeTypeCursor : public Cursor {
   }
 
  private:
-  const ScanAllByEdgeType &self_;
+  const ScanAll &self_;
   const Symbol output_symbol_;
   const UniqueCursorPtr input_cursor_;
   storage::View view_;
@@ -613,10 +614,7 @@ UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource *mem) const {
 
 ScanAllByEdgeType::ScanAllByEdgeType(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
                                      storage::EdgeTypeId edge_type, storage::View view)
-    : input_(input ? input : std::make_shared<Once>()),
-      output_symbol_(std::move(output_symbol)),
-      view_(view),
-      edge_type_(edge_type) {}
+    : ScanAll(input, output_symbol, view), edge_type_(edge_type) {}
 
 ACCEPT_WITH_INPUT(ScanAllByEdgeType)
 
@@ -782,6 +780,38 @@ UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const {
                                                                 view_, std::move(vertices), "ScanAllById");
 }
 
+ScanAllByEdgeId::ScanAllByEdgeId(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
+                                 Expression *expression, storage::View view)
+    : ScanAll(input, output_symbol, view), expression_(expression) {
+  MG_ASSERT(expression);
+}
+
+ACCEPT_WITH_INPUT(ScanAllByEdgeId)
+
+UniqueCursorPtr ScanAllByEdgeId::MakeCursor(utils::MemoryResource *mem) const {
+  memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByEdgeIdOperator);
+
+  auto edges = [this](Frame &frame, ExecutionContext &context) -> std::optional<std::vector<EdgeAccessor>> {
+    auto *db = context.db_accessor;
+    ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_);
+    auto value = expression_->Accept(evaluator);
+    if (!value.IsNumeric()) return std::nullopt;
+    int64_t id = value.IsInt() ? value.ValueInt() : value.ValueDouble();
+    if (value.IsDouble() && id != value.ValueDouble()) return std::nullopt;
+    auto maybe_edge = db->FindEdge(storage::Gid::FromInt(id), view_);
+    if (!maybe_edge) return std::nullopt;
+    return std::vector<EdgeAccessor>{*maybe_edge};
+  };
+  return MakeUniqueCursorPtr<ScanAllByEdgeTypeCursor<decltype(edges)>>(
+      mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(edges), "ScanAllByEdgeId");
+}
+
+std::vector<Symbol> ScanAllByEdgeId::ModifiedSymbols(const SymbolTable &table) const {
+  auto symbols = input_->ModifiedSymbols(table);
+  symbols.emplace_back(output_symbol_);
+  return symbols;
+}
+
 namespace {
 bool CheckExistingNode(const VertexAccessor &new_node, const Symbol &existing_node_sym, Frame &frame) {
   const TypedValue &existing_node = frame[existing_node_sym];
diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp
index 6563c2bb0..2abd9744f 100644
--- a/src/query/plan/operator.hpp
+++ b/src/query/plan/operator.hpp
@@ -100,6 +100,7 @@ class ScanAllByLabelPropertyValue;
 class ScanAllByLabelProperty;
 class ScanAllById;
 class ScanAllByEdgeType;
+class ScanAllByEdgeId;
 class Expand;
 class ExpandVariable;
 class ConstructNamedPath;
@@ -133,13 +134,12 @@ class IndexedJoin;
 class HashJoin;
 class RollUpApply;
 
-using LogicalOperatorCompositeVisitor =
-    utils::CompositeVisitor<Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange,
-                            ScanAllByLabelPropertyValue, ScanAllByLabelProperty, ScanAllById, ScanAllByEdgeType, Expand,
-                            ExpandVariable, ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties,
-                            SetLabels, RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip,
-                            Limit, OrderBy, Merge, Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv,
-                            Foreach, EmptyResult, EvaluatePatternFilter, Apply, IndexedJoin, HashJoin, RollUpApply>;
+using LogicalOperatorCompositeVisitor = utils::CompositeVisitor<
+    Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange, ScanAllByLabelPropertyValue,
+    ScanAllByLabelProperty, ScanAllById, ScanAllByEdgeType, ScanAllByEdgeId, Expand, ExpandVariable, ConstructNamedPath,
+    Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, EdgeUniquenessFilter,
+    Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, Optional, Unwind, Distinct, Union, Cartesian, CallProcedure,
+    LoadCsv, Foreach, EmptyResult, EvaluatePatternFilter, Apply, IndexedJoin, HashJoin, RollUpApply>;
 
 using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>;
 
@@ -593,7 +593,7 @@ class ScanAllByLabel : public memgraph::query::plan::ScanAll {
   }
 };
 
-class ScanAllByEdgeType : public memgraph::query::plan::LogicalOperator {
+class ScanAllByEdgeType : public memgraph::query::plan::ScanAll {
  public:
   static const utils::TypeInfo kType;
   const utils::TypeInfo &GetTypeInfo() const override { return kType; }
@@ -613,10 +613,6 @@ class ScanAllByEdgeType : public memgraph::query::plan::LogicalOperator {
     return fmt::format("ScanAllByEdgeType ({} :{})", output_symbol_.name(), dba_->EdgeTypeToName(edge_type_));
   }
 
-  std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
-  Symbol output_symbol_;
-  storage::View view_;
-
   storage::EdgeTypeId edge_type_;
 
   std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
@@ -819,6 +815,35 @@ class ScanAllById : public memgraph::query::plan::ScanAll {
     return object;
   }
 };
+class ScanAllByEdgeId : public memgraph::query::plan::ScanAll {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  ScanAllByEdgeId() = default;
+  ScanAllByEdgeId(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression,
+                  storage::View view = storage::View::OLD);
+  bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+  UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+  std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+  bool HasSingleInput() const override { return true; }
+  std::shared_ptr<LogicalOperator> input() const override { return input_; }
+  void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+  std::string ToString() const override { return fmt::format("ScanAllByEdgeId ({})", output_symbol_.name()); }
+
+  Expression *expression_;
+
+  std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+    auto object = std::make_unique<ScanAllByEdgeId>();
+    object->input_ = input_ ? input_->Clone(storage) : nullptr;
+    object->output_symbol_ = output_symbol_;
+    object->view_ = view_;
+    object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+    return object;
+  }
+};
 
 struct ExpandCommon {
   static const utils::TypeInfo kType;
diff --git a/src/query/plan/operator_type_info.cpp b/src/query/plan/operator_type_info.cpp
index 6b0a28313..ac75581d1 100644
--- a/src/query/plan/operator_type_info.cpp
+++ b/src/query/plan/operator_type_info.cpp
@@ -49,9 +49,14 @@ constexpr utils::TypeInfo query::plan::ScanAllByLabelProperty::kType{
 
 constexpr utils::TypeInfo query::plan::ScanAllById::kType{utils::TypeId::SCAN_ALL_BY_ID, "ScanAllById",
                                                           &query::plan::ScanAll::kType};
+
+// TODO check if the kType is correct here.
 constexpr utils::TypeInfo query::plan::ScanAllByEdgeType::kType{utils::TypeId::SCAN_ALL_BY_EDGE_TYPE,
                                                                 "ScanAllByEdgeType", &query::plan::ScanAll::kType};
 
+constexpr utils::TypeInfo query::plan::ScanAllByEdgeId::kType{utils::TypeId::SCAN_ALL_BY_ID, "ScanAllByEdgeId",
+                                                              &query::plan::ScanAll::kType};
+
 constexpr utils::TypeInfo query::plan::ExpandCommon::kType{utils::TypeId::EXPAND_COMMON, "ExpandCommon", nullptr};
 
 constexpr utils::TypeInfo query::plan::Expand::kType{utils::TypeId::EXPAND, "Expand",
diff --git a/src/query/plan/pretty_print.cpp b/src/query/plan/pretty_print.cpp
index eeb0c15b5..93daf9e57 100644
--- a/src/query/plan/pretty_print.cpp
+++ b/src/query/plan/pretty_print.cpp
@@ -83,6 +83,11 @@ bool PlanPrinter::PreVisit(query::plan::ScanAllByEdgeType &op) {
   return true;
 }
 
+bool PlanPrinter::PreVisit(query::plan::ScanAllByEdgeId &op) {
+  WithPrintLn([&op](auto &out) { out << "* " << op.ToString(); });
+  return true;
+}
+
 bool PlanPrinter::PreVisit(query::plan::Expand &op) {
   op.dba_ = dba_;
   WithPrintLn([&op](auto &out) { out << "* " << op.ToString(); });
@@ -484,6 +489,16 @@ bool PlanToJsonVisitor::PreVisit(ScanAllByEdgeType &op) {
   return false;
 }
 
+bool PlanToJsonVisitor::PreVisit(ScanAllByEdgeId &op) {
+  json self;
+  self["name"] = "ScanAllByEdgeId";
+  self["output_symbol"] = ToJson(op.output_symbol_);
+  op.input_->Accept(*this);
+  self["input"] = PopOutput();
+  output_ = std::move(self);
+  return false;
+}
+
 bool PlanToJsonVisitor::PreVisit(CreateNode &op) {
   json self;
   self["name"] = "CreateNode";
diff --git a/src/query/plan/pretty_print.hpp b/src/query/plan/pretty_print.hpp
index d62ae6bf2..9b418aa3e 100644
--- a/src/query/plan/pretty_print.hpp
+++ b/src/query/plan/pretty_print.hpp
@@ -68,6 +68,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
   bool PreVisit(ScanAllByLabelProperty &) override;
   bool PreVisit(ScanAllById &) override;
   bool PreVisit(ScanAllByEdgeType &) override;
+  bool PreVisit(ScanAllByEdgeId &) override;
 
   bool PreVisit(Expand &) override;
   bool PreVisit(ExpandVariable &) override;
@@ -206,6 +207,7 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor {
   bool PreVisit(ScanAllByLabelProperty &) override;
   bool PreVisit(ScanAllById &) override;
   bool PreVisit(ScanAllByEdgeType &) override;
+  bool PreVisit(ScanAllByEdgeId &) override;
 
   bool PreVisit(EmptyResult &) override;
   bool PreVisit(Produce &) override;
diff --git a/src/query/plan/rewrite/edge_type_index_lookup.hpp b/src/query/plan/rewrite/edge_type_index_lookup.hpp
index ed8666513..eedf36ae0 100644
--- a/src/query/plan/rewrite/edge_type_index_lookup.hpp
+++ b/src/query/plan/rewrite/edge_type_index_lookup.hpp
@@ -254,6 +254,15 @@ class EdgeTypeIndexRewriter final : public HierarchicalLogicalOperatorVisitor {
     return true;
   }
 
+  bool PreVisit(ScanAllByEdgeId &op) override {
+    prev_ops_.push_back(&op);
+    return true;
+  }
+  bool PostVisit(ScanAllByEdgeId &) override {
+    prev_ops_.pop_back();
+    return true;
+  }
+
   bool PreVisit(ConstructNamedPath &op) override {
     prev_ops_.push_back(&op);
     return true;
diff --git a/src/utils/event_counter.cpp b/src/utils/event_counter.cpp
index 54ff4ed5c..8ad705906 100644
--- a/src/utils/event_counter.cpp
+++ b/src/utils/event_counter.cpp
@@ -1,4 +1,4 @@
-// Copyright 2023 Memgraph Ltd.
+// Copyright 2024 Memgraph Ltd.
 //
 // Use of this software is governed by the Business Source License
 // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@@ -27,6 +27,7 @@
   M(ScanAllByLabelPropertyOperator, Operator, "Number of times ScanAllByLabelProperty operator was used.")           \
   M(ScanAllByIdOperator, Operator, "Number of times ScanAllById operator was used.")                                 \
   M(ScanAllByEdgeTypeOperator, Operator, "Number of times ScanAllByEdgeTypeOperator operator was used.")             \
+  M(ScanAllByEdgeIdOperator, Operator, "Number of times ScanAllByEdgeIdOperator operator was used.")                 \
   M(ExpandOperator, Operator, "Number of times Expand operator was used.")                                           \
   M(ExpandVariableOperator, Operator, "Number of times ExpandVariable operator was used.")                           \
   M(ConstructNamedPathOperator, Operator, "Number of times ConstructNamedPath operator was used.")                   \
diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp
index 6eef3841a..aee241b55 100644
--- a/tests/unit/query_plan_checker.hpp
+++ b/tests/unit/query_plan_checker.hpp
@@ -66,6 +66,7 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor {
   PRE_VISIT(ScanAllByLabelPropertyRange);
   PRE_VISIT(ScanAllByLabelProperty);
   PRE_VISIT(ScanAllByEdgeType);
+  PRE_VISIT(ScanAllByEdgeId);
   PRE_VISIT(ScanAllById);
   PRE_VISIT(Expand);
   PRE_VISIT(ExpandVariable);
@@ -172,6 +173,7 @@ using ExpectDelete = OpChecker<Delete>;
 using ExpectScanAll = OpChecker<ScanAll>;
 using ExpectScanAllByLabel = OpChecker<ScanAllByLabel>;
 using ExpectScanAllByEdgeType = OpChecker<ScanAllByEdgeType>;
+using ExpectScanAllByEdgeId = OpChecker<ScanAllByEdgeId>;
 using ExpectScanAllById = OpChecker<ScanAllById>;
 using ExpectExpand = OpChecker<Expand>;
 using ExpectConstructNamedPath = OpChecker<ConstructNamedPath>;