Query::Plan - Skip and Limit added

Reviewers: mislav.bradac, teon.banek

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D292
This commit is contained in:
florijan 2017-04-18 15:19:42 +02:00
parent 157327de48
commit 541c3f0af7
5 changed files with 282 additions and 2 deletions

View File

@ -658,10 +658,14 @@ class Return : public Clause {
for (auto &expr : named_expressions_) { for (auto &expr : named_expressions_) {
expr->Accept(visitor); expr->Accept(visitor);
} }
if (skip_) skip_->Accept(visitor);
if (limit_) limit_->Accept(visitor);
visitor.PostVisit(*this); visitor.PostVisit(*this);
} }
} }
std::vector<NamedExpression *> named_expressions_; std::vector<NamedExpression *> named_expressions_;
Expression *skip_ = nullptr;
Expression *limit_ = nullptr;
protected: protected:
Return(int uid) : Clause(uid) {} Return(int uid) : Clause(uid) {}
@ -678,6 +682,8 @@ class With : public Clause {
expr->Accept(visitor); expr->Accept(visitor);
} }
if (where_) where_->Accept(visitor); if (where_) where_->Accept(visitor);
if (skip_) skip_->Accept(visitor);
if (limit_) limit_->Accept(visitor);
visitor.PostVisit(*this); visitor.PostVisit(*this);
} }
} }
@ -685,6 +691,8 @@ class With : public Clause {
bool distinct_ = false; bool distinct_ = false;
std::vector<NamedExpression *> named_expressions_; std::vector<NamedExpression *> named_expressions_;
Where *where_ = nullptr; Where *where_ = nullptr;
Expression *skip_ = nullptr;
Expression *limit_ = nullptr;
protected: protected:
With(int uid) : Clause(uid) {} With(int uid) : Clause(uid) {}

View File

@ -1198,5 +1198,92 @@ bool Aggregate::AggregateCursor::TypedValueListEqual::operator()(
TypedValue::BoolEqual{}); TypedValue::BoolEqual{});
} }
Skip::Skip(const std::shared_ptr<LogicalOperator> &input,
Expression *expression)
: input_(input), expression_(expression) {}
void Skip::Accept(LogicalOperatorVisitor &visitor) {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
input_->Accept(visitor);
visitor.PostVisit(*this);
}
}
std::unique_ptr<Cursor> Skip::MakeCursor(GraphDbAccessor &db) {
return std::make_unique<SkipCursor>(*this, db);
}
Skip::SkipCursor::SkipCursor(Skip &self, GraphDbAccessor &db)
: self_(self), input_cursor_(self_.input_->MakeCursor(db)) {}
bool Skip::SkipCursor::Pull(Frame &frame, const SymbolTable &symbol_table) {
while (input_cursor_->Pull(frame, symbol_table)) {
if (to_skip_ == -1) {
// first successful pull from the input
// evaluate the skip expression
ExpressionEvaluator evaluator(frame, symbol_table);
self_.expression_->Accept(evaluator);
TypedValue to_skip = evaluator.PopBack();
if (to_skip.type() != TypedValue::Type::Int)
throw QueryRuntimeException("Result of SKIP expression must be an int");
to_skip_ = to_skip.Value<int64_t>();
if (to_skip_ < 0)
throw QueryRuntimeException(
"Result of SKIP expression must be greater or equal to zero");
}
if (skipped_++ < to_skip_) continue;
return true;
}
return false;
}
Limit::Limit(const std::shared_ptr<LogicalOperator> &input,
Expression *expression)
: input_(input), expression_(expression) {}
void Limit::Accept(LogicalOperatorVisitor &visitor) {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
input_->Accept(visitor);
visitor.PostVisit(*this);
}
}
std::unique_ptr<Cursor> Limit::MakeCursor(GraphDbAccessor &db) {
return std::make_unique<LimitCursor>(*this, db);
}
Limit::LimitCursor::LimitCursor(Limit &self, GraphDbAccessor &db)
: self_(self), input_cursor_(self_.input_->MakeCursor(db)) {}
bool Limit::LimitCursor::Pull(Frame &frame, const SymbolTable &symbol_table) {
// we need to evaluate the limit expression before the first input Pull
// because it might be 0 and thereby we shouldn't Pull from input at all
// we can do this before Pulling from the input because the limit expression
// is not allowed to contain any identifiers
if (limit_ == -1) {
ExpressionEvaluator evaluator(frame, symbol_table);
self_.expression_->Accept(evaluator);
TypedValue limit = evaluator.PopBack();
if (limit.type() != TypedValue::Type::Int)
throw QueryRuntimeException("Result of LIMIT expression must be an int");
limit_ = limit.Value<int64_t>();
if (limit_ < 0)
throw QueryRuntimeException(
"Result of LIMIT expression must be greater or equal to zero");
}
// check we have not exceeded the limit before pulling
if (pulled_++ >= limit_)
return false;
return input_cursor_->Pull(frame, symbol_table);
}
} // namespace plan } // namespace plan
} // namespace query } // namespace query

