From 87e5dc0dfb1d3f266e32f1ec753708bfd7f597c3 Mon Sep 17 00:00:00 2001 From: Teon Banek Date: Fri, 12 May 2017 10:37:22 +0100 Subject: [PATCH] Plan '*' in RETURN and WITH Summary: Make Symbol members read only. Check WITH/RETURN * in SymbolGenerator. Test semantic checks for WITH/RETURN *. Sort expanded user identifiers by name. Test planning WITH/RETURN *. Reviewers: buda, florijan, mislav.bradac Reviewed By: florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D357 --- .../frontend/semantic/symbol_generator.cpp | 41 ++++-- .../frontend/semantic/symbol_generator.hpp | 4 +- src/query/frontend/semantic/symbol_table.hpp | 53 ++++--- src/query/interpret/frame.hpp | 4 +- src/query/interpreter.hpp | 2 +- src/query/plan/planner.cpp | 136 ++++++++++++------ src/query/plan/planner.hpp | 2 +- tests/unit/query_expression_evaluator.cpp | 8 +- .../unit/query_plan_accumulate_aggregate.cpp | 24 ++-- tests/unit/query_plan_bag_semantics.cpp | 8 +- tests/unit/query_plan_common.hpp | 6 +- .../query_plan_create_set_remove_delete.cpp | 25 ++-- tests/unit/query_plan_match_filter_return.cpp | 43 +++--- tests/unit/query_planner.cpp | 59 +++++++- tests/unit/query_semantic.cpp | 80 ++++++++--- 15 files changed, 347 insertions(+), 148 deletions(-) diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index 76389cd65..6ac87515a 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -8,32 +8,46 @@ namespace query { -auto SymbolGenerator::CreateSymbol(const std::string &name, Symbol::Type type) { - auto symbol = symbol_table_.CreateSymbol(name, type); +auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, + Symbol::Type type) { + auto symbol = symbol_table_.CreateSymbol(name, user_declared, type); scope_.symbols[name] = symbol; return symbol; } auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, - Symbol::Type type) { + bool user_declared, Symbol::Type type) { auto search = scope_.symbols.find(name); if (search != scope_.symbols.end()) { auto symbol = search->second; // Unless we have `Any` type, check that types match. - if (type != Symbol::Type::Any && symbol.type_ != Symbol::Type::Any && - type != symbol.type_) { - throw TypeMismatchError(name, Symbol::TypeToString(symbol.type_), + if (type != Symbol::Type::Any && symbol.type() != Symbol::Type::Any && + type != symbol.type()) { + throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()), Symbol::TypeToString(type)); } return search->second; } - return CreateSymbol(name, type); + return CreateSymbol(name, user_declared, type); } void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) { for (auto &expr : body.named_expressions) { expr->Accept(*this); } + std::vector user_symbols; + if (body.all_identifiers) { + // Carry over user symbols because '*' appeared. + for (auto sym_pair : scope_.symbols) { + if (!sym_pair.second.user_declared()) { + continue; + } + user_symbols.emplace_back(sym_pair.second); + } + if (user_symbols.empty()) { + throw SemanticException("There are no variables in scope to use for '*'"); + } + } // WITH/RETURN clause removes declarations of all the previous variables and // declares only those established through named expressions. New declarations // must not be visible inside named expressions themselves. @@ -47,6 +61,10 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) { } // Create symbols for named expressions. std::unordered_set new_names; + for (const auto &user_sym : user_symbols) { + new_names.insert(user_sym.name()); + scope_.symbols[user_sym.name()] = user_sym; + } for (auto &named_expr : body.named_expressions) { const auto &name = named_expr->name_; if (!new_names.insert(name).second) { @@ -55,7 +73,7 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) { } // An improvement would be to infer the type of the expression, so that the // new symbol would have a more specific type. - symbol_table_[*named_expr] = CreateSymbol(name); + symbol_table_[*named_expr] = CreateSymbol(name, true); } scope_.in_order_by = true; for (const auto &order_pair : body.order_by) { @@ -119,7 +137,7 @@ void SymbolGenerator::PostVisit(Unwind &unwind) { if (HasSymbol(name)) { throw RedeclareVariableError(name); } - symbol_table_[*unwind.named_expression_] = CreateSymbol(name); + symbol_table_[*unwind.named_expression_] = CreateSymbol(name, true); } void SymbolGenerator::Visit(Match &) { scope_.in_match = true; } @@ -163,7 +181,7 @@ void SymbolGenerator::Visit(Identifier &ident) { if (scope_.in_edge_atom) { type = Symbol::Type::Edge; } - symbol = GetOrCreateSymbol(ident.name_, type); + symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, type); } else if (scope_.in_pattern && scope_.in_property_map && scope_.in_match) { // Variables in property maps during MATCH can reference symbols bound later // in the same MATCH. We collect them here, so that they can be checked @@ -193,7 +211,8 @@ void SymbolGenerator::Visit(Aggregation &aggr) { } // Create a virtual symbol for aggregation result. // Currently, we only have aggregation operators which return numbers. - symbol_table_[aggr] = symbol_table_.CreateSymbol("", Symbol::Type::Number); + symbol_table_[aggr] = + symbol_table_.CreateSymbol("", false, Symbol::Type::Number); scope_.in_aggregation = true; scope_.has_aggregation = true; } diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index ef964480b..4fba06e68 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -88,12 +88,12 @@ class SymbolGenerator : public TreeVisitorBase { // Returns a freshly generated symbol. Previous mapping of the same name to a // different symbol is replaced with the new one. - auto CreateSymbol(const std::string &name, + auto CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::Any); // Returns the symbol by name. If the mapping already exists, checks if the // types match. Otherwise, returns a new symbol. - auto GetOrCreateSymbol(const std::string &name, + auto GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::Any); void VisitReturnBody(ReturnBody &body, Where *where = nullptr); diff --git a/src/query/frontend/semantic/symbol_table.hpp b/src/query/frontend/semantic/symbol_table.hpp index 76856728a..754af7af2 100644 --- a/src/query/frontend/semantic/symbol_table.hpp +++ b/src/query/frontend/semantic/symbol_table.hpp @@ -17,34 +17,38 @@ class Symbol { return enum_string[static_cast(type)]; } - // Calculates the Symbol hash based on its position. - struct Hash { - size_t operator()(const Symbol &symbol) const { - return std::hash{}(symbol.position_); - } - }; - Symbol() {} - Symbol(const std::string &name, int position, Type type = Type::Any) - : name_(name), position_(position), type_(type) {} - - std::string name_; - int position_; - Type type_{Type::Any}; + Symbol(const std::string &name, int position, bool user_declared, + Type type = Type::Any) + : name_(name), + position_(position), + user_declared_(user_declared), + type_(type) {} bool operator==(const Symbol &other) const { return position_ == other.position_ && name_ == other.name_ && type_ == other.type_; } bool operator!=(const Symbol &other) const { return !operator==(other); } + + const auto &name() const { return name_; } + int position() const { return position_; } + Type type() const { return type_; } + bool user_declared() const { return user_declared_; } + + private: + std::string name_; + int position_; + bool user_declared_ = true; + Type type_ = Type::Any; }; class SymbolTable { public: - Symbol CreateSymbol(const std::string &name, + Symbol CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::Any) { int position = position_++; - return Symbol(name, position, type); + return Symbol(name, position, user_declared, type); } auto &operator[](const Tree &tree) { return table_[tree.uid()]; } @@ -59,4 +63,21 @@ class SymbolTable { std::map table_; }; -} +} // namespace query + +namespace std { + +template <> +struct hash { + size_t operator()(const query::Symbol &symbol) const { + size_t prime = 265443599u; + size_t hash = std::hash{}(symbol.position()); + hash ^= prime * std::hash{}(symbol.name()); + hash ^= prime * std::hash{}(symbol.user_declared()); + hash ^= prime * std::hash{}(static_cast(symbol.type())); + return hash; + } +}; + +} // namespace std + diff --git a/src/query/interpret/frame.hpp b/src/query/interpret/frame.hpp index 091575216..9449eb2ca 100644 --- a/src/query/interpret/frame.hpp +++ b/src/query/interpret/frame.hpp @@ -12,10 +12,10 @@ class Frame { Frame(int size) : size_(size), elems_(size_) {} TypedValue &operator[](const Symbol &symbol) { - return elems_[symbol.position_]; + return elems_[symbol.position()]; } const TypedValue &operator[](const Symbol &symbol) const { - return elems_[symbol.position_]; + return elems_[symbol.position()]; } private: diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 3cd14c7d1..d9d4690d9 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -48,7 +48,7 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor, // clause, so stream out the results. // generate header - for (const auto &symbol : output_symbols) header.push_back(symbol.name_); + for (const auto &symbol : output_symbols) header.push_back(symbol.name()); stream.Header(header); // stream out results diff --git a/src/query/plan/planner.cpp b/src/query/plan/planner.cpp index d6f356112..39ab89e33 100644 --- a/src/query/plan/planner.cpp +++ b/src/query/plan/planner.cpp @@ -13,8 +13,9 @@ namespace { // Returns false if the symbol was already bound, otherwise binds it and // returns true. -bool BindSymbol(std::unordered_set &bound_symbols, const Symbol &symbol) { - auto insertion = bound_symbols.insert(symbol.position_); +bool BindSymbol(std::unordered_set &bound_symbols, + const Symbol &symbol) { + auto insertion = bound_symbols.insert(symbol); return insertion.second; } @@ -61,7 +62,7 @@ auto ReducePattern( auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op, const SymbolTable &symbol_table, - std::unordered_set &bound_symbols) { + std::unordered_set &bound_symbols) { auto base = [&](NodeAtom *node) -> LogicalOperator * { if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) return new CreateNode(node, std::shared_ptr(input_op)); @@ -92,7 +93,7 @@ auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op, auto GenCreate(Create &create, LogicalOperator *input_op, const SymbolTable &symbol_table, - std::unordered_set &bound_symbols) { + std::unordered_set &bound_symbols) { auto last_op = input_op; for (auto pattern : create.patterns_) { last_op = @@ -109,17 +110,16 @@ class UsedSymbolsCollector : public TreeVisitorBase { using TreeVisitorBase::Visit; void Visit(Identifier &ident) override { - const auto &symbol = symbol_table_.at(ident); - symbols_.insert(symbol.position_); + symbols_.insert(symbol_table_.at(ident)); } - std::unordered_set symbols_; + std::unordered_set symbols_; const SymbolTable &symbol_table_; }; bool HasBoundFilterSymbols( - const std::unordered_set &bound_symbols, - const std::pair> &filter) { + const std::unordered_set &bound_symbols, + const std::pair> &filter) { for (const auto &symbol : filter.second) { if (bound_symbols.find(symbol) == bound_symbols.end()) { return false; @@ -154,7 +154,7 @@ Expression *PropertiesEqual(AstTreeStorage &storage, auto &CollectPatternFilters( Pattern &pattern, const SymbolTable &symbol_table, - std::list>> &filters, + std::list>> &filters, AstTreeStorage &storage) { UsedSymbolsCollector collector(symbol_table); auto node_filter = [&](NodeAtom *node) { @@ -164,7 +164,7 @@ auto &CollectPatternFilters( node->identifier_, node->labels_); auto *props_filter = PropertiesEqual(storage, collector, node); if (labels_filter || props_filter) { - collector.symbols_.insert(symbol_table.at(*node->identifier_).position_); + collector.symbols_.insert(symbol_table.at(*node->identifier_)); filters.emplace_back( BoolJoin(storage, labels_filter, props_filter), collector.symbols_); @@ -181,7 +181,7 @@ auto &CollectPatternFilters( auto *props_filter = PropertiesEqual(storage, collector, edge); if (types_filter || props_filter) { const auto &edge_symbol = symbol_table.at(*edge->identifier_); - collector.symbols_.insert(edge_symbol.position_); + collector.symbols_.insert(edge_symbol); filters->emplace_back( BoolJoin(storage, types_filter, props_filter), collector.symbols_); @@ -190,13 +190,13 @@ auto &CollectPatternFilters( return node_filter(node); }; return *ReducePattern< - std::list>> *>( + std::list>> *>( pattern, node_filter, expand_filter); } void CollectMatchFilters( const Match &match, const SymbolTable &symbol_table, - std::list>> &filters, + std::list>> &filters, AstTreeStorage &storage) { for (auto *pattern : match.patterns_) { CollectPatternFilters(*pattern, symbol_table, filters, storage); @@ -214,22 +214,22 @@ struct MatchContext { // Already bound symbols, which are used to determine whether the operator // should reference them or establish new. This is both read from and written // to during generation. - std::unordered_set &bound_symbols; + std::unordered_set &bound_symbols; // Determines whether the match should see the new graph state or not. GraphView graph_view = GraphView::OLD; // Pairs of filter expression and symbols used in them. The list should be // filled using CollectPatternFilters function, and later modified during // GenMatchForPattern. - std::list>> filters; + std::list>> filters; // Symbols for edges established in match, used to ensure Cyphermorphism. - std::unordered_set edge_symbols; + std::unordered_set edge_symbols; // All the newly established symbols in match. std::vector new_symbols; }; auto GenFilters( - LogicalOperator *last_op, const std::unordered_set &bound_symbols, - std::list>> &filters, + LogicalOperator *last_op, const std::unordered_set &bound_symbols, + std::list>> &filters, AstTreeStorage &storage) { Expression *filter_expr = nullptr; for (auto filters_it = filters.begin(); filters_it != filters.end();) { @@ -312,7 +312,7 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op, auto GenMatches(std::vector &matches, LogicalOperator *input_op, const SymbolTable &symbol_table, - std::unordered_set &bound_symbols, + std::unordered_set &bound_symbols, AstTreeStorage &storage) { auto *last_op = input_op; MatchContext req_ctx{symbol_table, bound_symbols}; @@ -369,15 +369,27 @@ auto GenMatches(std::vector &matches, LogicalOperator *input_op, // aggregations and expressions used for group by. class ReturnBodyContext : public TreeVisitorBase { public: - ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table, - Where *where = nullptr) - : body_(body), symbol_table_(symbol_table), where_(where) { + ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table, + const std::unordered_set &bound_symbols, + AstTreeStorage &storage, Where *where = nullptr) + : body_(body), + symbol_table_(symbol_table), + bound_symbols_(bound_symbols), + storage_(storage), + where_(where) { // Collect symbols from named expressions. output_symbols_.reserve(body_.named_expressions.size()); + if (body.all_identifiers) { + // Expand '*' to expressions and symbols first, so that their results come + // before regular named expressions. + ExpandUserSymbols(); + } for (auto &named_expr : body_.named_expressions) { output_symbols_.emplace_back(symbol_table_.at(*named_expr)); named_expr->Accept(*this); + named_expressions_.emplace_back(named_expr); } + // Collect aggregations. if (aggregations_.empty()) { // Visit order_by and where if we do not have aggregations. This way we // prevent collecting group_by expressions from order_by and where, which @@ -456,7 +468,7 @@ class ReturnBodyContext : public TreeVisitorBase { // Aggregation contains a virtual symbol, where the result will be stored. const auto &symbol = symbol_table_.at(aggr); aggregations_.emplace_back(aggr.expression_, aggr.op_, symbol); - // aggregation expression_ is opional in COUNT(*) so it's possible the + // aggregation expression_ is optional in COUNT(*), so it's possible the // has_aggregation_ stack is empty if (aggr.expression_) has_aggregation_.back() = true; @@ -474,10 +486,41 @@ class ReturnBodyContext : public TreeVisitorBase { has_aggregation_.pop_back(); } + // Creates NamedExpression with an Identifier for each user declared symbol. + // This should be used when body.all_identifiers is true, to generate + // expressions for Produce operator. + void ExpandUserSymbols() { + debug_assert( + named_expressions_.empty(), + "ExpandUserSymbols should be first to fill named_expressions_"); + debug_assert(output_symbols_.empty(), + "ExpandUserSymbols should be first to fill output_symbols_"); + for (const auto &symbol : bound_symbols_) { + if (!symbol.user_declared()) { + continue; + } + auto *ident = storage_.Create(symbol.name()); + symbol_table_[*ident] = symbol; + auto *named_expr = storage_.Create(symbol.name(), ident); + symbol_table_[*named_expr] = symbol; + // Fill output expressions and symbols with expanded identifiers. + named_expressions_.emplace_back(named_expr); + output_symbols_.emplace_back(symbol); + used_symbols_.insert(symbol); + // Don't forget to group by expanded identifiers. + group_by_.emplace_back(ident); + } + // Cypher RETURN/WITH * expects to expand '*' sorted by name. + std::sort(output_symbols_.begin(), output_symbols_.end(), + [](const auto &a, const auto &b) { return a.name() < b.name(); }); + std::sort(named_expressions_.begin(), named_expressions_.end(), + [](const auto &a, const auto &b) { return a->name_ < b->name_; }); + } + // If true, results need to be distinct. bool distinct() const { return body_.distinct; } // Named expressions which are used to produce results. - const auto &named_expressions() const { return body_.named_expressions; } + const auto &named_expressions() const { return named_expressions_; } // Pairs of (Ordering, Expression *) for sorting results. const auto &order_by() const { return body_.order_by; } // Optional expression which determines how many results to skip. @@ -496,21 +539,23 @@ class ReturnBodyContext : public TreeVisitorBase { // expressions are used for grouping. For example, in `WITH sum(n.a) + 2 * n.b // AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`. const auto &group_by() const { return group_by_; } - // All symbols generated by named expressions. They are collected in order of // named_expressions. const auto &output_symbols() const { return output_symbols_; } private: const ReturnBody &body_; - const SymbolTable &symbol_table_; + SymbolTable &symbol_table_; + const std::unordered_set &bound_symbols_; + AstTreeStorage &storage_; const Where *const where_ = nullptr; - std::unordered_set used_symbols_; + std::unordered_set used_symbols_; std::vector output_symbols_; std::vector aggregations_; std::vector group_by_; // Flag indicating whether an expression contains an aggregation. std::list has_aggregation_; + std::vector named_expressions_; }; auto GenReturnBody(LogicalOperator *input_op, bool advance_command, @@ -561,9 +606,9 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command, return last_op; } -auto GenWith(With &with, LogicalOperator *input_op, - const SymbolTable &symbol_table, bool is_write, - std::unordered_set &bound_symbols) { +auto GenWith(With &with, LogicalOperator *input_op, SymbolTable &symbol_table, + bool is_write, std::unordered_set &bound_symbols, + AstTreeStorage &storage) { // WITH clause is Accumulate/Aggregate (advance_command) + Produce and // optional Filter. In case of update and aggregation, we want to accumulate // first, so that when aggregating, we get the latest results. Similar to @@ -571,7 +616,8 @@ auto GenWith(With &with, LogicalOperator *input_op, bool accumulate = is_write; // No need to advance the command if we only performed reads. bool advance_command = is_write; - ReturnBodyContext body(with.body_, symbol_table, with.where_); + ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, + with.where_); LogicalOperator *last_op = GenReturnBody(input_op, advance_command, body, accumulate); // Reset bound symbols, so that only those in WITH are exposed. @@ -583,7 +629,9 @@ auto GenWith(With &with, LogicalOperator *input_op, } auto GenReturn(Return &ret, LogicalOperator *input_op, - const SymbolTable &symbol_table, bool is_write) { + SymbolTable &symbol_table, bool is_write, + const std::unordered_set &bound_symbols, + AstTreeStorage &storage) { // Similar to WITH clause, but we want to accumulate and advance command when // the query writes to the database. This way we handle the case when we want // to return expressions with the latest updated results. For example, @@ -592,7 +640,7 @@ auto GenReturn(Return &ret, LogicalOperator *input_op, // value is the same, final result of 'k' increments. bool accumulate = is_write; bool advance_command = false; - ReturnBodyContext body(ret.body_, symbol_table); + ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage); return GenReturnBody(input_op, advance_command, body, accumulate); } @@ -600,7 +648,7 @@ auto GenReturn(Return &ret, LogicalOperator *input_op, // isn't handled, returns nullptr. LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op, const SymbolTable &symbol_table, - std::unordered_set &bound_symbols) { + std::unordered_set &bound_symbols) { if (auto *create = dynamic_cast(clause)) { return GenCreate(*create, input_op, symbol_table, bound_symbols); } else if (auto *del = dynamic_cast(clause)) { @@ -632,10 +680,11 @@ LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op, auto GenMerge(query::Merge &merge, LogicalOperator *input_op, const SymbolTable &symbol_table, - std::unordered_set &bound_symbols, AstTreeStorage &storage) { + std::unordered_set &bound_symbols, + AstTreeStorage &storage) { // Copy the bound symbol set, because we don't want to use the updated version // when generating the create part. - std::unordered_set bound_symbols_copy(bound_symbols); + std::unordered_set bound_symbols_copy(bound_symbols); MatchContext context{symbol_table, bound_symbols_copy, GraphView::NEW}; CollectPatternFilters(*merge.pattern_, symbol_table, context.filters, storage); @@ -659,14 +708,14 @@ auto GenMerge(query::Merge &merge, LogicalOperator *input_op, } // namespace -std::unique_ptr MakeLogicalPlan( - AstTreeStorage &storage, const SymbolTable &symbol_table) { +std::unique_ptr MakeLogicalPlan(AstTreeStorage &storage, + SymbolTable &symbol_table) { auto query = storage.query(); // bound_symbols set is used to differentiate cycles in pattern matching, so // that the operator can be correctly initialized whether to read the symbol // or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first // `n`, but the latter `n` would only read the already written information. - std::unordered_set bound_symbols; + std::unordered_set bound_symbols; // Set to true if a query command writes to the database. bool is_write = false; LogicalOperator *input_op = nullptr; @@ -681,7 +730,8 @@ std::unique_ptr MakeLogicalPlan( GenMatches(matches, input_op, symbol_table, bound_symbols, storage); matches.clear(); if (auto *ret = dynamic_cast(clause)) { - input_op = GenReturn(*ret, input_op, symbol_table, is_write); + input_op = GenReturn(*ret, input_op, symbol_table, is_write, + bound_symbols, storage); } else if (auto *merge = dynamic_cast(clause)) { input_op = GenMerge(*merge, input_op, symbol_table, bound_symbols, storage); @@ -689,8 +739,8 @@ std::unique_ptr MakeLogicalPlan( // anything. is_write = true; } else if (auto *with = dynamic_cast(clause)) { - input_op = - GenWith(*with, input_op, symbol_table, is_write, bound_symbols); + input_op = GenWith(*with, input_op, symbol_table, is_write, + bound_symbols, storage); // WITH clause advances the command, so reset the flag. is_write = false; } else if (auto *op = HandleWriteClause(clause, input_op, symbol_table, diff --git a/src/query/plan/planner.hpp b/src/query/plan/planner.hpp index af76d6559..f311195b3 100644 --- a/src/query/plan/planner.hpp +++ b/src/query/plan/planner.hpp @@ -18,7 +18,7 @@ namespace plan { /// use in operators. @c SymbolTable is used to determine inputs and outputs of /// certain operators. std::unique_ptr MakeLogicalPlan( - AstTreeStorage &storage, const query::SymbolTable &symbol_table); + AstTreeStorage &storage, query::SymbolTable &symbol_table); } } // namespace plan diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index eb40fdcd5..c57f2a127 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -505,7 +505,7 @@ TEST(ExpressionEvaluator, PropertyLookup) { auto v1 = dba->insert_vertex(); v1.PropsSet(dba->property("age"), 10); auto *identifier = storage.Create("n"); - auto node_symbol = eval.symbol_table.CreateSymbol("n"); + auto node_symbol = eval.symbol_table.CreateSymbol("n", true); eval.symbol_table[*identifier] = node_symbol; eval.frame[node_symbol] = v1; { @@ -537,7 +537,7 @@ TEST(ExpressionEvaluator, LabelsTest) { v1.add_label(dba->label("DOG")); v1.add_label(dba->label("NICE_DOG")); auto *identifier = storage.Create("n"); - auto node_symbol = eval.symbol_table.CreateSymbol("n"); + auto node_symbol = eval.symbol_table.CreateSymbol("n", true); eval.symbol_table[*identifier] = node_symbol; eval.frame[node_symbol] = v1; { @@ -575,7 +575,7 @@ TEST(ExpressionEvaluator, EdgeTypeTest) { auto v2 = dba->insert_vertex(); auto e = dba->insert_edge(v1, v2, dba->edge_type("TYPE1")); auto *identifier = storage.Create("e"); - auto edge_symbol = eval.symbol_table.CreateSymbol("e"); + auto edge_symbol = eval.symbol_table.CreateSymbol("e", true); eval.symbol_table[*identifier] = edge_symbol; eval.frame[edge_symbol] = e; { @@ -608,7 +608,7 @@ TEST(ExpressionEvaluator, Aggregation) { auto aggr = storage.Create(storage.Create(42), Aggregation::Op::COUNT); SymbolTable symbol_table; - auto aggr_sym = symbol_table.CreateSymbol("aggr"); + auto aggr_sym = symbol_table.CreateSymbol("aggr", true); symbol_table[*aggr] = aggr_sym; Frame frame{symbol_table.max_position()}; frame[aggr_sym] = TypedValue(1); diff --git a/tests/unit/query_plan_accumulate_aggregate.cpp b/tests/unit/query_plan_accumulate_aggregate.cpp index 19e69c03e..bd3590547 100644 --- a/tests/unit/query_plan_accumulate_aggregate.cpp +++ b/tests/unit/query_plan_accumulate_aggregate.cpp @@ -65,9 +65,9 @@ TEST(QueryPlan, Accumulate) { } auto n_p_ne = NEXPR("n.p", n_p); - symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne"); + symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne", true); auto m_p_ne = NEXPR("m.p", m_p); - symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne"); + symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne", true); auto produce = MakeProduce(last_op, n_p_ne, m_p_ne); ResultStreamFaker results = CollectProduce(produce, symbol_table, *dba); std::vector results_data; @@ -95,7 +95,7 @@ TEST(QueryPlan, AccumulateAdvance) { SymbolTable symbol_table; auto node = NODE("n"); - auto sym_n = symbol_table.CreateSymbol("n"); + auto sym_n = symbol_table.CreateSymbol("n", true); symbol_table[*node->identifier_] = sym_n; auto create = std::make_shared(node, nullptr); auto accumulate = std::make_shared( @@ -126,8 +126,9 @@ std::shared_ptr MakeAggregationProduce( auto named_expr = NEXPR("", IDENT("aggregation")); named_expressions.push_back(named_expr); symbol_table[*named_expr->expression_] = - symbol_table.CreateSymbol("aggregation"); - symbol_table[*named_expr] = symbol_table.CreateSymbol("named_expression"); + symbol_table.CreateSymbol("aggregation", true); + symbol_table[*named_expr] = + symbol_table.CreateSymbol("named_expression", true); aggregates.emplace_back(*aggr_inputs_it++, aggr_op, symbol_table[*named_expr->expression_]); } @@ -137,7 +138,8 @@ std::shared_ptr MakeAggregationProduce( for (auto group_by_expr : group_by_exprs) { auto named_expr = NEXPR("", group_by_expr); named_expressions.push_back(named_expr); - symbol_table[*named_expr] = symbol_table.CreateSymbol("named_expression"); + symbol_table[*named_expr] = + symbol_table.CreateSymbol("named_expression", true); } auto aggregation = std::make_shared(input, aggregates, group_by_exprs, remember); @@ -307,7 +309,7 @@ TEST(QueryPlan, AggregateNoInput) { auto two = LITERAL(2); auto output = NEXPR("two", IDENT("two")); - symbol_table[*output->expression_] = symbol_table.CreateSymbol("two"); + symbol_table[*output->expression_] = symbol_table.CreateSymbol("two", true); auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two}, {Aggregation::Op::COUNT}, {}, {}); @@ -494,18 +496,18 @@ TEST(QueryPlan, Unwind) { std::vector{1, true, "x"}, std::vector{}, std::vector{"bla"}}); - auto x = symbol_table.CreateSymbol("x"); + auto x = symbol_table.CreateSymbol("x", true); auto unwind_0 = std::make_shared(nullptr, input_expr, x); auto x_expr = IDENT("x"); symbol_table[*x_expr] = x; - auto y = symbol_table.CreateSymbol("y"); + auto y = symbol_table.CreateSymbol("y", true); auto unwind_1 = std::make_shared(unwind_0, x_expr, y); auto x_ne = NEXPR("x", x_expr); - symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne"); + symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne", true); auto y_ne = NEXPR("y", IDENT("y")); symbol_table[*y_ne->expression_] = y; - symbol_table[*y_ne] = symbol_table.CreateSymbol("y_ne"); + symbol_table[*y_ne] = symbol_table.CreateSymbol("y_ne", true); auto produce = MakeProduce(unwind_1, x_ne, y_ne); auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); diff --git a/tests/unit/query_plan_bag_semantics.cpp b/tests/unit/query_plan_bag_semantics.cpp index d733822d0..b8e1daa10 100644 --- a/tests/unit/query_plan_bag_semantics.cpp +++ b/tests/unit/query_plan_bag_semantics.cpp @@ -95,7 +95,7 @@ TEST(QueryPlan, CreateLimit) { auto n = MakeScanAll(storage, symbol_table, "n1"); auto m = NODE("m"); - symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m"); + symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true); auto c = std::make_shared(m, n.op_); auto skip = std::make_shared(c, LITERAL(1)); @@ -157,7 +157,7 @@ TEST(QueryPlan, OrderBy) { {order_value_pair.first, n_p}}, std::vector{n.sym_}); auto n_p_ne = NEXPR("n.p", n_p); - symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n.p"); + symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n.p", true); auto produce = MakeProduce(order_by, n_p_ne); auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); ASSERT_EQ(values.size(), results.size()); @@ -208,9 +208,9 @@ TEST(QueryPlan, OrderByMultiple) { }, std::vector{n.sym_}); auto n_p1_ne = NEXPR("n.p1", n_p1); - symbol_table[*n_p1_ne] = symbol_table.CreateSymbol("n.p1"); + symbol_table[*n_p1_ne] = symbol_table.CreateSymbol("n.p1", true); auto n_p2_ne = NEXPR("n.p2", n_p2); - symbol_table[*n_p2_ne] = symbol_table.CreateSymbol("n.p2"); + symbol_table[*n_p2_ne] = symbol_table.CreateSymbol("n.p2", true); auto produce = MakeProduce(order_by, n_p1_ne, n_p2_ne); auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); ASSERT_EQ(N * N, results.size()); diff --git a/tests/unit/query_plan_common.hpp b/tests/unit/query_plan_common.hpp index 49e21edeb..24b6bbd53 100644 --- a/tests/unit/query_plan_common.hpp +++ b/tests/unit/query_plan_common.hpp @@ -93,7 +93,7 @@ ScanAllTuple MakeScanAll(AstTreeStorage &storage, SymbolTable &symbol_table, GraphView graph_view = GraphView::OLD) { auto node = NODE(identifier); auto logical_op = std::make_shared(node, input, graph_view); - auto symbol = symbol_table.CreateSymbol(identifier); + auto symbol = symbol_table.CreateSymbol(identifier, true); symbol_table[*node->identifier_] = symbol; // return std::make_tuple(node, logical_op, symbol); return ScanAllTuple{node, logical_op, symbol}; @@ -114,11 +114,11 @@ ExpandTuple MakeExpand(AstTreeStorage &storage, SymbolTable &symbol_table, const std::string &node_identifier, bool existing_node, GraphView graph_view = GraphView::AS_IS) { auto edge = EDGE(edge_identifier, direction); - auto edge_sym = symbol_table.CreateSymbol(edge_identifier); + auto edge_sym = symbol_table.CreateSymbol(edge_identifier, true); symbol_table[*edge->identifier_] = edge_sym; auto node = NODE(node_identifier); - auto node_sym = symbol_table.CreateSymbol(node_identifier); + auto node_sym = symbol_table.CreateSymbol(node_identifier, true); symbol_table[*node->identifier_] = node_sym; auto op = std::make_shared(node, edge, input, input_symbol, diff --git a/tests/unit/query_plan_create_set_remove_delete.cpp b/tests/unit/query_plan_create_set_remove_delete.cpp index 9492538d0..a9c3ae003 100644 --- a/tests/unit/query_plan_create_set_remove_delete.cpp +++ b/tests/unit/query_plan_create_set_remove_delete.cpp @@ -33,7 +33,7 @@ TEST(QueryPlan, CreateNodeWithAttributes) { SymbolTable symbol_table; auto node = NODE("n"); - symbol_table[*node->identifier_] = symbol_table.CreateSymbol("n"); + symbol_table[*node->identifier_] = symbol_table.CreateSymbol("n", true); node->labels_.emplace_back(label); node->properties_[property] = LITERAL(42); @@ -67,19 +67,20 @@ TEST(QueryPlan, CreateReturn) { SymbolTable symbol_table; auto node = NODE("n"); - auto sym_n = symbol_table.CreateSymbol("n"); + auto sym_n = symbol_table.CreateSymbol("n", true); symbol_table[*node->identifier_] = sym_n; node->labels_.emplace_back(label); node->properties_[property] = LITERAL(42); auto create = std::make_shared(node, nullptr); auto named_expr_n = NEXPR("n", IDENT("n")); - symbol_table[*named_expr_n] = symbol_table.CreateSymbol("named_expr_n"); + symbol_table[*named_expr_n] = symbol_table.CreateSymbol("named_expr_n", true); symbol_table[*named_expr_n->expression_] = sym_n; auto prop_lookup = PROPERTY_LOOKUP("n", property); symbol_table[*prop_lookup->expression_] = sym_n; auto named_expr_n_p = NEXPR("n", prop_lookup); - symbol_table[*named_expr_n_p] = symbol_table.CreateSymbol("named_expr_n_p"); + symbol_table[*named_expr_n_p] = + symbol_table.CreateSymbol("named_expr_n_p", true); symbol_table[*named_expr_n->expression_] = sym_n; auto produce = MakeProduce(create, named_expr_n, named_expr_n_p); @@ -119,7 +120,7 @@ TEST(QueryPlan, CreateExpand) { auto n = NODE("n"); n->labels_.emplace_back(label_node_1); n->properties_[property] = LITERAL(1); - auto n_sym = symbol_table.CreateSymbol("n"); + auto n_sym = symbol_table.CreateSymbol("n", true); symbol_table[*n->identifier_] = n_sym; // data for the second node @@ -129,10 +130,10 @@ TEST(QueryPlan, CreateExpand) { if (cycle) symbol_table[*m->identifier_] = n_sym; else - symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m"); + symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true); auto r = EDGE("r", EdgeAtom::Direction::RIGHT); - symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r"); + symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r", true); r->edge_types_.emplace_back(edge_type); r->properties_[property] = LITERAL(3); @@ -188,7 +189,7 @@ TEST(QueryPlan, MatchCreateNode) { auto n_scan_all = MakeScanAll(storage, symbol_table, "n"); // second node auto m = NODE("m"); - symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m"); + symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true); // creation op auto create_node = std::make_shared(m, n_scan_all.op_); @@ -229,10 +230,10 @@ TEST(QueryPlan, MatchCreateExpand) { if (cycle) symbol_table[*m->identifier_] = n_scan_all.sym_; else - symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m"); + symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true); auto r = EDGE("r", EdgeAtom::Direction::RIGHT); - symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r"); + symbol_table[*r->identifier_] = symbol_table.CreateSymbol("r", true); r->edge_types_.emplace_back(edge_type); auto create_expand = std::make_shared(m, r, n_scan_all.op_, @@ -403,7 +404,7 @@ TEST(QueryPlan, DeleteReturn) { storage.Create(storage.Create("n"), prop); symbol_table[*prop_lookup->expression_] = n.sym_; auto n_p = storage.Create("n", prop_lookup); - symbol_table[*n_p] = symbol_table.CreateSymbol("bla"); + symbol_table[*n_p] = symbol_table.CreateSymbol("bla", true); auto produce = MakeProduce(delete_op, n_p); auto result = CollectProduce(produce, symbol_table, *dba); @@ -843,7 +844,7 @@ TEST(QueryPlan, MergeNoInput) { SymbolTable symbol_table; auto node = NODE("n"); - auto sym_n = symbol_table.CreateSymbol("n"); + auto sym_n = symbol_table.CreateSymbol("n", true); symbol_table[*node->identifier_] = sym_n; auto create = std::make_shared(node, nullptr); auto merge = std::make_shared(nullptr, create, create); diff --git a/tests/unit/query_plan_match_filter_return.cpp b/tests/unit/query_plan_match_filter_return.cpp index 5ccdab33f..aea2720f0 100644 --- a/tests/unit/query_plan_match_filter_return.cpp +++ b/tests/unit/query_plan_match_filter_return.cpp @@ -39,7 +39,8 @@ TEST(QueryPlan, MatchReturn) { auto output = NEXPR("n", IDENT("n")); auto produce = MakeProduce(scan_all.op_, output); symbol_table[*output->expression_] = scan_all.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output] = + symbol_table.CreateSymbol("named_expression_1", true); return PullAll(produce, *dba, symbol_table); }; @@ -68,10 +69,12 @@ TEST(QueryPlan, MatchReturnCartesian) { auto m = MakeScanAll(storage, symbol_table, "m", n.op_); auto return_n = NEXPR("n", IDENT("n")); symbol_table[*return_n->expression_] = n.sym_; - symbol_table[*return_n] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*return_n] = + symbol_table.CreateSymbol("named_expression_1", true); auto return_m = NEXPR("m", IDENT("m")); symbol_table[*return_m->expression_] = m.sym_; - symbol_table[*return_m] = symbol_table.CreateSymbol("named_expression_2"); + symbol_table[*return_m] = + symbol_table.CreateSymbol("named_expression_2", true); auto produce = MakeProduce(m.op_, return_n, return_m); ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); @@ -99,7 +102,7 @@ TEST(QueryPlan, StandaloneReturn) { auto output = NEXPR("n", LITERAL(42)); auto produce = MakeProduce(std::shared_ptr(nullptr), output); - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); EXPECT_EQ(result.GetResults().size(), 1); @@ -150,7 +153,7 @@ TEST(QueryPlan, NodeFilterLabelsAndProperties) { // make a named expression and a produce auto output = NEXPR("x", IDENT("n")); symbol_table[*output->expression_] = n.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); auto produce = MakeProduce(node_filter, output); EXPECT_EQ(1, PullAll(produce, *dba, symbol_table)); @@ -206,7 +209,7 @@ TEST(QueryPlan, NodeFilterMultipleLabels) { auto produce = MakeProduce(node_filter, output); // fill up the symbol table - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); symbol_table[*output->expression_] = n.sym_; ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); @@ -240,7 +243,8 @@ TEST(QueryPlan, Expand) { // make a named expression and a produce auto output = NEXPR("m", IDENT("m")); symbol_table[*output->expression_] = r_m.node_sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output] = + symbol_table.CreateSymbol("named_expression_1", true); auto produce = MakeProduce(r_m.op_, output); return PullAll(produce, *dba, symbol_table); @@ -298,13 +302,13 @@ TEST(QueryPlan, ExpandOptional) { // RETURN n, r, m auto n_ne = NEXPR("n", IDENT("n")); symbol_table[*n_ne->expression_] = n.sym_; - symbol_table[*n_ne] = symbol_table.CreateSymbol("n"); + symbol_table[*n_ne] = symbol_table.CreateSymbol("n", true); auto r_ne = NEXPR("r", IDENT("r")); symbol_table[*r_ne->expression_] = r_m.edge_sym_; - symbol_table[*r_ne] = symbol_table.CreateSymbol("r"); + symbol_table[*r_ne] = symbol_table.CreateSymbol("r", true); auto m_ne = NEXPR("m", IDENT("m")); symbol_table[*m_ne->expression_] = r_m.node_sym_; - symbol_table[*m_ne] = symbol_table.CreateSymbol("m"); + symbol_table[*m_ne] = symbol_table.CreateSymbol("m", true); auto produce = MakeProduce(optional, n_ne, r_ne, m_ne); auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); @@ -339,7 +343,7 @@ TEST(QueryPlan, OptionalMatchEmptyDB) { // RETURN n auto n_ne = NEXPR("n", IDENT("n")); symbol_table[*n_ne->expression_] = n.sym_; - symbol_table[*n_ne] = symbol_table.CreateSymbol("n"); + symbol_table[*n_ne] = symbol_table.CreateSymbol("n", true); auto optional = std::make_shared(nullptr, n.op_, std::vector{n.sym_}); auto produce = MakeProduce(optional, n_ne); @@ -377,7 +381,8 @@ TEST(QueryPlan, ExpandExistingNode) { // make a named expression and a produce auto output = NEXPR("n", IDENT("n")); symbol_table[*output->expression_] = n.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output] = + symbol_table.CreateSymbol("named_expression_1", true); auto produce = MakeProduce(r_n.op_, output); ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); @@ -420,7 +425,8 @@ TEST(QueryPlan, ExpandExistingEdge) { // make a named expression and a produce auto output = NEXPR("r", IDENT("r")); symbol_table[*output->expression_] = r_j.edge_sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output] = + symbol_table.CreateSymbol("named_expression_1", true); auto produce = MakeProduce(r_k.op_, output); ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); @@ -505,7 +511,8 @@ TEST(QueryPlan, EdgeFilter) { // make a named expression and a produce auto output = NEXPR("m", IDENT("m")); symbol_table[*output->expression_] = r_m.node_sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output] = + symbol_table.CreateSymbol("named_expression_1", true); auto produce = MakeProduce(edge_filter, output); return PullAll(produce, *dba, symbol_table); @@ -552,7 +559,7 @@ TEST(QueryPlan, EdgeFilterMultipleTypes) { auto produce = MakeProduce(edge_filter, output); // fill up the symbol table - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); symbol_table[*output->expression_] = r_m.node_sym_; ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); @@ -582,7 +589,7 @@ TEST(QueryPlan, Filter) { auto output = storage.Create("x", storage.Create("n")); symbol_table[*output->expression_] = n.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); auto produce = MakeProduce(f, output); EXPECT_EQ(CollectProduce(produce, symbol_table, *dba).GetResults().size(), 2); @@ -647,7 +654,7 @@ TEST(QueryPlan, Distinct) { auto input_expr = LITERAL(TypedValue(input)); - auto x = symbol_table.CreateSymbol("x"); + auto x = symbol_table.CreateSymbol("x", true); auto unwind = std::make_shared(nullptr, input_expr, x); auto x_expr = IDENT("x"); symbol_table[*x_expr] = x; @@ -656,7 +663,7 @@ TEST(QueryPlan, Distinct) { std::make_shared(unwind, std::vector{x}); auto x_ne = NEXPR("x", x_expr); - symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne"); + symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne", true); auto produce = MakeProduce(distinct, x_ne); auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index 9dbc152ce..fb3c94648 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -120,17 +120,17 @@ using ExpectDistinct = OpChecker; class ExpectAccumulate : public OpChecker { public: - ExpectAccumulate(const std::unordered_set &symbols) + ExpectAccumulate(const std::unordered_set &symbols) : symbols_(symbols) {} void ExpectOp(Accumulate &op, const SymbolTable &symbol_table) override { - std::unordered_set got_symbols(op.symbols().begin(), - op.symbols().end()); + std::unordered_set got_symbols(op.symbols().begin(), + op.symbols().end()); EXPECT_EQ(symbols_, got_symbols); } private: - const std::unordered_set symbols_; + const std::unordered_set symbols_; }; class ExpectAggregate : public OpChecker { @@ -771,4 +771,55 @@ TEST(TestLogicalPlanner, MatchOptionalMatchWhere) { ExpectProduce()); } +TEST(TestLogicalPlanner, MatchReturnAsterisk) { + // Test MATCH (n) -[e]- (m) RETURN *, m.prop + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + auto ret = RETURN(PROPERTY_LOOKUP("m", prop), AS("m.prop")); + ret->body_.all_identifiers = true; + auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), ret); + auto symbol_table = MakeSymbolTable(*query); + auto plan = MakeLogicalPlan(storage, symbol_table); + CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectExpand(), + ExpectProduce()); + std::vector output_names; + for (const auto &output_symbol : plan->OutputSymbols(symbol_table)) { + output_names.emplace_back(output_symbol.name()); + } + std::vector expected_names{"e", "m", "n", "m.prop"}; + EXPECT_EQ(output_names, expected_names); +} + +TEST(TestLogicalPlanner, MatchReturnAsteriskSum) { + // Test MATCH (n) RETURN *, SUM(n.prop) AS s + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + auto sum = SUM(PROPERTY_LOOKUP("n", prop)); + auto ret = RETURN(sum, AS("s")); + ret->body_.all_identifiers = true; + auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret); + auto symbol_table = MakeSymbolTable(*query); + auto plan = MakeLogicalPlan(storage, symbol_table); + auto *produce = dynamic_cast(plan.get()); + ASSERT_TRUE(produce); + const auto &named_expressions = produce->named_expressions(); + ASSERT_EQ(named_expressions.size(), 2); + auto *expanded_ident = + dynamic_cast(named_expressions[0]->expression_); + ASSERT_TRUE(expanded_ident); + auto aggr = ExpectAggregate({sum}, {expanded_ident}); + CheckPlan(*plan, symbol_table, ExpectScanAll(), aggr, ExpectProduce()); + std::vector output_names; + for (const auto &output_symbol : plan->OutputSymbols(symbol_table)) { + output_names.emplace_back(output_symbol.name()); + } + std::vector expected_names{"n", "s"}; + EXPECT_EQ(output_names, expected_names); +} + + } // namespace diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 3ea95a4f0..b1fd700ad 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -26,12 +26,12 @@ TEST(TestSymbolGenerator, MatchNodeReturn) { auto pattern = match->patterns_[0]; auto node_atom = dynamic_cast(pattern->atoms_[0]); auto node_sym = symbol_table[*node_atom->identifier_]; - EXPECT_EQ(node_sym.name_, "node_atom_1"); - EXPECT_EQ(node_sym.type_, Symbol::Type::Vertex); + EXPECT_EQ(node_sym.name(), "node_atom_1"); + EXPECT_EQ(node_sym.type(), Symbol::Type::Vertex); auto ret = dynamic_cast(query_ast->clauses_[1]); auto named_expr = ret->body_.named_expressions[0]; auto column_sym = symbol_table[*named_expr]; - EXPECT_EQ(node_sym.name_, column_sym.name_); + EXPECT_EQ(node_sym.name(), column_sym.name()); EXPECT_NE(node_sym, column_sym); auto ret_sym = symbol_table[*named_expr->expression_]; EXPECT_EQ(node_sym, ret_sym); @@ -87,12 +87,12 @@ TEST(TestSymbolGenerator, MatchSameEdge) { is_node = !is_node; } auto &node_symbol = node_symbols.front(); - EXPECT_EQ(node_symbol.type_, Symbol::Type::Vertex); + EXPECT_EQ(node_symbol.type(), Symbol::Type::Vertex); for (auto &symbol : node_symbols) { EXPECT_EQ(node_symbol, symbol); } auto &edge_symbol = edge_symbols.front(); - EXPECT_EQ(edge_symbol.type_, Symbol::Type::Edge); + EXPECT_EQ(edge_symbol.type(), Symbol::Type::Edge); for (auto &symbol : edge_symbols) { EXPECT_EQ(edge_symbol, symbol); } @@ -129,12 +129,12 @@ TEST(TestSymbolGenerator, CreateNodeReturn) { auto pattern = create->patterns_[0]; auto node_atom = dynamic_cast(pattern->atoms_[0]); auto node_sym = symbol_table[*node_atom->identifier_]; - EXPECT_EQ(node_sym.name_, "n"); - EXPECT_EQ(node_sym.type_, Symbol::Type::Vertex); + EXPECT_EQ(node_sym.name(), "n"); + EXPECT_EQ(node_sym.type(), Symbol::Type::Vertex); auto ret = dynamic_cast(query_ast->clauses_[1]); auto named_expr = ret->body_.named_expressions[0]; auto column_sym = symbol_table[*named_expr]; - EXPECT_EQ(node_sym.name_, column_sym.name_); + EXPECT_EQ(node_sym.name(), column_sym.name()); EXPECT_NE(node_sym, column_sym); auto ret_sym = symbol_table[*named_expr->expression_]; EXPECT_EQ(node_sym, ret_sym); @@ -260,7 +260,7 @@ TEST(TestSymbolGenerator, CreateDelete) { EXPECT_EQ(symbol_table.max_position(), 1); auto node_symbol = symbol_table.at(*node->identifier_); auto ident_symbol = symbol_table.at(*ident); - EXPECT_EQ(node_symbol.type_, Symbol::Type::Vertex); + EXPECT_EQ(node_symbol.type(), Symbol::Type::Vertex); EXPECT_EQ(node_symbol, ident_symbol); } @@ -369,18 +369,18 @@ TEST(TestSymbolGenerator, CreateMultiExpand) { auto n1 = symbol_table.at(*node_n1->identifier_); auto n2 = symbol_table.at(*node_n2->identifier_); EXPECT_EQ(n1, n2); - EXPECT_EQ(n1.type_, Symbol::Type::Vertex); + EXPECT_EQ(n1.type(), Symbol::Type::Vertex); auto m = symbol_table.at(*node_m->identifier_); - EXPECT_EQ(m.type_, Symbol::Type::Vertex); + EXPECT_EQ(m.type(), Symbol::Type::Vertex); EXPECT_NE(m, n1); auto l = symbol_table.at(*node_l->identifier_); - EXPECT_EQ(l.type_, Symbol::Type::Vertex); + EXPECT_EQ(l.type(), Symbol::Type::Vertex); EXPECT_NE(l, n1); EXPECT_NE(l, m); auto r = symbol_table.at(*edge_r->identifier_); auto p = symbol_table.at(*edge_p->identifier_); - EXPECT_EQ(r.type_, Symbol::Type::Edge); - EXPECT_EQ(p.type_, Symbol::Type::Edge); + EXPECT_EQ(r.type(), Symbol::Type::Edge); + EXPECT_EQ(p.type(), Symbol::Type::Edge); EXPECT_NE(r, p); } @@ -527,11 +527,11 @@ TEST(TestSymbolGenerator, MatchWithCreate) { query->Accept(symbol_generator); EXPECT_EQ(symbol_table.max_position(), 3); auto n = symbol_table.at(*node_1->identifier_); - EXPECT_EQ(n.type_, Symbol::Type::Vertex); + EXPECT_EQ(n.type(), Symbol::Type::Vertex); auto m = symbol_table.at(*node_2->identifier_); EXPECT_NE(n, m); // Currently we don't infer expression types, so we lost true type of 'm'. - EXPECT_EQ(m.type_, Symbol::Type::Any); + EXPECT_EQ(m.type(), Symbol::Type::Any); EXPECT_EQ(m, symbol_table.at(*node_3->identifier_)); } @@ -792,4 +792,52 @@ TEST(TestSymbolGenerator, MatchCrossReferenceVariable) { EXPECT_NE(m, symbol_table.at(*as_n)); } +TEST(TestSymbolGenerator, MatchWithAsteriskReturnAsterisk) { + // MATCH (n) -[e]- (m) WITH * RETURN *, n.prop + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + auto n_prop = PROPERTY_LOOKUP("n", prop); + auto ret = RETURN(n_prop, AS("n.prop")); + ret->body_.all_identifiers = true; + auto node_n = NODE("n"); + auto edge = EDGE("e"); + auto node_m = NODE("m"); + auto with = storage.Create(); + with->body_.all_identifiers = true; + auto query = QUERY(MATCH(PATTERN(node_n, edge, node_m)), with, ret); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + query->Accept(symbol_generator); + // Symbols for `n`, `e`, `m`, `AS n.prop`. + EXPECT_EQ(symbol_table.max_position(), 4); + auto n = symbol_table.at(*node_n->identifier_); + EXPECT_EQ(n, symbol_table.at(*n_prop->expression_)); +} + +TEST(TestSymbolGenerator, MatchReturnAsteriskSameResult) { + // MATCH (n) RETURN *, n AS n + AstTreeStorage storage; + auto ret = RETURN(IDENT("n"), AS("n")); + ret->body_.all_identifiers = true; + auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query->Accept(symbol_generator), SemanticException); +} + +TEST(TestSymbolGenerator, MatchReturnAsteriskNoUserVariables) { + // MATCH () RETURN * + AstTreeStorage storage; + auto ret = storage.Create(); + ret->body_.all_identifiers = true; + auto ident_n = storage.Create("anon", false); + auto node = storage.Create(ident_n); + auto query = QUERY(MATCH(PATTERN(node)), ret); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query->Accept(symbol_generator), SemanticException); +} + } // namespace