From 4cb3b064c457fbc7671aeef233fdf90a7049274b Mon Sep 17 00:00:00 2001
From: Jeremy B <97525434+42jeremy@users.noreply.github.com>
Date: Wed, 12 Oct 2022 11:46:59 +0200
Subject: [PATCH] Add filter to scan all (#575)

Add several versions of ScanAll with filters.
Add helper function to transform an expression into string that can be parsed again once on the storage.
---
 ...retty_print_ast_to_original_expression.hpp | 279 ++++++++++++++++++
 src/query/v2/plan/operator.cpp                | 172 +++++++----
 src/query/v2/requests.hpp                     |   5 +-
 tests/unit/CMakeLists.txt                     |   3 +
 ..._print_ast_to_original_expression_test.cpp |  94 ++++++
 5 files changed, 487 insertions(+), 66 deletions(-)
 create mode 100644 src/expr/ast/pretty_print_ast_to_original_expression.hpp
 create mode 100644 tests/unit/pretty_print_ast_to_original_expression_test.cpp

diff --git a/src/expr/ast/pretty_print_ast_to_original_expression.hpp b/src/expr/ast/pretty_print_ast_to_original_expression.hpp
new file mode 100644
index 000000000..de8b5b89b
--- /dev/null
+++ b/src/expr/ast/pretty_print_ast_to_original_expression.hpp
@@ -0,0 +1,279 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#pragma once
+
+#include <iostream>
+#include <type_traits>
+
+#include "expr/ast.hpp"
+#include "expr/typed_value.hpp"
+#include "utils/algorithm.hpp"
+#include "utils/logging.hpp"
+#include "utils/string.hpp"
+
+namespace memgraph::expr {
+inline constexpr const char *identifier_node_symbol = "MG_SYMBOL_NODE";
+inline constexpr const char *identifier_edge_symbol = "MG_SYMBOL_EDGE";
+
+namespace detail {
+template <typename T>
+void PrintObject(std::ostream *out, const T &arg) {
+  static_assert(!std::is_convertible<T, Expression *>::value,
+                "This overload shouldn't be called with pointers convertible "
+                "to Expression *. This means your other PrintObject overloads aren't "
+                "being called for certain AST nodes when they should (or perhaps such "
+                "overloads don't exist yet).");
+  *out << arg;
+}
+
+inline void PrintObject(std::ostream *out, const std::string &str) { *out << str; }
+
+inline void PrintObject(std::ostream * /*out*/, Aggregation::Op /*op*/) {
+  throw utils::NotYetImplemented("PrintObject: Aggregation::Op");
+}
+
+inline void PrintObject(std::ostream * /*out*/, Expression * /*expr*/);
+
+inline void PrintObject(std::ostream *out, Identifier *expr) { PrintObject(out, static_cast<Expression *>(expr)); }
+
+template <typename T>
+void PrintObject(std::ostream * /*out*/, const std::vector<T> & /*vec*/) {
+  throw utils::NotYetImplemented("PrintObject: vector<T>");
+}
+
+template <typename T>
+void PrintObject(std::ostream * /*out*/, const std::vector<T, utils::Allocator<T>> & /*vec*/) {
+  throw utils::NotYetImplemented("PrintObject: vector<T, utils::Allocator<T>>");
+}
+
+template <typename K, typename V>
+void PrintObject(std::ostream * /*out*/, const std::map<K, V> & /*map*/) {
+  throw utils::NotYetImplemented("PrintObject: map<K, V>");
+}
+
+template <typename T>
+void PrintObject(std::ostream * /*out*/, const utils::pmr::map<utils::pmr::string, T> & /*map*/) {
+  throw utils::NotYetImplemented("PrintObject: map<utils::pmr::string, V>");
+}
+
+template <typename T1, typename T2, typename T3>
+inline void PrintObject(std::ostream *out, const TypedValueT<T1, T2, T3> &value) {
+  using TypedValue = TypedValueT<T1, T2, T3>;
+  switch (value.type()) {
+    case TypedValue::Type::Null:
+      *out << "null";
+      break;
+    case TypedValue::Type::String:
+      PrintObject(out, value.ValueString());
+      break;
+    case TypedValue::Type::Bool:
+      *out << (value.ValueBool() ? "true" : "false");
+      break;
+    case TypedValue::Type::Int:
+      PrintObject(out, value.ValueInt());
+      break;
+    case TypedValue::Type::Double:
+      PrintObject(out, value.ValueDouble());
+      break;
+    case TypedValue::Type::List:
+      PrintObject(out, value.ValueList());
+      break;
+    case TypedValue::Type::Map:
+      PrintObject(out, value.ValueMap());
+      break;
+    case TypedValue::Type::Date:
+      PrintObject(out, value.ValueDate());
+      break;
+    case TypedValue::Type::Duration:
+      PrintObject(out, value.ValueDuration());
+      break;
+    case TypedValue::Type::LocalTime:
+      PrintObject(out, value.ValueLocalTime());
+      break;
+    case TypedValue::Type::LocalDateTime:
+      PrintObject(out, value.ValueLocalDateTime());
+      break;
+    default:
+      MG_ASSERT(false, "PrintObject(std::ostream *out, const TypedValue &value) should not reach here");
+  }
+}
+
+template <typename T>
+void PrintOperatorArgs(const std::string & /*name*/, std::ostream *out, bool with_parenthesis, const T &arg) {
+  PrintObject(out, arg);
+  if (with_parenthesis) {
+    *out << ")";
+  }
+}
+
+template <typename T, typename... Ts>
+void PrintOperatorArgs(const std::string &name, std::ostream *out, bool with_parenthesis, const T &arg,
+                       const Ts &...args) {
+  PrintObject(out, arg);
+  *out << " " << name << " ";
+  PrintOperatorArgs(name, out, with_parenthesis, args...);
+}
+
+template <typename... Ts>
+void PrintOperator(const std::string &name, std::ostream *out, bool with_parenthesis, const Ts &...args) {
+  if (with_parenthesis) {
+    *out << "(";
+  }
+  PrintOperatorArgs(name, out, with_parenthesis, args...);
+}
+
+// new
+template <typename T>
+void PrintOperatorArgs(std::ostream *out, const T &arg) {
+  PrintObject(out, arg);
+}
+
+template <typename T, typename... Ts>
+void PrintOperatorArgs(std::ostream *out, const T &arg, const Ts &...args) {
+  PrintObject(out, arg);
+  PrintOperatorArgs(out, args...);
+}
+
+template <typename... Ts>
+void PrintOperator(std::ostream *out, const Ts &...args) {
+  PrintOperatorArgs(out, args...);
+}
+}  // namespace detail
+
+class ExpressionPrettyPrinter : public ExpressionVisitor<void> {
+ public:
+  explicit ExpressionPrettyPrinter(std::ostream *out) : out_(out) {}
+
+  // Unary operators
+  // NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
+#define UNARY_OPERATOR_VISIT(OP_NODE, OP_STR)      \
+  /* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
+  void Visit(OP_NODE &op) override { detail::PrintOperator(OP_STR, out_, false /*with_parenthesis*/, op.expression_); }
+
+  UNARY_OPERATOR_VISIT(NotOperator, "Not");
+  UNARY_OPERATOR_VISIT(UnaryPlusOperator, "+");
+  UNARY_OPERATOR_VISIT(UnaryMinusOperator, "-");
+  UNARY_OPERATOR_VISIT(IsNullOperator, "IsNull");
+
+#undef UNARY_OPERATOR_VISIT
+
+  // Binary operators
+// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
+#define BINARY_OPERATOR_VISIT(OP_NODE, OP_STR)                                                        \
+  /* NOLINTNEXTLINE(bugprone-macro-parentheses) */                                                    \
+  void Visit(OP_NODE &op) override {                                                                  \
+    detail::PrintOperator(OP_STR, out_, true /*with_parenthesis*/, op.expression1_, op.expression2_); \
+  }
+// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
+#define BINARY_OPERATOR_VISIT_NOT_IMPL(OP_NODE, OP_STR) \
+  /* NOLINTNEXTLINE(bugprone-macro-parentheses) */      \
+  void Visit(OP_NODE & /*op*/) override { throw utils::NotYetImplemented("OP_NODE"); }
+
+  BINARY_OPERATOR_VISIT(OrOperator, "Or");
+  BINARY_OPERATOR_VISIT(XorOperator, "Xor");
+  BINARY_OPERATOR_VISIT(AndOperator, "And");
+  BINARY_OPERATOR_VISIT(AdditionOperator, "+");
+  BINARY_OPERATOR_VISIT(SubtractionOperator, "-");
+  BINARY_OPERATOR_VISIT(MultiplicationOperator, "*");
+  BINARY_OPERATOR_VISIT(DivisionOperator, "/");
+  BINARY_OPERATOR_VISIT(ModOperator, "%");
+  BINARY_OPERATOR_VISIT(NotEqualOperator, "!=");
+  BINARY_OPERATOR_VISIT(EqualOperator, "=");
+  BINARY_OPERATOR_VISIT(LessOperator, "<");
+  BINARY_OPERATOR_VISIT(GreaterOperator, ">");
+  BINARY_OPERATOR_VISIT(LessEqualOperator, "<=");
+  BINARY_OPERATOR_VISIT(GreaterEqualOperator, ">=");
+  BINARY_OPERATOR_VISIT_NOT_IMPL(InListOperator, "In");
+  BINARY_OPERATOR_VISIT_NOT_IMPL(SubscriptOperator, "Subscript");
+
+#undef BINARY_OPERATOR_VISIT
+#undef BINARY_OPERATOR_VISIT_NOT_IMPL
+
+  // Other
+  void Visit(ListSlicingOperator & /*op*/) override { throw utils::NotYetImplemented("ListSlicingOperator"); }
+
+  void Visit(IfOperator & /*op*/) override { throw utils::NotYetImplemented("IfOperator"); }
+
+  void Visit(ListLiteral & /*op*/) override { throw utils::NotYetImplemented("ListLiteral"); }
+
+  void Visit(MapLiteral & /*op*/) override { throw utils::NotYetImplemented("MapLiteral"); }
+
+  void Visit(LabelsTest & /*op*/) override { throw utils::NotYetImplemented("LabelsTest"); }
+
+  void Visit(Aggregation & /*op*/) override { throw utils::NotYetImplemented("Aggregation"); }
+
+  void Visit(Function & /*op*/) override { throw utils::NotYetImplemented("Function"); }
+
+  void Visit(Reduce & /*op*/) override { throw utils::NotYetImplemented("Reduce"); }
+
+  void Visit(Coalesce & /*op*/) override { throw utils::NotYetImplemented("Coalesce"); }
+
+  void Visit(Extract & /*op*/) override { throw utils::NotYetImplemented("Extract"); }
+
+  void Visit(All & /*op*/) override { throw utils::NotYetImplemented("All"); }
+
+  void Visit(Single & /*op*/) override { throw utils::NotYetImplemented("Single"); }
+
+  void Visit(Any & /*op*/) override { throw utils::NotYetImplemented("Any"); }
+
+  void Visit(None & /*op*/) override { throw utils::NotYetImplemented("None"); }
+
+  void Visit(Identifier &op) override {
+    auto is_node = true;
+    auto is_edge = false;
+    auto is_other = false;
+    if (is_node) {
+      detail::PrintOperator(out_, identifier_node_symbol);
+    } else if (is_edge) {
+      detail::PrintOperator(out_, identifier_edge_symbol);
+    } else {
+      MG_ASSERT(is_other);
+      detail::PrintOperator(out_, op.name_);
+    }
+  }
+
+  void Visit(PrimitiveLiteral &op) override { detail::PrintObject(out_, op.value_); }
+
+  void Visit(PropertyLookup &op) override { detail::PrintOperator(out_, op.expression_, ".", op.property_.name); }
+
+  void Visit(ParameterLookup & /*op*/) override { throw utils::NotYetImplemented("ParameterLookup"); }
+
+  void Visit(NamedExpression & /*op*/) override { throw utils::NotYetImplemented("NamedExpression"); }
+
+  void Visit(RegexMatch & /*op*/) override { throw utils::NotYetImplemented("RegexMatch"); }
+
+ private:
+  std::ostream *out_;
+};
+
+namespace detail {
+inline void PrintObject(std::ostream *out, Expression *expr) {
+  if (expr) {
+    ExpressionPrettyPrinter printer{out};
+    expr->Accept(printer);
+  } else {
+    *out << "<null>";
+  }
+}
+}  // namespace detail
+
+inline void PrintExpressionToOriginalAndReplaceNodeAndEdgeSymbols(Expression *expr, std::ostream *out) {
+  ExpressionPrettyPrinter printer{out};
+  expr->Accept(printer);
+}
+
+inline std::string ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(Expression *expr) {
+  std::ostringstream ss;
+  expr::PrintExpressionToOriginalAndReplaceNodeAndEdgeSymbols(expr, &ss);
+  return ss.str();
+}
+}  // namespace memgraph::expr
diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp
index 0a67169f0..10369c34b 100644
--- a/src/query/v2/plan/operator.cpp
+++ b/src/query/v2/plan/operator.cpp
@@ -26,6 +26,7 @@
 #include <cppitertools/chain.hpp>
 #include <cppitertools/imap.hpp>
 
