Plan WITH clause without aggregation

Summary: Generate symbols for WITH clause.

Reviewers: florijan, mislav.bradac, buda

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D231
This commit is contained in:
Teon Banek 2017-04-05 14:19:14 +02:00
parent 919258d6f6
commit aa6cae0b16
6 changed files with 245 additions and 6 deletions

View File

@ -1,5 +1,6 @@
#include "query/frontend/logical/planner.hpp"
#include <functional>
#include <unordered_set>
#include "query/frontend/ast/ast.hpp"
@ -176,6 +177,59 @@ auto GenMatch(Match &match, LogicalOperator *input_op,
return last_op;
}
// Ast tree visitor which collects all the symbols referenced by identifiers.
class SymbolCollector : public TreeVisitorBase {
public:
SymbolCollector(const SymbolTable &symbol_table)
: symbol_table_(symbol_table) {}
using TreeVisitorBase::Visit;
using TreeVisitorBase::PostVisit;
void Visit(Identifier &ident) override {
symbols_.insert(symbol_table_.at(ident));
}
const auto &symbols() const { return symbols_; }
private:
// Calculates the Symbol hash based on its position.
struct SymbolHash {
size_t operator()(const Symbol &symbol) const {
return std::hash<int>{}(symbol.position_);
}
};
const SymbolTable &symbol_table_;
std::unordered_set<Symbol, SymbolHash> symbols_;
};
auto GenWith(With &with, LogicalOperator *input_op,
const query::SymbolTable &symbol_table) {
if (with.distinct_) {
// TODO: Plan disctint with, when operator available.
throw NotYetImplemented();
}
// WITH clause is Accumulate/Aggregate (advance_command) + Produce.
SymbolCollector symbol_collector(symbol_table);
// Collect used symbols so that accumulate doesn't copy the whole frame.
for (auto &named_expr : with.named_expressions_) {
named_expr->expression_->Accept(symbol_collector);
}
auto symbols = symbol_collector.symbols();
// TODO: Check whether we need aggregate instead of accumulate.
LogicalOperator *last_op =
new Accumulate(std::shared_ptr<LogicalOperator>(input_op),
std::vector<Symbol>(symbols.begin(), symbols.end()), true);
last_op = new Produce(std::shared_ptr<LogicalOperator>(last_op),
with.named_expressions_);
if (with.where_) {
last_op = new Filter(std::shared_ptr<LogicalOperator>(last_op),
with.where_->expression_);
}
return last_op;
}
} // namespace
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
@ -223,6 +277,8 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
input_op =
new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op),
input_symbol, rem->labels_);
} else if (auto *with = dynamic_cast<query::With *>(clause_ptr)) {
input_op = GenWith(*with, input_op, symbol_table);
} else {
throw NotYetImplemented();
}

View File

