From 4cb3b064c457fbc7671aeef233fdf90a7049274b Mon Sep 17 00:00:00 2001 From: Jeremy B <97525434+42jeremy@users.noreply.github.com> Date: Wed, 12 Oct 2022 11:46:59 +0200 Subject: [PATCH] Add filter to scan all (#575) Add several versions of ScanAll with filters. Add helper function to transform an expression into string that can be parsed again once on the storage. --- ...retty_print_ast_to_original_expression.hpp | 279 ++++++++++++++++++ src/query/v2/plan/operator.cpp | 172 +++++++---- src/query/v2/requests.hpp | 5 +- tests/unit/CMakeLists.txt | 3 + ..._print_ast_to_original_expression_test.cpp | 94 ++++++ 5 files changed, 487 insertions(+), 66 deletions(-) create mode 100644 src/expr/ast/pretty_print_ast_to_original_expression.hpp create mode 100644 tests/unit/pretty_print_ast_to_original_expression_test.cpp diff --git a/src/expr/ast/pretty_print_ast_to_original_expression.hpp b/src/expr/ast/pretty_print_ast_to_original_expression.hpp new file mode 100644 index 000000000..de8b5b89b --- /dev/null +++ b/src/expr/ast/pretty_print_ast_to_original_expression.hpp @@ -0,0 +1,279 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <iostream> +#include <type_traits> + +#include "expr/ast.hpp" +#include "expr/typed_value.hpp" +#include "utils/algorithm.hpp" +#include "utils/logging.hpp" +#include "utils/string.hpp" + +namespace memgraph::expr { +inline constexpr const char *identifier_node_symbol = "MG_SYMBOL_NODE"; +inline constexpr const char *identifier_edge_symbol = "MG_SYMBOL_EDGE"; + +namespace detail { +template <typename T> +void PrintObject(std::ostream *out, const T &arg) { + static_assert(!std::is_convertible<T, Expression *>::value, + "This overload shouldn't be called with pointers convertible " + "to Expression *. This means your other PrintObject overloads aren't " + "being called for certain AST nodes when they should (or perhaps such " + "overloads don't exist yet)."); + *out << arg; +} + +inline void PrintObject(std::ostream *out, const std::string &str) { *out << str; } + +inline void PrintObject(std::ostream * /*out*/, Aggregation::Op /*op*/) { + throw utils::NotYetImplemented("PrintObject: Aggregation::Op"); +} + +inline void PrintObject(std::ostream * /*out*/, Expression * /*expr*/); + +inline void PrintObject(std::ostream *out, Identifier *expr) { PrintObject(out, static_cast<Expression *>(expr)); } + +template <typename T> +void PrintObject(std::ostream * /*out*/, const std::vector<T> & /*vec*/) { + throw utils::NotYetImplemented("PrintObject: vector<T>"); +} + +template <typename T> +void PrintObject(std::ostream * /*out*/, const std::vector<T, utils::Allocator<T>> & /*vec*/) { + throw utils::NotYetImplemented("PrintObject: vector<T, utils::Allocator<T>>"); +} + +template <typename K, typename V> +void PrintObject(std::ostream * /*out*/, const std::map<K, V> & /*map*/) { + throw utils::NotYetImplemented("PrintObject: map<K, V>"); +} + +template <typename T> +void PrintObject(std::ostream * /*out*/, const utils::pmr::map<utils::pmr::string, T> & /*map*/) { + throw utils::NotYetImplemented("PrintObject: map<utils::pmr::string, V>"); +} + +template <typename T1, typename T2, typename T3> +inline void PrintObject(std::ostream *out, const TypedValueT<T1, T2, T3> &value) { + using TypedValue = TypedValueT<T1, T2, T3>; + switch (value.type()) { + case TypedValue::Type::Null: + *out << "null"; + break; + case TypedValue::Type::String: + PrintObject(out, value.ValueString()); + break; + case TypedValue::Type::Bool: + *out << (value.ValueBool() ? "true" : "false"); + break; + case TypedValue::Type::Int: + PrintObject(out, value.ValueInt()); + break; + case TypedValue::Type::Double: + PrintObject(out, value.ValueDouble()); + break; + case TypedValue::Type::List: + PrintObject(out, value.ValueList()); + break; + case TypedValue::Type::Map: + PrintObject(out, value.ValueMap()); + break; + case TypedValue::Type::Date: + PrintObject(out, value.ValueDate()); + break; + case TypedValue::Type::Duration: + PrintObject(out, value.ValueDuration()); + break; + case TypedValue::Type::LocalTime: + PrintObject(out, value.ValueLocalTime()); + break; + case TypedValue::Type::LocalDateTime: + PrintObject(out, value.ValueLocalDateTime()); + break; + default: + MG_ASSERT(false, "PrintObject(std::ostream *out, const TypedValue &value) should not reach here"); + } +} + +template <typename T> +void PrintOperatorArgs(const std::string & /*name*/, std::ostream *out, bool with_parenthesis, const T &arg) { + PrintObject(out, arg); + if (with_parenthesis) { + *out << ")"; + } +} + +template <typename T, typename... Ts> +void PrintOperatorArgs(const std::string &name, std::ostream *out, bool with_parenthesis, const T &arg, + const Ts &...args) { + PrintObject(out, arg); + *out << " " << name << " "; + PrintOperatorArgs(name, out, with_parenthesis, args...); +} + +template <typename... Ts> +void PrintOperator(const std::string &name, std::ostream *out, bool with_parenthesis, const Ts &...args) { + if (with_parenthesis) { + *out << "("; + } + PrintOperatorArgs(name, out, with_parenthesis, args...); +} + +// new +template <typename T> +void PrintOperatorArgs(std::ostream *out, const T &arg) { + PrintObject(out, arg); +} + +template <typename T, typename... Ts> +void PrintOperatorArgs(std::ostream *out, const T &arg, const Ts &...args) { + PrintObject(out, arg); + PrintOperatorArgs(out, args...); +} + +template <typename... Ts> +void PrintOperator(std::ostream *out, const Ts &...args) { + PrintOperatorArgs(out, args...); +} +} // namespace detail + +class ExpressionPrettyPrinter : public ExpressionVisitor<void> { + public: + explicit ExpressionPrettyPrinter(std::ostream *out) : out_(out) {} + + // Unary operators + // NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define UNARY_OPERATOR_VISIT(OP_NODE, OP_STR) \ + /* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \ + void Visit(OP_NODE &op) override { detail::PrintOperator(OP_STR, out_, false /*with_parenthesis*/, op.expression_); } + + UNARY_OPERATOR_VISIT(NotOperator, "Not"); + UNARY_OPERATOR_VISIT(UnaryPlusOperator, "+"); + UNARY_OPERATOR_VISIT(UnaryMinusOperator, "-"); + UNARY_OPERATOR_VISIT(IsNullOperator, "IsNull"); + +#undef UNARY_OPERATOR_VISIT + + // Binary operators +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BINARY_OPERATOR_VISIT(OP_NODE, OP_STR) \ + /* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \ + void Visit(OP_NODE &op) override { \ + detail::PrintOperator(OP_STR, out_, true /*with_parenthesis*/, op.expression1_, op.expression2_); \ + } +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define BINARY_OPERATOR_VISIT_NOT_IMPL(OP_NODE, OP_STR) \ + /* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \ + void Visit(OP_NODE & /*op*/) override { throw utils::NotYetImplemented("OP_NODE"); } + + BINARY_OPERATOR_VISIT(OrOperator, "Or"); + BINARY_OPERATOR_VISIT(XorOperator, "Xor"); + BINARY_OPERATOR_VISIT(AndOperator, "And"); + BINARY_OPERATOR_VISIT(AdditionOperator, "+"); + BINARY_OPERATOR_VISIT(SubtractionOperator, "-"); + BINARY_OPERATOR_VISIT(MultiplicationOperator, "*"); + BINARY_OPERATOR_VISIT(DivisionOperator, "/"); + BINARY_OPERATOR_VISIT(ModOperator, "%"); + BINARY_OPERATOR_VISIT(NotEqualOperator, "!="); + BINARY_OPERATOR_VISIT(EqualOperator, "="); + BINARY_OPERATOR_VISIT(LessOperator, "<"); + BINARY_OPERATOR_VISIT(GreaterOperator, ">"); + BINARY_OPERATOR_VISIT(LessEqualOperator, "<="); + BINARY_OPERATOR_VISIT(GreaterEqualOperator, ">="); + BINARY_OPERATOR_VISIT_NOT_IMPL(InListOperator, "In"); + BINARY_OPERATOR_VISIT_NOT_IMPL(SubscriptOperator, "Subscript"); + +#undef BINARY_OPERATOR_VISIT +#undef BINARY_OPERATOR_VISIT_NOT_IMPL + + // Other + void Visit(ListSlicingOperator & /*op*/) override { throw utils::NotYetImplemented("ListSlicingOperator"); } + + void Visit(IfOperator & /*op*/) override { throw utils::NotYetImplemented("IfOperator"); } + + void Visit(ListLiteral & /*op*/) override { throw utils::NotYetImplemented("ListLiteral"); } + + void Visit(MapLiteral & /*op*/) override { throw utils::NotYetImplemented("MapLiteral"); } + + void Visit(LabelsTest & /*op*/) override { throw utils::NotYetImplemented("LabelsTest"); } + + void Visit(Aggregation & /*op*/) override { throw utils::NotYetImplemented("Aggregation"); } + + void Visit(Function & /*op*/) override { throw utils::NotYetImplemented("Function"); } + + void Visit(Reduce & /*op*/) override { throw utils::NotYetImplemented("Reduce"); } + + void Visit(Coalesce & /*op*/) override { throw utils::NotYetImplemented("Coalesce"); } + + void Visit(Extract & /*op*/) override { throw utils::NotYetImplemented("Extract"); } + + void Visit(All & /*op*/) override { throw utils::NotYetImplemented("All"); } + + void Visit(Single & /*op*/) override { throw utils::NotYetImplemented("Single"); } + + void Visit(Any & /*op*/) override { throw utils::NotYetImplemented("Any"); } + + void Visit(None & /*op*/) override { throw utils::NotYetImplemented("None"); } + + void Visit(Identifier &op) override { + auto is_node = true; + auto is_edge = false; + auto is_other = false; + if (is_node) { + detail::PrintOperator(out_, identifier_node_symbol); + } else if (is_edge) { + detail::PrintOperator(out_, identifier_edge_symbol); + } else { + MG_ASSERT(is_other); + detail::PrintOperator(out_, op.name_); + } + } + + void Visit(PrimitiveLiteral &op) override { detail::PrintObject(out_, op.value_); } + + void Visit(PropertyLookup &op) override { detail::PrintOperator(out_, op.expression_, ".", op.property_.name); } + + void Visit(ParameterLookup & /*op*/) override { throw utils::NotYetImplemented("ParameterLookup"); } + + void Visit(NamedExpression & /*op*/) override { throw utils::NotYetImplemented("NamedExpression"); } + + void Visit(RegexMatch & /*op*/) override { throw utils::NotYetImplemented("RegexMatch"); } + + private: + std::ostream *out_; +}; + +namespace detail { +inline void PrintObject(std::ostream *out, Expression *expr) { + if (expr) { + ExpressionPrettyPrinter printer{out}; + expr->Accept(printer); + } else { + *out << "<null>"; + } +} +} // namespace detail + +inline void PrintExpressionToOriginalAndReplaceNodeAndEdgeSymbols(Expression *expr, std::ostream *out) { + ExpressionPrettyPrinter printer{out}; + expr->Accept(printer); +} + +inline std::string ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(Expression *expr) { + std::ostringstream ss; + expr::PrintExpressionToOriginalAndReplaceNodeAndEdgeSymbols(expr, &ss); + return ss.str(); +} +} // namespace memgraph::expr diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp index 0a67169f0..10369c34b 100644 --- a/src/query/v2/plan/operator.cpp +++ b/src/query/v2/plan/operator.cpp @@ -26,6 +26,7 @@ #include <cppitertools/chain.hpp> #include <cppitertools/imap.hpp> +#include "expr/ast/pretty_print_ast_to_original_expression.hpp" #include "expr/exceptions.hpp" #include "query/exceptions.hpp" #include "query/v2/accessors.hpp" @@ -332,17 +333,111 @@ class ScanAllCursor : public Cursor { msgs::ExecutionState<msgs::ScanVerticesRequest> request_state; }; +class DistributedScanAllAndFilterCursor : public Cursor { + public: + explicit DistributedScanAllAndFilterCursor( + Symbol output_symbol, UniqueCursorPtr input_cursor, const char *op_name, + std::optional<storage::v3::LabelId> label, + std::optional<std::pair<storage::v3::PropertyId, Expression *>> property_expression_pair, + std::optional<std::vector<Expression *>> filter_expressions) + : output_symbol_(output_symbol), + input_cursor_(std::move(input_cursor)), + op_name_(op_name), + label_(label), + property_expression_pair_(property_expression_pair), + filter_expressions_(filter_expressions) { + ResetExecutionState(); + } + + using VertexAccessor = accessors::VertexAccessor; + + bool MakeRequest(msgs::ShardRequestManagerInterface &shard_manager) { + current_batch = shard_manager.Request(request_state_); + current_vertex_it = current_batch.begin(); + return !current_batch.empty(); + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP(op_name_); + auto &shard_manager = *context.shard_request_manager; + if (MustAbort(context)) { + throw HintedAbortError(); + } + using State = msgs::ExecutionState<msgs::ScanVerticesRequest>; + + if (request_state_.state == State::INITIALIZING) { + if (!input_cursor_->Pull(frame, context)) { + return false; + } + } + + if (current_vertex_it == current_batch.end()) { + if (request_state_.state == State::COMPLETED || !MakeRequest(shard_manager)) { + ResetExecutionState(); + return Pull(frame, context); + } + } + + frame[output_symbol_] = TypedValue(std::move(*current_vertex_it)); + ++current_vertex_it; + return true; + } + + void Shutdown() override { input_cursor_->Shutdown(); } + + void ResetExecutionState() { + current_batch.clear(); + current_vertex_it = current_batch.end(); + request_state_ = msgs::ExecutionState<msgs::ScanVerticesRequest>{}; + + auto request = msgs::ScanVerticesRequest{}; + if (label_.has_value()) { + request.label = msgs::Label{.id = label_.value()}; + } + if (property_expression_pair_.has_value()) { + request.property_expression_pair = std::make_pair( + property_expression_pair_.value().first, + expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(property_expression_pair_.value().second)); + } + if (filter_expressions_.has_value()) { + auto res = std::vector<std::string>{}; + res.reserve(filter_expressions_->size()); + std::transform(filter_expressions_->begin(), filter_expressions_->end(), std::back_inserter(res), + [](auto &filter) { return expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(filter); }); + + request.filter_expressions = res; + } + request_state_.requests.emplace_back(request); + } + + void Reset() override { + input_cursor_->Reset(); + ResetExecutionState(); + } + + private: + const Symbol output_symbol_; + const UniqueCursorPtr input_cursor_; + const char *op_name_; + std::vector<VertexAccessor> current_batch; + std::vector<VertexAccessor>::iterator current_vertex_it; + msgs::ExecutionState<msgs::ScanVerticesRequest> request_state_; + std::optional<storage::v3::LabelId> label_; + std::optional<std::pair<storage::v3::PropertyId, Expression *>> property_expression_pair_; + std::optional<std::vector<Expression *>> filter_expressions_; +}; + ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::v3::View view) : input_(input ? input : std::make_shared<Once>()), output_symbol_(output_symbol), view_(view) {} ACCEPT_WITH_INPUT(ScanAll) -class DistributedScanAllCursor; - UniqueCursorPtr ScanAll::MakeCursor(utils::MemoryResource *mem) const { EventCounter::IncrementCounter(EventCounter::ScanAllOperator); - return MakeUniqueCursorPtr<DistributedScanAllCursor>(mem, output_symbol_, input_->MakeCursor(mem), "ScanAll"); + return MakeUniqueCursorPtr<DistributedScanAllAndFilterCursor>( + mem, output_symbol_, input_->MakeCursor(mem), "ScanAll", std::nullopt /*label*/, + std::nullopt /*property_expression_pair*/, std::nullopt /*filter_expressions*/); } std::vector<Symbol> ScanAll::ModifiedSymbols(const SymbolTable &table) const { @@ -357,10 +452,12 @@ ScanAllByLabel::ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, Sy ACCEPT_WITH_INPUT(ScanAllByLabel) -UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource * /*mem*/) const { +UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource *mem) const { EventCounter::IncrementCounter(EventCounter::ScanAllByLabelOperator); - throw QueryRuntimeException("ScanAllByLabel is not supported"); + return MakeUniqueCursorPtr<DistributedScanAllAndFilterCursor>( + mem, output_symbol_, input_->MakeCursor(mem), "ScanAllByLabel", label_, std::nullopt /*property_expression_pair*/, + std::nullopt /*filter_expressions*/); } // TODO(buda): Implement ScanAllByLabelProperty operator to iterate over @@ -404,10 +501,12 @@ ScanAllByLabelPropertyValue::ScanAllByLabelPropertyValue(const std::shared_ptr<L ACCEPT_WITH_INPUT(ScanAllByLabelPropertyValue) -UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource * /*mem*/) const { +UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource *mem) const { EventCounter::IncrementCounter(EventCounter::ScanAllByLabelPropertyValueOperator); - throw QueryRuntimeException("ScanAllByLabelPropertyValue is not supported"); + return MakeUniqueCursorPtr<DistributedScanAllAndFilterCursor>( + mem, output_symbol_, input_->MakeCursor(mem), "ScanAllByLabelPropertyValue", label_, + std::make_pair(property_, expression_), std::nullopt /*filter_expressions*/); } ScanAllByLabelProperty::ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, @@ -432,6 +531,7 @@ ACCEPT_WITH_INPUT(ScanAllById) UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const { EventCounter::IncrementCounter(EventCounter::ScanAllByIdOperator); + // TODO Reimplement when we have reliable conversion between hash value and pk auto vertices = [](Frame & /*frame*/, ExecutionContext & /*context*/) -> std::optional<std::vector<VertexAccessor>> { return std::nullopt; }; @@ -2256,62 +2356,4 @@ bool Foreach::Accept(HierarchicalLogicalOperatorVisitor &visitor) { return visitor.PostVisit(*this); } -class DistributedScanAllCursor : public Cursor { - public: - explicit DistributedScanAllCursor(Symbol output_symbol, UniqueCursorPtr input_cursor, const char *op_name) - : output_symbol_(output_symbol), input_cursor_(std::move(input_cursor)), op_name_(op_name) {} - - using VertexAccessor = accessors::VertexAccessor; - - bool MakeRequest(msgs::ShardRequestManagerInterface &shard_manager) { - // TODO(antaljanosbenjamin) Use real label - request_state_.label = "label"; - current_batch = shard_manager.Request(request_state_); - current_vertex_it = current_batch.begin(); - return !current_batch.empty(); - } - - bool Pull(Frame &frame, ExecutionContext &context) override { - SCOPED_PROFILE_OP(op_name_); - auto &shard_manager = *context.shard_request_manager; - if (MustAbort(context)) throw HintedAbortError(); - using State = msgs::ExecutionState<msgs::ScanVerticesRequest>; - - if (request_state_.state == State::INITIALIZING) { - if (!input_cursor_->Pull(frame, context)) return false; - } - - if (current_vertex_it == current_batch.end()) { - if (request_state_.state == State::COMPLETED || !MakeRequest(shard_manager)) { - ResetExecutionState(); - return Pull(frame, context); - } - } - - frame[output_symbol_] = TypedValue(std::move(*current_vertex_it)); - ++current_vertex_it; - return true; - } - - void Shutdown() override { input_cursor_->Shutdown(); } - - void ResetExecutionState() { - current_batch.clear(); - current_vertex_it = current_batch.end(); - request_state_ = msgs::ExecutionState<msgs::ScanVerticesRequest>{}; - } - - void Reset() override { - input_cursor_->Reset(); - ResetExecutionState(); - } - - private: - const Symbol output_symbol_; - const UniqueCursorPtr input_cursor_; - const char *op_name_; - std::vector<VertexAccessor> current_batch; - decltype(std::vector<VertexAccessor>().begin()) current_vertex_it; - msgs::ExecutionState<msgs::ScanVerticesRequest> request_state_; -}; } // namespace memgraph::query::v2::plan diff --git a/src/query/v2/requests.hpp b/src/query/v2/requests.hpp index 80e688c9f..ee6446b77 100644 --- a/src/query/v2/requests.hpp +++ b/src/query/v2/requests.hpp @@ -380,9 +380,12 @@ struct ScanVerticesRequest { Hlc transaction_id; VertexId start_id; std::optional<std::vector<PropertyId>> props_to_return; - std::optional<std::vector<std::string>> filter_expressions; std::optional<size_t> batch_limit; StorageView storage_view{StorageView::NEW}; + + std::optional<Label> label; + std::optional<std::pair<PropertyId, std::string>> property_expression_pair; + std::optional<std::vector<std::string>> filter_expressions; }; struct ScanResultRow { diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 24a61a21c..d70c4d867 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -433,3 +433,6 @@ target_link_libraries(${test_prefix}local_transport mg-io) # Test MachineManager with LocalTransport add_unit_test(machine_manager.cpp) target_link_libraries(${test_prefix}machine_manager mg-io mg-coordinator mg-storage-v3 mg-query-v2) + +add_unit_test(pretty_print_ast_to_original_expression_test.cpp) +target_link_libraries(${test_prefix}pretty_print_ast_to_original_expression_test mg-io mg-expr mg-query-v2) diff --git a/tests/unit/pretty_print_ast_to_original_expression_test.cpp b/tests/unit/pretty_print_ast_to_original_expression_test.cpp new file mode 100644 index 000000000..e5d77ae0e --- /dev/null +++ b/tests/unit/pretty_print_ast_to_original_expression_test.cpp @@ -0,0 +1,94 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <chrono> +#include <limits> +#include <thread> + +#include <gtest/gtest.h> + +#include "common/types.hpp" +#include "exceptions.hpp" +#include "parser/opencypher/parser.hpp" +#include "query/v2/bindings/cypher_main_visitor.hpp" +#include "query/v2/bindings/eval.hpp" +#include "query/v2/bindings/frame.hpp" +#include "query/v2/bindings/symbol_generator.hpp" +#include "query/v2/bindings/symbol_table.hpp" +#include "query/v2/bindings/typed_value.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "utils/string.hpp" + +#include "expr/ast/pretty_print_ast_to_original_expression.hpp" + +namespace memgraph::query::v2::test { + +class ExpressiontoStringTest : public ::testing::TestWithParam<std::pair<std::string, std::string>> { + protected: + AstStorage storage; +}; + +TEST_P(ExpressiontoStringTest, Example) { + const auto [original_expression, expected_expression] = GetParam(); + + memgraph::frontend::opencypher::Parser<frontend::opencypher::ParserOpTag::EXPRESSION> parser(original_expression); + expr::ParsingContext pc; + CypherMainVisitor visitor(pc, &storage); + + auto *ast = parser.tree(); + auto expression = visitor.visit(ast); + + const auto rewritten_expression = + expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(std::any_cast<Expression *>(expression)); + + // We check that the expression is what we expect + EXPECT_EQ(rewritten_expression, expected_expression); + + // We check that the rewritten expression can be parsed again + memgraph::frontend::opencypher::Parser<frontend::opencypher::ParserOpTag::EXPRESSION> parser2(rewritten_expression); + expr::ParsingContext pc2; + CypherMainVisitor visitor2(pc2, &storage); + + auto *ast2 = parser2.tree(); + auto expression2 = visitor2.visit(ast2); + const auto rewritten_expression2 = + expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(std::any_cast<Expression *>(expression)); + + // We check that the re-written expression from the re-written expression is exactly the same + EXPECT_EQ(rewritten_expression, rewritten_expression2); +} + +INSTANTIATE_TEST_CASE_P( + PARAMETER, ExpressiontoStringTest, + ::testing::Values( + std::make_pair(std::string("2 / 1"), std::string("(2 / 1)")), + std::make_pair(std::string("2 + 1 + 5 + 2"), std::string("(((2 + 1) + 5) + 2)")), + std::make_pair(std::string("2 + 1 * 5 + 2"), std::string("((2 + (1 * 5)) + 2)")), + std::make_pair(std::string("2 + 1 * (5 + 2)"), std::string("(2 + (1 * (5 + 2)))")), + std::make_pair(std::string("n"), std::string("MG_SYMBOL_NODE")), + std::make_pair(std::string("n.property1"), std::string("MG_SYMBOL_NODE.property1")), + std::make_pair(std::string("n.property1 > 3"), std::string("(MG_SYMBOL_NODE.property1 > 3)")), + std::make_pair(std::string("n.property1 != n.property2"), + std::string("(MG_SYMBOL_NODE.property1 != MG_SYMBOL_NODE.property2)")), + std::make_pair(std::string("n And n"), std::string("(MG_SYMBOL_NODE And MG_SYMBOL_NODE)")), + std::make_pair(std::string("n.property1 > 3 And n.property + 7 < 10"), + std::string("((MG_SYMBOL_NODE.property1 > 3) And ((MG_SYMBOL_NODE.property + 7) < 10))")), + std::make_pair(std::string("MG_SYMBOL_NODE.property1 > 3 And (MG_SYMBOL_NODE.property + 7 < 10 Or " + "MG_SYMBOL_NODE.property3 = true)"), + std::string("((MG_SYMBOL_NODE.property1 > 3) And (((MG_SYMBOL_NODE.property + 7) < 10) Or " + "(MG_SYMBOL_NODE.property3 = true)))")), + std::make_pair(std::string("(MG_SYMBOL_NODE.property1 > 3 Or MG_SYMBOL_NODE.property + 7 < 10) And " + "MG_SYMBOL_NODE.property3 = true"), + std::string("(((MG_SYMBOL_NODE.property1 > 3) Or ((MG_SYMBOL_NODE.property + 7) < 10)) And " + "(MG_SYMBOL_NODE.property3 = true))")))); + +} // namespace memgraph::query::v2::test