+#include "expr/ast/pretty_print_ast_to_original_expression.hpp"
 #include "expr/exceptions.hpp"
 #include "query/exceptions.hpp"
 #include "query/v2/accessors.hpp"
@@ -332,17 +333,111 @@ class ScanAllCursor : public Cursor {
   msgs::ExecutionState<msgs::ScanVerticesRequest> request_state;
 };
 
+class DistributedScanAllAndFilterCursor : public Cursor {
+ public:
+  explicit DistributedScanAllAndFilterCursor(
+      Symbol output_symbol, UniqueCursorPtr input_cursor, const char *op_name,
+      std::optional<storage::v3::LabelId> label,
+      std::optional<std::pair<storage::v3::PropertyId, Expression *>> property_expression_pair,
+      std::optional<std::vector<Expression *>> filter_expressions)
+      : output_symbol_(output_symbol),
+        input_cursor_(std::move(input_cursor)),
+        op_name_(op_name),
+        label_(label),
+        property_expression_pair_(property_expression_pair),
+        filter_expressions_(filter_expressions) {
+    ResetExecutionState();
+  }
+
+  using VertexAccessor = accessors::VertexAccessor;
+
+  bool MakeRequest(msgs::ShardRequestManagerInterface &shard_manager) {
+    current_batch = shard_manager.Request(request_state_);
+    current_vertex_it = current_batch.begin();
+    return !current_batch.empty();
+  }
+
+  bool Pull(Frame &frame, ExecutionContext &context) override {
+    SCOPED_PROFILE_OP(op_name_);
+    auto &shard_manager = *context.shard_request_manager;
+    if (MustAbort(context)) {
+      throw HintedAbortError();
+    }
+    using State = msgs::ExecutionState<msgs::ScanVerticesRequest>;
+
+    if (request_state_.state == State::INITIALIZING) {
+      if (!input_cursor_->Pull(frame, context)) {
+        return false;
+      }
+    }
+
+    if (current_vertex_it == current_batch.end()) {
+      if (request_state_.state == State::COMPLETED || !MakeRequest(shard_manager)) {
+        ResetExecutionState();
+        return Pull(frame, context);
+      }
+    }
+
+    frame[output_symbol_] = TypedValue(std::move(*current_vertex_it));
+    ++current_vertex_it;
+    return true;
+  }
+
+  void Shutdown() override { input_cursor_->Shutdown(); }
+
+  void ResetExecutionState() {
+    current_batch.clear();
+    current_vertex_it = current_batch.end();
+    request_state_ = msgs::ExecutionState<msgs::ScanVerticesRequest>{};
+
+    auto request = msgs::ScanVerticesRequest{};
+    if (label_.has_value()) {
+      request.label = msgs::Label{.id = label_.value()};
+    }
+    if (property_expression_pair_.has_value()) {
+      request.property_expression_pair = std::make_pair(
+          property_expression_pair_.value().first,
+          expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(property_expression_pair_.value().second));
+    }
+    if (filter_expressions_.has_value()) {
+      auto res = std::vector<std::string>{};
+      res.reserve(filter_expressions_->size());
+      std::transform(filter_expressions_->begin(), filter_expressions_->end(), std::back_inserter(res),
+                     [](auto &filter) { return expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(filter); });
+
+      request.filter_expressions = res;
+    }
+    request_state_.requests.emplace_back(request);
+  }
+
+  void Reset() override {
+    input_cursor_->Reset();
+    ResetExecutionState();
+  }
+
+ private:
+  const Symbol output_symbol_;
+  const UniqueCursorPtr input_cursor_;
+  const char *op_name_;
+  std::vector<VertexAccessor> current_batch;
+  std::vector<VertexAccessor>::iterator current_vertex_it;
+  msgs::ExecutionState<msgs::ScanVerticesRequest> request_state_;
+  std::optional<storage::v3::LabelId> label_;
+  std::optional<std::pair<storage::v3::PropertyId, Expression *>> property_expression_pair_;
+  std::optional<std::vector<Expression *>> filter_expressions_;
+};
+
 ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::v3::View view)
     : input_(input ? input : std::make_shared<Once>()), output_symbol_(output_symbol), view_(view) {}
 
 ACCEPT_WITH_INPUT(ScanAll)
 