View File

@ -61,6 +61,8 @@ class ExpandUniquenessFilter;
class Accumulate; class Accumulate;
class AdvanceCommand; class AdvanceCommand;
class Aggregate; class Aggregate;
class Skip;
class Limit;
/** @brief Base class for visitors of @c LogicalOperator class hierarchy. */ /** @brief Base class for visitors of @c LogicalOperator class hierarchy. */
using LogicalOperatorVisitor = using LogicalOperatorVisitor =
@ -69,7 +71,7 @@ using LogicalOperatorVisitor =
SetProperties, SetLabels, RemoveProperty, RemoveLabels, SetProperties, SetLabels, RemoveProperty, RemoveLabels,
ExpandUniquenessFilter<VertexAccessor>, ExpandUniquenessFilter<VertexAccessor>,
ExpandUniquenessFilter<EdgeAccessor>, Accumulate, ExpandUniquenessFilter<EdgeAccessor>, Accumulate,
AdvanceCommand, Aggregate>; AdvanceCommand, Aggregate, Skip, Limit>;
/** @brief Base class for logical operators. /** @brief Base class for logical operators.
* *
@ -929,5 +931,82 @@ class Aggregate : public LogicalOperator {
}; };
}; };
/** @brief Skips a number of Pulls from the input op.
*
* The given expression determines how many Pulls from the input
* should be skipped (ignored).
* All other successful Pulls from the
* input are simply passed through.
*
* The given expression is evaluated after the first Pull from
* the input, and only once. Neo does not allow this expression
* to contain identifiers, and neither does Memgraph, but this
* operator's implementation does not expect this.
*/
class Skip : public LogicalOperator {
public:
Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression);
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
const std::shared_ptr<LogicalOperator> input_;
Expression *expression_;
class SkipCursor : public Cursor {
public:
SkipCursor(Skip &self, GraphDbAccessor &db);
bool Pull(Frame &frame, const SymbolTable &symbol_table) override;
private:
Skip &self_;
std::unique_ptr<Cursor> input_cursor_;
// init to_skip_ to -1, indicating
// that it's still unknown (input has not been Pulled yet)
int to_skip_{-1};
int skipped_{0};
};
};
/** @brief Limits the number of Pulls from the input op.
*
* The given expression determines how many
* input Pulls should be passed through. The input is not
* Pulled once this limit is reached. Note that this has
* implications: the out-of-bounds input Pulls are never
* evaluated.
*
* The limit expression must NOT use anything from the
* Frame. It is evaluated before the first Pull from the
* input. This is consistent with Neo (they don't allow
* identifiers in limit expressions), and it's necessary
* when limit evaluates to 0 (because 0 Pulls from the
* input should be performed).
*/
class Limit : public LogicalOperator {
public:
Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression);
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
const std::shared_ptr<LogicalOperator> input_;
Expression *expression_;
class LimitCursor : public Cursor {
public:
LimitCursor(Limit &self, GraphDbAccessor &db);
bool Pull(Frame &frame, const SymbolTable &symbol_table) override;
private:
Limit &self_;
std::unique_ptr<Cursor> input_cursor_;
// init limit_ to -1, indicating
// that it's still unknown (Cursor has not been Pulled yet)
int limit_{-1};
int pulled_{0};
};
};
} // namespace plan } // namespace plan
} // namespace query } // namespace query