@ -39,6 +39,36 @@ void SymbolGenerator::PostVisit(Return &ret) {
}
}
void SymbolGenerator::SetWithSymbols(With &with) {
// WITH clause removes declarations of all the previous variables and declares
// only those established through named expressions. New declarations must not
// be visible inside named expressions themselves.
scope_.symbols.clear();
for (auto &named_expr : with.named_expressions_) {
symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
}
}
void SymbolGenerator::Visit(With &with) {
scope_.with = &with;
}
void SymbolGenerator::Visit(Where &where) {
if (scope_.with) {
// New symbols must be visible in WHERE clause, so this must be done here
// and not in PostVisit(With&).
SetWithSymbols(*scope_.with);
}
}
void SymbolGenerator::PostVisit(With &with) {
if (!with.where_) {
// This wasn't done when visiting Where, so do it here.
SetWithSymbols(with);
}
scope_.with = nullptr;
}
// Expressions
void SymbolGenerator::Visit(Identifier &ident) {

View File

@ -27,6 +27,9 @@ class SymbolGenerator : public TreeVisitorBase {
void Visit(Create &create) override;
void PostVisit(Create &create) override;
void PostVisit(Return &ret) override;
void Visit(With &with) override;
void PostVisit(With &with) override;
void Visit(Where &where) override;
// Expressions
void Visit(Identifier &ident) override;
@ -53,6 +56,8 @@ class SymbolGenerator : public TreeVisitorBase {
bool in_node_atom{false};
bool in_edge_atom{false};
bool in_property_map{false};
// Pointer to With clause if we are inside it, otherwise nullptr.
With *with{nullptr};
std::map<std::string, Symbol> symbols;
};
@ -68,6 +73,9 @@ class SymbolGenerator : public TreeVisitorBase {
auto GetOrCreateSymbol(const std::string &name,
Symbol::Type type = Symbol::Type::Any);
// Clear old symbol bindings and establish new from WITH clause.
void SetWithSymbols(With &with);
SymbolTable &symbol_table_;
Scope scope_;
};

View File

@ -78,10 +78,15 @@ auto GetQuery(AstTreeStorage &storage, Clause *clause) {
storage.query()->clauses_.emplace_back(clause);
return storage.query();
}
template <class... T>
auto GetQuery(AstTreeStorage &storage, Clause *clause, T *... clauses) {
storage.query()->clauses_.emplace_back(clause);
return GetQuery(storage, clauses...);
auto GetQuery(AstTreeStorage &storage, Match *match, Where *where) {
match->where_ = where;
storage.query()->clauses_.emplace_back(match);
return storage.query();
}
auto GetQuery(AstTreeStorage &storage, With *with, Where *where) {
with->where_ = where;
storage.query()->clauses_.emplace_back(with);
return storage.query();
}
template <class... T>
auto GetQuery(AstTreeStorage &storage, Match *match, Where *where,
@ -90,6 +95,18 @@ auto GetQuery(AstTreeStorage &storage, Match *match, Where *where,
storage.query()->clauses_.emplace_back(match);
return GetQuery(storage, clauses...);
}
template <class... T>
auto GetQuery(AstTreeStorage &storage, With *with, Where *where,
T *... clauses) {
with->where_ = where;
storage.query()->clauses_.emplace_back(with);
return GetQuery(storage, clauses...);
}
template <class... T>
auto GetQuery(AstTreeStorage &storage, Clause *clause, T *... clauses) {
storage.query()->clauses_.emplace_back(clause);
return GetQuery(storage, clauses...);
}
///
/// Create the return clause with given named expressions.
@ -123,6 +140,38 @@ auto GetReturn(AstTreeStorage &storage, T *... exprs) {
return GetReturn(ret, exprs...);
}
///
/// Create the with clause with given named expressions.
///
auto GetWith(With *with, NamedExpression *named_expr) {
with->named_expressions_.emplace_back(named_expr);
return with;
}
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) {
// This overload supports `RETURN(expr, AS(name))` construct, since
// NamedExpression does not inherit Expression.
named_expr->expression_ = expr;
with->named_expressions_.emplace_back(named_expr);
return with;
}
template <class... T>
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr,
T *... rest) {
named_expr->expression_ = expr;
with->named_expressions_.emplace_back(named_expr);
return GetWith(with, rest...);
}
template <class... T>
auto GetWith(With *with, NamedExpression *named_expr, T *... rest) {
with->named_expressions_.emplace_back(named_expr);
return GetWith(with, rest...);
}
template <class... T>
auto GetWith(AstTreeStorage &storage, T *... exprs) {
auto with = storage.Create<With>();
return GetWith(with, exprs...);
}
///
/// Create the delete clause with given named expressions.
///
@ -207,10 +256,11 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
query::test_common::GetPropertyLookup(storage, __VA_ARGS__)
#define NEXPR(name, expr) storage.Create<query::NamedExpression>((name), (expr))
// AS is alternative to NEXPR which does not initialize NamedExpression with
// Expression. It should be used with RETURN. For example:
// Expression. It should be used with RETURN or WITH. For example:
// RETURN(IDENT("n"), AS("n")) vs. RETURN(NEXPR("n", IDENT("n"))).
#define AS(name) storage.Create<query::NamedExpression>((name))
#define RETURN(...) query::test_common::GetReturn(storage, __VA_ARGS__)
#define WITH(...) query::test_common::GetWith(storage, __VA_ARGS__)
#define DELETE(...) query::test_common::GetDelete(storage, {__VA_ARGS__})
#define DETACH_DELETE(...) \
query::test_common::GetDelete(storage, {__VA_ARGS__}, true)

View File

@ -7,8 +7,8 @@
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/logical/operator.hpp"
#include "query/frontend/logical/planner.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/frontend/semantic/symbol_generator.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query_common.hpp"
@ -47,6 +47,7 @@ class PlanChecker : public LogicalOperatorVisitor {
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override {
AssertType(op);
}
void Visit(Accumulate &op) override { AssertType(op); }
std::list<size_t> types_;
@ -258,4 +259,24 @@ TEST(TestLogicalPlanner, MatchEdgeCycle) {
CheckPlan<ScanAll, Expand, Expand>(*query);
}
TEST(TestLogicalPlanner, MatchWithReturn) {
// Test MATCH (old) WITH old AS new RETURN new AS new
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")),
RETURN(IDENT("new"), AS("new")));
CheckPlan<ScanAll, Accumulate, Produce, Produce>(*query);
}
TEST(TestLogicalPlanner, MatchWithWhereReturn) {
// Test MATCH (old) WITH old AS new WHERE new.prop < 42 RETURN new AS new
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")),
WHERE(LESS(PROPERTY_LOOKUP("new", prop), LITERAL(42))),
RETURN(IDENT("new"), AS("new")));
CheckPlan<ScanAll, Accumulate, Produce, Filter, Produce>(*query);
}
} // namespace