-class DistributedScanAllCursor;
-
 UniqueCursorPtr ScanAll::MakeCursor(utils::MemoryResource *mem) const {
   EventCounter::IncrementCounter(EventCounter::ScanAllOperator);
 
-  return MakeUniqueCursorPtr<DistributedScanAllCursor>(mem, output_symbol_, input_->MakeCursor(mem), "ScanAll");
+  return MakeUniqueCursorPtr<DistributedScanAllAndFilterCursor>(
+      mem, output_symbol_, input_->MakeCursor(mem), "ScanAll", std::nullopt /*label*/,
+      std::nullopt /*property_expression_pair*/, std::nullopt /*filter_expressions*/);
 }
 
 std::vector<Symbol> ScanAll::ModifiedSymbols(const SymbolTable &table) const {
@@ -357,10 +452,12 @@ ScanAllByLabel::ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, Sy
 
 ACCEPT_WITH_INPUT(ScanAllByLabel)
 
-UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource * /*mem*/) const {
+UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource *mem) const {
   EventCounter::IncrementCounter(EventCounter::ScanAllByLabelOperator);
 
-  throw QueryRuntimeException("ScanAllByLabel is not supported");
+  return MakeUniqueCursorPtr<DistributedScanAllAndFilterCursor>(
+      mem, output_symbol_, input_->MakeCursor(mem), "ScanAllByLabel", label_, std::nullopt /*property_expression_pair*/,
+      std::nullopt /*filter_expressions*/);
 }
 
 // TODO(buda): Implement ScanAllByLabelProperty operator to iterate over