View File

@ -0,0 +1,106 @@
//
// Copyright 2017 Memgraph
// Created by Florijan Stamenkovic on 14.03.17.
//
#include <iterator>
#include <memory>
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "communication/result_stream_faker.hpp"
#include "dbms/dbms.hpp"
#include "query/context.hpp"
#include "query/exceptions.hpp"
#include "query/plan/operator.hpp"
#include "query_plan_common.hpp"
using namespace query;
using namespace query::plan;
TEST(QueryPlan, Skip) {
Dbms dbms;
auto dba = dbms.active();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n1");
auto skip = std::make_shared<plan::Skip>(n.op_, LITERAL(2));
EXPECT_EQ(0, PullAll(skip, *dba, symbol_table));
dba->insert_vertex();
dba->advance_command();
EXPECT_EQ(0, PullAll(skip, *dba, symbol_table));
dba->insert_vertex();
dba->advance_command();
EXPECT_EQ(0, PullAll(skip, *dba, symbol_table));
dba->insert_vertex();
dba->advance_command();
EXPECT_EQ(1, PullAll(skip, *dba, symbol_table));
for (int i = 0; i < 10; ++i)
dba->insert_vertex();
dba->advance_command();
EXPECT_EQ(11, PullAll(skip, *dba, symbol_table));
}
TEST(QueryPlan, Limit) {
Dbms dbms;
auto dba = dbms.active();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n1");
auto skip = std::make_shared<plan::Limit>(n.op_, LITERAL(2));
EXPECT_EQ(0, PullAll(skip, *dba, symbol_table));
dba->insert_vertex();
dba->advance_command();
EXPECT_EQ(1, PullAll(skip, *dba, symbol_table));
dba->insert_vertex();
dba->advance_command();
EXPECT_EQ(2, PullAll(skip, *dba, symbol_table));
dba->insert_vertex();
dba->advance_command();
EXPECT_EQ(2, PullAll(skip, *dba, symbol_table));
for (int i = 0; i < 10; ++i)
dba->insert_vertex();
dba->advance_command();
EXPECT_EQ(2, PullAll(skip, *dba, symbol_table));
}
TEST(QueryPlan, CreateLimit) {
// CREATE (n), (m)
// MATCH (n) CREATE (m) LIMIT 1
// in the end we need to have 3 vertices in the db
Dbms dbms;
auto dba = dbms.active();
dba->insert_vertex();
dba->insert_vertex();
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n1");
auto m = NODE("m");
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m");
auto c = std::make_shared<CreateNode>(m, n.op_);
auto skip = std::make_shared<plan::Limit>(c, LITERAL(1));
EXPECT_EQ(1, PullAll(skip, *dba, symbol_table));
dba->advance_command();
EXPECT_EQ(3, CountIterable(dba->vertices()));
}

View File

@ -57,7 +57,7 @@ auto CollectProduce(std::shared_ptr<Produce> produce, SymbolTable &symbol_table,
} }
int PullAll(std::shared_ptr<LogicalOperator> logical_op, GraphDbAccessor &db, int PullAll(std::shared_ptr<LogicalOperator> logical_op, GraphDbAccessor &db,
SymbolTable symbol_table) { SymbolTable &symbol_table) {
Frame frame(symbol_table.max_position()); Frame frame(symbol_table.max_position());
auto cursor = logical_op->MakeCursor(db); auto cursor = logical_op->MakeCursor(db);
int count = 0; int count = 0;