Plan '*' in RETURN and WITH

Summary:
Make Symbol members read only.
Check WITH/RETURN * in SymbolGenerator.
Test semantic checks for WITH/RETURN *.
Sort expanded user identifiers by name.
Test planning WITH/RETURN *.

Reviewers: buda, florijan, mislav.bradac

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D357
This commit is contained in:
Teon Banek 2017-05-12 10:37:22 +01:00
parent f82bda6c0c
commit 87e5dc0dfb
15 changed files with 347 additions and 148 deletions

View File

@ -8,32 +8,46 @@
namespace query {
auto SymbolGenerator::CreateSymbol(const std::string &name, Symbol::Type type) {
auto symbol = symbol_table_.CreateSymbol(name, type);
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared,
Symbol::Type type) {
auto symbol = symbol_table_.CreateSymbol(name, user_declared, type);
scope_.symbols[name] = symbol;
return symbol;
}
auto SymbolGenerator::GetOrCreateSymbol(const std::string &name,
Symbol::Type type) {
bool user_declared, Symbol::Type type) {
auto search = scope_.symbols.find(name);
if (search != scope_.symbols.end()) {
auto symbol = search->second;
// Unless we have `Any` type, check that types match.
if (type != Symbol::Type::Any && symbol.type_ != Symbol::Type::Any &&
type != symbol.type_) {
throw TypeMismatchError(name, Symbol::TypeToString(symbol.type_),
if (type != Symbol::Type::Any && symbol.type() != Symbol::Type::Any &&
type != symbol.type()) {
throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()),
Symbol::TypeToString(type));
}
return search->second;
}
return CreateSymbol(name, type);
return CreateSymbol(name, user_declared, type);
}
void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
for (auto &expr : body.named_expressions) {
expr->Accept(*this);
}
std::vector<Symbol> user_symbols;
if (body.all_identifiers) {
// Carry over user symbols because '*' appeared.
for (auto sym_pair : scope_.symbols) {
if (!sym_pair.second.user_declared()) {
continue;
}
user_symbols.emplace_back(sym_pair.second);
}
if (user_symbols.empty()) {
throw SemanticException("There are no variables in scope to use for '*'");
}
}
// WITH/RETURN clause removes declarations of all the previous variables and
// declares only those established through named expressions. New declarations
// must not be visible inside named expressions themselves.
@ -47,6 +61,10 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
}
// Create symbols for named expressions.
std::unordered_set<std::string> new_names;
for (const auto &user_sym : user_symbols) {
new_names.insert(user_sym.name());
scope_.symbols[user_sym.name()] = user_sym;
}
for (auto &named_expr : body.named_expressions) {
const auto &name = named_expr->name_;
if (!new_names.insert(name).second) {
@ -55,7 +73,7 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
}
// An improvement would be to infer the type of the expression, so that the
// new symbol would have a more specific type.
symbol_table_[*named_expr] = CreateSymbol(name);
symbol_table_[*named_expr] = CreateSymbol(name, true);
}
scope_.in_order_by = true;
for (const auto &order_pair : body.order_by) {
@ -119,7 +137,7 @@ void SymbolGenerator::PostVisit(Unwind &unwind) {
if (HasSymbol(name)) {
throw RedeclareVariableError(name);
}
symbol_table_[*unwind.named_expression_] = CreateSymbol(name);
symbol_table_[*unwind.named_expression_] = CreateSymbol(name, true);
}
void SymbolGenerator::Visit(Match &) { scope_.in_match = true; }
@ -163,7 +181,7 @@ void SymbolGenerator::Visit(Identifier &ident) {
if (scope_.in_edge_atom) {
type = Symbol::Type::Edge;
}
symbol = GetOrCreateSymbol(ident.name_, type);
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, type);
} else if (scope_.in_pattern && scope_.in_property_map && scope_.in_match) {
// Variables in property maps during MATCH can reference symbols bound later
// in the same MATCH. We collect them here, so that they can be checked
@ -193,7 +211,8 @@ void SymbolGenerator::Visit(Aggregation &aggr) {
}
// Create a virtual symbol for aggregation result.
// Currently, we only have aggregation operators which return numbers.
symbol_table_[aggr] = symbol_table_.CreateSymbol("", Symbol::Type::Number);
symbol_table_[aggr] =
symbol_table_.CreateSymbol("", false, Symbol::Type::Number);
scope_.in_aggregation = true;
scope_.has_aggregation = true;
}

View File

