memgraph/src/expr/ast/pretty_print.hpp
2022-09-07 18:15:32 +03:00

272 lines
9.2 KiB
C++

// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <iostream>
#include <type_traits>
#include "expr/ast.hpp"
#include "expr/typed_value.hpp"
#include "utils/algorithm.hpp"
#include "utils/logging.hpp"
#include "utils/string.hpp"
namespace memgraph::expr {
namespace detail {
template <typename T>
void PrintObject(std::ostream *out, const T &arg) {
static_assert(!std::is_convertible<T, Expression *>::value,
"This overload shouldn't be called with pointers convertible "
"to Expression *. This means your other PrintObject overloads aren't "
"being called for certain AST nodes when they should (or perhaps such "
"overloads don't exist yet).");
*out << arg;
}
inline void PrintObject(std::ostream *out, const std::string &str) { *out << utils::Escape(str); }
inline void PrintObject(std::ostream *out, Aggregation::Op op) { *out << Aggregation::OpToString(op); }
inline void PrintObject(std::ostream *out, Expression *expr);
inline void PrintObject(std::ostream *out, Identifier *expr) { PrintObject(out, static_cast<Expression *>(expr)); }
/// Prints a vector as `[e1, e2, ...]`, rendering each element with the
/// matching `PrintObject` overload.
template <typename T>
void PrintObject(std::ostream *out, const std::vector<T> &vec) {
  auto print_element = [](auto &stream, const auto &element) { PrintObject(&stream, element); };
  *out << "[";
  utils::PrintIterable(*out, vec, ", ", print_element);
  *out << "]";
}
/// Prints a vector that uses `utils::Allocator` as `[e1, e2, ...]`, rendering
/// each element with the matching `PrintObject` overload.
template <typename T>
void PrintObject(std::ostream *out, const std::vector<T, utils::Allocator<T>> &vec) {
  auto print_element = [](auto &stream, const auto &element) { PrintObject(&stream, element); };
  *out << "[";
  utils::PrintIterable(*out, vec, ", ", print_element);
  *out << "]";
}
/// Prints a `std::map` as `{k1: v1, k2: v2, ...}` in the map's iteration
/// order. Keys and values are rendered with the matching `PrintObject` overloads.
template <typename K, typename V>
void PrintObject(std::ostream *out, const std::map<K, V> &map) {
  auto print_entry = [](auto &stream, const auto &entry) {
    const auto &[key, val] = entry;
    PrintObject(&stream, key);
    stream << ": ";
    PrintObject(&stream, val);
  };
  *out << "{";
  utils::PrintIterable(*out, map, ", ", print_entry);
  *out << "}";
}
/// Prints a `utils::pmr::map` keyed by `utils::pmr::string` as
/// `{k1: v1, k2: v2, ...}` in the map's iteration order.
template <typename T>
void PrintObject(std::ostream *out, const utils::pmr::map<utils::pmr::string, T> &map) {
  auto print_entry = [](auto &stream, const auto &entry) {
    const auto &key = entry.first;
    const auto &val = entry.second;
    PrintObject(&stream, key);
    stream << ": ";
    PrintObject(&stream, val);
  };
  *out << "{";
  utils::PrintIterable(*out, map, ", ", print_entry);
  *out << "}";
}
template <typename T1, typename T2, typename T3>
inline void PrintObject(std::ostream *out, const TypedValueT<T1, T2, T3> &value) {
using TypedValue = TypedValueT<T1, T2, T3>;
switch (value.type()) {
case TypedValue::Type::Null:
*out << "null";
break;
case TypedValue::Type::String:
PrintObject(out, value.ValueString());
break;
case TypedValue::Type::Bool:
*out << (value.ValueBool() ? "true" : "false");
break;
case TypedValue::Type::Int:
PrintObject(out, value.ValueInt());
break;
case TypedValue::Type::Double:
PrintObject(out, value.ValueDouble());
break;
case TypedValue::Type::List:
PrintObject(out, value.ValueList());
break;
case TypedValue::Type::Map:
PrintObject(out, value.ValueMap());
break;
case TypedValue::Type::Date:
PrintObject(out, value.ValueDate());
break;
case TypedValue::Type::Duration:
PrintObject(out, value.ValueDuration());
break;
case TypedValue::Type::LocalTime:
PrintObject(out, value.ValueLocalTime());
break;
case TypedValue::Type::LocalDateTime:
PrintObject(out, value.ValueLocalDateTime());
break;
default:
MG_ASSERT(false, "PrintObject(std::ostream *out, const TypedValue &value) should not reach here");
}
}
/// Terminates an operator's argument list by writing the closing parenthesis.
///
/// Providing this zero-argument overload as the recursion's base case also
/// lets `PrintOperator` be called with no operands (printing `(name)`). A
/// one-argument-only base case would fail to compile in that situation.
inline void PrintOperatorArgs(std::ostream *out) { *out << ")"; }

/// Writes each argument preceded by a single space, using the matching
/// `PrintObject` overload, then closes the list with `)`.
template <typename T, typename... Ts>
void PrintOperatorArgs(std::ostream *out, const T &arg, const Ts &...args) {
  *out << " ";
  PrintObject(out, arg);
  // With an empty `args` pack this resolves to the zero-argument overload above.
  PrintOperatorArgs(out, args...);
}

/// Prints an operator node in prefix form: `(name arg1 arg2 ...)`.
///
/// @param out  Stream to write to.
/// @param name Operator name written right after the opening parenthesis.
/// @param args Operands, each rendered with the matching `PrintObject` overload.
template <typename... Ts>
void PrintOperator(std::ostream *out, const std::string &name, const Ts &...args) {
  *out << "(" << name;
  PrintOperatorArgs(out, args...);
}
} // namespace detail
/// Expression visitor that writes a compact, prefix-notation rendering of an
/// expression tree to the stream given to the constructor.
///
/// Operators are printed as `(name arg ...)`. A `MapLiteral` is printed as
/// `{key: value, ...}` and a `PrimitiveLiteral` as its bare value. Null child
/// expressions (including a missing WHERE clause of All/Single/Any/None) are
/// printed as `<null>`.
class ExpressionPrettyPrinter : public ExpressionVisitor<void> {
 public:
  /// @param out Stream to write to. It is stored as a raw pointer and written to
  ///            on every Visit, so it must stay valid while the printer is used.
  explicit ExpressionPrettyPrinter(std::ostream *out) : out_(out) {}