View File

@ -2,6 +2,7 @@
#include "gtest/gtest.h"
#include "dbms/dbms.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/interpret/interpret.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
@ -266,4 +267,77 @@ TEST(TestSymbolGenerator, CreateDeleteUnbound) {
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
}
TEST(TestSymbolGenerator, MatchWithReturn) {
// Test MATCH (old) WITH old AS n RETURN n AS n
AstTreeStorage storage;
auto node = NODE("old");
auto old_ident = IDENT("old");
auto with_as_n = AS("n");
auto n_ident = IDENT("n");
auto ret_as_n = AS("n");
auto query =
QUERY(MATCH(PATTERN(node)), WITH(old_ident, with_as_n), RETURN(n_ident, ret_as_n));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query->Accept(symbol_generator);
EXPECT_EQ(symbol_table.max_position(), 3);
auto node_symbol = symbol_table.at(*node->identifier_);
auto old = symbol_table.at(*old_ident);
EXPECT_EQ(node_symbol, old);
auto with_n = symbol_table.at(*with_as_n);
EXPECT_NE(old, with_n);
auto n = symbol_table.at(*n_ident);
EXPECT_EQ(n, with_n);
auto ret_n = symbol_table.at(*ret_as_n);
EXPECT_NE(n, ret_n);
}
TEST(TestSymbolGenerator, MatchWithReturnUnbound) {
// Test MATCH (old) WITH old AS n RETURN old AS old
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("n")),
RETURN(IDENT("old"), AS("old")));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
}
TEST(TestSymbolGenerator, MatchWithWhere) {
// Test MATCH (old) WITH old AS n WHERE n.prop < 42
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto node = NODE("old");
auto old_ident = IDENT("old");
auto with_as_n = AS("n");
auto n_prop = PROPERTY_LOOKUP("n", prop);
auto query = QUERY(MATCH(PATTERN(node)), WITH(old_ident, with_as_n),
WHERE(LESS(n_prop, LITERAL(42))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query->Accept(symbol_generator);
EXPECT_EQ(symbol_table.max_position(), 2);
auto node_symbol = symbol_table.at(*node->identifier_);
auto old = symbol_table.at(*old_ident);
EXPECT_EQ(node_symbol, old);
auto with_n = symbol_table.at(*with_as_n);
EXPECT_NE(old, with_n);
auto n = symbol_table.at(*n_prop->expression_);
EXPECT_EQ(n, with_n);
}
TEST(TestSymbolGenerator, MatchWithWhereUnbound) {
// Test MATCH (old) WITH old AS n WHERE old.prop < 42
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("n")),
WHERE(LESS(PROPERTY_LOOKUP("old", prop), LITERAL(42))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
}
}