@ -88,12 +88,12 @@ class SymbolGenerator : public TreeVisitorBase {
// Returns a freshly generated symbol. Previous mapping of the same name to a
// different symbol is replaced with the new one.
auto CreateSymbol(const std::string &name,
auto CreateSymbol(const std::string &name, bool user_declared,
Symbol::Type type = Symbol::Type::Any);
// Returns the symbol by name. If the mapping already exists, checks if the
// types match. Otherwise, returns a new symbol.
auto GetOrCreateSymbol(const std::string &name,
auto GetOrCreateSymbol(const std::string &name, bool user_declared,
Symbol::Type type = Symbol::Type::Any);
void VisitReturnBody(ReturnBody &body, Where *where = nullptr);

View File

@ -17,34 +17,38 @@ class Symbol {
return enum_string[static_cast<int>(type)];
}
// Calculates the Symbol hash based on its position.
struct Hash {
size_t operator()(const Symbol &symbol) const {
return std::hash<int>{}(symbol.position_);
}
};
Symbol() {}
Symbol(const std::string &name, int position, Type type = Type::Any)
: name_(name), position_(position), type_(type) {}
std::string name_;
int position_;
Type type_{Type::Any};
Symbol(const std::string &name, int position, bool user_declared,
Type type = Type::Any)
: name_(name),
position_(position),
user_declared_(user_declared),
type_(type) {}
bool operator==(const Symbol &other) const {
return position_ == other.position_ && name_ == other.name_ &&
type_ == other.type_;
}
bool operator!=(const Symbol &other) const { return !operator==(other); }
const auto &name() const { return name_; }
int position() const { return position_; }
Type type() const { return type_; }
bool user_declared() const { return user_declared_; }
private:
std::string name_;
int position_;
bool user_declared_ = true;
Type type_ = Type::Any;
};
class SymbolTable {
public:
Symbol CreateSymbol(const std::string &name,
Symbol CreateSymbol(const std::string &name, bool user_declared,
Symbol::Type type = Symbol::Type::Any) {
int position = position_++;
return Symbol(name, position, type);
return Symbol(name, position, user_declared, type);
}
auto &operator[](const Tree &tree) { return table_[tree.uid()]; }
@ -59,4 +63,21 @@ class SymbolTable {
std::map<int, Symbol> table_;
};
}
} // namespace query
namespace std {
template <>
struct hash<query::Symbol> {
size_t operator()(const query::Symbol &symbol) const {
size_t prime = 265443599u;
size_t hash = std::hash<int>{}(symbol.position());
hash ^= prime * std::hash<std::string>{}(symbol.name());
hash ^= prime * std::hash<bool>{}(symbol.user_declared());
hash ^= prime * std::hash<int>{}(static_cast<int>(symbol.type()));
return hash;
}
};
} // namespace std

View File

@ -12,10 +12,10 @@ class Frame {
Frame(int size) : size_(size), elems_(size_) {}
TypedValue &operator[](const Symbol &symbol) {
return elems_[symbol.position_];
return elems_[symbol.position()];
}
const TypedValue &operator[](const Symbol &symbol) const {
return elems_[symbol.position_];
return elems_[symbol.position()];
}
private:

View File

@ -48,7 +48,7 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
// clause, so stream out the results.
// generate header
for (const auto &symbol : output_symbols) header.push_back(symbol.name_);
for (const auto &symbol : output_symbols) header.push_back(symbol.name());
stream.Header(header);
// stream out results

View File

@ -13,8 +13,9 @@ namespace {
// Returns false if the symbol was already bound, otherwise binds it and
// returns true.
bool BindSymbol(std::unordered_set<int> &bound_symbols, const Symbol &symbol) {
auto insertion = bound_symbols.insert(symbol.position_);
bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
const Symbol &symbol) {
auto insertion = bound_symbols.insert(symbol);
return insertion.second;
}
@ -61,7 +62,7 @@ auto ReducePattern(
auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols) {
std::unordered_set<Symbol> &bound_symbols) {
auto base = [&](NodeAtom *node) -> LogicalOperator * {
if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_)))
return new CreateNode(node, std::shared_ptr<LogicalOperator>(input_op));
@ -92,7 +93,7 @@ auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
auto GenCreate(Create &create, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols) {
std::unordered_set<Symbol> &bound_symbols) {
auto last_op = input_op;
for (auto pattern : create.patterns_) {
last_op =
@ -109,17 +110,16 @@ class UsedSymbolsCollector : public TreeVisitorBase {
using TreeVisitorBase::Visit;
void Visit(Identifier &ident) override {
const auto &symbol = symbol_table_.at(ident);
symbols_.insert(symbol.position_);
symbols_.insert(symbol_table_.at(ident));
}
std::unordered_set<int> symbols_;
std::unordered_set<Symbol> symbols_;
const SymbolTable &symbol_table_;
};
bool HasBoundFilterSymbols(
const std::unordered_set<int> &bound_symbols,
const std::pair<Expression *, std::unordered_set<int>> &filter) {
const std::unordered_set<Symbol> &bound_symbols,
const std::pair<Expression *, std::unordered_set<Symbol>> &filter) {
for (const auto &symbol : filter.second) {
if (bound_symbols.find(symbol) == bound_symbols.end()) {
return false;
@ -154,7 +154,7 @@ Expression *PropertiesEqual(AstTreeStorage &storage,
auto &CollectPatternFilters(
Pattern &pattern, const SymbolTable &symbol_table,
std::list<std::pair<Expression *, std::unordered_set<int>>> &filters,
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
AstTreeStorage &storage) {
UsedSymbolsCollector collector(symbol_table);
auto node_filter = [&](NodeAtom *node) {
@ -164,7 +164,7 @@ auto &CollectPatternFilters(
node->identifier_, node->labels_);
auto *props_filter = PropertiesEqual(storage, collector, node);
if (labels_filter || props_filter) {
collector.symbols_.insert(symbol_table.at(*node->identifier_).position_);
collector.symbols_.insert(symbol_table.at(*node->identifier_));
filters.emplace_back(
BoolJoin<FilterAndOperator>(storage, labels_filter, props_filter),
collector.symbols_);
@ -181,7 +181,7 @@ auto &CollectPatternFilters(
auto *props_filter = PropertiesEqual(storage, collector, edge);
if (types_filter || props_filter) {
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
collector.symbols_.insert(edge_symbol.position_);
collector.symbols_.insert(edge_symbol);
filters->emplace_back(
BoolJoin<FilterAndOperator>(storage, types_filter, props_filter),
collector.symbols_);
@ -190,13 +190,13 @@ auto &CollectPatternFilters(
return node_filter(node);
};
return *ReducePattern<
std::list<std::pair<Expression *, std::unordered_set<int>>> *>(
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> *>(
pattern, node_filter, expand_filter);
}
void CollectMatchFilters(
const Match &match, const SymbolTable &symbol_table,
std::list<std::pair<Expression *, std::unordered_set<int>>> &filters,
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
AstTreeStorage &storage) {
for (auto *pattern : match.patterns_) {
CollectPatternFilters(*pattern, symbol_table, filters, storage);
@ -214,22 +214,22 @@ struct MatchContext {
// Already bound symbols, which are used to determine whether the operator
// should reference them or establish new. This is both read from and written
// to during generation.
std::unordered_set<int> &bound_symbols;
std::unordered_set<Symbol> &bound_symbols;
// Determines whether the match should see the new graph state or not.
GraphView graph_view = GraphView::OLD;
// Pairs of filter expression and symbols used in them. The list should be
// filled using CollectPatternFilters function, and later modified during
// GenMatchForPattern.
std::list<std::pair<Expression *, std::unordered_set<int>>> filters;
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> filters;
// Symbols for edges established in match, used to ensure Cyphermorphism.
std::unordered_set<Symbol, Symbol::Hash> edge_symbols;
std::unordered_set<Symbol> edge_symbols;
// All the newly established symbols in match.
std::vector<Symbol> new_symbols;
};
auto GenFilters(
LogicalOperator *last_op, const std::unordered_set<int> &bound_symbols,
std::list<std::pair<Expression *, std::unordered_set<int>>> &filters,
LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols,
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
AstTreeStorage &storage) {
Expression *filter_expr = nullptr;
for (auto filters_it = filters.begin(); filters_it != filters.end();) {
@ -312,7 +312,7 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
auto GenMatches(std::vector<Match *> &matches, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols,
std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage) {
auto *last_op = input_op;
MatchContext req_ctx{symbol_table, bound_symbols};
@ -369,15 +369,27 @@ auto GenMatches(std::vector<Match *> &matches, LogicalOperator *input_op,
// aggregations and expressions used for group by.
class ReturnBodyContext : public TreeVisitorBase {
public:
ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table,
Where *where = nullptr)
: body_(body), symbol_table_(symbol_table), where_(where) {
ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table,
const std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage, Where *where = nullptr)
: body_(body),
symbol_table_(symbol_table),
bound_symbols_(bound_symbols),
storage_(storage),
where_(where) {
// Collect symbols from named expressions.
output_symbols_.reserve(body_.named_expressions.size());
if (body.all_identifiers) {
// Expand '*' to expressions and symbols first, so that their results come
// before regular named expressions.
ExpandUserSymbols();
}
for (auto &named_expr : body_.named_expressions) {
output_symbols_.emplace_back(symbol_table_.at(*named_expr));
named_expr->Accept(*this);
named_expressions_.emplace_back(named_expr);
}
// Collect aggregations.
if (aggregations_.empty()) {
// Visit order_by and where if we do not have aggregations. This way we
// prevent collecting group_by expressions from order_by and where, which
@ -456,7 +468,7 @@ class ReturnBodyContext : public TreeVisitorBase {
// Aggregation contains a virtual symbol, where the result will be stored.
const auto &symbol = symbol_table_.at(aggr);
aggregations_.emplace_back(aggr.expression_, aggr.op_, symbol);
// aggregation expression_ is opional in COUNT(*) so it's possible the
// aggregation expression_ is optional in COUNT(*), so it's possible the
// has_aggregation_ stack is empty
if (aggr.expression_)
has_aggregation_.back() = true;
@ -474,10 +486,41 @@ class ReturnBodyContext : public TreeVisitorBase {
has_aggregation_.pop_back();
}
// Creates NamedExpression with an Identifier for each user declared symbol.
// This should be used when body.all_identifiers is true, to generate
// expressions for Produce operator.
void ExpandUserSymbols() {
debug_assert(
named_expressions_.empty(),
"ExpandUserSymbols should be first to fill named_expressions_");
debug_assert(output_symbols_.empty(),
"ExpandUserSymbols should be first to fill output_symbols_");
for (const auto &symbol : bound_symbols_) {
if (!symbol.user_declared()) {
continue;
}
auto *ident = storage_.Create<Identifier>(symbol.name());
symbol_table_[*ident] = symbol;
auto *named_expr = storage_.Create<NamedExpression>(symbol.name(), ident);
symbol_table_[*named_expr] = symbol;
// Fill output expressions and symbols with expanded identifiers.
named_expressions_.emplace_back(named_expr);
output_symbols_.emplace_back(symbol);
used_symbols_.insert(symbol);
// Don't forget to group by expanded identifiers.
group_by_.emplace_back(ident);
}
// Cypher RETURN/WITH * expects to expand '*' sorted by name.
std::sort(output_symbols_.begin(), output_symbols_.end(),
[](const auto &a, const auto &b) { return a.name() < b.name(); });
std::sort(named_expressions_.begin(), named_expressions_.end(),
[](const auto &a, const auto &b) { return a->name_ < b->name_; });
}
// If true, results need to be distinct.
bool distinct() const { return body_.distinct; }
// Named expressions which are used to produce results.
const auto &named_expressions() const { return body_.named_expressions; }
const auto &named_expressions() const { return named_expressions_; }
// Pairs of (Ordering, Expression *) for sorting results.
const auto &order_by() const { return body_.order_by; }
// Optional expression which determines how many results to skip.
@ -496,21 +539,23 @@ class ReturnBodyContext : public TreeVisitorBase {
// expressions are used for grouping. For example, in `WITH sum(n.a) + 2 * n.b
// AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`.
const auto &group_by() const { return group_by_; }
// All symbols generated by named expressions. They are collected in order of
// named_expressions.
const auto &output_symbols() const { return output_symbols_; }
private:
const ReturnBody &body_;
const SymbolTable &symbol_table_;
SymbolTable &symbol_table_;
const std::unordered_set<Symbol> &bound_symbols_;
AstTreeStorage &storage_;
const Where *const where_ = nullptr;
std::unordered_set<Symbol, Symbol::Hash> used_symbols_;
std::unordered_set<Symbol> used_symbols_;
std::vector<Symbol> output_symbols_;
std::vector<Aggregate::Element> aggregations_;
std::vector<Expression *> group_by_;
// Flag indicating whether an expression contains an aggregation.
std::list<bool> has_aggregation_;
std::vector<NamedExpression *> named_expressions_;
};
auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
@ -561,9 +606,9 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
return last_op;
}
auto GenWith(With &with, LogicalOperator *input_op,
const SymbolTable &symbol_table, bool is_write,
std::unordered_set<int> &bound_symbols) {
auto GenWith(With &with, LogicalOperator *input_op, SymbolTable &symbol_table,
bool is_write, std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage) {
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
// optional Filter. In case of update and aggregation, we want to accumulate
// first, so that when aggregating, we get the latest results. Similar to
@ -571,7 +616,8 @@ auto GenWith(With &with, LogicalOperator *input_op,
bool accumulate = is_write;
// No need to advance the command if we only performed reads.
bool advance_command = is_write;
ReturnBodyContext body(with.body_, symbol_table, with.where_);
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage,
with.where_);
LogicalOperator *last_op =
GenReturnBody(input_op, advance_command, body, accumulate);
// Reset bound symbols, so that only those in WITH are exposed.
@ -583,7 +629,9 @@ auto GenWith(With &with, LogicalOperator *input_op,
}
auto GenReturn(Return &ret, LogicalOperator *input_op,
const SymbolTable &symbol_table, bool is_write) {
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage) {
// Similar to WITH clause, but we want to accumulate and advance command when
// the query writes to the database. This way we handle the case when we want
// to return expressions with the latest updated results. For example,
@ -592,7 +640,7 @@ auto GenReturn(Return &ret, LogicalOperator *input_op,
// value is the same, final result of 'k' increments.
bool accumulate = is_write;
bool advance_command = false;
ReturnBodyContext body(ret.body_, symbol_table);
ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
return GenReturnBody(input_op, advance_command, body, accumulate);
}
@ -600,7 +648,7 @@ auto GenReturn(Return &ret, LogicalOperator *input_op,
// isn't handled, returns nullptr.
LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols) {
std::unordered_set<Symbol> &bound_symbols) {
if (auto *create = dynamic_cast<Create *>(clause)) {
return GenCreate(*create, input_op, symbol_table, bound_symbols);
} else if (auto *del = dynamic_cast<query::Delete *>(clause)) {
@ -632,10 +680,11 @@ LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols, AstTreeStorage &storage) {
std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage) {
// Copy the bound symbol set, because we don't want to use the updated version
// when generating the create part.
std::unordered_set<int> bound_symbols_copy(bound_symbols);
std::unordered_set<Symbol> bound_symbols_copy(bound_symbols);
MatchContext context{symbol_table, bound_symbols_copy, GraphView::NEW};
CollectPatternFilters(*merge.pattern_, symbol_table, context.filters,
storage);
@ -659,14 +708,14 @@ auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
} // namespace
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
AstTreeStorage &storage, const SymbolTable &symbol_table) {
std::unique_ptr<LogicalOperator> MakeLogicalPlan(AstTreeStorage &storage,
SymbolTable &symbol_table) {
auto query = storage.query();
// bound_symbols set is used to differentiate cycles in pattern matching, so
// that the operator can be correctly initialized whether to read the symbol
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
// `n`, but the latter `n` would only read the already written information.
std::unordered_set<int> bound_symbols;
std::unordered_set<Symbol> bound_symbols;
// Set to true if a query command writes to the database.
bool is_write = false;
LogicalOperator *input_op = nullptr;
@ -681,7 +730,8 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
GenMatches(matches, input_op, symbol_table, bound_symbols, storage);
matches.clear();
if (auto *ret = dynamic_cast<Return *>(clause)) {
input_op = GenReturn(*ret, input_op, symbol_table, is_write);
input_op = GenReturn(*ret, input_op, symbol_table, is_write,
bound_symbols, storage);
} else if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
input_op =
GenMerge(*merge, input_op, symbol_table, bound_symbols, storage);
@ -689,8 +739,8 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
// anything.
is_write = true;
} else if (auto *with = dynamic_cast<query::With *>(clause)) {
input_op =
GenWith(*with, input_op, symbol_table, is_write, bound_symbols);
input_op = GenWith(*with, input_op, symbol_table, is_write,
bound_symbols, storage);
// WITH clause advances the command, so reset the flag.
is_write = false;
} else if (auto *op = HandleWriteClause(clause, input_op, symbol_table,

View File

@ -18,7 +18,7 @@ namespace plan {
/// use in operators. @c SymbolTable is used to determine inputs and outputs of
/// certain operators.
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
AstTreeStorage &storage, const query::SymbolTable &symbol_table);
AstTreeStorage &storage, query::SymbolTable &symbol_table);
}
} // namespace plan

View File

@ -505,7 +505,7 @@ TEST(ExpressionEvaluator, PropertyLookup) {
auto v1 = dba->insert_vertex();
v1.PropsSet(dba->property("age"), 10);
auto *identifier = storage.Create<Identifier>("n");
auto node_symbol = eval.symbol_table.CreateSymbol("n");
auto node_symbol = eval.symbol_table.CreateSymbol("n", true);
eval.symbol_table[*identifier] = node_symbol;
eval.frame[node_symbol] = v1;
{
@ -537,7 +537,7 @@ TEST(ExpressionEvaluator, LabelsTest) {
v1.add_label(dba->label("DOG"));
v1.add_label(dba->label("NICE_DOG"));
auto *identifier = storage.Create<Identifier>("n");
auto node_symbol = eval.symbol_table.CreateSymbol("n");
auto node_symbol = eval.symbol_table.CreateSymbol("n", true);
eval.symbol_table[*identifier] = node_symbol;
eval.frame[node_symbol] = v1;
{
@ -575,7 +575,7 @@ TEST(ExpressionEvaluator, EdgeTypeTest) {
auto v2 = dba->insert_vertex();
auto e = dba->insert_edge(v1, v2, dba->edge_type("TYPE1"));
auto *identifier = storage.Create<Identifier>("e");
auto edge_symbol = eval.symbol_table.CreateSymbol("e");
auto edge_symbol = eval.symbol_table.CreateSymbol("e", true);
eval.symbol_table[*identifier] = edge_symbol;
eval.frame[edge_symbol] = e;
{
@ -608,7 +608,7 @@ TEST(ExpressionEvaluator, Aggregation) {
auto aggr = storage.Create<Aggregation>(storage.Create<PrimitiveLiteral>(42),
Aggregation::Op::COUNT);
SymbolTable symbol_table;
auto aggr_sym = symbol_table.CreateSymbol("aggr");
auto aggr_sym = symbol_table.CreateSymbol("aggr", true);
symbol_table[*aggr] = aggr_sym;
Frame frame{symbol_table.max_position()};
frame[aggr_sym] = TypedValue(1);

View File

@ -65,9 +65,9 @@ TEST(QueryPlan, Accumulate) {
}
auto n_p_ne = NEXPR("n.p", n_p);
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne");
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne", true);
auto m_p_ne = NEXPR("m.p", m_p);
symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne");
symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne", true);
auto produce = MakeProduce(last_op, n_p_ne, m_p_ne);
ResultStreamFaker results = CollectProduce(produce, symbol_table, *dba);
std::vector<int> results_data;
@ -95,7 +95,7 @@ TEST(QueryPlan, AccumulateAdvance) {
SymbolTable symbol_table;
auto node = NODE("n");
auto sym_n = symbol_table.CreateSymbol("n");
auto sym_n = symbol_table.CreateSymbol("n", true);
symbol_table[*node->identifier_] = sym_n;
auto create = std::make_shared<CreateNode>(node, nullptr);
auto accumulate = std::make_shared<Accumulate>(
@ -126,8 +126,9 @@ std::shared_ptr<Produce> MakeAggregationProduce(
auto named_expr = NEXPR("", IDENT("aggregation"));
named_expressions.push_back(named_expr);
symbol_table[*named_expr->expression_] =
symbol_table.CreateSymbol("aggregation");
symbol_table[*named_expr] = symbol_table.CreateSymbol("named_expression");
symbol_table.CreateSymbol("aggregation", true);
symbol_table[*named_expr] =
symbol_table.CreateSymbol("named_expression", true);
aggregates.emplace_back(*aggr_inputs_it++, aggr_op,
symbol_table[*named_expr->expression_]);
}
@ -137,7 +138,8 @@ std::shared_ptr<Produce> MakeAggregationProduce(
for (auto group_by_expr : group_by_exprs) {
auto named_expr = NEXPR("", group_by_expr);
named_expressions.push_back(named_expr);
symbol_table[*named_expr] = symbol_table.CreateSymbol("named_expression");
symbol_table[*named_expr] =
symbol_table.CreateSymbol("named_expression", true);
}
auto aggregation =
std::make_shared<Aggregate>(input, aggregates, group_by_exprs, remember);
@ -307,7 +309,7 @@ TEST(QueryPlan, AggregateNoInput) {
auto two = LITERAL(2);
auto output = NEXPR("two", IDENT("two"));
symbol_table[*output->expression_] = symbol_table.CreateSymbol("two");
symbol_table[*output->expression_] = symbol_table.CreateSymbol("two", true);
auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two},
{Aggregation::Op::COUNT}, {}, {});
@ -494,18 +496,18 @@ TEST(QueryPlan, Unwind) {
std::vector<TypedValue>{1, true, "x"}, std::vector<TypedValue>{},
std::vector<TypedValue>{"bla"}});
auto x = symbol_table.CreateSymbol("x");
auto x = symbol_table.CreateSymbol("x", true);
auto unwind_0 = std::make_shared<plan::Unwind>(nullptr, input_expr, x);
auto x_expr = IDENT("x");
symbol_table[*x_expr] = x;
auto y = symbol_table.CreateSymbol("y");
auto y = symbol_table.CreateSymbol("y", true);
auto unwind_1 = std::make_shared<plan::Unwind>(unwind_0, x_expr, y);
auto x_ne = NEXPR("x", x_expr);
symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne");
symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne", true);
auto y_ne = NEXPR("y", IDENT("y"));
symbol_table[*y_ne->expression_] = y;
symbol_table[*y_ne] = symbol_table.CreateSymbol("y_ne");
symbol_table[*y_ne] = symbol_table.CreateSymbol("y_ne", true);
auto produce = MakeProduce(unwind_1, x_ne, y_ne);
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();

View File

@ -95,7 +95,7 @@ TEST(QueryPlan, CreateLimit) {
auto n = MakeScanAll(storage, symbol_table, "n1");
auto m = NODE("m");
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m");
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true);
auto c = std::make_shared<CreateNode>(m, n.op_);
auto skip = std::make_shared<plan::Limit>(c, LITERAL(1));
@ -157,7 +157,7 @@ TEST(QueryPlan, OrderBy) {
{order_value_pair.first, n_p}},
std::vector<Symbol>{n.sym_});
auto n_p_ne = NEXPR("n.p", n_p);
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n.p");
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n.p", true);
auto produce = MakeProduce(order_by, n_p_ne);
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
ASSERT_EQ(values.size(), results.size());
@ -208,9 +208,9 @@ TEST(QueryPlan, OrderByMultiple) {
},
std::vector<Symbol>{n.sym_});
auto n_p1_ne = NEXPR("n.p1", n_p1);
symbol_table[*n_p1_ne] = symbol_table.CreateSymbol("n.p1");
symbol_table[*n_p1_ne] = symbol_table.CreateSymbol("n.p1", true);
auto n_p2_ne = NEXPR("n.p2", n_p2);
symbol_table[*n_p2_ne] = symbol_table.CreateSymbol("n.p2");
symbol_table[*n_p2_ne] = symbol_table.CreateSymbol("n.p2", true);
auto produce = MakeProduce(order_by, n_p1_ne, n_p2_ne);
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
ASSERT_EQ(N * N, results.size());

View File

@ -93,7 +93,7 @@ ScanAllTuple MakeScanAll(AstTreeStorage &storage, SymbolTable &symbol_table,
GraphView graph_view = GraphView::OLD) {
auto node = NODE(identifier);
auto logical_op = std::make_shared<ScanAll>(node, input, graph_view);
auto symbol = symbol_table.CreateSymbol(identifier);
auto symbol = symbol_table.CreateSymbol(identifier, true);
symbol_table[*node->identifier_] = symbol;
// return std::make_tuple(node, logical_op, symbol);
return ScanAllTuple{node, logical_op, symbol};
@ -114,11 +114,11 @@ ExpandTuple MakeExpand(AstTreeStorage &storage, SymbolTable &symbol_table,
const std::string &node_identifier, bool existing_node,
GraphView graph_view = GraphView::AS_IS) {
auto edge = EDGE(edge_identifier, direction);
auto edge_sym = symbol_table.CreateSymbol(edge_identifier);
auto edge_sym = symbol_table.CreateSymbol(edge_identifier, true);
symbol_table[*edge->identifier_] = edge_sym;
auto node = NODE(node_identifier);
auto node_sym = symbol_table.CreateSymbol(node_identifier);
auto node_sym = symbol_table.CreateSymbol(node_identifier, true);
symbol_table[*node->identifier_] = node_sym;
auto op = std::make_shared<Expand>(node, edge, input, input_symbol,

View File

@ -33,7 +33,7 @@ TEST(QueryPlan, CreateNodeWithAttributes) {
SymbolTable symbol_table;
auto node = NODE("n");
symbol_table[*node->identifier_] = symbol_table.CreateSymbol("n");
symbol_table[*node->identifier_] = symbol_table.CreateSymbol("n", true);
node->labels_.emplace_back(label);
node->properties_[property] = LITERAL(42);
@ -67,19 +67,20 @@ TEST(QueryPlan, CreateReturn) {
SymbolTable symbol_table;
auto node = NODE("n");
auto sym_n = symbol_table.CreateSymbol("n");
auto sym_n = symbol_table.CreateSymbol("n", true);
symbol_table[*node->identifier_] = sym_n;
node->labels_.emplace_back(label);
node->properties_[property] = LITERAL(42);
auto create = std::make_shared<CreateNode>(node, nullptr);
auto named_expr_n = NEXPR("n", IDENT("n"));
symbol_table[*named_expr_n] = symbol_table.CreateSymbol("named_expr_n");
symbol_table[*named_expr_n] = symbol_table.CreateSymbol("named_expr_n", true);
symbol_table[*named_expr_n->expression_] = sym_n;
auto prop_lookup = PROPERTY_LOOKUP("n", property);
symbol_table[*prop_lookup->expression_] = sym_n;
auto named_expr_n_p = NEXPR("n", prop_lookup);
symbol_table[*named_expr_n_p] = symbol_table.CreateSymbol("named_expr_n_p");
symbol_table[*named_expr_n_p] =
symbol_table.CreateSymbol("named_expr_n_p", true);
symbol_table[*named_expr_n->expression_] = sym_n;
auto produce = MakeProduce(create, named_expr_n, named_expr_n_p);
@ -119,7 +120,7 @@ TEST(QueryPlan, CreateExpand) {
auto n = NODE("n");
n->labels_.emplace_back(label_node_1);
n->properties_[property] = LITERAL(1);
auto n_sym = symbol_table.CreateSymbol("n");
auto n_sym = symbol_table.CreateSymbol("n", true);
symbol_table[*n->identifier_] = n_sym;
// data for the second node
@ -129,10 +130,10 @@ TEST(QueryPlan, CreateExpand) {
if (cycle)
symbol_table[*m->identifier_] = n_sym;
else
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m");
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true);
auto r = EDGE("r", EdgeAtom::Direction::RIGHT);
symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r");
symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r", true);
r->edge_types_.emplace_back(edge_type);
r->properties_[property] = LITERAL(3);
@ -188,7 +189,7 @@ TEST(QueryPlan, MatchCreateNode) {
auto n_scan_all = MakeScanAll(storage, symbol_table, "n");
// second node
auto m = NODE("m");
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m");
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true);
// creation op
auto create_node = std::make_shared<CreateNode>(m, n_scan_all.op_);
@ -229,10 +230,10 @@ TEST(QueryPlan, MatchCreateExpand) {
if (cycle)
symbol_table[*m->identifier_] = n_scan_all.sym_;
else
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m");
symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true);
auto r = EDGE("r", EdgeAtom::Direction::RIGHT);
symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r");
symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r", true);
r->edge_types_.emplace_back(edge_type);
auto create_expand = std::make_shared<CreateExpand>(m, r, n_scan_all.op_,
@ -403,7 +404,7 @@ TEST(QueryPlan, DeleteReturn) {
storage.Create<PropertyLookup>(storage.Create<Identifier>("n"), prop);
symbol_table[*prop_lookup->expression_] = n.sym_;
auto n_p = storage.Create<NamedExpression>("n", prop_lookup);
symbol_table[*n_p] = symbol_table.CreateSymbol("bla");
symbol_table[*n_p] = symbol_table.CreateSymbol("bla", true);
auto produce = MakeProduce(delete_op, n_p);
auto result = CollectProduce(produce, symbol_table, *dba);
@ -843,7 +844,7 @@ TEST(QueryPlan, MergeNoInput) {
SymbolTable symbol_table;
auto node = NODE("n");
auto sym_n = symbol_table.CreateSymbol("n");
auto sym_n = symbol_table.CreateSymbol("n", true);
symbol_table[*node->identifier_] = sym_n;
auto create = std::make_shared<CreateNode>(node, nullptr);
auto merge = std::make_shared<plan::Merge>(nullptr, create, create);

View File

@ -39,7 +39,8 @@ TEST(QueryPlan, MatchReturn) {
auto output = NEXPR("n", IDENT("n"));
auto produce = MakeProduce(scan_all.op_, output);
symbol_table[*output->expression_] = scan_all.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output] =
symbol_table.CreateSymbol("named_expression_1", true);
return PullAll(produce, *dba, symbol_table);
};
@ -68,10 +69,12 @@ TEST(QueryPlan, MatchReturnCartesian) {
auto m = MakeScanAll(storage, symbol_table, "m", n.op_);
auto return_n = NEXPR("n", IDENT("n"));
symbol_table[*return_n->expression_] = n.sym_;
symbol_table[*return_n] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*return_n] =
symbol_table.CreateSymbol("named_expression_1", true);
auto return_m = NEXPR("m", IDENT("m"));
symbol_table[*return_m->expression_] = m.sym_;
symbol_table[*return_m] = symbol_table.CreateSymbol("named_expression_2");
symbol_table[*return_m] =
symbol_table.CreateSymbol("named_expression_2", true);
auto produce = MakeProduce(m.op_, return_n, return_m);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
@ -99,7 +102,7 @@ TEST(QueryPlan, StandaloneReturn) {
auto output = NEXPR("n", LITERAL(42));
auto produce = MakeProduce(std::shared_ptr<LogicalOperator>(nullptr), output);
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 1);
@ -150,7 +153,7 @@ TEST(QueryPlan, NodeFilterLabelsAndProperties) {
// make a named expression and a produce
auto output = NEXPR("x", IDENT("n"));
symbol_table[*output->expression_] = n.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true);
auto produce = MakeProduce(node_filter, output);
EXPECT_EQ(1, PullAll(produce, *dba, symbol_table));
@ -206,7 +209,7 @@ TEST(QueryPlan, NodeFilterMultipleLabels) {
auto produce = MakeProduce(node_filter, output);
// fill up the symbol table
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true);
symbol_table[*output->expression_] = n.sym_;
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
@ -240,7 +243,8 @@ TEST(QueryPlan, Expand) {
// make a named expression and a produce
auto output = NEXPR("m", IDENT("m"));
symbol_table[*output->expression_] = r_m.node_sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output] =
symbol_table.CreateSymbol("named_expression_1", true);
auto produce = MakeProduce(r_m.op_, output);
return PullAll(produce, *dba, symbol_table);
@ -298,13 +302,13 @@ TEST(QueryPlan, ExpandOptional) {
// RETURN n, r, m
auto n_ne = NEXPR("n", IDENT("n"));
symbol_table[*n_ne->expression_] = n.sym_;
symbol_table[*n_ne] = symbol_table.CreateSymbol("n");
symbol_table[*n_ne] = symbol_table.CreateSymbol("n", true);
auto r_ne = NEXPR("r", IDENT("r"));
symbol_table[*r_ne->expression_] = r_m.edge_sym_;
symbol_table[*r_ne] = symbol_table.CreateSymbol("r");
symbol_table[*r_ne] = symbol_table.CreateSymbol("r", true);
auto m_ne = NEXPR("m", IDENT("m"));
symbol_table[*m_ne->expression_] = r_m.node_sym_;
symbol_table[*m_ne] = symbol_table.CreateSymbol("m");
symbol_table[*m_ne] = symbol_table.CreateSymbol("m", true);
auto produce = MakeProduce(optional, n_ne, r_ne, m_ne);
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
@ -339,7 +343,7 @@ TEST(QueryPlan, OptionalMatchEmptyDB) {
// RETURN n
auto n_ne = NEXPR("n", IDENT("n"));
symbol_table[*n_ne->expression_] = n.sym_;
symbol_table[*n_ne] = symbol_table.CreateSymbol("n");
symbol_table[*n_ne] = symbol_table.CreateSymbol("n", true);
auto optional = std::make_shared<plan::Optional>(nullptr, n.op_,
std::vector<Symbol>{n.sym_});
auto produce = MakeProduce(optional, n_ne);
@ -377,7 +381,8 @@ TEST(QueryPlan, ExpandExistingNode) {
// make a named expression and a produce
auto output = NEXPR("n", IDENT("n"));
symbol_table[*output->expression_] = n.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output] =
symbol_table.CreateSymbol("named_expression_1", true);
auto produce = MakeProduce(r_n.op_, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
@ -420,7 +425,8 @@ TEST(QueryPlan, ExpandExistingEdge) {
// make a named expression and a produce
auto output = NEXPR("r", IDENT("r"));
symbol_table[*output->expression_] = r_j.edge_sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output] =
symbol_table.CreateSymbol("named_expression_1", true);
auto produce = MakeProduce(r_k.op_, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
@ -505,7 +511,8 @@ TEST(QueryPlan, EdgeFilter) {
// make a named expression and a produce
auto output = NEXPR("m", IDENT("m"));
symbol_table[*output->expression_] = r_m.node_sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output] =
symbol_table.CreateSymbol("named_expression_1", true);
auto produce = MakeProduce(edge_filter, output);
return PullAll(produce, *dba, symbol_table);
@ -552,7 +559,7 @@ TEST(QueryPlan, EdgeFilterMultipleTypes) {
auto produce = MakeProduce(edge_filter, output);
// fill up the symbol table
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true);
symbol_table[*output->expression_] = r_m.node_sym_;
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
@ -582,7 +589,7 @@ TEST(QueryPlan, Filter) {
auto output =
storage.Create<NamedExpression>("x", storage.Create<Identifier>("n"));
symbol_table[*output->expression_] = n.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true);
auto produce = MakeProduce(f, output);
EXPECT_EQ(CollectProduce(produce, symbol_table, *dba).GetResults().size(), 2);
@ -647,7 +654,7 @@ TEST(QueryPlan, Distinct) {
auto input_expr = LITERAL(TypedValue(input));
auto x = symbol_table.CreateSymbol("x");
auto x = symbol_table.CreateSymbol("x", true);
auto unwind = std::make_shared<plan::Unwind>(nullptr, input_expr, x);
auto x_expr = IDENT("x");
symbol_table[*x_expr] = x;
@ -656,7 +663,7 @@ TEST(QueryPlan, Distinct) {
std::make_shared<plan::Distinct>(unwind, std::vector<Symbol>{x});
auto x_ne = NEXPR("x", x_expr);
symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne");
symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne", true);
auto produce = MakeProduce(distinct, x_ne);
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();

View File

@ -120,17 +120,17 @@ using ExpectDistinct = OpChecker<Distinct>;
class ExpectAccumulate : public OpChecker<Accumulate> {
public:
ExpectAccumulate(const std::unordered_set<Symbol, Symbol::Hash> &symbols)
ExpectAccumulate(const std::unordered_set<Symbol> &symbols)
: symbols_(symbols) {}
void ExpectOp(Accumulate &op, const SymbolTable &symbol_table) override {
std::unordered_set<Symbol, Symbol::Hash> got_symbols(op.symbols().begin(),
op.symbols().end());
std::unordered_set<Symbol> got_symbols(op.symbols().begin(),
op.symbols().end());
EXPECT_EQ(symbols_, got_symbols);
}
private:
const std::unordered_set<Symbol, Symbol::Hash> symbols_;
const std::unordered_set<Symbol> symbols_;
};
class ExpectAggregate : public OpChecker<Aggregate> {
@ -771,4 +771,55 @@ TEST(TestLogicalPlanner, MatchOptionalMatchWhere) {
ExpectProduce());
}
TEST(TestLogicalPlanner, MatchReturnAsterisk) {
// Test MATCH (n) -[e]- (m) RETURN *, m.prop
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto ret = RETURN(PROPERTY_LOOKUP("m", prop), AS("m.prop"));
ret->body_.all_identifiers = true;
auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), ret);
auto symbol_table = MakeSymbolTable(*query);
auto plan = MakeLogicalPlan(storage, symbol_table);
CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectExpand(),
ExpectProduce());
std::vector<std::string> output_names;
for (const auto &output_symbol : plan->OutputSymbols(symbol_table)) {
output_names.emplace_back(output_symbol.name());
}
std::vector<std::string> expected_names{"e", "m", "n", "m.prop"};
EXPECT_EQ(output_names, expected_names);
}
TEST(TestLogicalPlanner, MatchReturnAsteriskSum) {
// Test MATCH (n) RETURN *, SUM(n.prop) AS s
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
auto ret = RETURN(sum, AS("s"));
ret->body_.all_identifiers = true;
auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret);
auto symbol_table = MakeSymbolTable(*query);
auto plan = MakeLogicalPlan(storage, symbol_table);
auto *produce = dynamic_cast<Produce *>(plan.get());
ASSERT_TRUE(produce);
const auto &named_expressions = produce->named_expressions();
ASSERT_EQ(named_expressions.size(), 2);
auto *expanded_ident =
dynamic_cast<query::Identifier *>(named_expressions[0]->expression_);
ASSERT_TRUE(expanded_ident);
auto aggr = ExpectAggregate({sum}, {expanded_ident});
CheckPlan(*plan, symbol_table, ExpectScanAll(), aggr, ExpectProduce());
std::vector<std::string> output_names;
for (const auto &output_symbol : plan->OutputSymbols(symbol_table)) {
output_names.emplace_back(output_symbol.name());
}
std::vector<std::string> expected_names{"n", "s"};
EXPECT_EQ(output_names, expected_names);
}
} // namespace

View File

@ -26,12 +26,12 @@ TEST(TestSymbolGenerator, MatchNodeReturn) {
auto pattern = match->patterns_[0];
auto node_atom = dynamic_cast<NodeAtom *>(pattern->atoms_[0]);
auto node_sym = symbol_table[*node_atom->identifier_];
EXPECT_EQ(node_sym.name_, "node_atom_1");
EXPECT_EQ(node_sym.type_, Symbol::Type::Vertex);
EXPECT_EQ(node_sym.name(), "node_atom_1");
EXPECT_EQ(node_sym.type(), Symbol::Type::Vertex);
auto ret = dynamic_cast<Return *>(query_ast->clauses_[1]);
auto named_expr = ret->body_.named_expressions[0];
auto column_sym = symbol_table[*named_expr];
EXPECT_EQ(node_sym.name_, column_sym.name_);
EXPECT_EQ(node_sym.name(), column_sym.name());
EXPECT_NE(node_sym, column_sym);
auto ret_sym = symbol_table[*named_expr->expression_];
EXPECT_EQ(node_sym, ret_sym);
@ -87,12 +87,12 @@ TEST(TestSymbolGenerator, MatchSameEdge) {
is_node = !is_node;
}
auto &node_symbol = node_symbols.front();
EXPECT_EQ(node_symbol.type_, Symbol::Type::Vertex);
EXPECT_EQ(node_symbol.type(), Symbol::Type::Vertex);
for (auto &symbol : node_symbols) {
EXPECT_EQ(node_symbol, symbol);
}
auto &edge_symbol = edge_symbols.front();
EXPECT_EQ(edge_symbol.type_, Symbol::Type::Edge);
EXPECT_EQ(edge_symbol.type(), Symbol::Type::Edge);
for (auto &symbol : edge_symbols) {
EXPECT_EQ(edge_symbol, symbol);
}
@ -129,12 +129,12 @@ TEST(TestSymbolGenerator, CreateNodeReturn) {
auto pattern = create->patterns_[0];
auto node_atom = dynamic_cast<NodeAtom *>(pattern->atoms_[0]);
auto node_sym = symbol_table[*node_atom->identifier_];
EXPECT_EQ(node_sym.name_, "n");
EXPECT_EQ(node_sym.type_, Symbol::Type::Vertex);
EXPECT_EQ(node_sym.name(), "n");
EXPECT_EQ(node_sym.type(), Symbol::Type::Vertex);
auto ret = dynamic_cast<Return *>(query_ast->clauses_[1]);
auto named_expr = ret->body_.named_expressions[0];
auto column_sym = symbol_table[*named_expr];
EXPECT_EQ(node_sym.name_, column_sym.name_);
EXPECT_EQ(node_sym.name(), column_sym.name());
EXPECT_NE(node_sym, column_sym);
auto ret_sym = symbol_table[*named_expr->expression_];
EXPECT_EQ(node_sym, ret_sym);
@ -260,7 +260,7 @@ TEST(TestSymbolGenerator, CreateDelete) {
EXPECT_EQ(symbol_table.max_position(), 1);
auto node_symbol = symbol_table.at(*node->identifier_);
auto ident_symbol = symbol_table.at(*ident);
EXPECT_EQ(node_symbol.type_, Symbol::Type::Vertex);
EXPECT_EQ(node_symbol.type(), Symbol::Type::Vertex);
EXPECT_EQ(node_symbol, ident_symbol);
}
@ -369,18 +369,18 @@ TEST(TestSymbolGenerator, CreateMultiExpand) {
auto n1 = symbol_table.at(*node_n1->identifier_);
auto n2 = symbol_table.at(*node_n2->identifier_);
EXPECT_EQ(n1, n2);
EXPECT_EQ(n1.type_, Symbol::Type::Vertex);
EXPECT_EQ(n1.type(), Symbol::Type::Vertex);
auto m = symbol_table.at(*node_m->identifier_);
EXPECT_EQ(m.type_, Symbol::Type::Vertex);
EXPECT_EQ(m.type(), Symbol::Type::Vertex);
EXPECT_NE(m, n1);
auto l = symbol_table.at(*node_l->identifier_);
EXPECT_EQ(l.type_, Symbol::Type::Vertex);
EXPECT_EQ(l.type(), Symbol::Type::Vertex);
EXPECT_NE(l, n1);
EXPECT_NE(l, m);
auto r = symbol_table.at(*edge_r->identifier_);
auto p = symbol_table.at(*edge_p->identifier_);
EXPECT_EQ(r.type_, Symbol::Type::Edge);
EXPECT_EQ(p.type_, Symbol::Type::Edge);
EXPECT_EQ(r.type(), Symbol::Type::Edge);
EXPECT_EQ(p.type(), Symbol::Type::Edge);
EXPECT_NE(r, p);
}
@ -527,11 +527,11 @@ TEST(TestSymbolGenerator, MatchWithCreate) {
query->Accept(symbol_generator);
EXPECT_EQ(symbol_table.max_position(), 3);
auto n = symbol_table.at(*node_1->identifier_);
EXPECT_EQ(n.type_, Symbol::Type::Vertex);
EXPECT_EQ(n.type(), Symbol::Type::Vertex);
auto m = symbol_table.at(*node_2->identifier_);
EXPECT_NE(n, m);
// Currently we don't infer expression types, so we lost true type of 'm'.
EXPECT_EQ(m.type_, Symbol::Type::Any);
EXPECT_EQ(m.type(), Symbol::Type::Any);
EXPECT_EQ(m, symbol_table.at(*node_3->identifier_));
}
@ -792,4 +792,52 @@ TEST(TestSymbolGenerator, MatchCrossReferenceVariable) {
EXPECT_NE(m, symbol_table.at(*as_n));
}
TEST(TestSymbolGenerator, MatchWithAsteriskReturnAsterisk) {
// MATCH (n) -[e]- (m) WITH * RETURN *, n.prop
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto n_prop = PROPERTY_LOOKUP("n", prop);
auto ret = RETURN(n_prop, AS("n.prop"));
ret->body_.all_identifiers = true;
auto node_n = NODE("n");
auto edge = EDGE("e");
auto node_m = NODE("m");
auto with = storage.Create<With>();
with->body_.all_identifiers = true;
auto query = QUERY(MATCH(PATTERN(node_n, edge, node_m)), with, ret);
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query->Accept(symbol_generator);
// Symbols for `n`, `e`, `m`, `AS n.prop`.
EXPECT_EQ(symbol_table.max_position(), 4);
auto n = symbol_table.at(*node_n->identifier_);
EXPECT_EQ(n, symbol_table.at(*n_prop->expression_));
}
TEST(TestSymbolGenerator, MatchReturnAsteriskSameResult) {
// MATCH (n) RETURN *, n AS n
AstTreeStorage storage;
auto ret = RETURN(IDENT("n"), AS("n"));
ret->body_.all_identifiers = true;
auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret);
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
}
TEST(TestSymbolGenerator, MatchReturnAsteriskNoUserVariables) {
// MATCH () RETURN *
AstTreeStorage storage;
auto ret = storage.Create<Return>();
ret->body_.all_identifiers = true;
auto ident_n = storage.Create<Identifier>("anon", false);
auto node = storage.Create<NodeAtom>(ident_n);
auto query = QUERY(MATCH(PATTERN(node)), ret);
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
}
} // namespace