Query::Plan::Unwind added
Reviewers: teon.banek, mislav.bradac, buda Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D328
This commit is contained in:
parent
dc5ef33c8f
commit
ffc977dbfc
@ -1537,4 +1537,48 @@ void Optional::OptionalCursor::Reset() {
|
||||
pull_input_ = true;
|
||||
}
|
||||
|
||||
Unwind::Unwind(const std::shared_ptr<LogicalOperator> &input,
|
||||
Expression *input_expression, Symbol output_symbol)
|
||||
: input_(input ? input : std::make_shared<Once>()),
|
||||
input_expression_(input_expression),
|
||||
output_symbol_(output_symbol) {}
|
||||
|
||||
ACCEPT_WITH_INPUT(Unwind)
|
||||
|
||||
std::unique_ptr<Cursor> Unwind::MakeCursor(GraphDbAccessor &db) {
|
||||
return std::make_unique<UnwindCursor>(*this, db);
|
||||
}
|
||||
|
||||
Unwind::UnwindCursor::UnwindCursor(Unwind &self, GraphDbAccessor &db)
|
||||
: self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {}
|
||||
|
||||
bool Unwind::UnwindCursor::Pull(Frame &frame, const SymbolTable &symbol_table) {
|
||||
// if we reached the end of our list of values
|
||||
// pull from the input
|
||||
if (input_value_it_ == input_value_.end()) {
|
||||
if (!input_cursor_->Pull(frame, symbol_table)) return false;
|
||||
|
||||
// successful pull from input, initialize value and iterator
|
||||
ExpressionEvaluator evaluator(frame, symbol_table, db_);
|
||||
self_.input_expression_->Accept(evaluator);
|
||||
TypedValue input_value = evaluator.PopBack();
|
||||
if (input_value.type() != TypedValue::Type::List)
|
||||
throw QueryRuntimeException("UNWIND only accepts list values");
|
||||
input_value_ = input_value.Value<std::vector<TypedValue>>();
|
||||
input_value_it_ = input_value_.begin();
|
||||
}
|
||||
|
||||
// if we reached the end of our list of values goto back to top
|
||||
if (input_value_it_ == input_value_.end()) return Pull(frame, symbol_table);
|
||||
|
||||
frame[self_.output_symbol_] = *input_value_it_++;
|
||||
return true;
|
||||
}
|
||||
|
||||
void Unwind::UnwindCursor::Reset() {
|
||||
input_cursor_->Reset();
|
||||
input_value_.clear();
|
||||
input_value_it_ = input_value_.end();
|
||||
}
|
||||
|
||||
} // namespace query::plan
|
||||
|
@ -75,16 +75,15 @@ class Limit;
|
||||
class OrderBy;
|
||||
class Merge;
|
||||
class Optional;
|
||||
class Unwind;
|
||||
|
||||
/** @brief Base class for visitors of @c LogicalOperator class hierarchy. */
|
||||
using LogicalOperatorVisitor =
|
||||
::utils::Visitor<Once, CreateNode, CreateExpand, ScanAll, Expand,
|
||||
NodeFilter, EdgeFilter, Filter, Produce, Delete,
|
||||
SetProperty, SetProperties, SetLabels, RemoveProperty,
|
||||
RemoveLabels, ExpandUniquenessFilter<VertexAccessor>,
|
||||
ExpandUniquenessFilter<EdgeAccessor>, Accumulate,
|
||||
AdvanceCommand, Aggregate, Skip, Limit, OrderBy, Merge,
|
||||
Optional>;
|
||||
using LogicalOperatorVisitor = ::utils::Visitor<
|
||||
Once, CreateNode, CreateExpand, ScanAll, Expand, NodeFilter, EdgeFilter,
|
||||
Filter, Produce, Delete, SetProperty, SetProperties, SetLabels,
|
||||
RemoveProperty, RemoveLabels, ExpandUniquenessFilter<VertexAccessor>,
|
||||
ExpandUniquenessFilter<EdgeAccessor>, Accumulate, AdvanceCommand, Aggregate,
|
||||
Skip, Limit, OrderBy, Merge, Optional, Unwind>;
|
||||
|
||||
/** @brief Base class for logical operators.
|
||||
*
|
||||
@ -1259,5 +1258,40 @@ class Optional : public LogicalOperator {
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Takes a list TypedValue as it's input and yields each
|
||||
* element as it's output.
|
||||
*
|
||||
* Input is optional (unwind can be the first clause in a query).
|
||||
*/
|
||||
class Unwind : public LogicalOperator {
|
||||
public:
|
||||
Unwind(const std::shared_ptr<LogicalOperator> &input,
|
||||
Expression *input_expression_, Symbol output_symbol);
|
||||
void Accept(LogicalOperatorVisitor &visitor) override;
|
||||
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
|
||||
|
||||
private:
|
||||
const std::shared_ptr<LogicalOperator> input_;
|
||||
Expression *input_expression_;
|
||||
const Symbol output_symbol_;
|
||||
|
||||
class UnwindCursor : public Cursor {
|
||||
public:
|
||||
UnwindCursor(Unwind &self, GraphDbAccessor &db);
|
||||
bool Pull(Frame &frame, const SymbolTable &symbol_table) override;
|
||||
void Reset() override;
|
||||
|
||||
private:
|
||||
const Unwind &self_;
|
||||
GraphDbAccessor &db_;
|
||||
const std::unique_ptr<Cursor> input_cursor_;
|
||||
// typed values we are unwinding and yielding
|
||||
std::vector<TypedValue> input_value_;
|
||||
// current position in input_value_
|
||||
std::vector<TypedValue>::iterator input_value_it_ = input_value_.end();
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace plan
|
||||
} // namespace query
|
||||
|
@ -305,8 +305,7 @@ TEST(QueryPlan, AggregateNoInput) {
|
||||
symbol_table[*output->expression_] = symbol_table.CreateSymbol("two");
|
||||
|
||||
auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two},
|
||||
{Aggregation::Op::COUNT},
|
||||
{}, {});
|
||||
{Aggregation::Op::COUNT}, {}, {});
|
||||
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
|
||||
EXPECT_EQ(1, results.size());
|
||||
EXPECT_EQ(1, results[0].size());
|
||||
@ -412,8 +411,10 @@ TEST(QueryPlan, AggregateFirstValueTypes) {
|
||||
aggregate(n_prop_string, Aggregation::Op::COUNT);
|
||||
aggregate(n_prop_string, Aggregation::Op::MIN);
|
||||
aggregate(n_prop_string, Aggregation::Op::MAX);
|
||||
EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::AVG), TypedValueException);
|
||||
EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::SUM), TypedValueException);
|
||||
EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::AVG),
|
||||
TypedValueException);
|
||||
EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::SUM),
|
||||
TypedValueException);
|
||||
|
||||
// on ints nothing fails
|
||||
aggregate(n_prop_int, Aggregation::Op::COUNT);
|
||||
@ -476,3 +477,45 @@ TEST(QueryPlan, AggregateTypes) {
|
||||
EXPECT_THROW(aggregate(n_p2, Aggregation::Op::AVG), TypedValueException);
|
||||
EXPECT_THROW(aggregate(n_p2, Aggregation::Op::SUM), TypedValueException);
|
||||
}
|
||||
|
||||
TEST(QueryPlan, Unwind) {
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
AstTreeStorage storage;
|
||||
SymbolTable symbol_table;
|
||||
|
||||
// UNWIND [ [1, true, "x"], [], ["bla"] ] AS x UNWIND x as y RETURN x, y
|
||||
auto input_expr = storage.Create<PrimitiveLiteral>(std::vector<TypedValue>{
|
||||
std::vector<TypedValue>{1, true, "x"}, std::vector<TypedValue>{},
|
||||
std::vector<TypedValue>{"bla"}});
|
||||
|
||||
auto x = symbol_table.CreateSymbol("x");
|
||||
auto unwind_0 = std::make_shared<Unwind>(nullptr, input_expr, x);
|
||||
auto x_expr = IDENT("x");
|
||||
symbol_table[*x_expr] = x;
|
||||
auto y = symbol_table.CreateSymbol("y");
|
||||
auto unwind_1 = std::make_shared<Unwind>(unwind_0, x_expr, y);
|
||||
|
||||
auto x_ne = NEXPR("x", x_expr);
|
||||
symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne");
|
||||
auto y_ne = NEXPR("y", IDENT("y"));
|
||||
symbol_table[*y_ne->expression_] = y;
|
||||
symbol_table[*y_ne] = symbol_table.CreateSymbol("y_ne");
|
||||
auto produce = MakeProduce(unwind_1, x_ne, y_ne);
|
||||
|
||||
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
|
||||
ASSERT_EQ(4, results.size());
|
||||
const std::vector<int> expected_x_card{3, 3, 3, 1};
|
||||
auto expected_x_card_it = expected_x_card.begin();
|
||||
const std::vector<TypedValue> expected_y{1, true, "x", "bla"};
|
||||
auto expected_y_it = expected_y.begin();
|
||||
for (const auto &row : results) {
|
||||
ASSERT_EQ(2, row.size());
|
||||
ASSERT_EQ(row[0].type(), TypedValue::Type::List);
|
||||
EXPECT_EQ(row[0].Value<std::vector<TypedValue>>().size(),
|
||||
*expected_x_card_it);
|
||||
EXPECT_EQ(row[1].type(), expected_y_it->type());
|
||||
expected_x_card_it++;
|
||||
expected_y_it++;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user