@@ -404,10 +501,12 @@ ScanAllByLabelPropertyValue::ScanAllByLabelPropertyValue(const std::shared_ptr<L
 
 ACCEPT_WITH_INPUT(ScanAllByLabelPropertyValue)
 
-UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource * /*mem*/) const {
+UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource *mem) const {
   EventCounter::IncrementCounter(EventCounter::ScanAllByLabelPropertyValueOperator);
 
-  throw QueryRuntimeException("ScanAllByLabelPropertyValue is not supported");
+  return MakeUniqueCursorPtr<DistributedScanAllAndFilterCursor>(
+      mem, output_symbol_, input_->MakeCursor(mem), "ScanAllByLabelPropertyValue", label_,
+      std::make_pair(property_, expression_), std::nullopt /*filter_expressions*/);
 }
 
 ScanAllByLabelProperty::ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
@@ -432,6 +531,7 @@ ACCEPT_WITH_INPUT(ScanAllById)
 
 UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const {
   EventCounter::IncrementCounter(EventCounter::ScanAllByIdOperator);
+  // TODO Reimplement when we have reliable conversion between hash value and pk
   auto vertices = [](Frame & /*frame*/, ExecutionContext & /*context*/) -> std::optional<std::vector<VertexAccessor>> {
     return std::nullopt;
   };
@@ -2256,62 +2356,4 @@ bool Foreach::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
   return visitor.PostVisit(*this);
 }
 
