Plan WITH clause without aggregation
Summary: Generate symbols for WITH clause. Reviewers: florijan, mislav.bradac, buda Reviewed By: mislav.bradac Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D231
This commit is contained in:
parent
919258d6f6
commit
aa6cae0b16
@ -1,5 +1,6 @@
|
||||
#include "query/frontend/logical/planner.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
@ -176,6 +177,59 @@ auto GenMatch(Match &match, LogicalOperator *input_op,
|
||||
return last_op;
|
||||
}
|
||||
|
||||
// Ast tree visitor which collects all the symbols referenced by identifiers.
|
||||
class SymbolCollector : public TreeVisitorBase {
|
||||
public:
|
||||
SymbolCollector(const SymbolTable &symbol_table)
|
||||
: symbol_table_(symbol_table) {}
|
||||
|
||||
using TreeVisitorBase::Visit;
|
||||
using TreeVisitorBase::PostVisit;
|
||||
|
||||
void Visit(Identifier &ident) override {
|
||||
symbols_.insert(symbol_table_.at(ident));
|
||||
}
|
||||
|
||||
const auto &symbols() const { return symbols_; }
|
||||
|
||||
private:
|
||||
// Calculates the Symbol hash based on its position.
|
||||
struct SymbolHash {
|
||||
size_t operator()(const Symbol &symbol) const {
|
||||
return std::hash<int>{}(symbol.position_);
|
||||
}
|
||||
};
|
||||
|
||||
const SymbolTable &symbol_table_;
|
||||
std::unordered_set<Symbol, SymbolHash> symbols_;
|
||||
};
|
||||
|
||||
auto GenWith(With &with, LogicalOperator *input_op,
|
||||
const query::SymbolTable &symbol_table) {
|
||||
if (with.distinct_) {
|
||||
// TODO: Plan disctint with, when operator available.
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
// WITH clause is Accumulate/Aggregate (advance_command) + Produce.
|
||||
SymbolCollector symbol_collector(symbol_table);
|
||||
// Collect used symbols so that accumulate doesn't copy the whole frame.
|
||||
for (auto &named_expr : with.named_expressions_) {
|
||||
named_expr->expression_->Accept(symbol_collector);
|
||||
}
|
||||
auto symbols = symbol_collector.symbols();
|
||||
// TODO: Check whether we need aggregate instead of accumulate.
|
||||
LogicalOperator *last_op =
|
||||
new Accumulate(std::shared_ptr<LogicalOperator>(input_op),
|
||||
std::vector<Symbol>(symbols.begin(), symbols.end()), true);
|
||||
last_op = new Produce(std::shared_ptr<LogicalOperator>(last_op),
|
||||
with.named_expressions_);
|
||||
if (with.where_) {
|
||||
last_op = new Filter(std::shared_ptr<LogicalOperator>(last_op),
|
||||
with.where_->expression_);
|
||||
}
|
||||
return last_op;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
|
||||
@ -223,6 +277,8 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
|
||||
input_op =
|
||||
new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op),
|
||||
input_symbol, rem->labels_);
|
||||
} else if (auto *with = dynamic_cast<query::With *>(clause_ptr)) {
|
||||
input_op = GenWith(*with, input_op, symbol_table);
|
||||
} else {
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
|
@ -39,6 +39,36 @@ void SymbolGenerator::PostVisit(Return &ret) {
|
||||
}
|
||||
}
|
||||
|
||||
void SymbolGenerator::SetWithSymbols(With &with) {
|
||||
// WITH clause removes declarations of all the previous variables and declares
|
||||
// only those established through named expressions. New declarations must not
|
||||
// be visible inside named expressions themselves.
|
||||
scope_.symbols.clear();
|
||||
for (auto &named_expr : with.named_expressions_) {
|
||||
symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
|
||||
}
|
||||
}
|
||||
|
||||
void SymbolGenerator::Visit(With &with) {
|
||||
scope_.with = &with;
|
||||
}
|
||||
|
||||
void SymbolGenerator::Visit(Where &where) {
|
||||
if (scope_.with) {
|
||||
// New symbols must be visible in WHERE clause, so this must be done here
|
||||
// and not in PostVisit(With&).
|
||||
SetWithSymbols(*scope_.with);
|
||||
}
|
||||
}
|
||||
|
||||
void SymbolGenerator::PostVisit(With &with) {
|
||||
if (!with.where_) {
|
||||
// This wasn't done when visiting Where, so do it here.
|
||||
SetWithSymbols(with);
|
||||
}
|
||||
scope_.with = nullptr;
|
||||
}
|
||||
|
||||
// Expressions
|
||||
|
||||
void SymbolGenerator::Visit(Identifier &ident) {
|
||||
|
@ -27,6 +27,9 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
void Visit(Create &create) override;
|
||||
void PostVisit(Create &create) override;
|
||||
void PostVisit(Return &ret) override;
|
||||
void Visit(With &with) override;
|
||||
void PostVisit(With &with) override;
|
||||
void Visit(Where &where) override;
|
||||
|
||||
// Expressions
|
||||
void Visit(Identifier &ident) override;
|
||||
@ -53,6 +56,8 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
bool in_node_atom{false};
|
||||
bool in_edge_atom{false};
|
||||
bool in_property_map{false};
|
||||
// Pointer to With clause if we are inside it, otherwise nullptr.
|
||||
With *with{nullptr};
|
||||
std::map<std::string, Symbol> symbols;
|
||||
};
|
||||
|
||||
@ -68,6 +73,9 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
auto GetOrCreateSymbol(const std::string &name,
|
||||
Symbol::Type type = Symbol::Type::Any);
|
||||
|
||||
// Clear old symbol bindings and establish new from WITH clause.
|
||||
void SetWithSymbols(With &with);
|
||||
|
||||
SymbolTable &symbol_table_;
|
||||
Scope scope_;
|
||||
};
|
||||
|
@ -78,10 +78,15 @@ auto GetQuery(AstTreeStorage &storage, Clause *clause) {
|
||||
storage.query()->clauses_.emplace_back(clause);
|
||||
return storage.query();
|
||||
}
|
||||
template <class... T>
|
||||
auto GetQuery(AstTreeStorage &storage, Clause *clause, T *... clauses) {
|
||||
storage.query()->clauses_.emplace_back(clause);
|
||||
return GetQuery(storage, clauses...);
|
||||
auto GetQuery(AstTreeStorage &storage, Match *match, Where *where) {
|
||||
match->where_ = where;
|
||||
storage.query()->clauses_.emplace_back(match);
|
||||
return storage.query();
|
||||
}
|
||||
auto GetQuery(AstTreeStorage &storage, With *with, Where *where) {
|
||||
with->where_ = where;
|
||||
storage.query()->clauses_.emplace_back(with);
|
||||
return storage.query();
|
||||
}
|
||||
template <class... T>
|
||||
auto GetQuery(AstTreeStorage &storage, Match *match, Where *where,
|
||||
@ -90,6 +95,18 @@ auto GetQuery(AstTreeStorage &storage, Match *match, Where *where,
|
||||
storage.query()->clauses_.emplace_back(match);
|
||||
return GetQuery(storage, clauses...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetQuery(AstTreeStorage &storage, With *with, Where *where,
|
||||
T *... clauses) {
|
||||
with->where_ = where;
|
||||
storage.query()->clauses_.emplace_back(with);
|
||||
return GetQuery(storage, clauses...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetQuery(AstTreeStorage &storage, Clause *clause, T *... clauses) {
|
||||
storage.query()->clauses_.emplace_back(clause);
|
||||
return GetQuery(storage, clauses...);
|
||||
}
|
||||
|
||||
///
|
||||
/// Create the return clause with given named expressions.
|
||||
@ -123,6 +140,38 @@ auto GetReturn(AstTreeStorage &storage, T *... exprs) {
|
||||
return GetReturn(ret, exprs...);
|
||||
}
|
||||
|
||||
///
|
||||
/// Create the with clause with given named expressions.
|
||||
///
|
||||
auto GetWith(With *with, NamedExpression *named_expr) {
|
||||
with->named_expressions_.emplace_back(named_expr);
|
||||
return with;
|
||||
}
|
||||
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) {
|
||||
// This overload supports `RETURN(expr, AS(name))` construct, since
|
||||
// NamedExpression does not inherit Expression.
|
||||
named_expr->expression_ = expr;
|
||||
with->named_expressions_.emplace_back(named_expr);
|
||||
return with;
|
||||
}
|
||||
template <class... T>
|
||||
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr,
|
||||
T *... rest) {
|
||||
named_expr->expression_ = expr;
|
||||
with->named_expressions_.emplace_back(named_expr);
|
||||
return GetWith(with, rest...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetWith(With *with, NamedExpression *named_expr, T *... rest) {
|
||||
with->named_expressions_.emplace_back(named_expr);
|
||||
return GetWith(with, rest...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetWith(AstTreeStorage &storage, T *... exprs) {
|
||||
auto with = storage.Create<With>();
|
||||
return GetWith(with, exprs...);
|
||||
}
|
||||
|
||||
///
|
||||
/// Create the delete clause with given named expressions.
|
||||
///
|
||||
@ -207,10 +256,11 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
|
||||
query::test_common::GetPropertyLookup(storage, __VA_ARGS__)
|
||||
#define NEXPR(name, expr) storage.Create<query::NamedExpression>((name), (expr))
|
||||
// AS is alternative to NEXPR which does not initialize NamedExpression with
|
||||
// Expression. It should be used with RETURN. For example:
|
||||
// Expression. It should be used with RETURN or WITH. For example:
|
||||
// RETURN(IDENT("n"), AS("n")) vs. RETURN(NEXPR("n", IDENT("n"))).
|
||||
#define AS(name) storage.Create<query::NamedExpression>((name))
|
||||
#define RETURN(...) query::test_common::GetReturn(storage, __VA_ARGS__)
|
||||
#define WITH(...) query::test_common::GetWith(storage, __VA_ARGS__)
|
||||
#define DELETE(...) query::test_common::GetDelete(storage, {__VA_ARGS__})
|
||||
#define DETACH_DELETE(...) \
|
||||
query::test_common::GetDelete(storage, {__VA_ARGS__}, true)
|
||||
|
@ -7,8 +7,8 @@
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/logical/operator.hpp"
|
||||
#include "query/frontend/logical/planner.hpp"
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
#include "query/frontend/semantic/symbol_generator.hpp"
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
|
||||
#include "query_common.hpp"
|
||||
|
||||
@ -47,6 +47,7 @@ class PlanChecker : public LogicalOperatorVisitor {
|
||||
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override {
|
||||
AssertType(op);
|
||||
}
|
||||
void Visit(Accumulate &op) override { AssertType(op); }
|
||||
|
||||
std::list<size_t> types_;
|
||||
|
||||
@ -258,4 +259,24 @@ TEST(TestLogicalPlanner, MatchEdgeCycle) {
|
||||
CheckPlan<ScanAll, Expand, Expand>(*query);
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchWithReturn) {
|
||||
// Test MATCH (old) WITH old AS new RETURN new AS new
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")),
|
||||
RETURN(IDENT("new"), AS("new")));
|
||||
CheckPlan<ScanAll, Accumulate, Produce, Produce>(*query);
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchWithWhereReturn) {
|
||||
// Test MATCH (old) WITH old AS new WHERE new.prop < 42 RETURN new AS new
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")),
|
||||
WHERE(LESS(PROPERTY_LOOKUP("new", prop), LITERAL(42))),
|
||||
RETURN(IDENT("new"), AS("new")));
|
||||
CheckPlan<ScanAll, Accumulate, Produce, Filter, Produce>(*query);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "dbms/dbms.hpp"
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/interpret/interpret.hpp"
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
@ -266,4 +267,77 @@ TEST(TestSymbolGenerator, CreateDeleteUnbound) {
|
||||
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchWithReturn) {
|
||||
// Test MATCH (old) WITH old AS n RETURN n AS n
|
||||
AstTreeStorage storage;
|
||||
auto node = NODE("old");
|
||||
auto old_ident = IDENT("old");
|
||||
auto with_as_n = AS("n");
|
||||
auto n_ident = IDENT("n");
|
||||
auto ret_as_n = AS("n");
|
||||
auto query =
|
||||
QUERY(MATCH(PATTERN(node)), WITH(old_ident, with_as_n), RETURN(n_ident, ret_as_n));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
EXPECT_EQ(symbol_table.max_position(), 3);
|
||||
auto node_symbol = symbol_table.at(*node->identifier_);
|
||||
auto old = symbol_table.at(*old_ident);
|
||||
EXPECT_EQ(node_symbol, old);
|
||||
auto with_n = symbol_table.at(*with_as_n);
|
||||
EXPECT_NE(old, with_n);
|
||||
auto n = symbol_table.at(*n_ident);
|
||||
EXPECT_EQ(n, with_n);
|
||||
auto ret_n = symbol_table.at(*ret_as_n);
|
||||
EXPECT_NE(n, ret_n);
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchWithReturnUnbound) {
|
||||
// Test MATCH (old) WITH old AS n RETURN old AS old
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("n")),
|
||||
RETURN(IDENT("old"), AS("old")));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchWithWhere) {
|
||||
// Test MATCH (old) WITH old AS n WHERE n.prop < 42
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto node = NODE("old");
|
||||
auto old_ident = IDENT("old");
|
||||
auto with_as_n = AS("n");
|
||||
auto n_prop = PROPERTY_LOOKUP("n", prop);
|
||||
auto query = QUERY(MATCH(PATTERN(node)), WITH(old_ident, with_as_n),
|
||||
WHERE(LESS(n_prop, LITERAL(42))));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
EXPECT_EQ(symbol_table.max_position(), 2);
|
||||
auto node_symbol = symbol_table.at(*node->identifier_);
|
||||
auto old = symbol_table.at(*old_ident);
|
||||
EXPECT_EQ(node_symbol, old);
|
||||
auto with_n = symbol_table.at(*with_as_n);
|
||||
EXPECT_NE(old, with_n);
|
||||
auto n = symbol_table.at(*n_prop->expression_);
|
||||
EXPECT_EQ(n, with_n);
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchWithWhereUnbound) {
|
||||
// Test MATCH (old) WITH old AS n WHERE old.prop < 42
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("n")),
|
||||
WHERE(LESS(PROPERTY_LOOKUP("old", prop), LITERAL(42))));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user