Plan Skip and Limit operators
Summary: Support SKIP and LIMIT macros in tests. Test planning Skip and Limit. Prevent variables in SKIP and LIMIT. Reviewers: mislav.bradac, florijan Reviewed By: florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D296
This commit is contained in:
parent
893df584f6
commit
55dc08fc30
@ -44,17 +44,35 @@ void SymbolGenerator::BindNamedExpressionSymbols(
|
||||
}
|
||||
}
|
||||
|
||||
void SymbolGenerator::VisitSkipAndLimit(Expression *skip, Expression *limit) {
|
||||
if (skip) {
|
||||
scope_.in_skip = true;
|
||||
skip->Accept(*this);
|
||||
scope_.in_skip = false;
|
||||
}
|
||||
if (limit) {
|
||||
scope_.in_limit = true;
|
||||
limit->Accept(*this);
|
||||
scope_.in_limit = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Clauses
|
||||
|
||||
void SymbolGenerator::Visit(Create &create) { scope_.in_create = true; }
|
||||
void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; }
|
||||
|
||||
void SymbolGenerator::Visit(Return &ret) { scope_.in_return = true; }
|
||||
void SymbolGenerator::PostVisit(Return &ret) {
|
||||
bool SymbolGenerator::PreVisit(Return &ret) {
|
||||
scope_.in_return = true;
|
||||
for (auto &expr : ret.body_.named_expressions) {
|
||||
expr->Accept(*this);
|
||||
}
|
||||
// Named expressions establish bindings for expressions which come after
|
||||
// return, but not for the expressions contained inside.
|
||||
BindNamedExpressionSymbols(ret.body_.named_expressions);
|
||||
VisitSkipAndLimit(ret.body_.skip, ret.body_.limit);
|
||||
scope_.in_return = false;
|
||||
return false; // We handled the traversal ourselves.
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(With &with) {
|
||||
@ -68,6 +86,7 @@ bool SymbolGenerator::PreVisit(With &with) {
|
||||
// be visible inside named expressions themselves.
|
||||
scope_.symbols.clear();
|
||||
BindNamedExpressionSymbols(with.body_.named_expressions);
|
||||
VisitSkipAndLimit(with.body_.skip, with.body_.limit);
|
||||
if (with.where_) with.where_->Accept(*this);
|
||||
return false; // We handled the traversal ourselves.
|
||||
}
|
||||
@ -75,6 +94,10 @@ bool SymbolGenerator::PreVisit(With &with) {
|
||||
// Expressions
|
||||
|
||||
void SymbolGenerator::Visit(Identifier &ident) {
|
||||
if (scope_.in_skip || scope_.in_limit) {
|
||||
throw SemanticException("Variables are not allowed in {}",
|
||||
scope_.in_skip ? "SKIP" : "LIMIT");
|
||||
}
|
||||
Symbol symbol;
|
||||
if (scope_.in_pattern && !scope_.in_property_map) {
|
||||
// Patterns can bind new symbols or reference already bound. But there
|
||||
|
@ -27,8 +27,7 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
// Clauses
|
||||
void Visit(Create &) override;
|
||||
void PostVisit(Create &) override;
|
||||
void Visit(Return &) override;
|
||||
void PostVisit(Return &) override;
|
||||
bool PreVisit(Return &) override;
|
||||
bool PreVisit(With &) override;
|
||||
|
||||
// Expressions
|
||||
@ -61,6 +60,8 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
bool in_aggregation{false};
|
||||
bool in_return{false};
|
||||
bool in_with{false};
|
||||
bool in_skip{false};
|
||||
bool in_limit{false};
|
||||
std::map<std::string, Symbol> symbols;
|
||||
};
|
||||
|
||||
@ -79,6 +80,8 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
void BindNamedExpressionSymbols(
|
||||
const std::vector<NamedExpression *> &named_expressions);
|
||||
|
||||
void VisitSkipAndLimit(Expression *skip, Expression *limit);
|
||||
|
||||
SymbolTable &symbol_table_;
|
||||
Scope scope_;
|
||||
};
|
||||
|
@ -177,14 +177,25 @@ auto GenMatch(Match &match, LogicalOperator *input_op,
|
||||
return last_op;
|
||||
}
|
||||
|
||||
// Ast tree visitor which collects the context for a return body. The return
|
||||
// body are the named expressions found in WITH and RETURN clauses. The
|
||||
// collected context consists of used symbols, aggregations and group by named
|
||||
// expressions.
|
||||
// Ast tree visitor which collects the context for a return body.
|
||||
// The return body of WITH and RETURN clauses consists of:
|
||||
//
|
||||
// * named expressions (used to produce results);
|
||||
// * flag whether the results need to be DISTINCT;
|
||||
// * optional SKIP expression;
|
||||
// * optional LIMIT expression and
|
||||
// * optional ORDER BY expression.
|
||||
//
|
||||
// In addition to the above, we collect information on used symbols,
|
||||
// aggregations and expressions used for group by.
|
||||
class ReturnBodyContext : public TreeVisitorBase {
|
||||
public:
|
||||
ReturnBodyContext(const SymbolTable &symbol_table)
|
||||
: symbol_table_(symbol_table) {}
|
||||
ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table)
|
||||
: body_(body), symbol_table_(symbol_table) {
|
||||
for (auto &named_expr : body_.named_expressions) {
|
||||
named_expr->Accept(*this);
|
||||
}
|
||||
}
|
||||
|
||||
using TreeVisitorBase::PreVisit;
|
||||
using TreeVisitorBase::Visit;
|
||||
@ -249,6 +260,14 @@ class ReturnBodyContext : public TreeVisitorBase {
|
||||
has_aggregation_.pop_back();
|
||||
}
|
||||
|
||||
// If true, results need to be distinct.
|
||||
bool distinct() const { return body_.distinct; }
|
||||
// Named expressions which are used to produce results.
|
||||
const auto &named_expressions() const { return body_.named_expressions; }
|
||||
// Optional expression which determines how many results to skip.
|
||||
auto *skip() const { return body_.skip; }
|
||||
// Optional expression which determines how many results to produce.
|
||||
auto *limit() const { return body_.limit; }
|
||||
// Set of symbols used inside the visited expressions outside of aggregation
|
||||
// expression.
|
||||
const auto &symbols() const { return symbols_; }
|
||||
@ -267,6 +286,7 @@ class ReturnBodyContext : public TreeVisitorBase {
|
||||
}
|
||||
};
|
||||
|
||||
const ReturnBody &body_;
|
||||
const SymbolTable &symbol_table_;
|
||||
std::unordered_set<Symbol, SymbolHash> symbols_;
|
||||
std::vector<Aggregate::Element> aggregations_;
|
||||
@ -275,17 +295,36 @@ class ReturnBodyContext : public TreeVisitorBase {
|
||||
std::list<bool> has_aggregation_;
|
||||
};
|
||||
|
||||
auto GenSkipLimit(LogicalOperator *input_op, const ReturnBodyContext &body) {
|
||||
auto last_op = input_op;
|
||||
// SKIP is always before LIMIT clause.
|
||||
if (body.skip()) {
|
||||
last_op = new Skip(std::shared_ptr<LogicalOperator>(last_op), body.skip());
|
||||
}
|
||||
if (body.limit()) {
|
||||
last_op =
|
||||
new Limit(std::shared_ptr<LogicalOperator>(last_op), body.limit());
|
||||
}
|
||||
return last_op;
|
||||
}
|
||||
|
||||
auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
|
||||
const std::vector<NamedExpression *> &named_expressions,
|
||||
const SymbolTable &symbol_table, bool accumulate = false) {
|
||||
ReturnBodyContext context(symbol_table);
|
||||
// Generate context for all named expressions.
|
||||
for (auto &named_expr : named_expressions) {
|
||||
named_expr->Accept(context);
|
||||
const ReturnBodyContext &body, bool accumulate = false) {
|
||||
if (body.distinct()) {
|
||||
// TODO: Plan with distinct, when operator available.
|
||||
throw utils::NotYetImplemented();
|
||||
}
|
||||
auto symbols =
|
||||
std::vector<Symbol>(context.symbols().begin(), context.symbols().end());
|
||||
std::vector<Symbol>(body.symbols().begin(), body.symbols().end());
|
||||
auto last_op = input_op;
|
||||
if (body.aggregations().empty()) {
|
||||
// In case when we have SKIP/LIMIT and we don't perform aggregations, we
|
||||
// want to put them before (optional) accumulation. This way we ensure that
|
||||
// write part of the query will be limited.
|
||||
// For example, `MATCH (n) SET n.x = n.x + 1 RETURN n LIMIT 1` should
|
||||
// increment `n.x` only once.
|
||||
last_op = GenSkipLimit(last_op, body);
|
||||
}
|
||||
if (accumulate) {
|
||||
// We only advance the command in Accumulate. This is done for WITH clause,
|
||||
// when the first part updated the database. RETURN clause may only need an
|
||||
@ -293,32 +332,30 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
|
||||
last_op = new Accumulate(std::shared_ptr<LogicalOperator>(last_op), symbols,
|
||||
advance_command);
|
||||
}
|
||||
if (!context.aggregations().empty()) {
|
||||
last_op =
|
||||
if (!body.aggregations().empty()) {
|
||||
// When we have aggregation, SKIP/LIMIT should always come after it.
|
||||
last_op = GenSkipLimit(
|
||||
new Aggregate(std::shared_ptr<LogicalOperator>(last_op),
|
||||
context.aggregations(), context.group_by(), symbols);
|
||||
body.aggregations(), body.group_by(), symbols),
|
||||
body);
|
||||
}
|
||||
return new Produce(std::shared_ptr<LogicalOperator>(last_op),
|
||||
named_expressions);
|
||||
body.named_expressions());
|
||||
}
|
||||
|
||||
auto GenWith(With &with, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table, bool is_write,
|
||||
std::unordered_set<int> &bound_symbols) {
|
||||
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
|
||||
// optional Filter.
|
||||
if (with.body_.distinct) {
|
||||
// TODO: Plan distinct with, when operator available.
|
||||
throw utils::NotYetImplemented();
|
||||
}
|
||||
// In case of update and aggregation, we want to accumulate first, so that
|
||||
// when aggregating, we get the latest results. Similar to RETURN clause.
|
||||
// optional Filter. In case of update and aggregation, we want to accumulate
|
||||
// first, so that when aggregating, we get the latest results. Similar to
|
||||
// RETURN clause.
|
||||
bool accumulate = is_write;
|
||||
// No need to advance the command if we only performed reads.
|
||||
bool advance_command = is_write;
|
||||
ReturnBodyContext body(with.body_, symbol_table);
|
||||
LogicalOperator *last_op =
|
||||
GenReturnBody(input_op, advance_command, with.body_.named_expressions,
|
||||
symbol_table, accumulate);
|
||||
GenReturnBody(input_op, advance_command, body, accumulate);
|
||||
// Reset bound symbols, so that only those in WITH are exposed.
|
||||
bound_symbols.clear();
|
||||
for (auto &named_expr : with.body_.named_expressions) {
|
||||
@ -341,8 +378,8 @@ auto GenReturn(Return &ret, LogicalOperator *input_op,
|
||||
// value is the same, final result of 'k' increments.
|
||||
bool accumulate = is_write;
|
||||
bool advance_command = false;
|
||||
return GenReturnBody(input_op, advance_command, ret.body_.named_expressions,
|
||||
symbol_table, accumulate);
|
||||
ReturnBodyContext body(ret.body_, symbol_table);
|
||||
return GenReturnBody(input_op, advance_command, body, accumulate);
|
||||
}
|
||||
|
||||
// Generate an operator for a clause which writes to the database. If the clause
|
||||
|
@ -4,6 +4,15 @@ namespace query {
|
||||
|
||||
namespace test_common {
|
||||
|
||||
// Custom types for SKIP and LIMIT and expressions, so that they can be used to
|
||||
// resolve function calls.
|
||||
struct Skip {
|
||||
query::Expression *expression = nullptr;
|
||||
};
|
||||
struct Limit {
|
||||
query::Expression *expression = nullptr;
|
||||
};
|
||||
|
||||
///
|
||||
/// Create PropertyLookup with given name and property.
|
||||
///
|
||||
@ -115,6 +124,15 @@ auto GetReturn(Return *ret, NamedExpression *named_expr) {
|
||||
ret->body_.named_expressions.emplace_back(named_expr);
|
||||
return ret;
|
||||
}
|
||||
auto GetReturn(Return *ret, Skip skip, Limit limit = Limit{}) {
|
||||
ret->body_.skip = skip.expression;
|
||||
ret->body_.limit = limit.expression;
|
||||
return ret;
|
||||
}
|
||||
auto GetReturn(Return *ret, Limit limit) {
|
||||
ret->body_.limit = limit.expression;
|
||||
return ret;
|
||||
}
|
||||
auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr) {
|
||||
// This overload supports `RETURN(expr, AS(name))` construct, since
|
||||
// NamedExpression does not inherit Expression.
|
||||
@ -124,18 +142,18 @@ auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr) {
|
||||
}
|
||||
template <class... T>
|
||||
auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr,
|
||||
T *... rest) {
|
||||
T... rest) {
|
||||
named_expr->expression_ = expr;
|
||||
ret->body_.named_expressions.emplace_back(named_expr);
|
||||
return GetReturn(ret, rest...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetReturn(Return *ret, NamedExpression *named_expr, T *... rest) {
|
||||
auto GetReturn(Return *ret, NamedExpression *named_expr, T... rest) {
|
||||
ret->body_.named_expressions.emplace_back(named_expr);
|
||||
return GetReturn(ret, rest...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetReturn(AstTreeStorage &storage, T *... exprs) {
|
||||
auto GetReturn(AstTreeStorage &storage, T... exprs) {
|
||||
auto ret = storage.Create<Return>();
|
||||
return GetReturn(ret, exprs...);
|
||||
}
|
||||
@ -147,6 +165,15 @@ auto GetWith(With *with, NamedExpression *named_expr) {
|
||||
with->body_.named_expressions.emplace_back(named_expr);
|
||||
return with;
|
||||
}
|
||||
auto GetWith(With *with, Skip skip, Limit limit = {}) {
|
||||
with->body_.skip = skip.expression;
|
||||
with->body_.limit = limit.expression;
|
||||
return with;
|
||||
}
|
||||
auto GetWith(With *with, Limit limit) {
|
||||
with->body_.limit = limit.expression;
|
||||
return with;
|
||||
}
|
||||
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) {
|
||||
// This overload supports `RETURN(expr, AS(name))` construct, since
|
||||
// NamedExpression does not inherit Expression.
|
||||
@ -156,18 +183,18 @@ auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) {
|
||||
}
|
||||
template <class... T>
|
||||
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr,
|
||||
T *... rest) {
|
||||
T... rest) {
|
||||
named_expr->expression_ = expr;
|
||||
with->body_.named_expressions.emplace_back(named_expr);
|
||||
return GetWith(with, rest...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetWith(With *with, NamedExpression *named_expr, T *... rest) {
|
||||
auto GetWith(With *with, NamedExpression *named_expr, T... rest) {
|
||||
with->body_.named_expressions.emplace_back(named_expr);
|
||||
return GetWith(with, rest...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetWith(AstTreeStorage &storage, T *... exprs) {
|
||||
auto GetWith(AstTreeStorage &storage, T... exprs) {
|
||||
auto with = storage.Create<With>();
|
||||
return GetWith(with, exprs...);
|
||||
}
|
||||
@ -261,6 +288,8 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
|
||||
#define AS(name) storage.Create<query::NamedExpression>((name))
|
||||
#define RETURN(...) query::test_common::GetReturn(storage, __VA_ARGS__)
|
||||
#define WITH(...) query::test_common::GetWith(storage, __VA_ARGS__)
|
||||
#define SKIP(expr) query::test_common::Skip{(expr)}
|
||||
#define LIMIT(expr) query::test_common::Limit{(expr)}
|
||||
#define DELETE(...) query::test_common::GetDelete(storage, {__VA_ARGS__})
|
||||
#define DETACH_DELETE(...) \
|
||||
query::test_common::GetDelete(storage, {__VA_ARGS__}, true)
|
||||
|
@ -58,6 +58,8 @@ template <class TAccessor>
|
||||
using ExpectExpandUniquenessFilter =
|
||||
OpChecker<ExpandUniquenessFilter<TAccessor>>;
|
||||
using ExpectAccumulate = OpChecker<Accumulate>;
|
||||
using ExpectSkip = OpChecker<Skip>;
|
||||
using ExpectLimit = OpChecker<Limit>;
|
||||
|
||||
class ExpectAggregate : public OpChecker<Aggregate> {
|
||||
public:
|
||||
@ -115,6 +117,8 @@ class PlanChecker : public LogicalOperatorVisitor {
|
||||
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override { CheckOp(op); }
|
||||
void Visit(Accumulate &op) override { CheckOp(op); }
|
||||
void Visit(Aggregate &op) override { CheckOp(op); }
|
||||
void Visit(Skip &op) override { CheckOp(op); }
|
||||
void Visit(Limit &op) override { CheckOp(op); }
|
||||
|
||||
std::list<BaseOpChecker *> checkers_;
|
||||
|
||||
@ -431,4 +435,46 @@ TEST(TestLogicalPlanner, MatchWithCreate) {
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectCreateExpand());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchReturnSkipLimit) {
|
||||
// Test MATCH (n) RETURN n SKIP 2 LIMIT 1
|
||||
AstTreeStorage storage;
|
||||
auto query =
|
||||
QUERY(MATCH(PATTERN(NODE("n"))),
|
||||
RETURN(IDENT("n"), AS("n"), SKIP(LITERAL(2)), LIMIT(LITERAL(1))));
|
||||
// A simple Skip and Limit combo which should come before Produce.
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectSkip(), ExpectLimit(),
|
||||
ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) {
|
||||
// Test CREATE (n) WITH n AS m SKIP 2 RETURN m LIMIT 1
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(CREATE(PATTERN(NODE("n"))),
|
||||
WITH(IDENT("n"), AS("m"), SKIP(LITERAL(2))),
|
||||
RETURN(IDENT("m"), AS("m"), LIMIT(LITERAL(1))));
|
||||
// Since we have a write query, we need to have Accumulate, so Skip and Limit
|
||||
// need to come before it. This is a bit different than Neo4j, which optimizes
|
||||
// WITH followed by RETURN as a single RETURN clause. This would cause the
|
||||
// Limit operator to also appear before Accumulate, thus changing the
|
||||
// behaviour. We've decided to diverge from Neo4j here, for consistency sake.
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectSkip(), ExpectAccumulate(),
|
||||
ExpectProduce(), ExpectLimit(), ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) {
|
||||
// Test CREATE (n) RETURN SUM(n.prop) AS s SKIP 2 LIMIT 1
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
|
||||
auto query = QUERY(CREATE(PATTERN(NODE("n"))),
|
||||
RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1))));
|
||||
auto aggr = ExpectAggregate({sum}, {});
|
||||
// We have a write query and aggregation, therefore Skip and Limit should come
|
||||
// after Accumulate and Aggregate.
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), aggr, ExpectSkip(),
|
||||
ExpectLimit(), ExpectProduce());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -555,4 +555,34 @@ TEST(TestSymbolGenerator, SameResults) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, SkipLimitIdentifier) {
|
||||
// Test MATCH (old) WITH old AS new SKIP old
|
||||
{
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("old"))),
|
||||
WITH(IDENT("old"), AS("new"), SKIP(IDENT("old"))));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
||||
}
|
||||
// Test MATCH (old) WITH old AS new SKIP new
|
||||
{
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("old"))),
|
||||
WITH(IDENT("old"), AS("new"), SKIP(IDENT("new"))));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
||||
}
|
||||
// Test MATCH (n) RETURN n AS n LIMIT n
|
||||
{
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
|
||||
RETURN(IDENT("n"), AS("n"), SKIP(IDENT("n"))));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user