  // Unary operators
  // NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define UNARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
  /* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
  void Visit(OP_NODE &op) override { detail::PrintOperator(out_, OP_STR, op.expression_); }

  UNARY_OPERATOR_VISIT(NotOperator, "Not");
  UNARY_OPERATOR_VISIT(UnaryPlusOperator, "+");
  UNARY_OPERATOR_VISIT(UnaryMinusOperator, "-");
  UNARY_OPERATOR_VISIT(IsNullOperator, "IsNull");

#undef UNARY_OPERATOR_VISIT

  // Binary operators
  // NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define BINARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
  /* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
  void Visit(OP_NODE &op) override { detail::PrintOperator(out_, OP_STR, op.expression1_, op.expression2_); }

  BINARY_OPERATOR_VISIT(OrOperator, "Or");
  BINARY_OPERATOR_VISIT(XorOperator, "Xor");
  BINARY_OPERATOR_VISIT(AndOperator, "And");
  BINARY_OPERATOR_VISIT(AdditionOperator, "+");
  BINARY_OPERATOR_VISIT(SubtractionOperator, "-");
  BINARY_OPERATOR_VISIT(MultiplicationOperator, "*");
  BINARY_OPERATOR_VISIT(DivisionOperator, "/");
  BINARY_OPERATOR_VISIT(ModOperator, "%");
  BINARY_OPERATOR_VISIT(NotEqualOperator, "!=");
  BINARY_OPERATOR_VISIT(EqualOperator, "==");
  BINARY_OPERATOR_VISIT(LessOperator, "<");
  BINARY_OPERATOR_VISIT(GreaterOperator, ">");
  BINARY_OPERATOR_VISIT(LessEqualOperator, "<=");
  BINARY_OPERATOR_VISIT(GreaterEqualOperator, ">=");
  BINARY_OPERATOR_VISIT(InListOperator, "In");
  BINARY_OPERATOR_VISIT(SubscriptOperator, "Subscript");

#undef BINARY_OPERATOR_VISIT

  // Other
  void Visit(ListSlicingOperator &op) override {
    detail::PrintOperator(out_, "ListSlicing", op.list_, op.lower_bound_, op.upper_bound_);
  }

  void Visit(IfOperator &op) override {
    detail::PrintOperator(out_, "If", op.condition_, op.then_expression_, op.else_expression_);
  }

  void Visit(ListLiteral &op) override { detail::PrintOperator(out_, "ListLiteral", op.elements_); }

  void Visit(MapLiteral &op) override {
    // Copy into an ordered std::map so entries are printed sorted by key name,
    // regardless of how `op.elements_` itself iterates.
    std::map<std::string, Expression *> map;
    for (const auto &kv : op.elements_) {
      map[kv.first.name] = kv.second;
    }
    detail::PrintObject(out_, map);
  }

  void Visit(LabelsTest &op) override { detail::PrintOperator(out_, "LabelsTest", op.expression_); }

  // Only the aggregation kind is printed, not its argument expressions.
  void Visit(Aggregation &op) override { detail::PrintOperator(out_, "Aggregation", op.op_); }

  void Visit(Function &op) override { detail::PrintOperator(out_, "Function", op.function_name_, op.arguments_); }

  void Visit(Reduce &op) override {
    detail::PrintOperator(out_, "Reduce", op.accumulator_, op.initializer_, op.identifier_, op.list_, op.expression_);
  }

  void Visit(Coalesce &op) override { detail::PrintOperator(out_, "Coalesce", op.expressions_); }

  void Visit(Extract &op) override { detail::PrintOperator(out_, "Extract", op.identifier_, op.list_, op.expression_); }

  void Visit(All &op) override {
    detail::PrintOperator(out_, "All", op.identifier_, op.list_expression_, WhereExpression(op.where_));
  }

  void Visit(Single &op) override {
    detail::PrintOperator(out_, "Single", op.identifier_, op.list_expression_, WhereExpression(op.where_));
  }

  void Visit(Any &op) override {
    detail::PrintOperator(out_, "Any", op.identifier_, op.list_expression_, WhereExpression(op.where_));
  }

  void Visit(None &op) override {
    detail::PrintOperator(out_, "None", op.identifier_, op.list_expression_, WhereExpression(op.where_));
  }

  void Visit(Identifier &op) override { detail::PrintOperator(out_, "Identifier", op.name_); }

  void Visit(PrimitiveLiteral &op) override { detail::PrintObject(out_, op.value_); }

  void Visit(PropertyLookup &op) override {
    detail::PrintOperator(out_, "PropertyLookup", op.expression_, op.property_.name);
  }

  void Visit(ParameterLookup &op) override { detail::PrintOperator(out_, "ParameterLookup", op.token_position_); }

  void Visit(NamedExpression &op) override { detail::PrintOperator(out_, "NamedExpression", op.name_, op.expression_); }

  void Visit(RegexMatch &op) override { detail::PrintOperator(out_, "=~", op.string_expr_, op.regex_); }

 private:
  /// Returns the predicate expression of a WHERE clause, or nullptr when the
  /// clause itself is missing. The nullptr is then printed as `<null>` instead
  /// of being dereferenced.
  template <typename TWherePtr>
  static Expression *WhereExpression(const TWherePtr &where) {
    return where ? where->expression_ : nullptr;
  }

  std::ostream *out_;
};
namespace detail {
inline void PrintObject(std::ostream *out, Expression *expr) {
if (expr) {
ExpressionPrettyPrinter printer{out};
expr->Accept(printer);
} else {
*out << "<null>";
}
}
} // namespace detail
/// Pretty-prints `expr` to `out` using `ExpressionPrettyPrinter`.
///
/// A null `expr` is printed as `<null>` rather than dereferenced. This is the
/// same placeholder the printer already uses for null sub-expressions, reached
/// by delegating to `detail::PrintObject`.
inline void PrintExpression(Expression *expr, std::ostream *out) { detail::PrintObject(out, expr); }
/// Pretty-prints a named expression to `out` using `ExpressionPrettyPrinter`.
///
/// A null `expr` is printed as `<null>`, consistent with how the printer
/// renders null sub-expressions, instead of being dereferenced.
inline void PrintExpression(NamedExpression *expr, std::ostream *out) {
  if (expr == nullptr) {
    *out << "<null>";
    return;
  }
  ExpressionPrettyPrinter printer{out};
  expr->Accept(printer);
}
} // namespace memgraph::expr