Add filter to scan all (#575)
Add several versions of ScanAll with filters. Add helper function to transform an expression into string that can be parsed again once on the storage.
This commit is contained in:
parent
23171e76b6
commit
4cb3b064c4
279
src/expr/ast/pretty_print_ast_to_original_expression.hpp
Normal file
279
src/expr/ast/pretty_print_ast_to_original_expression.hpp
Normal file
@ -0,0 +1,279 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
|
||||
#include "expr/ast.hpp"
|
||||
#include "expr/typed_value.hpp"
|
||||
#include "utils/algorithm.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/string.hpp"
|
||||
|
||||
namespace memgraph::expr {
|
||||
inline constexpr const char *identifier_node_symbol = "MG_SYMBOL_NODE";
|
||||
inline constexpr const char *identifier_edge_symbol = "MG_SYMBOL_EDGE";
|
||||
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
void PrintObject(std::ostream *out, const T &arg) {
|
||||
static_assert(!std::is_convertible<T, Expression *>::value,
|
||||
"This overload shouldn't be called with pointers convertible "
|
||||
"to Expression *. This means your other PrintObject overloads aren't "
|
||||
"being called for certain AST nodes when they should (or perhaps such "
|
||||
"overloads don't exist yet).");
|
||||
*out << arg;
|
||||
}
|
||||
|
||||
inline void PrintObject(std::ostream *out, const std::string &str) { *out << str; }
|
||||
|
||||
inline void PrintObject(std::ostream * /*out*/, Aggregation::Op /*op*/) {
|
||||
throw utils::NotYetImplemented("PrintObject: Aggregation::Op");
|
||||
}
|
||||
|
||||
inline void PrintObject(std::ostream * /*out*/, Expression * /*expr*/);
|
||||
|
||||
inline void PrintObject(std::ostream *out, Identifier *expr) { PrintObject(out, static_cast<Expression *>(expr)); }
|
||||
|
||||
template <typename T>
|
||||
void PrintObject(std::ostream * /*out*/, const std::vector<T> & /*vec*/) {
|
||||
throw utils::NotYetImplemented("PrintObject: vector<T>");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PrintObject(std::ostream * /*out*/, const std::vector<T, utils::Allocator<T>> & /*vec*/) {
|
||||
throw utils::NotYetImplemented("PrintObject: vector<T, utils::Allocator<T>>");
|
||||
}
|
||||
|
||||
template <typename K, typename V>
|
||||
void PrintObject(std::ostream * /*out*/, const std::map<K, V> & /*map*/) {
|
||||
throw utils::NotYetImplemented("PrintObject: map<K, V>");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PrintObject(std::ostream * /*out*/, const utils::pmr::map<utils::pmr::string, T> & /*map*/) {
|
||||
throw utils::NotYetImplemented("PrintObject: map<utils::pmr::string, V>");
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
inline void PrintObject(std::ostream *out, const TypedValueT<T1, T2, T3> &value) {
|
||||
using TypedValue = TypedValueT<T1, T2, T3>;
|
||||
switch (value.type()) {
|
||||
case TypedValue::Type::Null:
|
||||
*out << "null";
|
||||
break;
|
||||
case TypedValue::Type::String:
|
||||
PrintObject(out, value.ValueString());
|
||||
break;
|
||||
case TypedValue::Type::Bool:
|
||||
*out << (value.ValueBool() ? "true" : "false");
|
||||
break;
|
||||
case TypedValue::Type::Int:
|
||||
PrintObject(out, value.ValueInt());
|
||||
break;
|
||||
case TypedValue::Type::Double:
|
||||
PrintObject(out, value.ValueDouble());
|
||||
break;
|
||||
case TypedValue::Type::List:
|
||||
PrintObject(out, value.ValueList());
|
||||
break;
|
||||
case TypedValue::Type::Map:
|
||||
PrintObject(out, value.ValueMap());
|
||||
break;
|
||||
case TypedValue::Type::Date:
|
||||
PrintObject(out, value.ValueDate());
|
||||
break;
|
||||
case TypedValue::Type::Duration:
|
||||
PrintObject(out, value.ValueDuration());
|
||||
break;
|
||||
case TypedValue::Type::LocalTime:
|
||||
PrintObject(out, value.ValueLocalTime());
|
||||
break;
|
||||
case TypedValue::Type::LocalDateTime:
|
||||
PrintObject(out, value.ValueLocalDateTime());
|
||||
break;
|
||||
default:
|
||||
MG_ASSERT(false, "PrintObject(std::ostream *out, const TypedValue &value) should not reach here");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PrintOperatorArgs(const std::string & /*name*/, std::ostream *out, bool with_parenthesis, const T &arg) {
|
||||
PrintObject(out, arg);
|
||||
if (with_parenthesis) {
|
||||
*out << ")";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
void PrintOperatorArgs(const std::string &name, std::ostream *out, bool with_parenthesis, const T &arg,
|
||||
const Ts &...args) {
|
||||
PrintObject(out, arg);
|
||||
*out << " " << name << " ";
|
||||
PrintOperatorArgs(name, out, with_parenthesis, args...);
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
void PrintOperator(const std::string &name, std::ostream *out, bool with_parenthesis, const Ts &...args) {
|
||||
if (with_parenthesis) {
|
||||
*out << "(";
|
||||
}
|
||||
PrintOperatorArgs(name, out, with_parenthesis, args...);
|
||||
}
|
||||
|
||||
// new
|
||||
template <typename T>
|
||||
void PrintOperatorArgs(std::ostream *out, const T &arg) {
|
||||
PrintObject(out, arg);
|
||||
}
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
void PrintOperatorArgs(std::ostream *out, const T &arg, const Ts &...args) {
|
||||
PrintObject(out, arg);
|
||||
PrintOperatorArgs(out, args...);
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
void PrintOperator(std::ostream *out, const Ts &...args) {
|
||||
PrintOperatorArgs(out, args...);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
class ExpressionPrettyPrinter : public ExpressionVisitor<void> {
|
||||
public:
|
||||
explicit ExpressionPrettyPrinter(std::ostream *out) : out_(out) {}
|
||||
|
||||
// Unary operators
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
|
||||
#define UNARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
|
||||
/* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
|
||||
void Visit(OP_NODE &op) override { detail::PrintOperator(OP_STR, out_, false /*with_parenthesis*/, op.expression_); }
|
||||
|
||||
UNARY_OPERATOR_VISIT(NotOperator, "Not");
|
||||
UNARY_OPERATOR_VISIT(UnaryPlusOperator, "+");
|
||||
UNARY_OPERATOR_VISIT(UnaryMinusOperator, "-");
|
||||
UNARY_OPERATOR_VISIT(IsNullOperator, "IsNull");
|
||||
|
||||
#undef UNARY_OPERATOR_VISIT
|
||||
|
||||
// Binary operators
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
|
||||
#define BINARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
|
||||
/* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
|
||||
void Visit(OP_NODE &op) override { \
|
||||
detail::PrintOperator(OP_STR, out_, true /*with_parenthesis*/, op.expression1_, op.expression2_); \
|
||||
}
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
|
||||
#define BINARY_OPERATOR_VISIT_NOT_IMPL(OP_NODE, OP_STR) \
|
||||
/* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
|
||||
void Visit(OP_NODE & /*op*/) override { throw utils::NotYetImplemented("OP_NODE"); }
|
||||
|
||||
BINARY_OPERATOR_VISIT(OrOperator, "Or");
|
||||
BINARY_OPERATOR_VISIT(XorOperator, "Xor");
|
||||
BINARY_OPERATOR_VISIT(AndOperator, "And");
|
||||
BINARY_OPERATOR_VISIT(AdditionOperator, "+");
|
||||
BINARY_OPERATOR_VISIT(SubtractionOperator, "-");
|
||||
BINARY_OPERATOR_VISIT(MultiplicationOperator, "*");
|
||||
BINARY_OPERATOR_VISIT(DivisionOperator, "/");
|
||||
BINARY_OPERATOR_VISIT(ModOperator, "%");
|
||||
BINARY_OPERATOR_VISIT(NotEqualOperator, "!=");
|
||||
BINARY_OPERATOR_VISIT(EqualOperator, "=");
|
||||
BINARY_OPERATOR_VISIT(LessOperator, "<");
|
||||
BINARY_OPERATOR_VISIT(GreaterOperator, ">");
|
||||
BINARY_OPERATOR_VISIT(LessEqualOperator, "<=");
|
||||
BINARY_OPERATOR_VISIT(GreaterEqualOperator, ">=");
|
||||
BINARY_OPERATOR_VISIT_NOT_IMPL(InListOperator, "In");
|
||||
BINARY_OPERATOR_VISIT_NOT_IMPL(SubscriptOperator, "Subscript");
|
||||
|
||||
#undef BINARY_OPERATOR_VISIT
|
||||
#undef BINARY_OPERATOR_VISIT_NOT_IMPL
|
||||
|
||||
// Other
|
||||
void Visit(ListSlicingOperator & /*op*/) override { throw utils::NotYetImplemented("ListSlicingOperator"); }
|
||||
|
||||
void Visit(IfOperator & /*op*/) override { throw utils::NotYetImplemented("IfOperator"); }
|
||||
|
||||
void Visit(ListLiteral & /*op*/) override { throw utils::NotYetImplemented("ListLiteral"); }
|
||||
|
||||
void Visit(MapLiteral & /*op*/) override { throw utils::NotYetImplemented("MapLiteral"); }
|
||||
|
||||
void Visit(LabelsTest & /*op*/) override { throw utils::NotYetImplemented("LabelsTest"); }
|
||||
|
||||
void Visit(Aggregation & /*op*/) override { throw utils::NotYetImplemented("Aggregation"); }
|
||||
|
||||
void Visit(Function & /*op*/) override { throw utils::NotYetImplemented("Function"); }
|
||||
|
||||
void Visit(Reduce & /*op*/) override { throw utils::NotYetImplemented("Reduce"); }
|
||||
|
||||
void Visit(Coalesce & /*op*/) override { throw utils::NotYetImplemented("Coalesce"); }
|
||||
|
||||
void Visit(Extract & /*op*/) override { throw utils::NotYetImplemented("Extract"); }
|
||||
|
||||
void Visit(All & /*op*/) override { throw utils::NotYetImplemented("All"); }
|
||||
|
||||
void Visit(Single & /*op*/) override { throw utils::NotYetImplemented("Single"); }
|
||||
|
||||
void Visit(Any & /*op*/) override { throw utils::NotYetImplemented("Any"); }
|
||||
|
||||
void Visit(None & /*op*/) override { throw utils::NotYetImplemented("None"); }
|
||||
|
||||
void Visit(Identifier &op) override {
|
||||
auto is_node = true;
|
||||
auto is_edge = false;
|
||||
auto is_other = false;
|
||||
if (is_node) {
|
||||
detail::PrintOperator(out_, identifier_node_symbol);
|
||||
} else if (is_edge) {
|
||||
detail::PrintOperator(out_, identifier_edge_symbol);
|
||||
} else {
|
||||
MG_ASSERT(is_other);
|
||||
detail::PrintOperator(out_, op.name_);
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(PrimitiveLiteral &op) override { detail::PrintObject(out_, op.value_); }
|
||||
|
||||
void Visit(PropertyLookup &op) override { detail::PrintOperator(out_, op.expression_, ".", op.property_.name); }
|
||||
|
||||
void Visit(ParameterLookup & /*op*/) override { throw utils::NotYetImplemented("ParameterLookup"); }
|
||||
|
||||
void Visit(NamedExpression & /*op*/) override { throw utils::NotYetImplemented("NamedExpression"); }
|
||||
|
||||
void Visit(RegexMatch & /*op*/) override { throw utils::NotYetImplemented("RegexMatch"); }
|
||||
|
||||
private:
|
||||
std::ostream *out_;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
inline void PrintObject(std::ostream *out, Expression *expr) {
|
||||
if (expr) {
|
||||
ExpressionPrettyPrinter printer{out};
|
||||
expr->Accept(printer);
|
||||
} else {
|
||||
*out << "<null>";
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
inline void PrintExpressionToOriginalAndReplaceNodeAndEdgeSymbols(Expression *expr, std::ostream *out) {
|
||||
ExpressionPrettyPrinter printer{out};
|
||||
expr->Accept(printer);
|
||||
}
|
||||
|
||||
inline std::string ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(Expression *expr) {
|
||||
std::ostringstream ss;
|
||||
expr::PrintExpressionToOriginalAndReplaceNodeAndEdgeSymbols(expr, &ss);
|
||||
return ss.str();
|
||||
}
|
||||
} // namespace memgraph::expr
|
@ -26,6 +26,7 @@
|
||||
#include <cppitertools/chain.hpp>
|
||||
#include <cppitertools/imap.hpp>
|
||||
|
||||
#include "expr/ast/pretty_print_ast_to_original_expression.hpp"
|
||||
#include "expr/exceptions.hpp"
|
||||
#include "query/exceptions.hpp"
|
||||
#include "query/v2/accessors.hpp"
|
||||
@ -332,17 +333,111 @@ class ScanAllCursor : public Cursor {
|
||||
msgs::ExecutionState<msgs::ScanVerticesRequest> request_state;
|
||||
};
|
||||
|
||||
class DistributedScanAllAndFilterCursor : public Cursor {
|
||||
public:
|
||||
explicit DistributedScanAllAndFilterCursor(
|
||||
Symbol output_symbol, UniqueCursorPtr input_cursor, const char *op_name,
|
||||
std::optional<storage::v3::LabelId> label,
|
||||
std::optional<std::pair<storage::v3::PropertyId, Expression *>> property_expression_pair,
|
||||
std::optional<std::vector<Expression *>> filter_expressions)
|
||||
: output_symbol_(output_symbol),
|
||||
input_cursor_(std::move(input_cursor)),
|
||||
op_name_(op_name),
|
||||
label_(label),
|
||||
property_expression_pair_(property_expression_pair),
|
||||
filter_expressions_(filter_expressions) {
|
||||
ResetExecutionState();
|
||||
}
|
||||
|
||||
using VertexAccessor = accessors::VertexAccessor;
|
||||
|
||||
bool MakeRequest(msgs::ShardRequestManagerInterface &shard_manager) {
|
||||
current_batch = shard_manager.Request(request_state_);
|
||||
current_vertex_it = current_batch.begin();
|
||||
return !current_batch.empty();
|
||||
}
|
||||
|
||||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||||
SCOPED_PROFILE_OP(op_name_);
|
||||
auto &shard_manager = *context.shard_request_manager;
|
||||
if (MustAbort(context)) {
|
||||
throw HintedAbortError();
|
||||
}
|
||||
using State = msgs::ExecutionState<msgs::ScanVerticesRequest>;
|
||||
|
||||
if (request_state_.state == State::INITIALIZING) {
|
||||
if (!input_cursor_->Pull(frame, context)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (current_vertex_it == current_batch.end()) {
|
||||
if (request_state_.state == State::COMPLETED || !MakeRequest(shard_manager)) {
|
||||
ResetExecutionState();
|
||||
return Pull(frame, context);
|
||||
}
|
||||
}
|
||||
|
||||
frame[output_symbol_] = TypedValue(std::move(*current_vertex_it));
|
||||
++current_vertex_it;
|
||||
return true;
|
||||
}
|
||||
|
||||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||||
|
||||
void ResetExecutionState() {
|
||||
current_batch.clear();
|
||||
current_vertex_it = current_batch.end();
|
||||
request_state_ = msgs::ExecutionState<msgs::ScanVerticesRequest>{};
|
||||
|
||||
auto request = msgs::ScanVerticesRequest{};
|
||||
if (label_.has_value()) {
|
||||
request.label = msgs::Label{.id = label_.value()};
|
||||
}
|
||||
if (property_expression_pair_.has_value()) {
|
||||
request.property_expression_pair = std::make_pair(
|
||||
property_expression_pair_.value().first,
|
||||
expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(property_expression_pair_.value().second));
|
||||
}
|
||||
if (filter_expressions_.has_value()) {
|
||||
auto res = std::vector<std::string>{};
|
||||
res.reserve(filter_expressions_->size());
|
||||
std::transform(filter_expressions_->begin(), filter_expressions_->end(), std::back_inserter(res),
|
||||
[](auto &filter) { return expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(filter); });
|
||||
|
||||
request.filter_expressions = res;
|
||||
}
|
||||
request_state_.requests.emplace_back(request);
|
||||
}
|
||||
|
||||
void Reset() override {
|
||||
input_cursor_->Reset();
|
||||
ResetExecutionState();
|
||||
}
|
||||
|
||||
private:
|
||||
const Symbol output_symbol_;
|
||||
const UniqueCursorPtr input_cursor_;
|
||||
const char *op_name_;
|
||||
std::vector<VertexAccessor> current_batch;
|
||||
std::vector<VertexAccessor>::iterator current_vertex_it;
|
||||
msgs::ExecutionState<msgs::ScanVerticesRequest> request_state_;
|
||||
std::optional<storage::v3::LabelId> label_;
|
||||
std::optional<std::pair<storage::v3::PropertyId, Expression *>> property_expression_pair_;
|
||||
std::optional<std::vector<Expression *>> filter_expressions_;
|
||||
};
|
||||
|
||||
ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::v3::View view)
|
||||
: input_(input ? input : std::make_shared<Once>()), output_symbol_(output_symbol), view_(view) {}
|
||||
|
||||
ACCEPT_WITH_INPUT(ScanAll)
|
||||
|
||||
class DistributedScanAllCursor;
|
||||
|
||||
UniqueCursorPtr ScanAll::MakeCursor(utils::MemoryResource *mem) const {
|
||||
EventCounter::IncrementCounter(EventCounter::ScanAllOperator);
|
||||
|
||||
return MakeUniqueCursorPtr<DistributedScanAllCursor>(mem, output_symbol_, input_->MakeCursor(mem), "ScanAll");
|
||||
return MakeUniqueCursorPtr<DistributedScanAllAndFilterCursor>(
|
||||
mem, output_symbol_, input_->MakeCursor(mem), "ScanAll", std::nullopt /*label*/,
|
||||
std::nullopt /*property_expression_pair*/, std::nullopt /*filter_expressions*/);
|
||||
}
|
||||
|
||||
std::vector<Symbol> ScanAll::ModifiedSymbols(const SymbolTable &table) const {
|
||||
@ -357,10 +452,12 @@ ScanAllByLabel::ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, Sy
|
||||
|
||||
ACCEPT_WITH_INPUT(ScanAllByLabel)
|
||||
|
||||
UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource * /*mem*/) const {
|
||||
UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource *mem) const {
|
||||
EventCounter::IncrementCounter(EventCounter::ScanAllByLabelOperator);
|
||||
|
||||
throw QueryRuntimeException("ScanAllByLabel is not supported");
|
||||
return MakeUniqueCursorPtr<DistributedScanAllAndFilterCursor>(
|
||||
mem, output_symbol_, input_->MakeCursor(mem), "ScanAllByLabel", label_, std::nullopt /*property_expression_pair*/,
|
||||
std::nullopt /*filter_expressions*/);
|
||||
}
|
||||
|
||||
// TODO(buda): Implement ScanAllByLabelProperty operator to iterate over
|
||||
@ -404,10 +501,12 @@ ScanAllByLabelPropertyValue::ScanAllByLabelPropertyValue(const std::shared_ptr<L
|
||||
|
||||
ACCEPT_WITH_INPUT(ScanAllByLabelPropertyValue)
|
||||
|
||||
UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource * /*mem*/) const {
|
||||
UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource *mem) const {
|
||||
EventCounter::IncrementCounter(EventCounter::ScanAllByLabelPropertyValueOperator);
|
||||
|
||||
throw QueryRuntimeException("ScanAllByLabelPropertyValue is not supported");
|
||||
return MakeUniqueCursorPtr<DistributedScanAllAndFilterCursor>(
|
||||
mem, output_symbol_, input_->MakeCursor(mem), "ScanAllByLabelPropertyValue", label_,
|
||||
std::make_pair(property_, expression_), std::nullopt /*filter_expressions*/);
|
||||
}
|
||||
|
||||
ScanAllByLabelProperty::ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
|
||||
@ -432,6 +531,7 @@ ACCEPT_WITH_INPUT(ScanAllById)
|
||||
|
||||
UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const {
|
||||
EventCounter::IncrementCounter(EventCounter::ScanAllByIdOperator);
|
||||
// TODO Reimplement when we have reliable conversion between hash value and pk
|
||||
auto vertices = [](Frame & /*frame*/, ExecutionContext & /*context*/) -> std::optional<std::vector<VertexAccessor>> {
|
||||
return std::nullopt;
|
||||
};
|
||||
@ -2256,62 +2356,4 @@ bool Foreach::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
|
||||
return visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
class DistributedScanAllCursor : public Cursor {
|
||||
public:
|
||||
explicit DistributedScanAllCursor(Symbol output_symbol, UniqueCursorPtr input_cursor, const char *op_name)
|
||||
: output_symbol_(output_symbol), input_cursor_(std::move(input_cursor)), op_name_(op_name) {}
|
||||
|
||||
using VertexAccessor = accessors::VertexAccessor;
|
||||
|
||||
bool MakeRequest(msgs::ShardRequestManagerInterface &shard_manager) {
|
||||
// TODO(antaljanosbenjamin) Use real label
|
||||
request_state_.label = "label";
|
||||
current_batch = shard_manager.Request(request_state_);
|
||||
current_vertex_it = current_batch.begin();
|
||||
return !current_batch.empty();
|
||||
}
|
||||
|
||||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||||
SCOPED_PROFILE_OP(op_name_);
|
||||
auto &shard_manager = *context.shard_request_manager;
|
||||
if (MustAbort(context)) throw HintedAbortError();
|
||||
using State = msgs::ExecutionState<msgs::ScanVerticesRequest>;
|
||||
|
||||
if (request_state_.state == State::INITIALIZING) {
|
||||
if (!input_cursor_->Pull(frame, context)) return false;
|
||||
}
|
||||
|
||||
if (current_vertex_it == current_batch.end()) {
|
||||
if (request_state_.state == State::COMPLETED || !MakeRequest(shard_manager)) {
|
||||
ResetExecutionState();
|
||||
return Pull(frame, context);
|
||||
}
|
||||
}
|
||||
|
||||
frame[output_symbol_] = TypedValue(std::move(*current_vertex_it));
|
||||
++current_vertex_it;
|
||||
return true;
|
||||
}
|
||||
|
||||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||||
|
||||
void ResetExecutionState() {
|
||||
current_batch.clear();
|
||||
current_vertex_it = current_batch.end();
|
||||
request_state_ = msgs::ExecutionState<msgs::ScanVerticesRequest>{};
|
||||
}
|
||||
|
||||
void Reset() override {
|
||||
input_cursor_->Reset();
|
||||
ResetExecutionState();
|
||||
}
|
||||
|
||||
private:
|
||||
const Symbol output_symbol_;
|
||||
const UniqueCursorPtr input_cursor_;
|
||||
const char *op_name_;
|
||||
std::vector<VertexAccessor> current_batch;
|
||||
decltype(std::vector<VertexAccessor>().begin()) current_vertex_it;
|
||||
msgs::ExecutionState<msgs::ScanVerticesRequest> request_state_;
|
||||
};
|
||||
} // namespace memgraph::query::v2::plan
|
||||
|
@ -380,9 +380,12 @@ struct ScanVerticesRequest {
|
||||
Hlc transaction_id;
|
||||
VertexId start_id;
|
||||
std::optional<std::vector<PropertyId>> props_to_return;
|
||||
std::optional<std::vector<std::string>> filter_expressions;
|
||||
std::optional<size_t> batch_limit;
|
||||
StorageView storage_view{StorageView::NEW};
|
||||
|
||||
std::optional<Label> label;
|
||||
std::optional<std::pair<PropertyId, std::string>> property_expression_pair;
|
||||
std::optional<std::vector<std::string>> filter_expressions;
|
||||
};
|
||||
|
||||
struct ScanResultRow {
|
||||
|
@ -433,3 +433,6 @@ target_link_libraries(${test_prefix}local_transport mg-io)
|
||||
# Test MachineManager with LocalTransport
|
||||
add_unit_test(machine_manager.cpp)
|
||||
target_link_libraries(${test_prefix}machine_manager mg-io mg-coordinator mg-storage-v3 mg-query-v2)
|
||||
|
||||
add_unit_test(pretty_print_ast_to_original_expression_test.cpp)
|
||||
target_link_libraries(${test_prefix}pretty_print_ast_to_original_expression_test mg-io mg-expr mg-query-v2)
|
||||
|
94
tests/unit/pretty_print_ast_to_original_expression_test.cpp
Normal file
94
tests/unit/pretty_print_ast_to_original_expression_test.cpp
Normal file
@ -0,0 +1,94 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <chrono>
|
||||
#include <limits>
|
||||
#include <thread>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common/types.hpp"
|
||||
#include "exceptions.hpp"
|
||||
#include "parser/opencypher/parser.hpp"
|
||||
#include "query/v2/bindings/cypher_main_visitor.hpp"
|
||||
#include "query/v2/bindings/eval.hpp"
|
||||
#include "query/v2/bindings/frame.hpp"
|
||||
#include "query/v2/bindings/symbol_generator.hpp"
|
||||
#include "query/v2/bindings/symbol_table.hpp"
|
||||
#include "query/v2/bindings/typed_value.hpp"
|
||||
#include "query/v2/db_accessor.hpp"
|
||||
#include "query/v2/frontend/ast/ast.hpp"
|
||||
#include "utils/string.hpp"
|
||||
|
||||
#include "expr/ast/pretty_print_ast_to_original_expression.hpp"
|
||||
|
||||
namespace memgraph::query::v2::test {
|
||||
|
||||
class ExpressiontoStringTest : public ::testing::TestWithParam<std::pair<std::string, std::string>> {
|
||||
protected:
|
||||
AstStorage storage;
|
||||
};
|
||||
|
||||
TEST_P(ExpressiontoStringTest, Example) {
|
||||
const auto [original_expression, expected_expression] = GetParam();
|
||||
|
||||
memgraph::frontend::opencypher::Parser<frontend::opencypher::ParserOpTag::EXPRESSION> parser(original_expression);
|
||||
expr::ParsingContext pc;
|
||||
CypherMainVisitor visitor(pc, &storage);
|
||||
|
||||
auto *ast = parser.tree();
|
||||
auto expression = visitor.visit(ast);
|
||||
|
||||
const auto rewritten_expression =
|
||||
expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(std::any_cast<Expression *>(expression));
|
||||
|
||||
// We check that the expression is what we expect
|
||||
EXPECT_EQ(rewritten_expression, expected_expression);
|
||||
|
||||
// We check that the rewritten expression can be parsed again
|
||||
memgraph::frontend::opencypher::Parser<frontend::opencypher::ParserOpTag::EXPRESSION> parser2(rewritten_expression);
|
||||
expr::ParsingContext pc2;
|
||||
CypherMainVisitor visitor2(pc2, &storage);
|
||||
|
||||
auto *ast2 = parser2.tree();
|
||||
auto expression2 = visitor2.visit(ast2);
|
||||
const auto rewritten_expression2 =
|
||||
expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(std::any_cast<Expression *>(expression));
|
||||
|
||||
// We check that the re-written expression from the re-written expression is exactly the same
|
||||
EXPECT_EQ(rewritten_expression, rewritten_expression2);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
PARAMETER, ExpressiontoStringTest,
|
||||
::testing::Values(
|
||||
std::make_pair(std::string("2 / 1"), std::string("(2 / 1)")),
|
||||
std::make_pair(std::string("2 + 1 + 5 + 2"), std::string("(((2 + 1) + 5) + 2)")),
|
||||
std::make_pair(std::string("2 + 1 * 5 + 2"), std::string("((2 + (1 * 5)) + 2)")),
|
||||
std::make_pair(std::string("2 + 1 * (5 + 2)"), std::string("(2 + (1 * (5 + 2)))")),
|
||||
std::make_pair(std::string("n"), std::string("MG_SYMBOL_NODE")),
|
||||
std::make_pair(std::string("n.property1"), std::string("MG_SYMBOL_NODE.property1")),
|
||||
std::make_pair(std::string("n.property1 > 3"), std::string("(MG_SYMBOL_NODE.property1 > 3)")),
|
||||
std::make_pair(std::string("n.property1 != n.property2"),
|
||||
std::string("(MG_SYMBOL_NODE.property1 != MG_SYMBOL_NODE.property2)")),
|
||||
std::make_pair(std::string("n And n"), std::string("(MG_SYMBOL_NODE And MG_SYMBOL_NODE)")),
|
||||
std::make_pair(std::string("n.property1 > 3 And n.property + 7 < 10"),
|
||||
std::string("((MG_SYMBOL_NODE.property1 > 3) And ((MG_SYMBOL_NODE.property + 7) < 10))")),
|
||||
std::make_pair(std::string("MG_SYMBOL_NODE.property1 > 3 And (MG_SYMBOL_NODE.property + 7 < 10 Or "
|
||||
"MG_SYMBOL_NODE.property3 = true)"),
|
||||
std::string("((MG_SYMBOL_NODE.property1 > 3) And (((MG_SYMBOL_NODE.property + 7) < 10) Or "
|
||||
"(MG_SYMBOL_NODE.property3 = true)))")),
|
||||
std::make_pair(std::string("(MG_SYMBOL_NODE.property1 > 3 Or MG_SYMBOL_NODE.property + 7 < 10) And "
|
||||
"MG_SYMBOL_NODE.property3 = true"),
|
||||
std::string("(((MG_SYMBOL_NODE.property1 > 3) Or ((MG_SYMBOL_NODE.property + 7) < 10)) And "
|
||||
"(MG_SYMBOL_NODE.property3 = true))"))));
|
||||
|
||||
} // namespace memgraph::query::v2::test
|
Loading…
Reference in New Issue
Block a user