-class DistributedScanAllCursor : public Cursor {
- public:
-  explicit DistributedScanAllCursor(Symbol output_symbol, UniqueCursorPtr input_cursor, const char *op_name)
-      : output_symbol_(output_symbol), input_cursor_(std::move(input_cursor)), op_name_(op_name) {}
-
-  using VertexAccessor = accessors::VertexAccessor;
-
-  bool MakeRequest(msgs::ShardRequestManagerInterface &shard_manager) {
-    // TODO(antaljanosbenjamin) Use real label
-    request_state_.label = "label";
-    current_batch = shard_manager.Request(request_state_);
-    current_vertex_it = current_batch.begin();
-    return !current_batch.empty();
-  }
-
-  bool Pull(Frame &frame, ExecutionContext &context) override {
-    SCOPED_PROFILE_OP(op_name_);
-    auto &shard_manager = *context.shard_request_manager;
-    if (MustAbort(context)) throw HintedAbortError();
-    using State = msgs::ExecutionState<msgs::ScanVerticesRequest>;
-
-    if (request_state_.state == State::INITIALIZING) {
-      if (!input_cursor_->Pull(frame, context)) return false;
-    }
-
-    if (current_vertex_it == current_batch.end()) {
-      if (request_state_.state == State::COMPLETED || !MakeRequest(shard_manager)) {
-        ResetExecutionState();
-        return Pull(frame, context);
-      }
-    }
-
-    frame[output_symbol_] = TypedValue(std::move(*current_vertex_it));
-    ++current_vertex_it;
-    return true;
-  }
-
-  void Shutdown() override { input_cursor_->Shutdown(); }
-
-  void ResetExecutionState() {
-    current_batch.clear();
-    current_vertex_it = current_batch.end();
-    request_state_ = msgs::ExecutionState<msgs::ScanVerticesRequest>{};
-  }
-
-  void Reset() override {
-    input_cursor_->Reset();
-    ResetExecutionState();
-  }
-
- private:
-  const Symbol output_symbol_;
-  const UniqueCursorPtr input_cursor_;
-  const char *op_name_;
-  std::vector<VertexAccessor> current_batch;
-  decltype(std::vector<VertexAccessor>().begin()) current_vertex_it;
-  msgs::ExecutionState<msgs::ScanVerticesRequest> request_state_;
-};
 }  // namespace memgraph::query::v2::plan
