Plan Skip and Limit operators

Summary:
Support SKIP and LIMIT macros in tests.
Test planning Skip and Limit.
Prevent variables in SKIP and LIMIT.

Reviewers: mislav.bradac, florijan

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D296
This commit is contained in:
Teon Banek 2017-04-20 11:20:20 +02:00
parent 893df584f6
commit 55dc08fc30
6 changed files with 207 additions and 39 deletions

View File

@ -44,17 +44,35 @@ void SymbolGenerator::BindNamedExpressionSymbols(
}
}
void SymbolGenerator::VisitSkipAndLimit(Expression *skip, Expression *limit) {
if (skip) {
scope_.in_skip = true;
skip->Accept(*this);
scope_.in_skip = false;
}
if (limit) {
scope_.in_limit = true;
limit->Accept(*this);
scope_.in_limit = false;
}
}
// Clauses
void SymbolGenerator::Visit(Create &create) { scope_.in_create = true; }
void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; }
void SymbolGenerator::Visit(Return &ret) { scope_.in_return = true; }
void SymbolGenerator::PostVisit(Return &ret) {
bool SymbolGenerator::PreVisit(Return &ret) {
scope_.in_return = true;
for (auto &expr : ret.body_.named_expressions) {
expr->Accept(*this);
}
// Named expressions establish bindings for expressions which come after
// return, but not for the expressions contained inside.
BindNamedExpressionSymbols(ret.body_.named_expressions);
VisitSkipAndLimit(ret.body_.skip, ret.body_.limit);
scope_.in_return = false;
return false; // We handled the traversal ourselves.
}
bool SymbolGenerator::PreVisit(With &with) {
@ -68,6 +86,7 @@ bool SymbolGenerator::PreVisit(With &with) {
// be visible inside named expressions themselves.
scope_.symbols.clear();
BindNamedExpressionSymbols(with.body_.named_expressions);
VisitSkipAndLimit(with.body_.skip, with.body_.limit);
if (with.where_) with.where_->Accept(*this);
return false; // We handled the traversal ourselves.
}
@ -75,6 +94,10 @@ bool SymbolGenerator::PreVisit(With &with) {
// Expressions
void SymbolGenerator::Visit(Identifier &ident) {
if (scope_.in_skip || scope_.in_limit) {
throw SemanticException("Variables are not allowed in {}",
scope_.in_skip ? "SKIP" : "LIMIT");
}
Symbol symbol;
if (scope_.in_pattern && !scope_.in_property_map) {
// Patterns can bind new symbols or reference already bound. But there

View File

@ -27,8 +27,7 @@ class SymbolGenerator : public TreeVisitorBase {
// Clauses
void Visit(Create &) override;
void PostVisit(Create &) override;
void Visit(Return &) override;
void PostVisit(Return &) override;
bool PreVisit(Return &) override;
bool PreVisit(With &) override;
// Expressions
@ -61,6 +60,8 @@ class SymbolGenerator : public TreeVisitorBase {
bool in_aggregation{false};
bool in_return{false};
bool in_with{false};
bool in_skip{false};
bool in_limit{false};
std::map<std::string, Symbol> symbols;
};
@ -79,6 +80,8 @@ class SymbolGenerator : public TreeVisitorBase {
void BindNamedExpressionSymbols(
const std::vector<NamedExpression *> &named_expressions);
void VisitSkipAndLimit(Expression *skip, Expression *limit);
SymbolTable &symbol_table_;
Scope scope_;
};

View File

@ -177,14 +177,25 @@ auto GenMatch(Match &match, LogicalOperator *input_op,
return last_op;
}
// Ast tree visitor which collects the context for a return body. The return
// body are the named expressions found in WITH and RETURN clauses. The
// collected context consists of used symbols, aggregations and group by named
// expressions.
// Ast tree visitor which collects the context for a return body.
// The return body of WITH and RETURN clauses consists of:
//
// * named expressions (used to produce results);
// * flag whether the results need to be DISTINCT;
// * optional SKIP expression;
// * optional LIMIT expression and
// * optional ORDER BY expression.
//
// In addition to the above, we collect information on used symbols,
// aggregations and expressions used for group by.
class ReturnBodyContext : public TreeVisitorBase {
public:
ReturnBodyContext(const SymbolTable &symbol_table)
: symbol_table_(symbol_table) {}
ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table)
: body_(body), symbol_table_(symbol_table) {
for (auto &named_expr : body_.named_expressions) {
named_expr->Accept(*this);
}
}
using TreeVisitorBase::PreVisit;
using TreeVisitorBase::Visit;
@ -249,6 +260,14 @@ class ReturnBodyContext : public TreeVisitorBase {
has_aggregation_.pop_back();
}
// If true, results need to be distinct.
bool distinct() const { return body_.distinct; }
// Named expressions which are used to produce results.
const auto &named_expressions() const { return body_.named_expressions; }
// Optional expression which determines how many results to skip.
auto *skip() const { return body_.skip; }
// Optional expression which determines how many results to produce.
auto *limit() const { return body_.limit; }
// Set of symbols used inside the visited expressions outside of aggregation
// expression.
const auto &symbols() const { return symbols_; }
@ -267,6 +286,7 @@ class ReturnBodyContext : public TreeVisitorBase {
}
};
const ReturnBody &body_;
const SymbolTable &symbol_table_;
std::unordered_set<Symbol, SymbolHash> symbols_;
std::vector<Aggregate::Element> aggregations_;
@ -275,17 +295,36 @@ class ReturnBodyContext : public TreeVisitorBase {
std::list<bool> has_aggregation_;
};
auto GenSkipLimit(LogicalOperator *input_op, const ReturnBodyContext &body) {
auto last_op = input_op;
// SKIP is always before LIMIT clause.
if (body.skip()) {
last_op = new Skip(std::shared_ptr<LogicalOperator>(last_op), body.skip());
}
if (body.limit()) {
last_op =
new Limit(std::shared_ptr<LogicalOperator>(last_op), body.limit());
}
return last_op;
}
auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
const std::vector<NamedExpression *> &named_expressions,
const SymbolTable &symbol_table, bool accumulate = false) {
ReturnBodyContext context(symbol_table);
// Generate context for all named expressions.
for (auto &named_expr : named_expressions) {
named_expr->Accept(context);
const ReturnBodyContext &body, bool accumulate = false) {
if (body.distinct()) {
// TODO: Plan with distinct, when operator available.
throw utils::NotYetImplemented();
}
auto symbols =
std::vector<Symbol>(context.symbols().begin(), context.symbols().end());
std::vector<Symbol>(body.symbols().begin(), body.symbols().end());
auto last_op = input_op;
if (body.aggregations().empty()) {
// In case when we have SKIP/LIMIT and we don't perform aggregations, we
// want to put them before (optional) accumulation. This way we ensure that
// write part of the query will be limited.
// For example, `MATCH (n) SET n.x = n.x + 1 RETURN n LIMIT 1` should
// increment `n.x` only once.
last_op = GenSkipLimit(last_op, body);
}
if (accumulate) {
// We only advance the command in Accumulate. This is done for WITH clause,
// when the first part updated the database. RETURN clause may only need an
@ -293,32 +332,30 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
last_op = new Accumulate(std::shared_ptr<LogicalOperator>(last_op), symbols,
advance_command);
}
if (!context.aggregations().empty()) {
last_op =
if (!body.aggregations().empty()) {
// When we have aggregation, SKIP/LIMIT should always come after it.
last_op = GenSkipLimit(
new Aggregate(std::shared_ptr<LogicalOperator>(last_op),
context.aggregations(), context.group_by(), symbols);
body.aggregations(), body.group_by(), symbols),
body);
}
return new Produce(std::shared_ptr<LogicalOperator>(last_op),
named_expressions);
body.named_expressions());
}
auto GenWith(With &with, LogicalOperator *input_op,
const SymbolTable &symbol_table, bool is_write,
std::unordered_set<int> &bound_symbols) {
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
// optional Filter.
if (with.body_.distinct) {
// TODO: Plan distinct with, when operator available.
throw utils::NotYetImplemented();
}
// In case of update and aggregation, we want to accumulate first, so that
// when aggregating, we get the latest results. Similar to RETURN clause.
// optional Filter. In case of update and aggregation, we want to accumulate
// first, so that when aggregating, we get the latest results. Similar to
// RETURN clause.
bool accumulate = is_write;
// No need to advance the command if we only performed reads.
bool advance_command = is_write;
ReturnBodyContext body(with.body_, symbol_table);
LogicalOperator *last_op =
GenReturnBody(input_op, advance_command, with.body_.named_expressions,
symbol_table, accumulate);
GenReturnBody(input_op, advance_command, body, accumulate);
// Reset bound symbols, so that only those in WITH are exposed.
bound_symbols.clear();
for (auto &named_expr : with.body_.named_expressions) {
@ -341,8 +378,8 @@ auto GenReturn(Return &ret, LogicalOperator *input_op,
// value is the same, final result of 'k' increments.
bool accumulate = is_write;
bool advance_command = false;
return GenReturnBody(input_op, advance_command, ret.body_.named_expressions,
symbol_table, accumulate);
ReturnBodyContext body(ret.body_, symbol_table);
return GenReturnBody(input_op, advance_command, body, accumulate);
}
// Generate an operator for a clause which writes to the database. If the clause

View File

@ -4,6 +4,15 @@ namespace query {
namespace test_common {
// Custom types for SKIP and LIMIT and expressions, so that they can be used to
// resolve function calls.
struct Skip {
query::Expression *expression = nullptr;
};
struct Limit {
query::Expression *expression = nullptr;
};
///
/// Create PropertyLookup with given name and property.
///
@ -115,6 +124,15 @@ auto GetReturn(Return *ret, NamedExpression *named_expr) {
ret->body_.named_expressions.emplace_back(named_expr);
return ret;
}
auto GetReturn(Return *ret, Skip skip, Limit limit = Limit{}) {
ret->body_.skip = skip.expression;
ret->body_.limit = limit.expression;
return ret;
}
auto GetReturn(Return *ret, Limit limit) {
ret->body_.limit = limit.expression;
return ret;
}
auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr) {
// This overload supports `RETURN(expr, AS(name))` construct, since
// NamedExpression does not inherit Expression.
@ -124,18 +142,18 @@ auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr) {
}
template <class... T>
auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr,
T *... rest) {
T... rest) {
named_expr->expression_ = expr;
ret->body_.named_expressions.emplace_back(named_expr);
return GetReturn(ret, rest...);
}
template <class... T>
auto GetReturn(Return *ret, NamedExpression *named_expr, T *... rest) {
auto GetReturn(Return *ret, NamedExpression *named_expr, T... rest) {
ret->body_.named_expressions.emplace_back(named_expr);
return GetReturn(ret, rest...);
}
template <class... T>
auto GetReturn(AstTreeStorage &storage, T *... exprs) {
auto GetReturn(AstTreeStorage &storage, T... exprs) {
auto ret = storage.Create<Return>();
return GetReturn(ret, exprs...);
}
@ -147,6 +165,15 @@ auto GetWith(With *with, NamedExpression *named_expr) {
with->body_.named_expressions.emplace_back(named_expr);
return with;
}
auto GetWith(With *with, Skip skip, Limit limit = {}) {
with->body_.skip = skip.expression;
with->body_.limit = limit.expression;
return with;
}
auto GetWith(With *with, Limit limit) {
with->body_.limit = limit.expression;
return with;
}
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) {
// This overload supports `RETURN(expr, AS(name))` construct, since
// NamedExpression does not inherit Expression.
@ -156,18 +183,18 @@ auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) {
}
template <class... T>
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr,
T *... rest) {
T... rest) {
named_expr->expression_ = expr;
with->body_.named_expressions.emplace_back(named_expr);
return GetWith(with, rest...);
}
template <class... T>
auto GetWith(With *with, NamedExpression *named_expr, T *... rest) {
auto GetWith(With *with, NamedExpression *named_expr, T... rest) {
with->body_.named_expressions.emplace_back(named_expr);
return GetWith(with, rest...);
}
template <class... T>
auto GetWith(AstTreeStorage &storage, T *... exprs) {
auto GetWith(AstTreeStorage &storage, T... exprs) {
auto with = storage.Create<With>();
return GetWith(with, exprs...);
}
@ -261,6 +288,8 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
#define AS(name) storage.Create<query::NamedExpression>((name))
#define RETURN(...) query::test_common::GetReturn(storage, __VA_ARGS__)
#define WITH(...) query::test_common::GetWith(storage, __VA_ARGS__)
#define SKIP(expr) query::test_common::Skip{(expr)}
#define LIMIT(expr) query::test_common::Limit{(expr)}
#define DELETE(...) query::test_common::GetDelete(storage, {__VA_ARGS__})
#define DETACH_DELETE(...) \
query::test_common::GetDelete(storage, {__VA_ARGS__}, true)

View File

@ -58,6 +58,8 @@ template <class TAccessor>
using ExpectExpandUniquenessFilter =
OpChecker<ExpandUniquenessFilter<TAccessor>>;
using ExpectAccumulate = OpChecker<Accumulate>;
using ExpectSkip = OpChecker<Skip>;
using ExpectLimit = OpChecker<Limit>;
class ExpectAggregate : public OpChecker<Aggregate> {
public:
@ -115,6 +117,8 @@ class PlanChecker : public LogicalOperatorVisitor {
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override { CheckOp(op); }
void Visit(Accumulate &op) override { CheckOp(op); }
void Visit(Aggregate &op) override { CheckOp(op); }
void Visit(Skip &op) override { CheckOp(op); }
void Visit(Limit &op) override { CheckOp(op); }
std::list<BaseOpChecker *> checkers_;
@ -431,4 +435,46 @@ TEST(TestLogicalPlanner, MatchWithCreate) {
CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectCreateExpand());
}
TEST(TestLogicalPlanner, MatchReturnSkipLimit) {
// Test MATCH (n) RETURN n SKIP 2 LIMIT 1
AstTreeStorage storage;
auto query =
QUERY(MATCH(PATTERN(NODE("n"))),
RETURN(IDENT("n"), AS("n"), SKIP(LITERAL(2)), LIMIT(LITERAL(1))));
// A simple Skip and Limit combo which should come before Produce.
CheckPlan(*query, ExpectScanAll(), ExpectSkip(), ExpectLimit(),
ExpectProduce());
}
TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) {
// Test CREATE (n) WITH n AS m SKIP 2 RETURN m LIMIT 1
AstTreeStorage storage;
auto query = QUERY(CREATE(PATTERN(NODE("n"))),
WITH(IDENT("n"), AS("m"), SKIP(LITERAL(2))),
RETURN(IDENT("m"), AS("m"), LIMIT(LITERAL(1))));
// Since we have a write query, we need to have Accumulate, so Skip and Limit
// need to come before it. This is a bit different than Neo4j, which optimizes
// WITH followed by RETURN as a single RETURN clause. This would cause the
// Limit operator to also appear before Accumulate, thus changing the
// behaviour. We've decided to diverge from Neo4j here, for consistency sake.
CheckPlan(*query, ExpectCreateNode(), ExpectSkip(), ExpectAccumulate(),
ExpectProduce(), ExpectLimit(), ExpectProduce());
}
TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) {
// Test CREATE (n) RETURN SUM(n.prop) AS s SKIP 2 LIMIT 1
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
auto query = QUERY(CREATE(PATTERN(NODE("n"))),
RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1))));
auto aggr = ExpectAggregate({sum}, {});
// We have a write query and aggregation, therefore Skip and Limit should come
// after Accumulate and Aggregate.
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), aggr, ExpectSkip(),
ExpectLimit(), ExpectProduce());
}
} // namespace

View File

@ -555,4 +555,34 @@ TEST(TestSymbolGenerator, SameResults) {
}
}
TEST(TestSymbolGenerator, SkipLimitIdentifier) {
// Test MATCH (old) WITH old AS new SKIP old
{
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("old"))),
WITH(IDENT("old"), AS("new"), SKIP(IDENT("old"))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
}
// Test MATCH (old) WITH old AS new SKIP new
{
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("old"))),
WITH(IDENT("old"), AS("new"), SKIP(IDENT("new"))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
}
// Test MATCH (n) RETURN n AS n LIMIT n
{
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
RETURN(IDENT("n"), AS("n"), SKIP(IDENT("n"))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
}
}
}