Plan '*' in RETURN and WITH
Summary: Make Symbol members read only. Check WITH/RETURN * in SymbolGenerator. Test semantic checks for WITH/RETURN *. Sort expanded user identifiers by name. Test planning WITH/RETURN *. Reviewers: buda, florijan, mislav.bradac Reviewed By: florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D357
This commit is contained in:
parent
f82bda6c0c
commit
87e5dc0dfb
@ -8,32 +8,46 @@
|
||||
|
||||
namespace query {
|
||||
|
||||
auto SymbolGenerator::CreateSymbol(const std::string &name, Symbol::Type type) {
|
||||
auto symbol = symbol_table_.CreateSymbol(name, type);
|
||||
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared,
|
||||
Symbol::Type type) {
|
||||
auto symbol = symbol_table_.CreateSymbol(name, user_declared, type);
|
||||
scope_.symbols[name] = symbol;
|
||||
return symbol;
|
||||
}
|
||||
|
||||
auto SymbolGenerator::GetOrCreateSymbol(const std::string &name,
|
||||
Symbol::Type type) {
|
||||
bool user_declared, Symbol::Type type) {
|
||||
auto search = scope_.symbols.find(name);
|
||||
if (search != scope_.symbols.end()) {
|
||||
auto symbol = search->second;
|
||||
// Unless we have `Any` type, check that types match.
|
||||
if (type != Symbol::Type::Any && symbol.type_ != Symbol::Type::Any &&
|
||||
type != symbol.type_) {
|
||||
throw TypeMismatchError(name, Symbol::TypeToString(symbol.type_),
|
||||
if (type != Symbol::Type::Any && symbol.type() != Symbol::Type::Any &&
|
||||
type != symbol.type()) {
|
||||
throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()),
|
||||
Symbol::TypeToString(type));
|
||||
}
|
||||
return search->second;
|
||||
}
|
||||
return CreateSymbol(name, type);
|
||||
return CreateSymbol(name, user_declared, type);
|
||||
}
|
||||
|
||||
void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
|
||||
for (auto &expr : body.named_expressions) {
|
||||
expr->Accept(*this);
|
||||
}
|
||||
std::vector<Symbol> user_symbols;
|
||||
if (body.all_identifiers) {
|
||||
// Carry over user symbols because '*' appeared.
|
||||
for (auto sym_pair : scope_.symbols) {
|
||||
if (!sym_pair.second.user_declared()) {
|
||||
continue;
|
||||
}
|
||||
user_symbols.emplace_back(sym_pair.second);
|
||||
}
|
||||
if (user_symbols.empty()) {
|
||||
throw SemanticException("There are no variables in scope to use for '*'");
|
||||
}
|
||||
}
|
||||
// WITH/RETURN clause removes declarations of all the previous variables and
|
||||
// declares only those established through named expressions. New declarations
|
||||
// must not be visible inside named expressions themselves.
|
||||
@ -47,6 +61,10 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
|
||||
}
|
||||
// Create symbols for named expressions.
|
||||
std::unordered_set<std::string> new_names;
|
||||
for (const auto &user_sym : user_symbols) {
|
||||
new_names.insert(user_sym.name());
|
||||
scope_.symbols[user_sym.name()] = user_sym;
|
||||
}
|
||||
for (auto &named_expr : body.named_expressions) {
|
||||
const auto &name = named_expr->name_;
|
||||
if (!new_names.insert(name).second) {
|
||||
@ -55,7 +73,7 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
|
||||
}
|
||||
// An improvement would be to infer the type of the expression, so that the
|
||||
// new symbol would have a more specific type.
|
||||
symbol_table_[*named_expr] = CreateSymbol(name);
|
||||
symbol_table_[*named_expr] = CreateSymbol(name, true);
|
||||
}
|
||||
scope_.in_order_by = true;
|
||||
for (const auto &order_pair : body.order_by) {
|
||||
@ -119,7 +137,7 @@ void SymbolGenerator::PostVisit(Unwind &unwind) {
|
||||
if (HasSymbol(name)) {
|
||||
throw RedeclareVariableError(name);
|
||||
}
|
||||
symbol_table_[*unwind.named_expression_] = CreateSymbol(name);
|
||||
symbol_table_[*unwind.named_expression_] = CreateSymbol(name, true);
|
||||
}
|
||||
|
||||
void SymbolGenerator::Visit(Match &) { scope_.in_match = true; }
|
||||
@ -163,7 +181,7 @@ void SymbolGenerator::Visit(Identifier &ident) {
|
||||
if (scope_.in_edge_atom) {
|
||||
type = Symbol::Type::Edge;
|
||||
}
|
||||
symbol = GetOrCreateSymbol(ident.name_, type);
|
||||
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, type);
|
||||
} else if (scope_.in_pattern && scope_.in_property_map && scope_.in_match) {
|
||||
// Variables in property maps during MATCH can reference symbols bound later
|
||||
// in the same MATCH. We collect them here, so that they can be checked
|
||||
@ -193,7 +211,8 @@ void SymbolGenerator::Visit(Aggregation &aggr) {
|
||||
}
|
||||
// Create a virtual symbol for aggregation result.
|
||||
// Currently, we only have aggregation operators which return numbers.
|
||||
symbol_table_[aggr] = symbol_table_.CreateSymbol("", Symbol::Type::Number);
|
||||
symbol_table_[aggr] =
|
||||
symbol_table_.CreateSymbol("", false, Symbol::Type::Number);
|
||||
scope_.in_aggregation = true;
|
||||
scope_.has_aggregation = true;
|
||||
}
|
||||
|
@ -88,12 +88,12 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
|
||||
// Returns a freshly generated symbol. Previous mapping of the same name to a
|
||||
// different symbol is replaced with the new one.
|
||||
auto CreateSymbol(const std::string &name,
|
||||
auto CreateSymbol(const std::string &name, bool user_declared,
|
||||
Symbol::Type type = Symbol::Type::Any);
|
||||
|
||||
// Returns the symbol by name. If the mapping already exists, checks if the
|
||||
// types match. Otherwise, returns a new symbol.
|
||||
auto GetOrCreateSymbol(const std::string &name,
|
||||
auto GetOrCreateSymbol(const std::string &name, bool user_declared,
|
||||
Symbol::Type type = Symbol::Type::Any);
|
||||
|
||||
void VisitReturnBody(ReturnBody &body, Where *where = nullptr);
|
||||
|
@ -17,34 +17,38 @@ class Symbol {
|
||||
return enum_string[static_cast<int>(type)];
|
||||
}
|
||||
|
||||
// Calculates the Symbol hash based on its position.
|
||||
struct Hash {
|
||||
size_t operator()(const Symbol &symbol) const {
|
||||
return std::hash<int>{}(symbol.position_);
|
||||
}
|
||||
};
|
||||
|
||||
Symbol() {}
|
||||
Symbol(const std::string &name, int position, Type type = Type::Any)
|
||||
: name_(name), position_(position), type_(type) {}
|
||||
|
||||
std::string name_;
|
||||
int position_;
|
||||
Type type_{Type::Any};
|
||||
Symbol(const std::string &name, int position, bool user_declared,
|
||||
Type type = Type::Any)
|
||||
: name_(name),
|
||||
position_(position),
|
||||
user_declared_(user_declared),
|
||||
type_(type) {}
|
||||
|
||||
bool operator==(const Symbol &other) const {
|
||||
return position_ == other.position_ && name_ == other.name_ &&
|
||||
type_ == other.type_;
|
||||
}
|
||||
bool operator!=(const Symbol &other) const { return !operator==(other); }
|
||||
|
||||
const auto &name() const { return name_; }
|
||||
int position() const { return position_; }
|
||||
Type type() const { return type_; }
|
||||
bool user_declared() const { return user_declared_; }
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
int position_;
|
||||
bool user_declared_ = true;
|
||||
Type type_ = Type::Any;
|
||||
};
|
||||
|
||||
class SymbolTable {
|
||||
public:
|
||||
Symbol CreateSymbol(const std::string &name,
|
||||
Symbol CreateSymbol(const std::string &name, bool user_declared,
|
||||
Symbol::Type type = Symbol::Type::Any) {
|
||||
int position = position_++;
|
||||
return Symbol(name, position, type);
|
||||
return Symbol(name, position, user_declared, type);
|
||||
}
|
||||
|
||||
auto &operator[](const Tree &tree) { return table_[tree.uid()]; }
|
||||
@ -59,4 +63,21 @@ class SymbolTable {
|
||||
std::map<int, Symbol> table_;
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace query
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
struct hash<query::Symbol> {
|
||||
size_t operator()(const query::Symbol &symbol) const {
|
||||
size_t prime = 265443599u;
|
||||
size_t hash = std::hash<int>{}(symbol.position());
|
||||
hash ^= prime * std::hash<std::string>{}(symbol.name());
|
||||
hash ^= prime * std::hash<bool>{}(symbol.user_declared());
|
||||
hash ^= prime * std::hash<int>{}(static_cast<int>(symbol.type()));
|
||||
return hash;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
|
@ -12,10 +12,10 @@ class Frame {
|
||||
Frame(int size) : size_(size), elems_(size_) {}
|
||||
|
||||
TypedValue &operator[](const Symbol &symbol) {
|
||||
return elems_[symbol.position_];
|
||||
return elems_[symbol.position()];
|
||||
}
|
||||
const TypedValue &operator[](const Symbol &symbol) const {
|
||||
return elems_[symbol.position_];
|
||||
return elems_[symbol.position()];
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -48,7 +48,7 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
|
||||
// clause, so stream out the results.
|
||||
|
||||
// generate header
|
||||
for (const auto &symbol : output_symbols) header.push_back(symbol.name_);
|
||||
for (const auto &symbol : output_symbols) header.push_back(symbol.name());
|
||||
stream.Header(header);
|
||||
|
||||
// stream out results
|
||||
|
@ -13,8 +13,9 @@ namespace {
|
||||
|
||||
// Returns false if the symbol was already bound, otherwise binds it and
|
||||
// returns true.
|
||||
bool BindSymbol(std::unordered_set<int> &bound_symbols, const Symbol &symbol) {
|
||||
auto insertion = bound_symbols.insert(symbol.position_);
|
||||
bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
|
||||
const Symbol &symbol) {
|
||||
auto insertion = bound_symbols.insert(symbol);
|
||||
return insertion.second;
|
||||
}
|
||||
|
||||
@ -61,7 +62,7 @@ auto ReducePattern(
|
||||
|
||||
auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols) {
|
||||
std::unordered_set<Symbol> &bound_symbols) {
|
||||
auto base = [&](NodeAtom *node) -> LogicalOperator * {
|
||||
if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_)))
|
||||
return new CreateNode(node, std::shared_ptr<LogicalOperator>(input_op));
|
||||
@ -92,7 +93,7 @@ auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
|
||||
auto GenCreate(Create &create, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols) {
|
||||
std::unordered_set<Symbol> &bound_symbols) {
|
||||
auto last_op = input_op;
|
||||
for (auto pattern : create.patterns_) {
|
||||
last_op =
|
||||
@ -109,17 +110,16 @@ class UsedSymbolsCollector : public TreeVisitorBase {
|
||||
|
||||
using TreeVisitorBase::Visit;
|
||||
void Visit(Identifier &ident) override {
|
||||
const auto &symbol = symbol_table_.at(ident);
|
||||
symbols_.insert(symbol.position_);
|
||||
symbols_.insert(symbol_table_.at(ident));
|
||||
}
|
||||
|
||||
std::unordered_set<int> symbols_;
|
||||
std::unordered_set<Symbol> symbols_;
|
||||
const SymbolTable &symbol_table_;
|
||||
};
|
||||
|
||||
bool HasBoundFilterSymbols(
|
||||
const std::unordered_set<int> &bound_symbols,
|
||||
const std::pair<Expression *, std::unordered_set<int>> &filter) {
|
||||
const std::unordered_set<Symbol> &bound_symbols,
|
||||
const std::pair<Expression *, std::unordered_set<Symbol>> &filter) {
|
||||
for (const auto &symbol : filter.second) {
|
||||
if (bound_symbols.find(symbol) == bound_symbols.end()) {
|
||||
return false;
|
||||
@ -154,7 +154,7 @@ Expression *PropertiesEqual(AstTreeStorage &storage,
|
||||
|
||||
auto &CollectPatternFilters(
|
||||
Pattern &pattern, const SymbolTable &symbol_table,
|
||||
std::list<std::pair<Expression *, std::unordered_set<int>>> &filters,
|
||||
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
|
||||
AstTreeStorage &storage) {
|
||||
UsedSymbolsCollector collector(symbol_table);
|
||||
auto node_filter = [&](NodeAtom *node) {
|
||||
@ -164,7 +164,7 @@ auto &CollectPatternFilters(
|
||||
node->identifier_, node->labels_);
|
||||
auto *props_filter = PropertiesEqual(storage, collector, node);
|
||||
if (labels_filter || props_filter) {
|
||||
collector.symbols_.insert(symbol_table.at(*node->identifier_).position_);
|
||||
collector.symbols_.insert(symbol_table.at(*node->identifier_));
|
||||
filters.emplace_back(
|
||||
BoolJoin<FilterAndOperator>(storage, labels_filter, props_filter),
|
||||
collector.symbols_);
|
||||
@ -181,7 +181,7 @@ auto &CollectPatternFilters(
|
||||
auto *props_filter = PropertiesEqual(storage, collector, edge);
|
||||
if (types_filter || props_filter) {
|
||||
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
|
||||
collector.symbols_.insert(edge_symbol.position_);
|
||||
collector.symbols_.insert(edge_symbol);
|
||||
filters->emplace_back(
|
||||
BoolJoin<FilterAndOperator>(storage, types_filter, props_filter),
|
||||
collector.symbols_);
|
||||
@ -190,13 +190,13 @@ auto &CollectPatternFilters(
|
||||
return node_filter(node);
|
||||
};
|
||||
return *ReducePattern<
|
||||
std::list<std::pair<Expression *, std::unordered_set<int>>> *>(
|
||||
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> *>(
|
||||
pattern, node_filter, expand_filter);
|
||||
}
|
||||
|
||||
void CollectMatchFilters(
|
||||
const Match &match, const SymbolTable &symbol_table,
|
||||
std::list<std::pair<Expression *, std::unordered_set<int>>> &filters,
|
||||
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
|
||||
AstTreeStorage &storage) {
|
||||
for (auto *pattern : match.patterns_) {
|
||||
CollectPatternFilters(*pattern, symbol_table, filters, storage);
|
||||
@ -214,22 +214,22 @@ struct MatchContext {
|
||||
// Already bound symbols, which are used to determine whether the operator
|
||||
// should reference them or establish new. This is both read from and written
|
||||
// to during generation.
|
||||
std::unordered_set<int> &bound_symbols;
|
||||
std::unordered_set<Symbol> &bound_symbols;
|
||||
// Determines whether the match should see the new graph state or not.
|
||||
GraphView graph_view = GraphView::OLD;
|
||||
// Pairs of filter expression and symbols used in them. The list should be
|
||||
// filled using CollectPatternFilters function, and later modified during
|
||||
// GenMatchForPattern.
|
||||
std::list<std::pair<Expression *, std::unordered_set<int>>> filters;
|
||||
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> filters;
|
||||
// Symbols for edges established in match, used to ensure Cyphermorphism.
|
||||
std::unordered_set<Symbol, Symbol::Hash> edge_symbols;
|
||||
std::unordered_set<Symbol> edge_symbols;
|
||||
// All the newly established symbols in match.
|
||||
std::vector<Symbol> new_symbols;
|
||||
};
|
||||
|
||||
auto GenFilters(
|
||||
LogicalOperator *last_op, const std::unordered_set<int> &bound_symbols,
|
||||
std::list<std::pair<Expression *, std::unordered_set<int>>> &filters,
|
||||
LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols,
|
||||
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
|
||||
AstTreeStorage &storage) {
|
||||
Expression *filter_expr = nullptr;
|
||||
for (auto filters_it = filters.begin(); filters_it != filters.end();) {
|
||||
@ -312,7 +312,7 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
|
||||
auto GenMatches(std::vector<Match *> &matches, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols,
|
||||
std::unordered_set<Symbol> &bound_symbols,
|
||||
AstTreeStorage &storage) {
|
||||
auto *last_op = input_op;
|
||||
MatchContext req_ctx{symbol_table, bound_symbols};
|
||||
@ -369,15 +369,27 @@ auto GenMatches(std::vector<Match *> &matches, LogicalOperator *input_op,
|
||||
// aggregations and expressions used for group by.
|
||||
class ReturnBodyContext : public TreeVisitorBase {
|
||||
public:
|
||||
ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table,
|
||||
Where *where = nullptr)
|
||||
: body_(body), symbol_table_(symbol_table), where_(where) {
|
||||
ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table,
|
||||
const std::unordered_set<Symbol> &bound_symbols,
|
||||
AstTreeStorage &storage, Where *where = nullptr)
|
||||
: body_(body),
|
||||
symbol_table_(symbol_table),
|
||||
bound_symbols_(bound_symbols),
|
||||
storage_(storage),
|
||||
where_(where) {
|
||||
// Collect symbols from named expressions.
|
||||
output_symbols_.reserve(body_.named_expressions.size());
|
||||
if (body.all_identifiers) {
|
||||
// Expand '*' to expressions and symbols first, so that their results come
|
||||
// before regular named expressions.
|
||||
ExpandUserSymbols();
|
||||
}
|
||||
for (auto &named_expr : body_.named_expressions) {
|
||||
output_symbols_.emplace_back(symbol_table_.at(*named_expr));
|
||||
named_expr->Accept(*this);
|
||||
named_expressions_.emplace_back(named_expr);
|
||||
}
|
||||
// Collect aggregations.
|
||||
if (aggregations_.empty()) {
|
||||
// Visit order_by and where if we do not have aggregations. This way we
|
||||
// prevent collecting group_by expressions from order_by and where, which
|
||||
@ -456,7 +468,7 @@ class ReturnBodyContext : public TreeVisitorBase {
|
||||
// Aggregation contains a virtual symbol, where the result will be stored.
|
||||
const auto &symbol = symbol_table_.at(aggr);
|
||||
aggregations_.emplace_back(aggr.expression_, aggr.op_, symbol);
|
||||
// aggregation expression_ is opional in COUNT(*) so it's possible the
|
||||
// aggregation expression_ is optional in COUNT(*), so it's possible the
|
||||
// has_aggregation_ stack is empty
|
||||
if (aggr.expression_)
|
||||
has_aggregation_.back() = true;
|
||||
@ -474,10 +486,41 @@ class ReturnBodyContext : public TreeVisitorBase {
|
||||
has_aggregation_.pop_back();
|
||||
}
|
||||
|
||||
// Creates NamedExpression with an Identifier for each user declared symbol.
|
||||
// This should be used when body.all_identifiers is true, to generate
|
||||
// expressions for Produce operator.
|
||||
void ExpandUserSymbols() {
|
||||
debug_assert(
|
||||
named_expressions_.empty(),
|
||||
"ExpandUserSymbols should be first to fill named_expressions_");
|
||||
debug_assert(output_symbols_.empty(),
|
||||
"ExpandUserSymbols should be first to fill output_symbols_");
|
||||
for (const auto &symbol : bound_symbols_) {
|
||||
if (!symbol.user_declared()) {
|
||||
continue;
|
||||
}
|
||||
auto *ident = storage_.Create<Identifier>(symbol.name());
|
||||
symbol_table_[*ident] = symbol;
|
||||
auto *named_expr = storage_.Create<NamedExpression>(symbol.name(), ident);
|
||||
symbol_table_[*named_expr] = symbol;
|
||||
// Fill output expressions and symbols with expanded identifiers.
|
||||
named_expressions_.emplace_back(named_expr);
|
||||
output_symbols_.emplace_back(symbol);
|
||||
used_symbols_.insert(symbol);
|
||||
// Don't forget to group by expanded identifiers.
|
||||
group_by_.emplace_back(ident);
|
||||
}
|
||||
// Cypher RETURN/WITH * expects to expand '*' sorted by name.
|
||||
std::sort(output_symbols_.begin(), output_symbols_.end(),
|
||||
[](const auto &a, const auto &b) { return a.name() < b.name(); });
|
||||
std::sort(named_expressions_.begin(), named_expressions_.end(),
|
||||
[](const auto &a, const auto &b) { return a->name_ < b->name_; });
|
||||
}
|
||||
|
||||
// If true, results need to be distinct.
|
||||
bool distinct() const { return body_.distinct; }
|
||||
// Named expressions which are used to produce results.
|
||||
const auto &named_expressions() const { return body_.named_expressions; }
|
||||
const auto &named_expressions() const { return named_expressions_; }
|
||||
// Pairs of (Ordering, Expression *) for sorting results.
|
||||
const auto &order_by() const { return body_.order_by; }
|
||||
// Optional expression which determines how many results to skip.
|
||||
@ -496,21 +539,23 @@ class ReturnBodyContext : public TreeVisitorBase {
|
||||
// expressions are used for grouping. For example, in `WITH sum(n.a) + 2 * n.b
|
||||
// AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`.
|
||||
const auto &group_by() const { return group_by_; }
|
||||
|
||||
// All symbols generated by named expressions. They are collected in order of
|
||||
// named_expressions.
|
||||
const auto &output_symbols() const { return output_symbols_; }
|
||||
|
||||
private:
|
||||
const ReturnBody &body_;
|
||||
const SymbolTable &symbol_table_;
|
||||
SymbolTable &symbol_table_;
|
||||
const std::unordered_set<Symbol> &bound_symbols_;
|
||||
AstTreeStorage &storage_;
|
||||
const Where *const where_ = nullptr;
|
||||
std::unordered_set<Symbol, Symbol::Hash> used_symbols_;
|
||||
std::unordered_set<Symbol> used_symbols_;
|
||||
std::vector<Symbol> output_symbols_;
|
||||
std::vector<Aggregate::Element> aggregations_;
|
||||
std::vector<Expression *> group_by_;
|
||||
// Flag indicating whether an expression contains an aggregation.
|
||||
std::list<bool> has_aggregation_;
|
||||
std::vector<NamedExpression *> named_expressions_;
|
||||
};
|
||||
|
||||
auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
|
||||
@ -561,9 +606,9 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
|
||||
return last_op;
|
||||
}
|
||||
|
||||
auto GenWith(With &with, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table, bool is_write,
|
||||
std::unordered_set<int> &bound_symbols) {
|
||||
auto GenWith(With &with, LogicalOperator *input_op, SymbolTable &symbol_table,
|
||||
bool is_write, std::unordered_set<Symbol> &bound_symbols,
|
||||
AstTreeStorage &storage) {
|
||||
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
|
||||
// optional Filter. In case of update and aggregation, we want to accumulate
|
||||
// first, so that when aggregating, we get the latest results. Similar to
|
||||
@ -571,7 +616,8 @@ auto GenWith(With &with, LogicalOperator *input_op,
|
||||
bool accumulate = is_write;
|
||||
// No need to advance the command if we only performed reads.
|
||||
bool advance_command = is_write;
|
||||
ReturnBodyContext body(with.body_, symbol_table, with.where_);
|
||||
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage,
|
||||
with.where_);
|
||||
LogicalOperator *last_op =
|
||||
GenReturnBody(input_op, advance_command, body, accumulate);
|
||||
// Reset bound symbols, so that only those in WITH are exposed.
|
||||
@ -583,7 +629,9 @@ auto GenWith(With &with, LogicalOperator *input_op,
|
||||
}
|
||||
|
||||
auto GenReturn(Return &ret, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table, bool is_write) {
|
||||
SymbolTable &symbol_table, bool is_write,
|
||||
const std::unordered_set<Symbol> &bound_symbols,
|
||||
AstTreeStorage &storage) {
|
||||
// Similar to WITH clause, but we want to accumulate and advance command when
|
||||
// the query writes to the database. This way we handle the case when we want
|
||||
// to return expressions with the latest updated results. For example,
|
||||
@ -592,7 +640,7 @@ auto GenReturn(Return &ret, LogicalOperator *input_op,
|
||||
// value is the same, final result of 'k' increments.
|
||||
bool accumulate = is_write;
|
||||
bool advance_command = false;
|
||||
ReturnBodyContext body(ret.body_, symbol_table);
|
||||
ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
|
||||
return GenReturnBody(input_op, advance_command, body, accumulate);
|
||||
}
|
||||
|
||||
@ -600,7 +648,7 @@ auto GenReturn(Return &ret, LogicalOperator *input_op,
|
||||
// isn't handled, returns nullptr.
|
||||
LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols) {
|
||||
std::unordered_set<Symbol> &bound_symbols) {
|
||||
if (auto *create = dynamic_cast<Create *>(clause)) {
|
||||
return GenCreate(*create, input_op, symbol_table, bound_symbols);
|
||||
} else if (auto *del = dynamic_cast<query::Delete *>(clause)) {
|
||||
@ -632,10 +680,11 @@ LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
|
||||
|
||||
auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols, AstTreeStorage &storage) {
|
||||
std::unordered_set<Symbol> &bound_symbols,
|
||||
AstTreeStorage &storage) {
|
||||
// Copy the bound symbol set, because we don't want to use the updated version
|
||||
// when generating the create part.
|
||||
std::unordered_set<int> bound_symbols_copy(bound_symbols);
|
||||
std::unordered_set<Symbol> bound_symbols_copy(bound_symbols);
|
||||
MatchContext context{symbol_table, bound_symbols_copy, GraphView::NEW};
|
||||
CollectPatternFilters(*merge.pattern_, symbol_table, context.filters,
|
||||
storage);
|
||||
@ -659,14 +708,14 @@ auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
|
||||
AstTreeStorage &storage, const SymbolTable &symbol_table) {
|
||||
std::unique_ptr<LogicalOperator> MakeLogicalPlan(AstTreeStorage &storage,
|
||||
SymbolTable &symbol_table) {
|
||||
auto query = storage.query();
|
||||
// bound_symbols set is used to differentiate cycles in pattern matching, so
|
||||
// that the operator can be correctly initialized whether to read the symbol
|
||||
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
|
||||
// `n`, but the latter `n` would only read the already written information.
|
||||
std::unordered_set<int> bound_symbols;
|
||||
std::unordered_set<Symbol> bound_symbols;
|
||||
// Set to true if a query command writes to the database.
|
||||
bool is_write = false;
|
||||
LogicalOperator *input_op = nullptr;
|
||||
@ -681,7 +730,8 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
|
||||
GenMatches(matches, input_op, symbol_table, bound_symbols, storage);
|
||||
matches.clear();
|
||||
if (auto *ret = dynamic_cast<Return *>(clause)) {
|
||||
input_op = GenReturn(*ret, input_op, symbol_table, is_write);
|
||||
input_op = GenReturn(*ret, input_op, symbol_table, is_write,
|
||||
bound_symbols, storage);
|
||||
} else if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
|
||||
input_op =
|
||||
GenMerge(*merge, input_op, symbol_table, bound_symbols, storage);
|
||||
@ -689,8 +739,8 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
|
||||
// anything.
|
||||
is_write = true;
|
||||
} else if (auto *with = dynamic_cast<query::With *>(clause)) {
|
||||
input_op =
|
||||
GenWith(*with, input_op, symbol_table, is_write, bound_symbols);
|
||||
input_op = GenWith(*with, input_op, symbol_table, is_write,
|
||||
bound_symbols, storage);
|
||||
// WITH clause advances the command, so reset the flag.
|
||||
is_write = false;
|
||||
} else if (auto *op = HandleWriteClause(clause, input_op, symbol_table,
|
||||
|
@ -18,7 +18,7 @@ namespace plan {
|
||||
/// use in operators. @c SymbolTable is used to determine inputs and outputs of
|
||||
/// certain operators.
|
||||
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
|
||||
AstTreeStorage &storage, const query::SymbolTable &symbol_table);
|
||||
AstTreeStorage &storage, query::SymbolTable &symbol_table);
|
||||
}
|
||||
|
||||
} // namespace plan
|
||||
|
@ -505,7 +505,7 @@ TEST(ExpressionEvaluator, PropertyLookup) {
|
||||
auto v1 = dba->insert_vertex();
|
||||
v1.PropsSet(dba->property("age"), 10);
|
||||
auto *identifier = storage.Create<Identifier>("n");
|
||||
auto node_symbol = eval.symbol_table.CreateSymbol("n");
|
||||
auto node_symbol = eval.symbol_table.CreateSymbol("n", true);
|
||||
eval.symbol_table[*identifier] = node_symbol;
|
||||
eval.frame[node_symbol] = v1;
|
||||
{
|
||||
@ -537,7 +537,7 @@ TEST(ExpressionEvaluator, LabelsTest) {
|
||||
v1.add_label(dba->label("DOG"));
|
||||
v1.add_label(dba->label("NICE_DOG"));
|
||||
auto *identifier = storage.Create<Identifier>("n");
|
||||
auto node_symbol = eval.symbol_table.CreateSymbol("n");
|
||||
auto node_symbol = eval.symbol_table.CreateSymbol("n", true);
|
||||
eval.symbol_table[*identifier] = node_symbol;
|
||||
eval.frame[node_symbol] = v1;
|
||||
{
|
||||
@ -575,7 +575,7 @@ TEST(ExpressionEvaluator, EdgeTypeTest) {
|
||||
auto v2 = dba->insert_vertex();
|
||||
auto e = dba->insert_edge(v1, v2, dba->edge_type("TYPE1"));
|
||||
auto *identifier = storage.Create<Identifier>("e");
|
||||
auto edge_symbol = eval.symbol_table.CreateSymbol("e");
|
||||
auto edge_symbol = eval.symbol_table.CreateSymbol("e", true);
|
||||
eval.symbol_table[*identifier] = edge_symbol;
|
||||
eval.frame[edge_symbol] = e;
|
||||
{
|
||||
@ -608,7 +608,7 @@ TEST(ExpressionEvaluator, Aggregation) {
|
||||
auto aggr = storage.Create<Aggregation>(storage.Create<PrimitiveLiteral>(42),
|
||||
Aggregation::Op::COUNT);
|
||||
SymbolTable symbol_table;
|
||||
auto aggr_sym = symbol_table.CreateSymbol("aggr");
|
||||
auto aggr_sym = symbol_table.CreateSymbol("aggr", true);
|
||||
symbol_table[*aggr] = aggr_sym;
|
||||
Frame frame{symbol_table.max_position()};
|
||||
frame[aggr_sym] = TypedValue(1);
|
||||
|
@ -65,9 +65,9 @@ TEST(QueryPlan, Accumulate) {
|
||||
}
|
||||
|
||||
auto n_p_ne = NEXPR("n.p", n_p);
|
||||
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne");
|
||||
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne", true);
|
||||
auto m_p_ne = NEXPR("m.p", m_p);
|
||||
symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne");
|
||||
symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne", true);
|
||||
auto produce = MakeProduce(last_op, n_p_ne, m_p_ne);
|
||||
ResultStreamFaker results = CollectProduce(produce, symbol_table, *dba);
|
||||
std::vector<int> results_data;
|
||||
@ -95,7 +95,7 @@ TEST(QueryPlan, AccumulateAdvance) {
|
||||
SymbolTable symbol_table;
|
||||
|
||||
auto node = NODE("n");
|
||||
auto sym_n = symbol_table.CreateSymbol("n");
|
||||
auto sym_n = symbol_table.CreateSymbol("n", true);
|
||||
symbol_table[*node->identifier_] = sym_n;
|
||||
auto create = std::make_shared<CreateNode>(node, nullptr);
|
||||
auto accumulate = std::make_shared<Accumulate>(
|
||||
@ -126,8 +126,9 @@ std::shared_ptr<Produce> MakeAggregationProduce(
|
||||
auto named_expr = NEXPR("", IDENT("aggregation"));
|
||||
named_expressions.push_back(named_expr);
|
||||
symbol_table[*named_expr->expression_] =
|
||||
symbol_table.CreateSymbol("aggregation");
|
||||
symbol_table[*named_expr] = symbol_table.CreateSymbol("named_expression");
|
||||
symbol_table.CreateSymbol("aggregation", true);
|
||||
symbol_table[*named_expr] =
|
||||
symbol_table.CreateSymbol("named_expression", true);
|
||||
aggregates.emplace_back(*aggr_inputs_it++, aggr_op,
|
||||
symbol_table[*named_expr->expression_]);
|
||||
}
|
||||
@ -137,7 +138,8 @@ std::shared_ptr<Produce> MakeAggregationProduce(
|
||||
for (auto group_by_expr : group_by_exprs) {
|
||||
auto named_expr = NEXPR("", group_by_expr);
|
||||
named_expressions.push_back(named_expr);
|
||||
symbol_table[*named_expr] = symbol_table.CreateSymbol("named_expression");
|
||||
symbol_table[*named_expr] =
|
||||
symbol_table.CreateSymbol("named_expression", true);
|
||||
}
|
||||
auto aggregation =
|
||||
std::make_shared<Aggregate>(input, aggregates, group_by_exprs, remember);
|
||||
@ -307,7 +309,7 @@ TEST(QueryPlan, AggregateNoInput) {
|
||||
|
||||
auto two = LITERAL(2);
|
||||
auto output = NEXPR("two", IDENT("two"));
|
||||
symbol_table[*output->expression_] = symbol_table.CreateSymbol("two");
|
||||
symbol_table[*output->expression_] = symbol_table.CreateSymbol("two", true);
|
||||
|
||||
auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two},
|
||||
{Aggregation::Op::COUNT}, {}, {});
|
||||
@ -494,18 +496,18 @@ TEST(QueryPlan, Unwind) {
|
||||
std::vector<TypedValue>{1, true, "x"}, std::vector<TypedValue>{},
|
||||
std::vector<TypedValue>{"bla"}});
|
||||
|
||||
auto x = symbol_table.CreateSymbol("x");
|
||||
auto x = symbol_table.CreateSymbol("x", true);
|
||||
auto unwind_0 = std::make_shared<plan::Unwind>(nullptr, input_expr, x);
|
||||
auto x_expr = IDENT("x");
|
||||
symbol_table[*x_expr] = x;
|
||||
auto y = symbol_table.CreateSymbol("y");
|
||||
auto y = symbol_table.CreateSymbol("y", true);
|
||||
auto unwind_1 = std::make_shared<plan::Unwind>(unwind_0, x_expr, y);
|
||||
|
||||
auto x_ne = NEXPR("x", x_expr);
|
||||
symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne");
|
||||
symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne", true);
|
||||
auto y_ne = NEXPR("y", IDENT("y"));
|
||||
symbol_table[*y_ne->expression_] = y;
|
||||
symbol_table[*y_ne] = symbol_table.CreateSymbol("y_ne");
|
||||
symbol_table[*y_ne] = symbol_table.CreateSymbol("y_ne", true);
|
||||
auto produce = MakeProduce(unwind_1, x_ne, y_ne);
|
||||
|
||||
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
|
||||
|
@ -95,7 +95,7 @@ TEST(QueryPlan, CreateLimit) {
|
||||
|
||||
auto n = MakeScanAll(storage, symbol_table, "n1");
|
||||
auto m = NODE("m");
|
||||
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m");
|
||||
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true);
|
||||
auto c = std::make_shared<CreateNode>(m, n.op_);
|
||||
auto skip = std::make_shared<plan::Limit>(c, LITERAL(1));
|
||||
|
||||
@ -157,7 +157,7 @@ TEST(QueryPlan, OrderBy) {
|
||||
{order_value_pair.first, n_p}},
|
||||
std::vector<Symbol>{n.sym_});
|
||||
auto n_p_ne = NEXPR("n.p", n_p);
|
||||
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n.p");
|
||||
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n.p", true);
|
||||
auto produce = MakeProduce(order_by, n_p_ne);
|
||||
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
|
||||
ASSERT_EQ(values.size(), results.size());
|
||||
@ -208,9 +208,9 @@ TEST(QueryPlan, OrderByMultiple) {
|
||||
},
|
||||
std::vector<Symbol>{n.sym_});
|
||||
auto n_p1_ne = NEXPR("n.p1", n_p1);
|
||||
symbol_table[*n_p1_ne] = symbol_table.CreateSymbol("n.p1");
|
||||
symbol_table[*n_p1_ne] = symbol_table.CreateSymbol("n.p1", true);
|
||||
auto n_p2_ne = NEXPR("n.p2", n_p2);
|
||||
symbol_table[*n_p2_ne] = symbol_table.CreateSymbol("n.p2");
|
||||
symbol_table[*n_p2_ne] = symbol_table.CreateSymbol("n.p2", true);
|
||||
auto produce = MakeProduce(order_by, n_p1_ne, n_p2_ne);
|
||||
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
|
||||
ASSERT_EQ(N * N, results.size());
|
||||
|
@ -93,7 +93,7 @@ ScanAllTuple MakeScanAll(AstTreeStorage &storage, SymbolTable &symbol_table,
|
||||
GraphView graph_view = GraphView::OLD) {
|
||||
auto node = NODE(identifier);
|
||||
auto logical_op = std::make_shared<ScanAll>(node, input, graph_view);
|
||||
auto symbol = symbol_table.CreateSymbol(identifier);
|
||||
auto symbol = symbol_table.CreateSymbol(identifier, true);
|
||||
symbol_table[*node->identifier_] = symbol;
|
||||
// return std::make_tuple(node, logical_op, symbol);
|
||||
return ScanAllTuple{node, logical_op, symbol};
|
||||
@ -114,11 +114,11 @@ ExpandTuple MakeExpand(AstTreeStorage &storage, SymbolTable &symbol_table,
|
||||
const std::string &node_identifier, bool existing_node,
|
||||
GraphView graph_view = GraphView::AS_IS) {
|
||||
auto edge = EDGE(edge_identifier, direction);
|
||||
auto edge_sym = symbol_table.CreateSymbol(edge_identifier);
|
||||
auto edge_sym = symbol_table.CreateSymbol(edge_identifier, true);
|
||||
symbol_table[*edge->identifier_] = edge_sym;
|
||||
|
||||
auto node = NODE(node_identifier);
|
||||
auto node_sym = symbol_table.CreateSymbol(node_identifier);
|
||||
auto node_sym = symbol_table.CreateSymbol(node_identifier, true);
|
||||
symbol_table[*node->identifier_] = node_sym;
|
||||
|
||||
auto op = std::make_shared<Expand>(node, edge, input, input_symbol,
|
||||
|
@ -33,7 +33,7 @@ TEST(QueryPlan, CreateNodeWithAttributes) {
|
||||
SymbolTable symbol_table;
|
||||
|
||||
auto node = NODE("n");
|
||||
symbol_table[*node->identifier_] = symbol_table.CreateSymbol("n");
|
||||
symbol_table[*node->identifier_] = symbol_table.CreateSymbol("n", true);
|
||||
node->labels_.emplace_back(label);
|
||||
node->properties_[property] = LITERAL(42);
|
||||
|
||||
@ -67,19 +67,20 @@ TEST(QueryPlan, CreateReturn) {
|
||||
SymbolTable symbol_table;
|
||||
|
||||
auto node = NODE("n");
|
||||
auto sym_n = symbol_table.CreateSymbol("n");
|
||||
auto sym_n = symbol_table.CreateSymbol("n", true);
|
||||
symbol_table[*node->identifier_] = sym_n;
|
||||
node->labels_.emplace_back(label);
|
||||
node->properties_[property] = LITERAL(42);
|
||||
|
||||
auto create = std::make_shared<CreateNode>(node, nullptr);
|
||||
auto named_expr_n = NEXPR("n", IDENT("n"));
|
||||
symbol_table[*named_expr_n] = symbol_table.CreateSymbol("named_expr_n");
|
||||
symbol_table[*named_expr_n] = symbol_table.CreateSymbol("named_expr_n", true);
|
||||
symbol_table[*named_expr_n->expression_] = sym_n;
|
||||
auto prop_lookup = PROPERTY_LOOKUP("n", property);
|
||||
symbol_table[*prop_lookup->expression_] = sym_n;
|
||||
auto named_expr_n_p = NEXPR("n", prop_lookup);
|
||||
symbol_table[*named_expr_n_p] = symbol_table.CreateSymbol("named_expr_n_p");
|
||||
symbol_table[*named_expr_n_p] =
|
||||
symbol_table.CreateSymbol("named_expr_n_p", true);
|
||||
symbol_table[*named_expr_n->expression_] = sym_n;
|
||||
|
||||
auto produce = MakeProduce(create, named_expr_n, named_expr_n_p);
|
||||
@ -119,7 +120,7 @@ TEST(QueryPlan, CreateExpand) {
|
||||
auto n = NODE("n");
|
||||
n->labels_.emplace_back(label_node_1);
|
||||
n->properties_[property] = LITERAL(1);
|
||||
auto n_sym = symbol_table.CreateSymbol("n");
|
||||
auto n_sym = symbol_table.CreateSymbol("n", true);
|
||||
symbol_table[*n->identifier_] = n_sym;
|
||||
|
||||
// data for the second node
|
||||
@ -129,10 +130,10 @@ TEST(QueryPlan, CreateExpand) {
|
||||
if (cycle)
|
||||
symbol_table[*m->identifier_] = n_sym;
|
||||
else
|
||||
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m");
|
||||
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true);
|
||||
|
||||
auto r = EDGE("r", EdgeAtom::Direction::RIGHT);
|
||||
symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r");
|
||||
symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r", true);
|
||||
r->edge_types_.emplace_back(edge_type);
|
||||
r->properties_[property] = LITERAL(3);
|
||||
|
||||
@ -188,7 +189,7 @@ TEST(QueryPlan, MatchCreateNode) {
|
||||
auto n_scan_all = MakeScanAll(storage, symbol_table, "n");
|
||||
// second node
|
||||
auto m = NODE("m");
|
||||
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m");
|
||||
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true);
|
||||
// creation op
|
||||
auto create_node = std::make_shared<CreateNode>(m, n_scan_all.op_);
|
||||
|
||||
@ -229,10 +230,10 @@ TEST(QueryPlan, MatchCreateExpand) {
|
||||
if (cycle)
|
||||
symbol_table[*m->identifier_] = n_scan_all.sym_;
|
||||
else
|
||||
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m");
|
||||
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true);
|
||||
|
||||
auto r = EDGE("r", EdgeAtom::Direction::RIGHT);
|
||||
symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r");
|
||||
symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r", true);
|
||||
r->edge_types_.emplace_back(edge_type);
|
||||
|
||||
auto create_expand = std::make_shared<CreateExpand>(m, r, n_scan_all.op_,
|
||||
@ -403,7 +404,7 @@ TEST(QueryPlan, DeleteReturn) {
|
||||
storage.Create<PropertyLookup>(storage.Create<Identifier>("n"), prop);
|
||||
symbol_table[*prop_lookup->expression_] = n.sym_;
|
||||
auto n_p = storage.Create<NamedExpression>("n", prop_lookup);
|
||||
symbol_table[*n_p] = symbol_table.CreateSymbol("bla");
|
||||
symbol_table[*n_p] = symbol_table.CreateSymbol("bla", true);
|
||||
auto produce = MakeProduce(delete_op, n_p);
|
||||
|
||||
auto result = CollectProduce(produce, symbol_table, *dba);
|
||||
@ -843,7 +844,7 @@ TEST(QueryPlan, MergeNoInput) {
|
||||
SymbolTable symbol_table;
|
||||
|
||||
auto node = NODE("n");
|
||||
auto sym_n = symbol_table.CreateSymbol("n");
|
||||
auto sym_n = symbol_table.CreateSymbol("n", true);
|
||||
symbol_table[*node->identifier_] = sym_n;
|
||||
auto create = std::make_shared<CreateNode>(node, nullptr);
|
||||
auto merge = std::make_shared<plan::Merge>(nullptr, create, create);
|
||||
|
@ -39,7 +39,8 @@ TEST(QueryPlan, MatchReturn) {
|
||||
auto output = NEXPR("n", IDENT("n"));
|
||||
auto produce = MakeProduce(scan_all.op_, output);
|
||||
symbol_table[*output->expression_] = scan_all.sym_;
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*output] =
|
||||
symbol_table.CreateSymbol("named_expression_1", true);
|
||||
return PullAll(produce, *dba, symbol_table);
|
||||
};
|
||||
|
||||
@ -68,10 +69,12 @@ TEST(QueryPlan, MatchReturnCartesian) {
|
||||
auto m = MakeScanAll(storage, symbol_table, "m", n.op_);
|
||||
auto return_n = NEXPR("n", IDENT("n"));
|
||||
symbol_table[*return_n->expression_] = n.sym_;
|
||||
symbol_table[*return_n] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*return_n] =
|
||||
symbol_table.CreateSymbol("named_expression_1", true);
|
||||
auto return_m = NEXPR("m", IDENT("m"));
|
||||
symbol_table[*return_m->expression_] = m.sym_;
|
||||
symbol_table[*return_m] = symbol_table.CreateSymbol("named_expression_2");
|
||||
symbol_table[*return_m] =
|
||||
symbol_table.CreateSymbol("named_expression_2", true);
|
||||
auto produce = MakeProduce(m.op_, return_n, return_m);
|
||||
|
||||
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
|
||||
@ -99,7 +102,7 @@ TEST(QueryPlan, StandaloneReturn) {
|
||||
|
||||
auto output = NEXPR("n", LITERAL(42));
|
||||
auto produce = MakeProduce(std::shared_ptr<LogicalOperator>(nullptr), output);
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true);
|
||||
|
||||
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
|
||||
EXPECT_EQ(result.GetResults().size(), 1);
|
||||
@ -150,7 +153,7 @@ TEST(QueryPlan, NodeFilterLabelsAndProperties) {
|
||||
// make a named expression and a produce
|
||||
auto output = NEXPR("x", IDENT("n"));
|
||||
symbol_table[*output->expression_] = n.sym_;
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true);
|
||||
auto produce = MakeProduce(node_filter, output);
|
||||
|
||||
EXPECT_EQ(1, PullAll(produce, *dba, symbol_table));
|
||||
@ -206,7 +209,7 @@ TEST(QueryPlan, NodeFilterMultipleLabels) {
|
||||
auto produce = MakeProduce(node_filter, output);
|
||||
|
||||
// fill up the symbol table
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true);
|
||||
symbol_table[*output->expression_] = n.sym_;
|
||||
|
||||
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
|
||||
@ -240,7 +243,8 @@ TEST(QueryPlan, Expand) {
|
||||
// make a named expression and a produce
|
||||
auto output = NEXPR("m", IDENT("m"));
|
||||
symbol_table[*output->expression_] = r_m.node_sym_;
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*output] =
|
||||
symbol_table.CreateSymbol("named_expression_1", true);
|
||||
auto produce = MakeProduce(r_m.op_, output);
|
||||
|
||||
return PullAll(produce, *dba, symbol_table);
|
||||
@ -298,13 +302,13 @@ TEST(QueryPlan, ExpandOptional) {
|
||||
// RETURN n, r, m
|
||||
auto n_ne = NEXPR("n", IDENT("n"));
|
||||
symbol_table[*n_ne->expression_] = n.sym_;
|
||||
symbol_table[*n_ne] = symbol_table.CreateSymbol("n");
|
||||
symbol_table[*n_ne] = symbol_table.CreateSymbol("n", true);
|
||||
auto r_ne = NEXPR("r", IDENT("r"));
|
||||
symbol_table[*r_ne->expression_] = r_m.edge_sym_;
|
||||
symbol_table[*r_ne] = symbol_table.CreateSymbol("r");
|
||||
symbol_table[*r_ne] = symbol_table.CreateSymbol("r", true);
|
||||
auto m_ne = NEXPR("m", IDENT("m"));
|
||||
symbol_table[*m_ne->expression_] = r_m.node_sym_;
|
||||
symbol_table[*m_ne] = symbol_table.CreateSymbol("m");
|
||||
symbol_table[*m_ne] = symbol_table.CreateSymbol("m", true);
|
||||
auto produce = MakeProduce(optional, n_ne, r_ne, m_ne);
|
||||
|
||||
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
|
||||
@ -339,7 +343,7 @@ TEST(QueryPlan, OptionalMatchEmptyDB) {
|
||||
// RETURN n
|
||||
auto n_ne = NEXPR("n", IDENT("n"));
|
||||
symbol_table[*n_ne->expression_] = n.sym_;
|
||||
symbol_table[*n_ne] = symbol_table.CreateSymbol("n");
|
||||
symbol_table[*n_ne] = symbol_table.CreateSymbol("n", true);
|
||||
auto optional = std::make_shared<plan::Optional>(nullptr, n.op_,
|
||||
std::vector<Symbol>{n.sym_});
|
||||
auto produce = MakeProduce(optional, n_ne);
|
||||
@ -377,7 +381,8 @@ TEST(QueryPlan, ExpandExistingNode) {
|
||||
// make a named expression and a produce
|
||||
auto output = NEXPR("n", IDENT("n"));
|
||||
symbol_table[*output->expression_] = n.sym_;
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*output] =
|
||||
symbol_table.CreateSymbol("named_expression_1", true);
|
||||
auto produce = MakeProduce(r_n.op_, output);
|
||||
|
||||
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
|
||||
@ -420,7 +425,8 @@ TEST(QueryPlan, ExpandExistingEdge) {
|
||||
// make a named expression and a produce
|
||||
auto output = NEXPR("r", IDENT("r"));
|
||||
symbol_table[*output->expression_] = r_j.edge_sym_;
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*output] =
|
||||
symbol_table.CreateSymbol("named_expression_1", true);
|
||||
auto produce = MakeProduce(r_k.op_, output);
|
||||
|
||||
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
|
||||
@ -505,7 +511,8 @@ TEST(QueryPlan, EdgeFilter) {
|
||||
// make a named expression and a produce
|
||||
auto output = NEXPR("m", IDENT("m"));
|
||||
symbol_table[*output->expression_] = r_m.node_sym_;
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*output] =
|
||||
symbol_table.CreateSymbol("named_expression_1", true);
|
||||
auto produce = MakeProduce(edge_filter, output);
|
||||
|
||||
return PullAll(produce, *dba, symbol_table);
|
||||
@ -552,7 +559,7 @@ TEST(QueryPlan, EdgeFilterMultipleTypes) {
|
||||
auto produce = MakeProduce(edge_filter, output);
|
||||
|
||||
// fill up the symbol table
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true);
|
||||
symbol_table[*output->expression_] = r_m.node_sym_;
|
||||
|
||||
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
|
||||
@ -582,7 +589,7 @@ TEST(QueryPlan, Filter) {
|
||||
auto output =
|
||||
storage.Create<NamedExpression>("x", storage.Create<Identifier>("n"));
|
||||
symbol_table[*output->expression_] = n.sym_;
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
|
||||
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true);
|
||||
auto produce = MakeProduce(f, output);
|
||||
|
||||
EXPECT_EQ(CollectProduce(produce, symbol_table, *dba).GetResults().size(), 2);
|
||||
@ -647,7 +654,7 @@ TEST(QueryPlan, Distinct) {
|
||||
|
||||
auto input_expr = LITERAL(TypedValue(input));
|
||||
|
||||
auto x = symbol_table.CreateSymbol("x");
|
||||
auto x = symbol_table.CreateSymbol("x", true);
|
||||
auto unwind = std::make_shared<plan::Unwind>(nullptr, input_expr, x);
|
||||
auto x_expr = IDENT("x");
|
||||
symbol_table[*x_expr] = x;
|
||||
@ -656,7 +663,7 @@ TEST(QueryPlan, Distinct) {
|
||||
std::make_shared<plan::Distinct>(unwind, std::vector<Symbol>{x});
|
||||
|
||||
auto x_ne = NEXPR("x", x_expr);
|
||||
symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne");
|
||||
symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne", true);
|
||||
auto produce = MakeProduce(distinct, x_ne);
|
||||
|
||||
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
|
||||
|
@ -120,17 +120,17 @@ using ExpectDistinct = OpChecker<Distinct>;
|
||||
|
||||
class ExpectAccumulate : public OpChecker<Accumulate> {
|
||||
public:
|
||||
ExpectAccumulate(const std::unordered_set<Symbol, Symbol::Hash> &symbols)
|
||||
ExpectAccumulate(const std::unordered_set<Symbol> &symbols)
|
||||
: symbols_(symbols) {}
|
||||
|
||||
void ExpectOp(Accumulate &op, const SymbolTable &symbol_table) override {
|
||||
std::unordered_set<Symbol, Symbol::Hash> got_symbols(op.symbols().begin(),
|
||||
op.symbols().end());
|
||||
std::unordered_set<Symbol> got_symbols(op.symbols().begin(),
|
||||
op.symbols().end());
|
||||
EXPECT_EQ(symbols_, got_symbols);
|
||||
}
|
||||
|
||||
private:
|
||||
const std::unordered_set<Symbol, Symbol::Hash> symbols_;
|
||||
const std::unordered_set<Symbol> symbols_;
|
||||
};
|
||||
|
||||
class ExpectAggregate : public OpChecker<Aggregate> {
|
||||
@ -771,4 +771,55 @@ TEST(TestLogicalPlanner, MatchOptionalMatchWhere) {
|
||||
ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchReturnAsterisk) {
|
||||
// Test MATCH (n) -[e]- (m) RETURN *, m.prop
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto ret = RETURN(PROPERTY_LOOKUP("m", prop), AS("m.prop"));
|
||||
ret->body_.all_identifiers = true;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), ret);
|
||||
auto symbol_table = MakeSymbolTable(*query);
|
||||
auto plan = MakeLogicalPlan(storage, symbol_table);
|
||||
CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectExpand(),
|
||||
ExpectProduce());
|
||||
std::vector<std::string> output_names;
|
||||
for (const auto &output_symbol : plan->OutputSymbols(symbol_table)) {
|
||||
output_names.emplace_back(output_symbol.name());
|
||||
}
|
||||
std::vector<std::string> expected_names{"e", "m", "n", "m.prop"};
|
||||
EXPECT_EQ(output_names, expected_names);
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchReturnAsteriskSum) {
|
||||
// Test MATCH (n) RETURN *, SUM(n.prop) AS s
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
|
||||
auto ret = RETURN(sum, AS("s"));
|
||||
ret->body_.all_identifiers = true;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret);
|
||||
auto symbol_table = MakeSymbolTable(*query);
|
||||
auto plan = MakeLogicalPlan(storage, symbol_table);
|
||||
auto *produce = dynamic_cast<Produce *>(plan.get());
|
||||
ASSERT_TRUE(produce);
|
||||
const auto &named_expressions = produce->named_expressions();
|
||||
ASSERT_EQ(named_expressions.size(), 2);
|
||||
auto *expanded_ident =
|
||||
dynamic_cast<query::Identifier *>(named_expressions[0]->expression_);
|
||||
ASSERT_TRUE(expanded_ident);
|
||||
auto aggr = ExpectAggregate({sum}, {expanded_ident});
|
||||
CheckPlan(*plan, symbol_table, ExpectScanAll(), aggr, ExpectProduce());
|
||||
std::vector<std::string> output_names;
|
||||
for (const auto &output_symbol : plan->OutputSymbols(symbol_table)) {
|
||||
output_names.emplace_back(output_symbol.name());
|
||||
}
|
||||
std::vector<std::string> expected_names{"n", "s"};
|
||||
EXPECT_EQ(output_names, expected_names);
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
||||
|
@ -26,12 +26,12 @@ TEST(TestSymbolGenerator, MatchNodeReturn) {
|
||||
auto pattern = match->patterns_[0];
|
||||
auto node_atom = dynamic_cast<NodeAtom *>(pattern->atoms_[0]);
|
||||
auto node_sym = symbol_table[*node_atom->identifier_];
|
||||
EXPECT_EQ(node_sym.name_, "node_atom_1");
|
||||
EXPECT_EQ(node_sym.type_, Symbol::Type::Vertex);
|
||||
EXPECT_EQ(node_sym.name(), "node_atom_1");
|
||||
EXPECT_EQ(node_sym.type(), Symbol::Type::Vertex);
|
||||
auto ret = dynamic_cast<Return *>(query_ast->clauses_[1]);
|
||||
auto named_expr = ret->body_.named_expressions[0];
|
||||
auto column_sym = symbol_table[*named_expr];
|
||||
EXPECT_EQ(node_sym.name_, column_sym.name_);
|
||||
EXPECT_EQ(node_sym.name(), column_sym.name());
|
||||
EXPECT_NE(node_sym, column_sym);
|
||||
auto ret_sym = symbol_table[*named_expr->expression_];
|
||||
EXPECT_EQ(node_sym, ret_sym);
|
||||
@ -87,12 +87,12 @@ TEST(TestSymbolGenerator, MatchSameEdge) {
|
||||
is_node = !is_node;
|
||||
}
|
||||
auto &node_symbol = node_symbols.front();
|
||||
EXPECT_EQ(node_symbol.type_, Symbol::Type::Vertex);
|
||||
EXPECT_EQ(node_symbol.type(), Symbol::Type::Vertex);
|
||||
for (auto &symbol : node_symbols) {
|
||||
EXPECT_EQ(node_symbol, symbol);
|
||||
}
|
||||
auto &edge_symbol = edge_symbols.front();
|
||||
EXPECT_EQ(edge_symbol.type_, Symbol::Type::Edge);
|
||||
EXPECT_EQ(edge_symbol.type(), Symbol::Type::Edge);
|
||||
for (auto &symbol : edge_symbols) {
|
||||
EXPECT_EQ(edge_symbol, symbol);
|
||||
}
|
||||
@ -129,12 +129,12 @@ TEST(TestSymbolGenerator, CreateNodeReturn) {
|
||||
auto pattern = create->patterns_[0];
|
||||
auto node_atom = dynamic_cast<NodeAtom *>(pattern->atoms_[0]);
|
||||
auto node_sym = symbol_table[*node_atom->identifier_];
|
||||
EXPECT_EQ(node_sym.name_, "n");
|
||||
EXPECT_EQ(node_sym.type_, Symbol::Type::Vertex);
|
||||
EXPECT_EQ(node_sym.name(), "n");
|
||||
EXPECT_EQ(node_sym.type(), Symbol::Type::Vertex);
|
||||
auto ret = dynamic_cast<Return *>(query_ast->clauses_[1]);
|
||||
auto named_expr = ret->body_.named_expressions[0];
|
||||
auto column_sym = symbol_table[*named_expr];
|
||||
EXPECT_EQ(node_sym.name_, column_sym.name_);
|
||||
EXPECT_EQ(node_sym.name(), column_sym.name());
|
||||
EXPECT_NE(node_sym, column_sym);
|
||||
auto ret_sym = symbol_table[*named_expr->expression_];
|
||||
EXPECT_EQ(node_sym, ret_sym);
|
||||
@ -260,7 +260,7 @@ TEST(TestSymbolGenerator, CreateDelete) {
|
||||
EXPECT_EQ(symbol_table.max_position(), 1);
|
||||
auto node_symbol = symbol_table.at(*node->identifier_);
|
||||
auto ident_symbol = symbol_table.at(*ident);
|
||||
EXPECT_EQ(node_symbol.type_, Symbol::Type::Vertex);
|
||||
EXPECT_EQ(node_symbol.type(), Symbol::Type::Vertex);
|
||||
EXPECT_EQ(node_symbol, ident_symbol);
|
||||
}
|
||||
|
||||
@ -369,18 +369,18 @@ TEST(TestSymbolGenerator, CreateMultiExpand) {
|
||||
auto n1 = symbol_table.at(*node_n1->identifier_);
|
||||
auto n2 = symbol_table.at(*node_n2->identifier_);
|
||||
EXPECT_EQ(n1, n2);
|
||||
EXPECT_EQ(n1.type_, Symbol::Type::Vertex);
|
||||
EXPECT_EQ(n1.type(), Symbol::Type::Vertex);
|
||||
auto m = symbol_table.at(*node_m->identifier_);
|
||||
EXPECT_EQ(m.type_, Symbol::Type::Vertex);
|
||||
EXPECT_EQ(m.type(), Symbol::Type::Vertex);
|
||||
EXPECT_NE(m, n1);
|
||||
auto l = symbol_table.at(*node_l->identifier_);
|
||||
EXPECT_EQ(l.type_, Symbol::Type::Vertex);
|
||||
EXPECT_EQ(l.type(), Symbol::Type::Vertex);
|
||||
EXPECT_NE(l, n1);
|
||||
EXPECT_NE(l, m);
|
||||
auto r = symbol_table.at(*edge_r->identifier_);
|
||||
auto p = symbol_table.at(*edge_p->identifier_);
|
||||
EXPECT_EQ(r.type_, Symbol::Type::Edge);
|
||||
EXPECT_EQ(p.type_, Symbol::Type::Edge);
|
||||
EXPECT_EQ(r.type(), Symbol::Type::Edge);
|
||||
EXPECT_EQ(p.type(), Symbol::Type::Edge);
|
||||
EXPECT_NE(r, p);
|
||||
}
|
||||
|
||||
@ -527,11 +527,11 @@ TEST(TestSymbolGenerator, MatchWithCreate) {
|
||||
query->Accept(symbol_generator);
|
||||
EXPECT_EQ(symbol_table.max_position(), 3);
|
||||
auto n = symbol_table.at(*node_1->identifier_);
|
||||
EXPECT_EQ(n.type_, Symbol::Type::Vertex);
|
||||
EXPECT_EQ(n.type(), Symbol::Type::Vertex);
|
||||
auto m = symbol_table.at(*node_2->identifier_);
|
||||
EXPECT_NE(n, m);
|
||||
// Currently we don't infer expression types, so we lost true type of 'm'.
|
||||
EXPECT_EQ(m.type_, Symbol::Type::Any);
|
||||
EXPECT_EQ(m.type(), Symbol::Type::Any);
|
||||
EXPECT_EQ(m, symbol_table.at(*node_3->identifier_));
|
||||
}
|
||||
|
||||
@ -792,4 +792,52 @@ TEST(TestSymbolGenerator, MatchCrossReferenceVariable) {
|
||||
EXPECT_NE(m, symbol_table.at(*as_n));
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchWithAsteriskReturnAsterisk) {
|
||||
// MATCH (n) -[e]- (m) WITH * RETURN *, n.prop
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto n_prop = PROPERTY_LOOKUP("n", prop);
|
||||
auto ret = RETURN(n_prop, AS("n.prop"));
|
||||
ret->body_.all_identifiers = true;
|
||||
auto node_n = NODE("n");
|
||||
auto edge = EDGE("e");
|
||||
auto node_m = NODE("m");
|
||||
auto with = storage.Create<With>();
|
||||
with->body_.all_identifiers = true;
|
||||
auto query = QUERY(MATCH(PATTERN(node_n, edge, node_m)), with, ret);
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
// Symbols for `n`, `e`, `m`, `AS n.prop`.
|
||||
EXPECT_EQ(symbol_table.max_position(), 4);
|
||||
auto n = symbol_table.at(*node_n->identifier_);
|
||||
EXPECT_EQ(n, symbol_table.at(*n_prop->expression_));
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchReturnAsteriskSameResult) {
|
||||
// MATCH (n) RETURN *, n AS n
|
||||
AstTreeStorage storage;
|
||||
auto ret = RETURN(IDENT("n"), AS("n"));
|
||||
ret->body_.all_identifiers = true;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret);
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchReturnAsteriskNoUserVariables) {
|
||||
// MATCH () RETURN *
|
||||
AstTreeStorage storage;
|
||||
auto ret = storage.Create<Return>();
|
||||
ret->body_.all_identifiers = true;
|
||||
auto ident_n = storage.Create<Identifier>("anon", false);
|
||||
auto node = storage.Create<NodeAtom>(ident_n);
|
||||
auto query = QUERY(MATCH(PATTERN(node)), ret);
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user