diff --git a/src/query/v2/requests.hpp b/src/query/v2/requests.hpp
index 80e688c9f..ee6446b77 100644
--- a/src/query/v2/requests.hpp
+++ b/src/query/v2/requests.hpp
@@ -380,9 +380,12 @@ struct ScanVerticesRequest {
   Hlc transaction_id;
   VertexId start_id;
   std::optional<std::vector<PropertyId>> props_to_return;
-  std::optional<std::vector<std::string>> filter_expressions;
   std::optional<size_t> batch_limit;
   StorageView storage_view{StorageView::NEW};
+
+  std::optional<Label> label;
+  std::optional<std::pair<PropertyId, std::string>> property_expression_pair;
+  std::optional<std::vector<std::string>> filter_expressions;
 };
 
 struct ScanResultRow {
diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt
index 24a61a21c..d70c4d867 100644
--- a/tests/unit/CMakeLists.txt
+++ b/tests/unit/CMakeLists.txt
@@ -433,3 +433,6 @@ target_link_libraries(${test_prefix}local_transport mg-io)
 # Test MachineManager with LocalTransport
 add_unit_test(machine_manager.cpp)
 target_link_libraries(${test_prefix}machine_manager mg-io mg-coordinator mg-storage-v3 mg-query-v2)
+
+add_unit_test(pretty_print_ast_to_original_expression_test.cpp)
+target_link_libraries(${test_prefix}pretty_print_ast_to_original_expression_test mg-io mg-expr mg-query-v2)
diff --git a/tests/unit/pretty_print_ast_to_original_expression_test.cpp b/tests/unit/pretty_print_ast_to_original_expression_test.cpp
new file mode 100644
index 000000000..e5d77ae0e
--- /dev/null
+++ b/tests/unit/pretty_print_ast_to_original_expression_test.cpp
@@ -0,0 +1,94 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#include <chrono>
+#include <limits>
+#include <thread>
+
+#include <gtest/gtest.h>
+
+#include "common/types.hpp"
+#include "exceptions.hpp"
+#include "parser/opencypher/parser.hpp"
+#include "query/v2/bindings/cypher_main_visitor.hpp"
+#include "query/v2/bindings/eval.hpp"
+#include "query/v2/bindings/frame.hpp"
+#include "query/v2/bindings/symbol_generator.hpp"
+#include "query/v2/bindings/symbol_table.hpp"
+#include "query/v2/bindings/typed_value.hpp"
+#include "query/v2/db_accessor.hpp"
+#include "query/v2/frontend/ast/ast.hpp"
+#include "utils/string.hpp"
+
+#include "expr/ast/pretty_print_ast_to_original_expression.hpp"
+
+namespace memgraph::query::v2::test {
+
+class ExpressiontoStringTest : public ::testing::TestWithParam<std::pair<std::string, std::string>> {
+ protected:
+  AstStorage storage;
+};
+
+TEST_P(ExpressiontoStringTest, Example) {
+  const auto [original_expression, expected_expression] = GetParam();
+
+  memgraph::frontend::opencypher::Parser<frontend::opencypher::ParserOpTag::EXPRESSION> parser(original_expression);
+  expr::ParsingContext pc;
+  CypherMainVisitor visitor(pc, &storage);
+
+  auto *ast = parser.tree();
+  auto expression = visitor.visit(ast);
+
+  const auto rewritten_expression =
+      expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(std::any_cast<Expression *>(expression));
+
+  // We check that the expression is what we expect
+  EXPECT_EQ(rewritten_expression, expected_expression);
+
+  // We check that the rewritten expression can be parsed again
+  memgraph::frontend::opencypher::Parser<frontend::opencypher::ParserOpTag::EXPRESSION> parser2(rewritten_expression);
+  expr::ParsingContext pc2;
+  CypherMainVisitor visitor2(pc2, &storage);
+
+  auto *ast2 = parser2.tree();
+  auto expression2 = visitor2.visit(ast2);
+  const auto rewritten_expression2 =
+      expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(std::any_cast<Expression *>(expression));
+
+  // We check that the re-written expression from the re-written expression is exactly the same
+  EXPECT_EQ(rewritten_expression, rewritten_expression2);
+}
+
+INSTANTIATE_TEST_CASE_P(
+    PARAMETER, ExpressiontoStringTest,
+    ::testing::Values(
+        std::make_pair(std::string("2 / 1"), std::string("(2 / 1)")),
+        std::make_pair(std::string("2 + 1 + 5 + 2"), std::string("(((2 + 1) + 5) + 2)")),
+        std::make_pair(std::string("2 + 1 * 5 + 2"), std::string("((2 + (1 * 5)) + 2)")),
+        std::make_pair(std::string("2 + 1 * (5 + 2)"), std::string("(2 + (1 * (5 + 2)))")),
+        std::make_pair(std::string("n"), std::string("MG_SYMBOL_NODE")),
+        std::make_pair(std::string("n.property1"), std::string("MG_SYMBOL_NODE.property1")),
+        std::make_pair(std::string("n.property1 > 3"), std::string("(MG_SYMBOL_NODE.property1 > 3)")),
+        std::make_pair(std::string("n.property1 != n.property2"),
+                       std::string("(MG_SYMBOL_NODE.property1 != MG_SYMBOL_NODE.property2)")),
+        std::make_pair(std::string("n And n"), std::string("(MG_SYMBOL_NODE And MG_SYMBOL_NODE)")),
+        std::make_pair(std::string("n.property1 > 3 And n.property + 7 < 10"),
+                       std::string("((MG_SYMBOL_NODE.property1 > 3) And ((MG_SYMBOL_NODE.property + 7) < 10))")),
+        std::make_pair(std::string("MG_SYMBOL_NODE.property1 > 3 And (MG_SYMBOL_NODE.property + 7 < 10 Or "
+                                   "MG_SYMBOL_NODE.property3 = true)"),
+                       std::string("((MG_SYMBOL_NODE.property1 > 3) And (((MG_SYMBOL_NODE.property + 7) < 10) Or "
+                                   "(MG_SYMBOL_NODE.property3 = true)))")),
+        std::make_pair(std::string("(MG_SYMBOL_NODE.property1 > 3 Or MG_SYMBOL_NODE.property + 7 < 10) And "
+                                   "MG_SYMBOL_NODE.property3 = true"),
+                       std::string("(((MG_SYMBOL_NODE.property1 > 3) Or ((MG_SYMBOL_NODE.property + 7) < 10)) And "
+                                   "(MG_SYMBOL_NODE.property3 = true))"))));
+
+}  // namespace memgraph::query::v2::test