Query::Plan::Unwind added

Reviewers: teon.banek, mislav.bradac, buda

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D328
This commit is contained in:
florijan 2017-04-28 15:14:28 +02:00
parent dc5ef33c8f
commit ffc977dbfc
3 changed files with 133 additions and 12 deletions

View File

@ -1537,4 +1537,48 @@ void Optional::OptionalCursor::Reset() {
pull_input_ = true;
}
Unwind::Unwind(const std::shared_ptr<LogicalOperator> &input,
Expression *input_expression, Symbol output_symbol)
: input_(input ? input : std::make_shared<Once>()),
input_expression_(input_expression),
output_symbol_(output_symbol) {}
ACCEPT_WITH_INPUT(Unwind)
std::unique_ptr<Cursor> Unwind::MakeCursor(GraphDbAccessor &db) {
return std::make_unique<UnwindCursor>(*this, db);
}
Unwind::UnwindCursor::UnwindCursor(Unwind &self, GraphDbAccessor &db)
: self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {}
bool Unwind::UnwindCursor::Pull(Frame &frame, const SymbolTable &symbol_table) {
// if we reached the end of our list of values
// pull from the input
if (input_value_it_ == input_value_.end()) {
if (!input_cursor_->Pull(frame, symbol_table)) return false;
// successful pull from input, initialize value and iterator
ExpressionEvaluator evaluator(frame, symbol_table, db_);
self_.input_expression_->Accept(evaluator);
TypedValue input_value = evaluator.PopBack();
if (input_value.type() != TypedValue::Type::List)
throw QueryRuntimeException("UNWIND only accepts list values");
input_value_ = input_value.Value<std::vector<TypedValue>>();
input_value_it_ = input_value_.begin();
}
// if we reached the end of our list of values goto back to top
if (input_value_it_ == input_value_.end()) return Pull(frame, symbol_table);
frame[self_.output_symbol_] = *input_value_it_++;
return true;
}
void Unwind::UnwindCursor::Reset() {
input_cursor_->Reset();
input_value_.clear();
input_value_it_ = input_value_.end();
}
} // namespace query::plan

View File

@ -75,16 +75,15 @@ class Limit;
class OrderBy;
class Merge;
class Optional;
class Unwind;
/** @brief Base class for visitors of @c LogicalOperator class hierarchy. */
using LogicalOperatorVisitor =
::utils::Visitor<Once, CreateNode, CreateExpand, ScanAll, Expand,
NodeFilter, EdgeFilter, Filter, Produce, Delete,
SetProperty, SetProperties, SetLabels, RemoveProperty,
RemoveLabels, ExpandUniquenessFilter<VertexAccessor>,
ExpandUniquenessFilter<EdgeAccessor>, Accumulate,
AdvanceCommand, Aggregate, Skip, Limit, OrderBy, Merge,
Optional>;
using LogicalOperatorVisitor = ::utils::Visitor<
Once, CreateNode, CreateExpand, ScanAll, Expand, NodeFilter, EdgeFilter,
Filter, Produce, Delete, SetProperty, SetProperties, SetLabels,
RemoveProperty, RemoveLabels, ExpandUniquenessFilter<VertexAccessor>,
ExpandUniquenessFilter<EdgeAccessor>, Accumulate, AdvanceCommand, Aggregate,
Skip, Limit, OrderBy, Merge, Optional, Unwind>;
/** @brief Base class for logical operators.
*
@ -1259,5 +1258,40 @@ class Optional : public LogicalOperator {
};
};
/**
* Takes a list TypedValue as it's input and yields each
* element as it's output.
*
* Input is optional (unwind can be the first clause in a query).
*/
class Unwind : public LogicalOperator {
public:
Unwind(const std::shared_ptr<LogicalOperator> &input,
Expression *input_expression_, Symbol output_symbol);
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
const std::shared_ptr<LogicalOperator> input_;
Expression *input_expression_;
const Symbol output_symbol_;
class UnwindCursor : public Cursor {
public:
UnwindCursor(Unwind &self, GraphDbAccessor &db);
bool Pull(Frame &frame, const SymbolTable &symbol_table) override;
void Reset() override;
private:
const Unwind &self_;
GraphDbAccessor &db_;
const std::unique_ptr<Cursor> input_cursor_;
// typed values we are unwinding and yielding
std::vector<TypedValue> input_value_;
// current position in input_value_
std::vector<TypedValue>::iterator input_value_it_ = input_value_.end();
};
};
} // namespace plan
} // namespace query

View File

@ -305,8 +305,7 @@ TEST(QueryPlan, AggregateNoInput) {
symbol_table[*output->expression_] = symbol_table.CreateSymbol("two");
auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two},
{Aggregation::Op::COUNT},
{}, {});
{Aggregation::Op::COUNT}, {}, {});
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
EXPECT_EQ(1, results.size());
EXPECT_EQ(1, results[0].size());
@ -412,8 +411,10 @@ TEST(QueryPlan, AggregateFirstValueTypes) {
aggregate(n_prop_string, Aggregation::Op::COUNT);
aggregate(n_prop_string, Aggregation::Op::MIN);
aggregate(n_prop_string, Aggregation::Op::MAX);
EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::AVG), TypedValueException);
EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::SUM), TypedValueException);
EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::AVG),
TypedValueException);
EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::SUM),
TypedValueException);
// on ints nothing fails
aggregate(n_prop_int, Aggregation::Op::COUNT);
@ -476,3 +477,45 @@ TEST(QueryPlan, AggregateTypes) {
EXPECT_THROW(aggregate(n_p2, Aggregation::Op::AVG), TypedValueException);
EXPECT_THROW(aggregate(n_p2, Aggregation::Op::SUM), TypedValueException);
}
TEST(QueryPlan, Unwind) {
Dbms dbms;
auto dba = dbms.active();
AstTreeStorage storage;
SymbolTable symbol_table;
// UNWIND [ [1, true, "x"], [], ["bla"] ] AS x UNWIND x as y RETURN x, y
auto input_expr = storage.Create<PrimitiveLiteral>(std::vector<TypedValue>{
std::vector<TypedValue>{1, true, "x"}, std::vector<TypedValue>{},
std::vector<TypedValue>{"bla"}});
auto x = symbol_table.CreateSymbol("x");
auto unwind_0 = std::make_shared<Unwind>(nullptr, input_expr, x);
auto x_expr = IDENT("x");
symbol_table[*x_expr] = x;
auto y = symbol_table.CreateSymbol("y");
auto unwind_1 = std::make_shared<Unwind>(unwind_0, x_expr, y);
auto x_ne = NEXPR("x", x_expr);
symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne");
auto y_ne = NEXPR("y", IDENT("y"));
symbol_table[*y_ne->expression_] = y;
symbol_table[*y_ne] = symbol_table.CreateSymbol("y_ne");
auto produce = MakeProduce(unwind_1, x_ne, y_ne);
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
ASSERT_EQ(4, results.size());
const std::vector<int> expected_x_card{3, 3, 3, 1};
auto expected_x_card_it = expected_x_card.begin();
const std::vector<TypedValue> expected_y{1, true, "x", "bla"};
auto expected_y_it = expected_y.begin();
for (const auto &row : results) {
ASSERT_EQ(2, row.size());
ASSERT_EQ(row[0].type(), TypedValue::Type::List);
EXPECT_EQ(row[0].Value<std::vector<TypedValue>>().size(),
*expected_x_card_it);
EXPECT_EQ(row[1].type(), expected_y_it->type());
expected_x_card_it++;
expected_y_it++;
}
}