From c03ca5f8f08b3effc017dbf16fde0e6cd9540329 Mon Sep 17 00:00:00 2001 From: Marin Tomic <marin.tomic@memgraph.io> Date: Tue, 5 Feb 2019 13:16:07 +0100 Subject: [PATCH] Remove UID tracking from AstStorage Summary: All AST nodes had a member `uid_` that was used as a key in `SymbolTable`. It is renamed to `symbol_pos_` and it appears only in `Identifier`, `NamedExpression` and `Aggregation`, since only those types were used in `SymbolTable`. SymbolGenerator is now responsible for creating symbols in `SymbolTable` and assigning positions to AST nodes. Cloning and serialization code is now simpler since there is no need to track UIDs. Reviewers: teon.banek Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1836 --- src/distributed/bfs_rpc_messages.lcp | 16 +- src/query/frontend/ast/ast.lcp | 464 ++++++------------ src/query/frontend/ast/ast_serialization.lcp | 59 +-- .../frontend/semantic/symbol_generator.cpp | 38 +- src/query/frontend/semantic/symbol_table.hpp | 31 +- src/query/plan/distributed.cpp | 34 +- src/query/plan/distributed_ops.lcp | 12 +- src/query/plan/operator.lcp | 94 ++-- src/query/plan/preprocess.cpp | 17 +- src/query/plan/rule_based_planner.cpp | 7 +- src/query/serialization.capnp | 8 +- src/query/serialization.hpp | 28 +- tests/unit/ast_serialization.cpp | 12 +- tests/unit/bfs_common.hpp | 9 +- tests/unit/distributed_query_plan.cpp | 103 ++-- tests/unit/query_expression_evaluator.cpp | 50 +- tests/unit/query_plan.cpp | 28 +- .../unit/query_plan_accumulate_aggregate.cpp | 77 ++- tests/unit/query_plan_bag_semantics.cpp | 24 +- tests/unit/query_plan_common.hpp | 14 +- .../query_plan_create_set_remove_delete.cpp | 120 ++--- tests/unit/query_plan_match_filter_return.cpp | 244 ++++----- tests/unit/query_semantic.cpp | 72 +-- 23 files changed, 592 insertions(+), 969 deletions(-) diff --git a/src/distributed/bfs_rpc_messages.lcp b/src/distributed/bfs_rpc_messages.lcp index 539aba1d1..db2b0db64 100644 --- a/src/distributed/bfs_rpc_messages.lcp +++ b/src/distributed/bfs_rpc_messages.lcp @@ -45,26 +45,14 @@ cpp<# :capnp-load (lcp:capnp-load-vector "::storage::capnp::EdgeType" "storage::EdgeType")) (filter-lambda "query::plan::ExpansionLambda" - :slk-save (lambda (member) - #>cpp - std::vector<int32_t> saved_ast_uids; - slk::Save(self.${member}, builder, &saved_ast_uids); - cpp<#) :slk-load (lambda (member) #>cpp - std::vector<int32_t> loaded_ast_uids; - slk::Load(&self->${member}, reader, ast_storage, &loaded_ast_uids); + slk::Load(&self->${member}, reader, ast_storage); cpp<#) :capnp-type "DistOps.ExpansionLambda" - :capnp-save (lambda (builder member capnp-name) - #>cpp - std::vector<int> saved_ast_uids; - Save(${member}, &${builder}, &saved_ast_uids); - cpp<#) :capnp-load (lambda (reader member capnp-name) #>cpp - std::vector<int> loaded_ast_uids; - Load(&${member}, ${reader}, ast_storage, &loaded_ast_uids); + Load(&${member}, ${reader}, ast_storage); cpp<#)) (symbol-table "query::SymbolTable" :capnp-type "Query.SymbolTable") (timestamp :int64_t) diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index f56521d8c..327926d72 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -30,20 +30,20 @@ cpp<# (defun slk-save-ast-pointer (member) #>cpp - query::SaveAstPointer(self.${member}, builder, saved_uids); + query::SaveAstPointer(self.${member}, builder); cpp<#) (defun slk-load-ast-pointer (type) (lambda (member) #>cpp - self->${member} = query::LoadAstPointer<query::${type}>(storage, reader, loaded_uids); + self->${member} = query::LoadAstPointer<query::${type}>(storage, reader); cpp<#)) (defun save-ast-pointer (builder member capnp-name) #>cpp if (${member}) { auto ${capnp-name}_builder = ${builder}->init${capnp-name}(); - Save(*${member}, &${capnp-name}_builder, saved_uids); + Save(*${member}, &${capnp-name}_builder); } cpp<#) @@ -52,7 +52,7 @@ cpp<# #>cpp if (${reader}.has${capnp-name}()) { ${member} = static_cast<${type}>( - Load(storage, ${reader}.get${capnp-name}(), loaded_uids)); + Load(storage, ${reader}.get${capnp-name}())); } else { ${member} = nullptr; } @@ -63,7 +63,7 @@ cpp<# size_t size = self.${member}.size(); slk::Save(size, builder); for (const auto *val : self.${member}) { - query::SaveAstPointer(val, builder, saved_uids); + query::SaveAstPointer(val, builder); } cpp<#) @@ -76,22 +76,22 @@ cpp<# for (size_t i = 0; i < size; ++i) { - self->${member}[i] = query::LoadAstPointer<query::${type}>(storage, reader, loaded_uids); + self->${member}[i] = query::LoadAstPointer<query::${type}>(storage, reader); } cpp<#)) (defun save-ast-vector (type) (lcp:capnp-save-vector "capnp::Tree" type - "[saved_uids](auto *builder, const auto &val) { - Save(*val, builder, saved_uids); + "[](auto *builder, const auto &val) { + Save(*val, builder); }")) (defun load-ast-vector (type) (lcp:capnp-load-vector "capnp::Tree" type (format nil - "[storage, loaded_uids](const auto &reader) { - return static_cast<~A>(Load(storage, reader, loaded_uids)); + "[storage](const auto &reader) { + return static_cast<~A>(Load(storage, reader)); }" type))) @@ -101,7 +101,7 @@ cpp<# slk::Save(size, builder); for (const auto &entry : self.${member}) { slk::Save(entry.first, builder); - query::SaveAstPointer(entry.second, builder, saved_uids); + query::SaveAstPointer(entry.second, builder); } cpp<#) @@ -114,7 +114,7 @@ cpp<# ++i) { query::PropertyIx key; slk::Load(&key, reader, storage); - auto *value = query::LoadAstPointer<query::Expression>(storage, reader, loaded_uids); + auto *value = query::LoadAstPointer<query::Expression>(storage, reader); self->${member}.emplace(key, value); } cpp<#) @@ -127,7 +127,7 @@ cpp<# auto key_builder = entries_builder[i].initKey(); Save(entry.first, &key_builder); auto value_builder = entries_builder[i].initValue(); - Save(*entry.second, &value_builder, saved_uids); + Save(*entry.second, &value_builder); ++i; } cpp<#) @@ -138,7 +138,7 @@ cpp<# PropertyIx prop; Load(&prop, entry.getKey(), storage); ${member}.emplace(prop, static_cast<Expression *>( - Load(storage, entry.getValue(), loaded_uids))); + Load(storage, entry.getValue()))); } cpp<#) @@ -280,18 +280,9 @@ class AstStorage { AstStorage(AstStorage &&) = default; AstStorage &operator=(AstStorage &&) = default; - template <typename T> - T *Create() { - T *ptr = new T(); - ptr->uid_ = ++max_existing_uid_; - std::unique_ptr<T> tmp(ptr); - storage_.emplace_back(std::move(tmp)); - return ptr; - } - template <typename T, typename... Args> T *Create(Args &&... args) { - T *ptr = new T(++max_existing_uid_, std::forward<Args>(args)...); + T *ptr = new T(std::forward<Args>(args)...); std::unique_ptr<T> tmp(ptr); storage_.emplace_back(std::move(tmp)); return ptr; @@ -315,7 +306,6 @@ class AstStorage { // Public only for serialization access std::vector<std::unique_ptr<Tree>> storage_; - int max_existing_uid_ = -1; private: int64_t FindOrAddName(const std::string &name, @@ -332,14 +322,7 @@ class AstStorage { cpp<# (lcp:define-class tree () - ((uid :int32_t :scope :public - :clone (lambda (source dest) - #>cpp - ${dest} = ${source}; - // TODO(mtomic): This is a hack. Adopting existing UIDs breaks everything - // because `AstStorage` keeps a counter for generating when `Create` is called. - storage->max_existing_uid_ = std::max(storage->max_existing_uid_, ${source}); - cpp<#))) + () (:abstractp t) (:public #>cpp @@ -350,28 +333,8 @@ cpp<# #>cpp friend class AstStorage; cpp<#) - (:protected - #>cpp - explicit Tree(int uid) : uid_(uid) {} - cpp<#) - (:serialize (:slk :save-args '((saved-uids "std::vector<int32_t> *")) - :load-args '((storage "query::AstStorage *") - (loaded-uids "std::vector<int32_t> *"))) - (:capnp - :save-args '((saved-uids "std::vector<int32_t> *")) - :load-args '((storage "AstStorage *") - (loaded-uids "std::vector<int32_t> *")) - :post-save (lambda (builder) - (declare (ignore builder)) - ;; This is a bit hacky because it relies on the fact that parent class - ;; serialization is inlined so we can short-circuit and avoid serializing - ;; the derived class. - #>cpp - if (utils::Contains(*saved_uids, self.uid_)) { - return; - } - saved_uids->push_back(self.uid_); - cpp<#))) + (:serialize (:slk :load-args '((storage "query::AstStorage *"))) + (:capnp :load-args '((storage "AstStorage *")))) (:clone :return-type (lambda (typename) (format nil "~A*" typename)) :args '((storage "AstStorage *")) @@ -396,10 +359,6 @@ cpp<# Expression() = default; cpp<#) - (:protected - #>cpp - explicit Expression(int uid) : Tree(uid) {} - cpp<#) (:private #>cpp friend class AstStorage; @@ -431,8 +390,7 @@ cpp<# cpp<#) (:protected #>cpp - explicit Where(int uid) : Tree(uid) {} - Where(int uid, Expression *expression) : Tree(uid), expression_(expression) {} + explicit Where(Expression *expression) : expression_(expression) {} cpp<#) (:private #>cpp @@ -463,9 +421,8 @@ cpp<# cpp<#) (:protected #>cpp - explicit BinaryOperator(int uid) : Expression(uid) {} - BinaryOperator(int uid, Expression *expression1, Expression *expression2) - : Expression(uid), expression1_(expression1), expression2_(expression2) {} + BinaryOperator(Expression *expression1, Expression *expression2) + : expression1_(expression1), expression2_(expression2) {} cpp<#) (:private #>cpp @@ -488,9 +445,7 @@ cpp<# cpp<#) (:protected #>cpp - explicit UnaryOperator(int uid) : Expression(uid) {} - UnaryOperator(int uid, Expression *expression) - : Expression(uid), expression_(expression) {} + explicit UnaryOperator(Expression *expression) : expression_(expression) {} cpp<#) (:private #>cpp @@ -567,7 +522,9 @@ cpp<# (define-unary-operators)) (lcp:define-class aggregation (binary-operator) - ((op "Op" :scope :public)) + ((op "Op" :scope :public) + (symbol-pos :int32_t :initval -1 :scope :public + :documentation "Symbol table position of the symbol this Aggregation is mapped to.")) (:public (lcp:define-enum op (count min max sum avg collect-list collect-map) @@ -597,17 +554,21 @@ cpp<# } return visitor.PostVisit(*this); } + + Aggregation *MapTo(const Symbol &symbol) { + symbol_pos_ = symbol.position(); + return this; + } cpp<#) (:protected #>cpp // Use only for serialization. - Aggregation(int uid) : BinaryOperator(uid) {} - Aggregation(int uid, Op op) : BinaryOperator(uid), op_(op) {} + explicit Aggregation(Op op) : op_(op) {} /// Aggregation's first expression is the value being aggregated. The second /// expression is the key used only in COLLECT_MAP. - Aggregation(int uid, Expression *expression1, Expression *expression2, Op op) - : BinaryOperator(uid, expression1, expression2), op_(op) { + Aggregation(Expression *expression1, Expression *expression2, Op op) + : BinaryOperator(expression1, expression2), op_(op) { // COUNT without expression denotes COUNT(*) in cypher. DCHECK(expression1 || op == Aggregation::Op::COUNT) << "All aggregations, except COUNT require expression"; @@ -663,14 +624,9 @@ cpp<# cpp<#) (:protected #>cpp - explicit ListSlicingOperator(int uid) : Expression(uid) {} - - ListSlicingOperator(int uid, Expression *list, Expression *lower_bound, + ListSlicingOperator(Expression *list, Expression *lower_bound, Expression *upper_bound) - : Expression(uid), - list_(list), - lower_bound_(lower_bound), - upper_bound_(upper_bound) {} + : list_(list), lower_bound_(lower_bound), upper_bound_(upper_bound) {} cpp<#) (:private #>cpp @@ -715,12 +671,9 @@ cpp<# cpp<#) (:protected #>cpp - explicit IfOperator(int uid) : Expression(uid) {} - - IfOperator(int uid, Expression *condition, Expression *then_expression, + IfOperator(Expression *condition, Expression *then_expression, Expression *else_expression) - : Expression(uid), - condition_(condition), + : condition_(condition), then_expression_(then_expression), else_expression_(else_expression) {} cpp<#) @@ -738,10 +691,6 @@ cpp<# #>cpp BaseLiteral() = default; cpp<#) - (:protected - #>cpp - explicit BaseLiteral(int uid) : Expression(uid) {} - cpp<#) (:private #>cpp friend class AstStorage; @@ -775,12 +724,11 @@ cpp<# cpp<#) (:protected #>cpp - explicit PrimitiveLiteral(int uid) : BaseLiteral(uid) {} template <typename T> - PrimitiveLiteral(int uid, T value) : BaseLiteral(uid), value_(value) {} + explicit PrimitiveLiteral(T value) : value_(value) {} template <typename T> - PrimitiveLiteral(int uid, T value, int token_position) - : BaseLiteral(uid), value_(value), token_position_(token_position) {} + PrimitiveLiteral(T value, int token_position) + : value_(value), token_position_(token_position) {} cpp<#) (:serialize (:slk) (:capnp)) (:clone)) @@ -809,9 +757,8 @@ cpp<# cpp<#) (:protected #>cpp - explicit ListLiteral(int uid) : BaseLiteral(uid) {} - ListLiteral(int uid, const std::vector<Expression *> &elements) - : BaseLiteral(uid), elements_(elements) {} + explicit ListLiteral(const std::vector<Expression *> &elements) + : elements_(elements) {} cpp<#) (:private #>cpp @@ -845,10 +792,9 @@ cpp<# cpp<#) (:protected #>cpp - explicit MapLiteral(int uid) : BaseLiteral(uid) {} - MapLiteral(int uid, - const std::unordered_map<PropertyIx, Expression *> &elements) - : BaseLiteral(uid), elements_(elements) {} + explicit MapLiteral( + const std::unordered_map<PropertyIx, Expression *> &elements) + : elements_(elements) {} cpp<#) (:private #>cpp @@ -859,7 +805,9 @@ cpp<# (lcp:define-class identifier (expression) ((name "std::string" :scope :public) - (user-declared :bool :initval "true" :scope :public)) + (user-declared :bool :initval "true" :scope :public) + (symbol-pos :int32_t :initval -1 :scope :public + :documentation "Symbol table position of the symbol this Identifier is mapped to.")) (:public #>cpp Identifier() = default; @@ -867,14 +815,17 @@ cpp<# DEFVISITABLE(ExpressionVisitor<TypedValue>); DEFVISITABLE(ExpressionVisitor<void>); DEFVISITABLE(HierarchicalTreeVisitor); + + Identifier *MapTo(const Symbol &symbol) { + symbol_pos_ = symbol.position(); + return this; + } cpp<#) (:protected #>cpp - Identifier(int uid) : Expression(uid) {} - - Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {} - Identifier(int uid, const std::string &name, bool user_declared) - : Expression(uid), name_(name), user_declared_(user_declared) {} + explicit Identifier(const std::string &name) : name_(name) {} + Identifier(const std::string &name, bool user_declared) + : name_(name), user_declared_(user_declared) {} cpp<#) (:private #>cpp @@ -914,11 +865,8 @@ cpp<# cpp<#) (:protected #>cpp - PropertyLookup(int uid) : Expression(uid) {} - PropertyLookup(int uid, Expression *expression, PropertyIx property) - : Expression(uid), - expression_(expression), - property_(property) {} + PropertyLookup(Expression *expression, PropertyIx property) + : expression_(expression), property_(property) {} cpp<#) (:private #>cpp @@ -967,11 +915,8 @@ cpp<# cpp<#) (:protected #>cpp - LabelsTest(int uid) : Expression(uid) {} - - LabelsTest(int uid, Expression *expression, - const std::vector<LabelIx> &labels) - : Expression(uid), expression_(expression), labels_(labels) {} + LabelsTest(Expression *expression, const std::vector<LabelIx> &labels) + : expression_(expression), labels_(labels) {} cpp<#) (:private #>cpp @@ -1019,12 +964,9 @@ cpp<# cpp<#) (:protected #>cpp - explicit Function(int uid) : Expression(uid) {} - - Function(int uid, const std::string &function_name, + Function(const std::string &function_name, const std::vector<Expression *> &arguments) - : Expression(uid), - arguments_(arguments), + : arguments_(arguments), function_name_(function_name), function_(NameToFunction(function_name_)) { DCHECK(function_) << "Unexpected missing function: " << function_name_; @@ -1090,10 +1032,9 @@ cpp<# cpp<#) (:protected #>cpp - Reduce(int uid, Identifier *accumulator, Expression *initializer, - Identifier *identifier, Expression *list, Expression *expression) - : Expression(uid), - accumulator_(accumulator), + Reduce(Identifier *accumulator, Expression *initializer, Identifier *identifier, + Expression *list, Expression *expression) + : accumulator_(accumulator), initializer_(initializer), identifier_(identifier), list_(list), @@ -1133,10 +1074,8 @@ cpp<# ) (:private #>cpp - Coalesce(int uid) : Expression(uid_) {} - - Coalesce(int uid, const std::vector<Expression *> &expressions) - : Expression(uid), expressions_(expressions) {} + explicit Coalesce(const std::vector<Expression *> &expressions) + : expressions_(expressions) {} friend class AstStorage; cpp<#) @@ -1181,14 +1120,8 @@ cpp<# cpp<#) (:protected #>cpp - Extract(int uid) : Expression(uid) {} - - Extract(int uid, Identifier *identifier, Expression *list, - Expression *expression) - : Expression(uid), - identifier_(identifier), - list_(list), - expression_(expression) {} + Extract(Identifier *identifier, Expression *list, Expression *expression) + : identifier_(identifier), list_(list), expression_(expression) {} cpp<#) (:private #>cpp @@ -1232,12 +1165,8 @@ cpp<# cpp<#) (:protected #>cpp - All(int uid) : Expression(uid) {} - - All(int uid, Identifier *identifier, Expression *list_expression, - Where *where) - : Expression(uid), - identifier_(identifier), + All(Identifier *identifier, Expression *list_expression, Where *where) + : identifier_(identifier), list_expression_(list_expression), where_(where) {} cpp<#) @@ -1286,10 +1215,8 @@ cpp<# cpp<#) (:protected #>cpp - Single(int uid, Identifier *identifier, Expression *list_expression, - Where *where) - : Expression(uid), - identifier_(identifier), + Single(Identifier *identifier, Expression *list_expression, Where *where) + : identifier_(identifier), list_expression_(list_expression), where_(where) {} cpp<#) @@ -1313,9 +1240,8 @@ cpp<# cpp<#) (:protected #>cpp - explicit ParameterLookup(int uid) : Expression(uid) {} - ParameterLookup(int uid, int token_position) - : Expression(uid), token_position_(token_position) {} + explicit ParameterLookup(int token_position) + : token_position_(token_position) {} cpp<#) (:private #>cpp @@ -1335,7 +1261,9 @@ cpp<# :capnp-save #'save-ast-pointer :capnp-load (load-ast-pointer "Expression *")) (token-position :int32_t :initval -1 :scope :public - :documentation "This field contains token position of first token in named expression used to create name_. If NamedExpression object is not created from query or it is aliased leave this value at -1.")) + :documentation "This field contains token position of first token in named expression used to create name_. If NamedExpression object is not created from query or it is aliased leave this value at -1.") + (symbol-pos :int32_t :initval -1 :scope :public + :documentation "Symbol table position of the symbol this NamedExpression is mapped to.")) (:public #>cpp using ::utils::Visitable<ExpressionVisitor<TypedValue>>::Accept; @@ -1352,19 +1280,20 @@ cpp<# } return visitor.PostVisit(*this); } + + NamedExpression *MapTo(const Symbol &symbol) { + symbol_pos_ = symbol.position(); + return this; + } cpp<#) (:protected #>cpp - explicit NamedExpression(int uid) : Tree(uid) {} - NamedExpression(int uid, const std::string &name) : Tree(uid), name_(name) {} - NamedExpression(int uid, const std::string &name, Expression *expression) - : Tree(uid), name_(name), expression_(expression) {} - NamedExpression(int uid, const std::string &name, Expression *expression, + explicit NamedExpression(const std::string &name) : name_(name) {} + NamedExpression(const std::string &name, Expression *expression) + : name_(name), expression_(expression) {} + NamedExpression(const std::string &name, Expression *expression, int token_position) - : Tree(uid), - name_(name), - expression_(expression), - token_position_(token_position) {} + : name_(name), expression_(expression), token_position_(token_position) {} cpp<#) (:private #>cpp @@ -1395,9 +1324,7 @@ cpp<# cpp<#) (:protected #>cpp - explicit PatternAtom(int uid) : Tree(uid) {} - PatternAtom(int uid, Identifier *identifier) - : Tree(uid), identifier_(identifier) {} + explicit PatternAtom(Identifier *identifier) : identifier_(identifier) {} cpp<#) (:private #>cpp @@ -1505,23 +1432,15 @@ cpp<# :documentation "Evaluated to upper bound in variable length expands.") (filter-lambda "Lambda" :scope :public :documentation "Filter lambda for variable length expands. Can have an empty expression, but identifiers must be valid, because an optimization pass may inline other expressions into this lambda." - :slk-save (lambda (member) - #>cpp - slk::Save(self.${member}, builder, saved_uids); - cpp<#) :slk-load (lambda (member) #>cpp - slk::Load(&self->${member}, reader, storage, loaded_uids); + slk::Load(&self->${member}, reader, storage); cpp<#)) (weight-lambda "Lambda" :scope :public :documentation "Used in weighted shortest path. It must have valid expressions and identifiers. In all other expand types, it is empty." - :slk-save (lambda (member) - #>cpp - slk::Save(self.${member}, builder, saved_uids); - cpp<#) :slk-load (lambda (member) #>cpp - slk::Load(&self->${member}, reader, storage, loaded_uids); + slk::Load(&self->${member}, reader, storage); cpp<#)) (total-weight "Identifier *" :initval "nullptr" :scope :public :slk-save #'slk-save-ast-pointer @@ -1560,13 +1479,8 @@ cpp<# :capnp-load (load-ast-pointer "Expression *") :documentation "Evaluates the result of the lambda.")) (:documentation "Lambda for use in filtering or weight calculation during variable expand.") - (:serialize (:slk :save-args '((saved-uids "std::vector<int32_t> *")) - :load-args '((storage "query::AstStorage *") - (loaded-uids "std::vector<int32_t> *"))) - (:capnp - :save-args '((saved-uids "std::vector<int32_t> *")) - :load-args '((storage "AstStorage *") - (loaded-uids "std::vector<int32_t> *")))) + (:serialize (:slk :load-args '((storage "query::AstStorage *"))) + (:capnp :load-args '((storage "AstStorage *")))) (:clone :args '((storage "AstStorage *")))) #>cpp bool Accept(HierarchicalTreeVisitor &visitor) override { @@ -1604,13 +1518,13 @@ cpp<# (:protected #>cpp using PatternAtom::PatternAtom; - EdgeAtom(int uid, Identifier *identifier, Type type, Direction direction) - : PatternAtom(uid, identifier), type_(type), direction_(direction) {} + EdgeAtom(Identifier *identifier, Type type, Direction direction) + : PatternAtom(identifier), type_(type), direction_(direction) {} // Creates an edge atom for a SINGLE expansion with the given . - EdgeAtom(int uid, Identifier *identifier, Type type, Direction direction, + EdgeAtom(Identifier *identifier, Type type, Direction direction, const std::vector<EdgeTypeIx> &edge_types) - : PatternAtom(uid, identifier), + : PatternAtom(identifier), type_(type), direction_(direction), edge_types_(edge_types) {} @@ -1654,10 +1568,6 @@ cpp<# return visitor.PostVisit(*this); } cpp<#) - (:protected - #>cpp - explicit Pattern(int uid) : Tree(uid) {} - cpp<#) (:private #>cpp friend class AstStorage; @@ -1676,10 +1586,6 @@ cpp<# Clause() = default; cpp<#) - (:protected - #>cpp - explicit Clause(int uid) : Tree(uid) {} - cpp<#) (:private #>cpp friend class AstStorage; @@ -1712,10 +1618,6 @@ cpp<# return visitor.PostVisit(*this); } cpp<#) - (:protected - #>cpp - explicit SingleQuery(int uid) : Tree(uid) {} - cpp<#) (:private #>cpp friend class AstStorage; @@ -1750,12 +1652,10 @@ cpp<# cpp<#) (:protected #>cpp - explicit CypherUnion(int uid) : Tree(uid) {} - CypherUnion(int uid, bool distinct) : Tree(uid), distinct_(distinct) {} - CypherUnion(int uid, bool distinct, SingleQuery *single_query, + explicit CypherUnion(bool distinct) : distinct_(distinct) {} + CypherUnion(bool distinct, SingleQuery *single_query, std::vector<Symbol> union_symbols) - : Tree(uid), - single_query_(single_query), + : single_query_(single_query), distinct_(distinct), union_symbols_(union_symbols) {} cpp<#) @@ -1777,10 +1677,6 @@ cpp<# Query() = default; cpp<#) - (:protected - #>cpp - explicit Query(int uid) : Tree(uid) {} - cpp<#) (:private #>cpp friend class AstStorage; @@ -1812,10 +1708,6 @@ cpp<# DEFVISITABLE(QueryVisitor<void>); cpp<#) - (:protected - #>cpp - explicit CypherQuery(int uid) : Query(uid) {} - cpp<#) (:private #>cpp friend class AstStorage; @@ -1837,10 +1729,6 @@ cpp<# DEFVISITABLE(QueryVisitor<void>); cpp<#) - (:protected - #>cpp - explicit ExplainQuery(int uid) : Query(uid) {} - cpp<#) (:private #>cpp friend class AstStorage; @@ -1867,7 +1755,6 @@ cpp<# cpp<#) (:private #>cpp - explicit ProfileQuery(int uid) : Query(uid) {} friend class AstStorage; cpp<#) (:serialize (:slk) (:capnp)) @@ -1914,10 +1801,8 @@ cpp<# cpp<#) (:protected #>cpp - explicit IndexQuery(int uid) : Query(uid) {} - IndexQuery(int uid, Action action, LabelIx label, - std::vector<PropertyIx> properties) - : Query(uid), action_(action), label_(label), properties_(properties) {} + IndexQuery(Action action, LabelIx label, std::vector<PropertyIx> properties) + : action_(action), label_(label), properties_(properties) {} cpp<#) (:private #>cpp @@ -1949,9 +1834,7 @@ cpp<# cpp<#) (:protected #>cpp - explicit Create(int uid) : Clause(uid) {} - Create(int uid, std::vector<Pattern *> patterns) - : Clause(uid), patterns_(patterns) {} + explicit Create(std::vector<Pattern *> patterns) : patterns_(patterns) {} cpp<#) (:private #>cpp @@ -1997,10 +1880,9 @@ cpp<# cpp<#) (:protected #>cpp - explicit Match(int uid) : Clause(uid) {} - Match(int uid, bool optional) : Clause(uid), optional_(optional) {} - Match(int uid, bool optional, Where *where, std::vector<Pattern *> patterns) - : Clause(uid), patterns_(patterns), where_(where), optional_(optional) {} + explicit Match(bool optional) : optional_(optional) {} + Match(bool optional, Where *where, std::vector<Pattern *> patterns) + : patterns_(patterns), where_(where), optional_(optional) {} cpp<#) (:private #>cpp @@ -2022,13 +1904,8 @@ cpp<# :capnp-type "Tree" :capnp-init nil :capnp-save #'save-ast-pointer :capnp-load (load-ast-pointer "Expression *"))) - (:serialize (:slk :save-args '((saved-uids "std::vector<int32_t> *")) - :load-args '((storage "query::AstStorage *") - (loaded-uids "std::vector<int32_t> *"))) - (:capnp - :save-args '((saved-uids "std::vector<int32_t> *")) - :load-args '((storage "AstStorage *") - (loaded-uids "std::vector<int32_t> *")))) + (:serialize (:slk :load-args '((storage "query::AstStorage *"))) + (:capnp :load-args '((storage "AstStorage *")))) (:clone :args '((storage "AstStorage *")))) (lcp:define-struct return-body () @@ -2044,14 +1921,6 @@ cpp<# :capnp-load (load-ast-vector "NamedExpression *") :documentation "Expressions which are used to produce results.") (order-by "std::vector<SortItem>" - :slk-save (lambda (member) - #>cpp - size_t size = self.${member}.size(); - slk::Save(size, builder); - for (const auto &v : self.${member}) { - slk::Save(v, builder, saved_uids); - } - cpp<#) :slk-load (lambda (member) #>cpp size_t size = 0; @@ -2060,21 +1929,15 @@ cpp<# for (size_t i = 0; i < size; ++i) { - slk::Load(&self->${member}[i], reader, storage, loaded_uids); + slk::Load(&self->${member}[i], reader, storage); } cpp<#) - :capnp-save (lcp:capnp-save-vector - "capnp::SortItem" - "SortItem" - "[saved_uids](auto *builder, const auto &val) { - Save(val, builder, saved_uids); - }") :capnp-load (lcp:capnp-load-vector "capnp::SortItem" "SortItem" - "[storage, loaded_uids](const auto &reader) { + "[storage](const auto &reader) { SortItem val; - Load(&val, reader, storage, loaded_uids); + Load(&val, reader, storage); return val; }") :documentation "Expressions used for ordering the results.") @@ -2093,24 +1956,15 @@ cpp<# :capnp-load (load-ast-pointer "Expression *") :documentation "Optional expression on how many results to produce.")) (:documentation "Contents common to @c Return and @c With clauses.") - (:serialize (:slk :save-args '((saved-uids "std::vector<int32_t> *")) - :load-args '((storage "query::AstStorage *") - (loaded-uids "std::vector<int32_t> *"))) - (:capnp - :save-args '((saved-uids "std::vector<int32_t> *")) - :load-args '((storage "AstStorage *") - (loaded-uids "std::vector<int32_t> *")))) + (:serialize (:slk :load-args '((storage "query::AstStorage *"))) + (:capnp :load-args '((storage "AstStorage *")))) (:clone :args '((storage "AstStorage *")))) (lcp:define-class return (clause) ((body "ReturnBody" :scope :public - :slk-save (lambda (member) - #>cpp - slk::Save(self.${member}, builder, saved_uids); - cpp<#) :slk-load (lambda (member) #>cpp - slk::Load(&self->${member}, reader, storage, loaded_uids); + slk::Load(&self->${member}, reader, storage); cpp<#))) (:public #>cpp @@ -2141,8 +1995,7 @@ cpp<# cpp<#) (:protected #>cpp - explicit Return(int uid) : Clause(uid) {} - Return(int uid, ReturnBody &body) : Clause(uid), body_(body) {} + explicit Return(ReturnBody &body) : body_(body) {} cpp<#) (:private #>cpp @@ -2153,13 +2006,9 @@ cpp<# (lcp:define-class with (clause) ((body "ReturnBody" :scope :public - :slk-save (lambda (member) - #>cpp - slk::Save(self.${member}, builder, saved_uids); - cpp<#) :slk-load (lambda (member) #>cpp - slk::Load(&self->${member}, reader, storage, loaded_uids); + slk::Load(&self->${member}, reader, storage); cpp<#)) (where "Where *" :initval "nullptr" :scope :public :slk-save #'slk-save-ast-pointer @@ -2197,9 +2046,7 @@ cpp<# cpp<#) (:protected #>cpp - explicit With(int uid) : Clause(uid) {} - With(int uid, ReturnBody &body, Where *where) - : Clause(uid), body_(body), where_(where) {} + With(ReturnBody &body, Where *where) : body_(body), where_(where) {} cpp<#) (:private #>cpp @@ -2232,9 +2079,8 @@ cpp<# cpp<#) (:protected #>cpp - explicit Delete(int uid) : Clause(uid) {} - Delete(int uid, bool detach, std::vector<Expression *> expressions) - : Clause(uid), expressions_(expressions), detach_(detach) {} + Delete(bool detach, std::vector<Expression *> expressions) + : expressions_(expressions), detach_(detach) {} cpp<#) (:private #>cpp @@ -2269,11 +2115,8 @@ cpp<# cpp<#) (:protected #>cpp - explicit SetProperty(int uid) : Clause(uid) {} - SetProperty(int uid, PropertyLookup *property_lookup, Expression *expression) - : Clause(uid), - property_lookup_(property_lookup), - expression_(expression) {} + SetProperty(PropertyLookup *property_lookup, Expression *expression) + : property_lookup_(property_lookup), expression_(expression) {} cpp<#) (:private #>cpp @@ -2309,13 +2152,9 @@ cpp<# cpp<#) (:protected #>cpp - explicit SetProperties(int uid) : Clause(uid) {} - SetProperties(int uid, Identifier *identifier, Expression *expression, + SetProperties(Identifier *identifier, Expression *expression, bool update = false) - : Clause(uid), - identifier_(identifier), - expression_(expression), - update_(update) {} + : identifier_(identifier), expression_(expression), update_(update) {} cpp<#) (:private #>cpp @@ -2362,10 +2201,8 @@ cpp<# cpp<#) (:protected #>cpp - explicit SetLabels(int uid) : Clause(uid) {} - SetLabels(int uid, Identifier *identifier, - const std::vector<LabelIx> &labels) - : Clause(uid), identifier_(identifier), labels_(labels) {} + SetLabels(Identifier *identifier, const std::vector<LabelIx> &labels) + : identifier_(identifier), labels_(labels) {} cpp<#) (:private #>cpp @@ -2394,9 +2231,8 @@ cpp<# cpp<#) (:protected #>cpp - explicit RemoveProperty(int uid) : Clause(uid) {} - RemoveProperty(int uid, PropertyLookup *property_lookup) - : Clause(uid), property_lookup_(property_lookup) {} + explicit RemoveProperty(PropertyLookup *property_lookup) + : property_lookup_(property_lookup) {} cpp<#) (:private #>cpp @@ -2443,10 +2279,8 @@ cpp<# cpp<#) (:protected #>cpp - explicit RemoveLabels(int uid) : Clause(uid) {} - RemoveLabels(int uid, Identifier *identifier, - const std::vector<LabelIx> &labels) - : Clause(uid), identifier_(identifier), labels_(labels) {} + RemoveLabels(Identifier *identifier, const std::vector<LabelIx> &labels) + : identifier_(identifier), labels_(labels) {} cpp<#) (:private #>cpp @@ -2505,13 +2339,9 @@ cpp<# cpp<#) (:protected #>cpp - explicit Merge(int uid) : Clause(uid) {} - Merge(int uid, Pattern *pattern, std::vector<Clause *> on_match, + Merge(Pattern *pattern, std::vector<Clause *> on_match, std::vector<Clause *> on_create) - : Clause(uid), - pattern_(pattern), - on_match_(on_match), - on_create_(on_create) {} + : pattern_(pattern), on_match_(on_match), on_create_(on_create) {} cpp<#) (:private #>cpp @@ -2540,12 +2370,9 @@ cpp<# cpp<#) (:protected #>cpp - explicit Unwind(int uid) : Clause(uid) {} - - Unwind(int uid, NamedExpression *named_expression) - : Clause(uid), named_expression_(named_expression) { - DCHECK(named_expression) - << "Unwind cannot take nullptr for named_expression"; + explicit Unwind(NamedExpression *named_expression) + : named_expression_(named_expression) { + DCHECK(named_expression) << "Unwind cannot take nullptr for named_expression"; } cpp<#) (:private @@ -2584,13 +2411,10 @@ cpp<# cpp<#) (:protected #>cpp - explicit AuthQuery(int uid) : Query(uid) {} - - explicit AuthQuery(int uid, Action action, std::string user, std::string role, - std::string user_or_role, Expression *password, - std::vector<Privilege> privileges) - : Query(uid), - action_(action), + AuthQuery(Action action, std::string user, std::string role, + std::string user_or_role, Expression *password, + std::vector<Privilege> privileges) + : action_(action), user_(user), role_(role), user_or_role_(user_or_role), @@ -2667,13 +2491,11 @@ cpp<# cpp<#) (:protected #>cpp - StreamQuery(int uid) : Query(uid) {} - StreamQuery(int uid, Action action, std::string stream_name, - Expression *stream_uri, Expression *stream_topic, - Expression *transform_uri, Expression *batch_interval_in_ms, - Expression *batch_size, Expression *limit_batches) - : Query(uid), - action_(action), + StreamQuery(Action action, std::string stream_name, Expression *stream_uri, + Expression *stream_topic, Expression *transform_uri, + Expression *batch_interval_in_ms, Expression *batch_size, + Expression *limit_batches) + : action_(action), stream_name_(std::move(stream_name)), stream_uri_(stream_uri), stream_topic_(stream_topic), diff --git a/src/query/frontend/ast/ast_serialization.lcp b/src/query/frontend/ast/ast_serialization.lcp index 3de307b94..3a8d0a8e9 100644 --- a/src/query/frontend/ast/ast_serialization.lcp +++ b/src/query/frontend/ast/ast_serialization.lcp @@ -15,26 +15,22 @@ cpp<# #>cpp /// Primary function for saving Ast nodes via SLK. -void SaveAstPointer(const Tree *ast, slk::Builder *builder, - std::vector<int32_t> *saved_uids); +void SaveAstPointer(const Tree *ast, slk::Builder *builder); -Tree *Load(AstStorage *ast, const capnp::Tree::Reader &tree, - std::vector<int32_t> *loaded_uids); +Tree *Load(AstStorage *ast, const capnp::Tree::Reader &tree); -Tree *Load(AstStorage *ast, slk::Reader *reader, - std::vector<int32_t> *loaded_uids); +Tree *Load(AstStorage *ast, slk::Reader *reader); /// Primary function for loading Ast nodes via SLK. template <class TAst> -TAst *LoadAstPointer(AstStorage *ast, slk::Reader *reader, - std::vector<int32_t> *loaded_uids) { +TAst *LoadAstPointer(AstStorage *ast, slk::Reader *reader) { static_assert(std::is_base_of<query::Tree, TAst>::value); bool has_ptr = false; slk::Load(&has_ptr, reader); if (!has_ptr) { return nullptr; } - auto *ret = utils::Downcast<TAst>(Load(ast, reader, loaded_uids)); + auto *ret = utils::Downcast<TAst>(Load(ast, reader)); if (!ret) { throw slk::SlkDecodeException("Loading unknown Ast node type"); } @@ -44,58 +40,25 @@ cpp<# (lcp:in-impl #>cpp - void SaveAstPointer(const Tree *ast, slk::Builder *builder, - std::vector<int32_t> *saved_uids) { + void SaveAstPointer(const Tree *ast, slk::Builder *builder) { slk::Save(static_cast<bool>(ast), builder); if (!ast) { return; } - slk::Save(ast->uid_, builder); - if (utils::Contains(*saved_uids, ast->uid_)) { - return; - } - slk::Save(*ast, builder, saved_uids); - CHECK(!utils::Contains(*saved_uids, ast->uid_)) << "Serializing cyclic AST"; - saved_uids->push_back(ast->uid_); + slk::Save(*ast, builder); } - Tree *Load(AstStorage *ast, slk::Reader *reader, - std::vector<int32_t> *loaded_uids) { - // Check if element already deserialized and if yes, return existing - // element from storage. - int32_t uid; - slk::Load(&uid, reader); - if (utils::Contains(*loaded_uids, uid)) { - auto found = std::find_if(ast->storage_.begin(), ast->storage_.end(), - [&](const auto &n) { return n->uid_ == uid; }); - CHECK(found != ast->storage_.end()); - return found->get(); - } + Tree *Load(AstStorage *ast, slk::Reader *reader) { std::unique_ptr<Tree> root; - slk::ConstructAndLoad(&root, reader, ast, loaded_uids); - root->uid_ = uid; + slk::ConstructAndLoad(&root, reader, ast); ast->storage_.emplace_back(std::move(root)); - loaded_uids->push_back(uid); - ast->max_existing_uid_ = std::max(ast->max_existing_uid_, uid); return ast->storage_.back().get(); } - Tree *Load(AstStorage *ast, const capnp::Tree::Reader &tree, - std::vector<int> *loaded_uids) { - // Check if element already deserialized and if yes, return existing - // element from storage. - auto uid = tree.getUid(); - if (utils::Contains(*loaded_uids, uid)) { - auto found = std::find_if(ast->storage_.begin(), ast->storage_.end(), - [&](const auto &n) { return n->uid_ == uid; }); - CHECK(found != ast->storage_.end()); - return found->get(); - } + Tree *Load(AstStorage *ast, const capnp::Tree::Reader &tree) { std::unique_ptr<Tree> root; - ::query::Load(&root, tree, ast, loaded_uids); + ::query::Load(&root, tree, ast); ast->storage_.emplace_back(std::move(root)); - loaded_uids->emplace_back(uid); - ast->max_existing_uid_ = std::max(ast->max_existing_uid_, uid); return ast->storage_.back().get(); } cpp<#) diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index d67f8cdd9..4bf75c321 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -51,7 +51,8 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) { user_symbols.emplace_back(sym_pair.second); } if (user_symbols.empty()) { - throw SemanticException("There are no variables in scope to use for '*'."); + throw SemanticException( + "There are no variables in scope to use for '*'."); } } // WITH/RETURN clause removes declarations of all the previous variables and @@ -79,8 +80,8 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) { } // An improvement would be to infer the type of the expression, so that the // new symbol would have a more specific type. - symbol_table_[*named_expr] = CreateSymbol(name, true, Symbol::Type::ANY, - named_expr->token_position_); + named_expr->MapTo(CreateSymbol(name, true, Symbol::Type::ANY, + named_expr->token_position_)); } scope_.in_order_by = true; for (const auto &order_pair : body.order_by) { @@ -198,7 +199,7 @@ bool SymbolGenerator::PostVisit(Unwind &unwind) { if (HasSymbol(name)) { throw RedeclareVariableError(name); } - symbol_table_[*unwind.named_expression_] = CreateSymbol(name, true); + unwind.named_expression_->MapTo(CreateSymbol(name, true)); return true; } @@ -212,7 +213,7 @@ bool SymbolGenerator::PostVisit(Match &) { // reference symbols out of bind order. for (auto &ident : scope_.identifiers_in_match) { if (!HasSymbol(ident->name_)) throw UnboundVariableError(ident->name_); - symbol_table_[*ident] = scope_.symbols[ident->name_]; + ident->MapTo(scope_.symbols[ident->name_]); } scope_.identifiers_in_match.clear(); return true; @@ -270,7 +271,7 @@ SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) { if (!HasSymbol(ident.name_)) throw UnboundVariableError(ident.name_); symbol = scope_.symbols[ident.name_]; } - symbol_table_[ident] = symbol; + ident.MapTo(symbol); return true; } @@ -302,9 +303,8 @@ bool SymbolGenerator::PreVisit(Aggregation &aggr) { // Create a virtual symbol for aggregation result. // Currently, we only have aggregation operators which return numbers. auto aggr_name = - Aggregation::OpToString(aggr.op_) + std::to_string(aggr.uid_); - symbol_table_[aggr] = - symbol_table_.CreateSymbol(aggr_name, false, Symbol::Type::NUMBER); + Aggregation::OpToString(aggr.op_) + std::to_string(aggr.symbol_pos_); + aggr.MapTo(CreateSymbol(aggr_name, false, Symbol::Type::NUMBER)); scope_.in_aggregation = true; scope_.has_aggregation = true; return true; @@ -435,12 +435,12 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) { } else { // Create inner symbols, but don't bind them in scope, since they are to // be used in the missing filter expression. - const auto *inner_edge = edge_atom.filter_lambda_.inner_edge; - symbol_table_[*inner_edge] = symbol_table_.CreateSymbol( - inner_edge->name_, inner_edge->user_declared_, Symbol::Type::EDGE); - const auto *inner_node = edge_atom.filter_lambda_.inner_node; - symbol_table_[*inner_node] = symbol_table_.CreateSymbol( - inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX); + auto *inner_edge = edge_atom.filter_lambda_.inner_edge; + inner_edge->MapTo(symbol_table_.CreateSymbol( + inner_edge->name_, inner_edge->user_declared_, Symbol::Type::EDGE)); + auto *inner_node = edge_atom.filter_lambda_.inner_node; + inner_node->MapTo(symbol_table_.CreateSymbol( + inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX)); } if (edge_atom.weight_lambda_.expression) { VisitWithIdentifiers(edge_atom.weight_lambda_.expression, @@ -456,9 +456,9 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) { if (HasSymbol(edge_atom.total_weight_->name_)) { throw RedeclareVariableError(edge_atom.total_weight_->name_); } - symbol_table_[*edge_atom.total_weight_] = GetOrCreateSymbol( + edge_atom.total_weight_->MapTo(GetOrCreateSymbol( edge_atom.total_weight_->name_, edge_atom.total_weight_->user_declared_, - Symbol::Type::NUMBER); + Symbol::Type::NUMBER)); } return false; } @@ -480,8 +480,8 @@ void SymbolGenerator::VisitWithIdentifiers( if (prev_symbol_it != scope_.symbols.end()) { prev_symbol = prev_symbol_it->second; } - symbol_table_[*identifier] = - CreateSymbol(identifier->name_, identifier->user_declared_); + identifier->MapTo( + CreateSymbol(identifier->name_, identifier->user_declared_)); prev_symbols.emplace_back(prev_symbol, identifier); } // Visit the expression with the new symbols bound. diff --git a/src/query/frontend/semantic/symbol_table.hpp b/src/query/frontend/semantic/symbol_table.hpp index 7951d0d26..613a04973 100644 --- a/src/query/frontend/semantic/symbol_table.hpp +++ b/src/query/frontend/semantic/symbol_table.hpp @@ -11,25 +11,32 @@ namespace query { class SymbolTable final { public: SymbolTable() {} - Symbol CreateSymbol(const std::string &name, bool user_declared, - Symbol::Type type = Symbol::Type::ANY, - int32_t token_position = -1) { - int32_t position = position_++; - return Symbol(name, position, user_declared, type, token_position); + const Symbol &CreateSymbol(const std::string &name, bool user_declared, + Symbol::Type type = Symbol::Type::ANY, + int32_t token_position = -1) { + CHECK(table_.size() <= std::numeric_limits<int32_t>::max()) + << "SymbolTable size doesn't fit into 32-bit integer!"; + int32_t position = static_cast<int32_t>(table_.size()); + table_.emplace_back(name, position, user_declared, type, token_position); + return table_.back(); } - auto &operator[](const Tree &tree) { return table_[tree.uid_]; } - - Symbol &at(const Tree &tree) { return table_.at(tree.uid_); } - const Symbol &at(const Tree &tree) const { return table_.at(tree.uid_); } + const Symbol &at(const Identifier &ident) const { + return table_.at(ident.symbol_pos_); + } + const Symbol &at(const NamedExpression &nexpr) const { + return table_.at(nexpr.symbol_pos_); + } + const Symbol &at(const Aggregation &aggr) const { + return table_.at(aggr.symbol_pos_); + } // TODO: Remove these since members are public - int32_t max_position() const { return position_; } + int32_t max_position() const { return static_cast<int32_t>(table_.size()); } const auto &table() const { return table_; } - int32_t position_{0}; - std::map<int32_t, Symbol> table_; + std::vector<Symbol> table_; }; } // namespace query diff --git a/src/query/plan/distributed.cpp b/src/query/plan/distributed.cpp index b290e37ac..69dd0d963 100644 --- a/src/query/plan/distributed.cpp +++ b/src/query/plan/distributed.cpp @@ -127,9 +127,10 @@ class IndependentSubtreeFinder : public DistributedOperatorVisitor { // a) Extract to ScanAllByLabel + Filter x2 auto make_prop_lookup = [&]() { - auto ident = storage_->Create<Identifier>( - scan.output_symbol_.name(), scan.output_symbol_.user_declared()); - (*symbol_table_)[*ident] = scan.output_symbol_; + auto ident = storage_ + ->Create<Identifier>(scan.output_symbol_.name(), + scan.output_symbol_.user_declared()) + ->MapTo(scan.output_symbol_); // TODO: When this extraction of a filter is removed, also remove // property_name from ScanAll operators. return storage_->Create<PropertyLookup>( @@ -224,9 +225,10 @@ class IndependentSubtreeFinder : public DistributedOperatorVisitor { // Split to ScanAllByLabel + Filter on property auto subtree = std::make_shared<ScanAllByLabel>( scan.input(), scan.output_symbol_, scan.label_, scan.graph_view_); - auto ident = storage_->Create<Identifier>( - scan.output_symbol_.name(), scan.output_symbol_.user_declared()); - (*symbol_table_)[*ident] = scan.output_symbol_; + auto ident = storage_ + ->Create<Identifier>(scan.output_symbol_.name(), + scan.output_symbol_.user_declared()) + ->MapTo(scan.output_symbol_); auto prop_lookup = storage_->Create<PropertyLookup>( ident, storage_->GetPropertyIx(scan.property_name_)); auto prop_equal = @@ -1259,20 +1261,21 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor { } auto make_ident = [this](const auto &symbol) { auto *ident = - distributed_plan_.ast_storage.Create<Identifier>(symbol.name()); - distributed_plan_.symbol_table[*ident] = symbol; + distributed_plan_.ast_storage.Create<Identifier>(symbol.name()) + ->MapTo(symbol); return ident; }; auto make_named_expr = [&](const auto &in_sym, const auto &out_sym) { - auto *nexpr = distributed_plan_.ast_storage.Create<NamedExpression>( - out_sym.name(), make_ident(in_sym)); - distributed_plan_.symbol_table[*nexpr] = out_sym; + auto *nexpr = + distributed_plan_.ast_storage + .Create<NamedExpression>(out_sym.name(), make_ident(in_sym)) + ->MapTo(out_sym); return nexpr; }; auto make_merge_aggregation = [&](auto op, const auto &worker_sym) { auto *worker_ident = make_ident(worker_sym); auto merge_name = Aggregation::OpToString(op) + - std::to_string(worker_ident->uid_) + "<-" + + std::to_string(worker_ident->symbol_pos_) + "<-" + worker_sym.name(); auto merge_sym = distributed_plan_.symbol_table.CreateSymbol( merge_name, false, Symbol::Type::NUMBER); @@ -1338,9 +1341,10 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor { auto *div_expr = distributed_plan_.ast_storage.Create<DivisionOperator>( master_sum_ident, to_float); - auto *as_avg = distributed_plan_.ast_storage.Create<NamedExpression>( - aggr.output_sym.name(), div_expr); - distributed_plan_.symbol_table[*as_avg] = aggr.output_sym; + auto *as_avg = + distributed_plan_.ast_storage + .Create<NamedExpression>(aggr.output_sym.name(), div_expr) + ->MapTo(aggr.output_sym); produce_exprs.emplace_back(as_avg); break; } diff --git a/src/query/plan/distributed_ops.lcp b/src/query/plan/distributed_ops.lcp index d7a49c88e..fddc10d85 100644 --- a/src/query/plan/distributed_ops.lcp +++ b/src/query/plan/distributed_ops.lcp @@ -291,22 +291,14 @@ by having only one result from each worker.") :capnp-load (load-ast-pointer "Expression *")) (filter-lambda "ExpansionLambda" :scope :public :documentation "Filter that must be satisfied for expansion to succeed." - :slk-save (lambda (member) - #>cpp - slk::Save(self.${member}, builder, &helper->saved_ast_uids); - cpp<#) :slk-load (lambda (member) #>cpp - slk::Load(&self->${member}, reader, &helper->ast_storage, &helper->loaded_ast_uids); + slk::Load(&self->${member}, reader, &helper->ast_storage); cpp<#) :capnp-type "ExpansionLambda" - :capnp-save (lambda (builder member capnp-name) - #>cpp - Save(${member}, &${builder}, &helper->saved_ast_uids); - cpp<#) :capnp-load (lambda (reader member capnp-name) #>cpp - Load(&${member}, ${reader}, &helper->ast_storage, &helper->loaded_ast_uids); + Load(&${member}, ${reader}, &helper->ast_storage); cpp<#))) (:documentation "BFS expansion operator suited for distributed execution.") (:public diff --git a/src/query/plan/operator.lcp b/src/query/plan/operator.lcp index ce60ee373..596b5749f 100644 --- a/src/query/plan/operator.lcp +++ b/src/query/plan/operator.lcp @@ -206,20 +206,17 @@ can serve as inputs to others and thus a sequence of operations is formed.") virtual void set_input(std::shared_ptr<LogicalOperator>) = 0; struct SaveHelper { - std::vector<int32_t> saved_ast_uids; std::vector<LogicalOperator *> saved_ops; }; struct LoadHelper { AstStorage ast_storage; - std::vector<int32_t> loaded_ast_uids; std::vector<std::pair<uint64_t, std::shared_ptr<LogicalOperator>>> loaded_ops; }; struct SlkLoadHelper { AstStorage ast_storage; - std::vector<int32_t> loaded_ast_uids; std::vector<std::shared_ptr<LogicalOperator>> loaded_ops; }; cpp<#) @@ -236,21 +233,21 @@ can serve as inputs to others and thus a sequence of operations is formed.") (defun slk-save-ast-pointer (member) #>cpp - query::SaveAstPointer(self.${member}, builder, &helper->saved_ast_uids); + query::SaveAstPointer(self.${member}, builder); cpp<#) (defun slk-load-ast-pointer (type) (lambda (member) #>cpp self->${member} = query::LoadAstPointer<query::${type}>( - &helper->ast_storage, reader, &helper->loaded_ast_uids); + &helper->ast_storage, reader); cpp<#)) (defun save-ast-pointer (builder member capnp-name) #>cpp if (${member}) { auto ${capnp-name}_builder = ${builder}->init${capnp-name}(); - Save(*${member}, &${capnp-name}_builder, &helper->saved_ast_uids); + Save(*${member}, &${capnp-name}_builder); } cpp<#) @@ -259,8 +256,7 @@ can serve as inputs to others and thus a sequence of operations is formed.") #>cpp if (${reader}.has${capnp-name}()) { ${member} = static_cast<${ast-type}>(Load(&helper->ast_storage, - ${reader}.get${capnp-name}(), - &helper->loaded_ast_uids)); + ${reader}.get${capnp-name}())); } else { ${member} = nullptr; } @@ -271,7 +267,7 @@ can serve as inputs to others and thus a sequence of operations is formed.") size_t size = self.${member}.size(); slk::Save(size, builder); for (const auto *val : self.${member}) { - query::SaveAstPointer(val, builder, &helper->saved_ast_uids); + query::SaveAstPointer(val, builder); } cpp<#) @@ -285,14 +281,14 @@ can serve as inputs to others and thus a sequence of operations is formed.") i < size; ++i) { self->${member}[i] = query::LoadAstPointer<query::${type}>( - &helper->ast_storage, reader, &helper->loaded_ast_uids); + &helper->ast_storage, reader); } cpp<#)) (defun save-ast-vector (ast-type) (lcp:capnp-save-vector "::query::capnp::Tree" ast-type "[helper](auto *builder, const auto &val) { - Save(*val, builder, &helper->saved_ast_uids); + Save(*val, builder); }")) (defun load-ast-vector (ast-type) @@ -300,8 +296,7 @@ can serve as inputs to others and thus a sequence of operations is formed.") (format nil "[helper](const auto &reader) { - return static_cast<~A>(Load(&helper->ast_storage, reader, - &helper->loaded_ast_uids)); + return static_cast<~A>(Load(&helper->ast_storage, reader)); }" ast-type))) @@ -383,7 +378,7 @@ and false on every following Pull.") slk::Save(size, builder); for (const auto &kv : self.${member}) { slk::Save(kv.first, builder); - query::SaveAstPointer(kv.second, builder, &helper->saved_ast_uids); + query::SaveAstPointer(kv.second, builder); } cpp<#) @@ -396,7 +391,7 @@ and false on every following Pull.") storage::Property prop; slk::Load(&prop, reader); auto *expr = query::LoadAstPointer<query::Expression>( - &helper->ast_storage, reader, &helper->loaded_ast_uids); + &helper->ast_storage, reader); self->${member}[i] = {prop, expr}; } cpp<#) @@ -407,7 +402,7 @@ and false on every following Pull.") auto prop_builder = ${builder}[i].initFirst(); storage::Save(${member}[i].first, &prop_builder); auto expr_builder = ${builder}[i].initSecond(); - Save(*${member}[i].second, &expr_builder, &helper->saved_ast_uids); + Save(*${member}[i].second, &expr_builder); } cpp<#) @@ -418,7 +413,7 @@ and false on every following Pull.") storage::Property prop; storage::Load(&prop, prop_reader); auto *expr = static_cast<Expression *>(Load( - &helper->ast_storage, pair_reader.getSecond(), &helper->loaded_ast_uids)); + &helper->ast_storage, pair_reader.getSecond())); ${member}.emplace_back(prop, expr); } cpp<#) @@ -725,7 +720,7 @@ given label. break; } slk::Save(bound_type, builder); - query::SaveAstPointer(bound.value(), builder, &helper->saved_ast_uids); + query::SaveAstPointer(bound.value(), builder); cpp<#) (defun slk-load-optional-bound (member) @@ -750,7 +745,7 @@ given label. throw slk::SlkDecodeException("Loading unknown BoundType"); } auto *value = query::LoadAstPointer<query::Expression>( - &helper->ast_storage, reader, &helper->loaded_ast_uids); + &helper->ast_storage, reader); self->${member}.emplace(utils::Bound<query::Expression *>(value, bound_type)); cpp<#) @@ -761,7 +756,7 @@ given label. ::utils::capnp::Bound<::query::capnp::Tree>::Type::INCLUSIVE : ::utils::capnp::Bound<::query::capnp::Tree>::Type::EXCLUSIVE); auto value_builder = builder->initValue(); - Save(*bound.value(), &value_builder, &helper->saved_ast_uids); + Save(*bound.value(), &value_builder); }")) (funcall (lcp:capnp-save-optional "::utils::capnp::Bound<::query::capnp::Tree>" "utils::Bound<Expression *>" @@ -776,7 +771,7 @@ given label. ? utils::BoundType::INCLUSIVE : utils::BoundType::EXCLUSIVE; auto *value = static_cast<Expression *>( - Load(&helper->ast_storage, reader.getValue(), &helper->loaded_ast_uids)); + Load(&helper->ast_storage, reader.getValue())); return utils::Bound<Expression *>(value, type); }")) (funcall (lcp:capnp-load-optional "::utils::capnp::Bound<::query::capnp::Tree>" @@ -1023,40 +1018,25 @@ pulled.") ((inner-edge-symbol "Symbol" :documentation "Currently expanded edge symbol.") (inner-node-symbol "Symbol" :documentation "Currently expanded node symbol.") (expression "Expression *" :documentation "Expression used in lambda during expansion." - :slk-save (lambda (member) - #>cpp - query::SaveAstPointer(self.${member}, builder, saved_ast_uids); - cpp<#) + :slk-save #'slk-save-ast-pointer :slk-load (lambda (member) #>cpp self->${member} = query::LoadAstPointer<query::Expression>( - ast_storage, reader, loaded_ast_uids); + ast_storage, reader); cpp<#) :capnp-type "Ast.Tree" :capnp-init nil - :capnp-save (lambda (builder member capnp-name) - #>cpp - if (${member}) { - auto ${capnp-name}_builder = ${builder}->init${capnp-name}(); - Save(*${member}, &${capnp-name}_builder, saved_ast_uids); - } - cpp<#) + :capnp-save #'save-ast-pointer :capnp-load (lambda (reader member capnp-name) #>cpp if (${reader}.has${capnp-name}()) { ${member} = static_cast<Expression *>(Load(ast_storage, - ${reader}.get${capnp-name}(), - loaded_ast_uids)); + ${reader}.get${capnp-name}())); } else { ${member} = nullptr; } cpp<#))) - (:serialize (:slk :save-args '((saved-ast-uids "std::vector<int32_t> *")) - :load-args '((ast-storage "query::AstStorage *") - (loaded-ast-uids "std::vector<int32_t> *"))) - (:capnp - :save-args '((saved-ast-uids "std::vector<int32_t> *")) - :load-args '((ast-storage "AstStorage *") - (loaded-ast-uids "std::vector<int32_t> *")))) + (:serialize (:slk :load-args '((ast-storage "query::AstStorage *"))) + (:capnp :load-args '((ast-storage "AstStorage *")))) (:clone :args '((storage "AstStorage *")))) (lcp:define-class expand-variable (logical-operator) @@ -1089,31 +1069,15 @@ pulled.") :documentation "Optional upper bound of the variable length expansion, defaults are (1, inf)") (filter-lambda "ExpansionLambda" :scope :public - :slk-save (lambda (member) - #>cpp - slk::Save(self.${member}, builder, &helper->saved_ast_uids); - cpp<#) :slk-load (lambda (member) #>cpp - slk::Load(&self->${member}, reader, &helper->ast_storage, &helper->loaded_ast_uids); + slk::Load(&self->${member}, reader, &helper->ast_storage); cpp<#) - :capnp-save (lambda (builder member capnp-name) - #>cpp - Save(${member}, &${builder}, &helper->saved_ast_uids); - cpp<#) :capnp-load (lambda (reader member capnp-name) #>cpp - Load(&${member}, ${reader}, &helper->ast_storage, &helper->loaded_ast_uids); + Load(&${member}, ${reader}, &helper->ast_storage); cpp<#)) (weight-lambda "std::experimental::optional<ExpansionLambda>" :scope :public - :slk-save (lambda (member) - #>cpp - slk::Save(static_cast<bool>(self.${member}), builder); - if (!self.${member}) { - return; - } - slk::Save(*self.${member}, builder, &helper->saved_ast_uids); - cpp<#) :slk-load (lambda (member) #>cpp bool has_value; @@ -1123,19 +1087,17 @@ pulled.") return; } query::plan::ExpansionLambda lambda; - slk::Load(&lambda, reader, &helper->ast_storage, &helper->loaded_ast_uids); + slk::Load(&lambda, reader, &helper->ast_storage); self->${member}.emplace(lambda); cpp<#) :capnp-save (lcp:capnp-save-optional - "capnp::ExpansionLambda" "ExpansionLambda" - "[helper](auto *builder, const auto &val) { - Save(val, builder, &helper->saved_ast_uids); - }") + "capnp::ExpansionLambda" + "ExpansionLambda") :capnp-load (lcp:capnp-load-optional "capnp::ExpansionLambda" "ExpansionLambda" "[helper](const auto &reader) { ExpansionLambda val; - Load(&val, reader, &helper->ast_storage, &helper->loaded_ast_uids); + Load(&val, reader, &helper->ast_storage); return val; }")) (total-weight "std::experimental::optional<Symbol>" :scope :public diff --git a/src/query/plan/preprocess.cpp b/src/query/plan/preprocess.cpp index f44ebd866..d1cd0da99 100644 --- a/src/query/plan/preprocess.cpp +++ b/src/query/plan/preprocess.cpp @@ -161,8 +161,7 @@ PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, } PropertyFilter::PropertyFilter( - const SymbolTable &symbol_table, const Symbol &symbol, - PropertyIx property, + const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property, const std::experimental::optional<PropertyFilter::Bound> &lower_bound, const std::experimental::optional<PropertyFilter::Bound> &upper_bound) : symbol_(symbol), @@ -251,10 +250,12 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, collector.symbols_.insert(symbol); // PropertyLookup uses the symbol. // Now handle the post-expansion filter. // Create a new identifier and a symbol which will be filled in All. - auto *identifier = storage.Create<Identifier>( - atom->identifier_->name_, atom->identifier_->user_declared_); - symbol_table[*identifier] = - symbol_table.CreateSymbol(identifier->name_, false); + auto *identifier = + storage + .Create<Identifier>(atom->identifier_->name_, + atom->identifier_->user_declared_) + ->MapTo( + symbol_table.CreateSymbol(atom->identifier_->name_, false)); // Create an equality expression and store it in all_filters_. auto *property_lookup = storage.Create<PropertyLookup>(identifier, prop_pair.first); @@ -282,8 +283,8 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, FilterInfo filter_info{FilterInfo::Type::Property, prop_equal, collector.symbols_}; // Store a PropertyFilter on the value of the property. - filter_info.property_filter.emplace( - symbol_table, symbol, prop_pair.first, prop_pair.second); + filter_info.property_filter.emplace(symbol_table, symbol, prop_pair.first, + prop_pair.second); all_filters_.emplace_back(filter_info); } }; diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index c9bc1d20c..bb29a1d75 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -370,10 +370,9 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { if (!symbol.user_declared()) { continue; } - auto *ident = storage_.Create<Identifier>(symbol.name()); - symbol_table_[*ident] = symbol; - auto *named_expr = storage_.Create<NamedExpression>(symbol.name(), ident); - symbol_table_[*named_expr] = symbol; + auto *ident = storage_.Create<Identifier>(symbol.name())->MapTo(symbol); + auto *named_expr = + storage_.Create<NamedExpression>(symbol.name(), ident)->MapTo(symbol); // Fill output expressions and symbols with expanded identifiers. named_expressions_.emplace_back(named_expr); output_symbols_.emplace_back(symbol); diff --git a/src/query/serialization.capnp b/src/query/serialization.capnp index 02ebc22ef..bf761a755 100644 --- a/src/query/serialization.capnp +++ b/src/query/serialization.capnp @@ -43,11 +43,5 @@ struct TypedValue { } struct SymbolTable { - position @0 :Int32; - table @1 :List(Entry); - - struct Entry { - key @0 :Int32; - val @1 :Sem.Symbol; - } + table @0 :List(Sem.Symbol); } diff --git a/src/query/serialization.hpp b/src/query/serialization.hpp index d7f7ef23c..09483d863 100644 --- a/src/query/serialization.hpp +++ b/src/query/serialization.hpp @@ -7,6 +7,7 @@ #include "query/serialization.capnp.h" #include "query/typed_value.hpp" #include "storage/distributed/rpc/serialization.hpp" +#include "utils/serialization.hpp" namespace distributed { class DataManager; @@ -31,27 +32,20 @@ void Load(TypedValueVectorCompare *comparator, inline void Save(const SymbolTable &symbol_table, capnp::SymbolTable::Builder *builder) { - builder->setPosition(symbol_table.max_position()); auto list_builder = builder->initTable(symbol_table.table().size()); - size_t i = 0; - for (const auto &entry : symbol_table.table()) { - auto entry_builder = list_builder[i++]; - entry_builder.setKey(entry.first); - auto sym_builder = entry_builder.initVal(); - Save(entry.second, &sym_builder); - } + utils::SaveVector<capnp::Symbol, Symbol>( + symbol_table.table(), &list_builder, + [](auto *builder, const auto &symbol) { Save(symbol, builder); }); } inline void Load(SymbolTable *symbol_table, const capnp::SymbolTable::Reader &reader) { - symbol_table->position_ = reader.getPosition(); - symbol_table->table_.clear(); - for (const auto &entry_reader : reader.getTable()) { - int key = entry_reader.getKey(); - Symbol val; - Load(&val, entry_reader.getVal()); - symbol_table->table_[key] = val; - } + utils::LoadVector<capnp::Symbol, Symbol>( + &symbol_table->table_, reader.getTable(), [](const auto &reader) { + Symbol val; + Load(&val, reader); + return val; + }); } void Save(const Parameters ¶meters, @@ -69,12 +63,10 @@ namespace slk { inline void Save(const query::SymbolTable &symbol_table, slk::Builder *builder) { - slk::Save(symbol_table.position_, builder); slk::Save(symbol_table.table_, builder); } inline void Load(query::SymbolTable *symbol_table, slk::Reader *reader) { - slk::Load(&symbol_table->position_, reader); slk::Load(&symbol_table->table_, reader); } diff --git a/tests/unit/ast_serialization.cpp b/tests/unit/ast_serialization.cpp index 1e5c51d89..d453bd454 100644 --- a/tests/unit/ast_serialization.cpp +++ b/tests/unit/ast_serialization.cpp @@ -86,15 +86,13 @@ class CapnpAstGenerator : public Base { { query::capnp::Tree::Builder builder = message.initRoot<query::capnp::Tree>(); - std::vector<int> saved_uids; - Save(*visitor.query(), &builder, &saved_uids); + Save(*visitor.query(), &builder); } { const query::capnp::Tree::Reader reader = message.getRoot<query::capnp::Tree>(); - std::vector<int> loaded_uids; - query_ = dynamic_cast<Query *>(Load(&storage_, reader, &loaded_uids)); + query_ = dynamic_cast<Query *>(Load(&storage_, reader)); } } @@ -124,14 +122,12 @@ class SlkAstGenerator : public Base { slk::Builder builder; { - std::vector<int32_t> saved_uids; - SaveAstPointer(visitor.query(), &builder, &saved_uids); + SaveAstPointer(visitor.query(), &builder); } { slk::Reader reader(builder.data(), builder.size()); - std::vector<int32_t> loaded_uids; - query_ = LoadAstPointer<Query>(&storage_, &reader, &loaded_uids); + query_ = LoadAstPointer<Query>(&storage_, &reader); } } diff --git a/tests/unit/bfs_common.hpp b/tests/unit/bfs_common.hpp index 34956d0a0..b50ea55db 100644 --- a/tests/unit/bfs_common.hpp +++ b/tests/unit/bfs_common.hpp @@ -322,12 +322,9 @@ void BfsTest(Database *db, int lower_bound, int upper_bound, context.symbol_table.CreateSymbol("inner_node", true); query::Symbol inner_edge_sym = context.symbol_table.CreateSymbol("inner_edge", true); - query::Identifier *blocked = IDENT("blocked"); - query::Identifier *inner_node = IDENT("inner_node"); - query::Identifier *inner_edge = IDENT("inner_edge"); - context.symbol_table[*blocked] = blocked_sym; - context.symbol_table[*inner_node] = inner_node_sym; - context.symbol_table[*inner_edge] = inner_edge_sym; + query::Identifier *blocked = IDENT("blocked")->MapTo(blocked_sym); + query::Identifier *inner_node = IDENT("inner_node")->MapTo(inner_node_sym); + query::Identifier *inner_edge = IDENT("inner_edge")->MapTo(inner_edge_sym); std::vector<VertexAddress> vertices; std::vector<EdgeAddress> edges; diff --git a/tests/unit/distributed_query_plan.cpp b/tests/unit/distributed_query_plan.cpp index e5e91772d..d7003d32a 100644 --- a/tests/unit/distributed_query_plan.cpp +++ b/tests/unit/distributed_query_plan.cpp @@ -45,11 +45,11 @@ ExpandTuple MakeDistributedExpand( GraphView graph_view) { auto edge = EDGE(edge_identifier, direction); auto edge_sym = symbol_table.CreateSymbol(edge_identifier, true); - symbol_table[*edge->identifier_] = edge_sym; + edge->identifier_->MapTo(edge_sym); auto node = NODE(node_identifier); auto node_sym = symbol_table.CreateSymbol(node_identifier, true); - symbol_table[*node->identifier_] = node_sym; + node->identifier_->MapTo(node_sym); auto op = std::make_shared<DistributedExpand>(input, input_symbol, node_sym, edge_sym, direction, edge_types, @@ -75,10 +75,9 @@ TEST_F(DistributedQueryPlan, PullProduceRpc) { LIST(LITERAL(42), LITERAL(true), LITERAL("bla"), LITERAL(1), LITERAL(2)); auto x = symbol_table.CreateSymbol("x", true); auto unwind = std::make_shared<plan::Unwind>(nullptr, list, x); - auto x_expr = IDENT("x"); - symbol_table[*x_expr] = x; - auto x_ne = NEXPR("x", x_expr); - symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne", true); + auto x_expr = IDENT("x")->MapTo(x); + auto x_ne = + NEXPR("x", x_expr)->MapTo(symbol_table.CreateSymbol("x_ne", true)); auto produce = MakeProduce(unwind, x_ne); // Test that the plan works locally. @@ -91,7 +90,7 @@ TEST_F(DistributedQueryPlan, PullProduceRpc) { tx::CommandId command_id = dba->transaction().cid(); auto &evaluation_context = ctx.evaluation_context; - std::vector<query::Symbol> symbols{ctx.symbol_table[*x_ne]}; + std::vector<query::Symbol> symbols{ctx.symbol_table.at(*x_ne)}; auto remote_pull = [this, &command_id, &evaluation_context, &symbols]( GraphDbAccessor &dba, int worker_id) { return master().pull_clients().Pull(&dba, worker_id, plan_id, command_id, @@ -168,18 +167,14 @@ TEST_F(DistributedQueryPlan, PullProduceRpcWithGraphElements) { auto p = std::make_shared<query::plan::ConstructNamedPath>( r_m.op_, p_sym, std::vector<Symbol>{n.sym_, r_m.edge_sym_, r_m.node_sym_}); - auto return_n = IDENT("n"); - symbol_table[*return_n] = n.sym_; - auto return_r = IDENT("r"); - symbol_table[*return_r] = r_m.edge_sym_; - auto return_n_r = NEXPR("[n, r]", LIST(return_n, return_r)); - symbol_table[*return_n_r] = symbol_table.CreateSymbol("", true); - auto return_m = NEXPR("m", IDENT("m")); - symbol_table[*return_m->expression_] = r_m.node_sym_; - symbol_table[*return_m] = symbol_table.CreateSymbol("", true); - auto return_p = NEXPR("p", IDENT("p")); - symbol_table[*return_p->expression_] = p_sym; - symbol_table[*return_p] = symbol_table.CreateSymbol("", true); + auto return_n = IDENT("n")->MapTo(n.sym_); + auto return_r = IDENT("r")->MapTo(r_m.edge_sym_); + auto return_n_r = NEXPR("[n, r]", LIST(return_n, return_r)) + ->MapTo(symbol_table.CreateSymbol("", true)); + auto return_m = NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_)) + ->MapTo(symbol_table.CreateSymbol("", true)); + auto return_p = NEXPR("p", IDENT("p")->MapTo(p_sym)) + ->MapTo(symbol_table.CreateSymbol("", true)); auto produce = MakeProduce(p, return_n_r, return_m, return_p); auto check_result = [prop](int worker_id, @@ -211,8 +206,8 @@ TEST_F(DistributedQueryPlan, PullProduceRpcWithGraphElements) { tx::CommandId command_id = dba->transaction().cid(); auto &evaluation_context = ctx.evaluation_context; - std::vector<query::Symbol> symbols{ctx.symbol_table[*return_n_r], - ctx.symbol_table[*return_m], p_sym}; + std::vector<query::Symbol> symbols{ctx.symbol_table.at(*return_n_r), + ctx.symbol_table.at(*return_m), p_sym}; auto remote_pull = [this, &command_id, &evaluation_context, &symbols]( GraphDbAccessor &dba, int worker_id) { return master().pull_clients().Pull(&dba, worker_id, plan_id, command_id, @@ -246,8 +241,7 @@ TEST_F(DistributedQueryPlan, Synchronize) { // SET auto literal = LITERAL(42); auto prop = PROPERTY_PAIR("prop"); - auto m_p = PROPERTY_LOOKUP("m", prop); - symbol_table[*m_p->expression_] = r_m.node_sym_; + auto m_p = PROPERTY_LOOKUP(IDENT("m")->MapTo(r_m.node_sym_), prop); auto set_m_p = std::make_shared<plan::SetProperty>(r_m.op_, prop.second, m_p, literal); @@ -261,11 +255,9 @@ TEST_F(DistributedQueryPlan, Synchronize) { std::make_shared<query::plan::Synchronize>(set_m_p, pull_remote, true); // RETURN - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; - auto return_n_p = NEXPR("n.prop", n_p); - auto return_n_p_sym = symbol_table.CreateSymbol("n.p", true); - symbol_table[*return_n_p] = return_n_p_sym; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto return_n_p = + NEXPR("n.prop", n_p)->MapTo(symbol_table.CreateSymbol("n.p", true)); auto produce = MakeProduce(synchronize, return_n_p); auto ctx = MakeContext(storage, symbol_table, &dba); auto results = CollectProduce(*produce, &ctx); @@ -331,8 +323,7 @@ TEST_F(DistributedQueryPlan, PullRemoteOrderBy) { // Query plan for: MATCH (n) RETURN n.prop ORDER BY n.prop; auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); auto order_by = std::make_shared<plan::OrderBy>( n.op_, std::vector<SortItem>{{Ordering::ASC, n_p}}, std::vector<Symbol>{n.sym_}); @@ -344,8 +335,8 @@ TEST_F(DistributedQueryPlan, PullRemoteOrderBy) { order_by, plan_id, std::vector<SortItem>{{Ordering::ASC, n_p}}, std::vector<Symbol>{n.sym_}); - auto n_p_ne = NEXPR("n.prop", n_p); - symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n.prop", true); + auto n_p_ne = + NEXPR("n.prop", n_p)->MapTo(symbol_table.CreateSymbol("n.prop", true)); auto produce = MakeProduce(pull_remote_order_by, n_p_ne); auto ctx = MakeContext(storage, symbol_table, &dba); auto results = CollectProduce(*produce, &ctx); @@ -374,10 +365,10 @@ TEST_F(DistributedTransactionTimeout, Timeout) { // Make distributed plan for MATCH (n) RETURN n auto scan_all = MakeScanAll(storage, symbol_table, "n"); - auto output = NEXPR("n", IDENT("n")); + auto output = + NEXPR("n", IDENT("n")->MapTo(scan_all.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(scan_all.op_, output); - symbol_table[*output->expression_] = scan_all.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); const int plan_id = 42; master().plan_dispatcher().DispatchPlan(plan_id, produce, symbol_table); @@ -387,7 +378,7 @@ TEST_F(DistributedTransactionTimeout, Timeout) { evaluation_context.properties = NamesToProperties(storage.properties_, dba.get()); evaluation_context.labels = NamesToLabels(storage.labels_, dba.get()); - std::vector<query::Symbol> symbols{symbol_table[*output]}; + std::vector<query::Symbol> symbols{symbol_table.at(*output)}; auto remote_pull = [this, &command_id, &evaluation_context, &symbols, &dba]() { return master() @@ -1205,7 +1196,8 @@ TYPED_TEST(TestPlanner, MatchReturnSum) { auto merge_sum = SUM(IDENT("worker_sum")); auto master_aggr = ExpectMasterAggregate({merge_sum}, {n_prop2}); ExpectPullRemote pull( - {symbol_table.at(*sum), symbol_table.at(*n_prop2->expression_)}); + {symbol_table.at(*sum), + symbol_table.at(*dynamic_cast<Identifier *>(n_prop2->expression_))}); auto expected = ExpectDistributed(MakeCheckers(ExpectScanAll(), aggr, pull, master_aggr, ExpectProduce(), ExpectProduce()), @@ -1314,9 +1306,11 @@ TYPED_TEST(TestPlanner, CreateWithOrderByWhere) { auto r_type = "r"; AstStorage storage; auto ident_n = IDENT("n"); + auto ident_r = IDENT("r"); + auto ident_m = IDENT("m"); auto new_prop = PROPERTY_LOOKUP("new", prop); - auto r_prop = PROPERTY_LOOKUP("r", prop); - auto m_prop = PROPERTY_LOOKUP("m", prop); + auto r_prop = PROPERTY_LOOKUP(ident_r, prop); + auto m_prop = PROPERTY_LOOKUP(ident_m, prop); auto query = QUERY(SINGLE_QUERY( CREATE( PATTERN(NODE("n"), EDGE("r", Direction::OUT, {r_type}), NODE("m"))), @@ -1325,9 +1319,9 @@ TYPED_TEST(TestPlanner, CreateWithOrderByWhere) { auto symbol_table = query::MakeSymbolTable(query); // Since this is a write query, we expect to accumulate to old used symbols. auto acc = ExpectAccumulate({ - symbol_table.at(*ident_n), // `n` in WITH - symbol_table.at(*r_prop->expression_), // `r` in ORDER BY - symbol_table.at(*m_prop->expression_), // `m` in WHERE + symbol_table.at(*ident_n), // `n` in WITH + symbol_table.at(*ident_r), // `r` in ORDER BY + symbol_table.at(*ident_m), // `m` in WHERE }); auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); auto expected = ExpectDistributed( @@ -1480,8 +1474,8 @@ TYPED_TEST(TestPlanner, DistributedAvg) { auto worker_aggr_op = std::dynamic_pointer_cast<Aggregate>(worker_plan); ASSERT_TRUE(worker_aggr_op); ASSERT_EQ(worker_aggr_op->aggregations_.size(), 2U); - symbol_table[*worker_sum] = worker_aggr_op->aggregations_[0].output_sym; - symbol_table[*worker_count] = worker_aggr_op->aggregations_[1].output_sym; + worker_sum->MapTo(worker_aggr_op->aggregations_[0].output_sym); + worker_count->MapTo(worker_aggr_op->aggregations_[1].output_sym); } auto worker_aggr = ExpectAggregate({worker_sum, worker_count}, {}); auto merge_sum = SUM(IDENT("worker_sum")); @@ -1817,16 +1811,13 @@ TEST(TestPlanner, DistributedCartesianIndexedScanByBothBounds) { auto sym_a = symbol_table.CreateSymbol("a", true); auto scan_a = std::make_shared<ScanAll>(nullptr, sym_a); auto sym_b = symbol_table.CreateSymbol("b", true); - query::Expression *lower_expr = IDENT("a"); - symbol_table[*lower_expr] = sym_a; + query::Expression *lower_expr = IDENT("a")->MapTo(sym_a); auto lower_bound = utils::MakeBoundExclusive(lower_expr); - query::Expression *upper_expr = IDENT("a"); - symbol_table[*upper_expr] = sym_a; + query::Expression *upper_expr = IDENT("a")->MapTo(sym_a); auto upper_bound = utils::MakeBoundExclusive(upper_expr); auto scan_b = std::make_shared<ScanAllByLabelPropertyRange>( scan_a, sym_b, label, prop, "prop", lower_bound, upper_bound); - auto ident_b = IDENT("b"); - symbol_table[*ident_b] = sym_b; + auto ident_b = IDENT("b")->MapTo(sym_b); auto as_b = NEXPR("b", ident_b); auto produce = std::make_shared<Produce>( scan_b, std::vector<query::NamedExpression *>{as_b}); @@ -1863,13 +1854,11 @@ TEST(TestPlanner, DistributedCartesianIndexedScanByLowerWithBothBounds) { auto sym_b = symbol_table.CreateSymbol("b", true); query::Expression *lower_expr = LITERAL(42); auto lower_bound = utils::MakeBoundExclusive(lower_expr); - query::Expression *upper_expr = IDENT("a"); - symbol_table[*upper_expr] = sym_a; + query::Expression *upper_expr = IDENT("a")->MapTo(sym_a); auto upper_bound = utils::MakeBoundExclusive(upper_expr); auto scan_b = std::make_shared<ScanAllByLabelPropertyRange>( scan_a, sym_b, label, prop, "prop", lower_bound, upper_bound); - auto ident_b = IDENT("b"); - symbol_table[*ident_b] = sym_b; + auto ident_b = IDENT("b")->MapTo(sym_b); auto as_b = NEXPR("b", ident_b); auto produce = std::make_shared<Produce>( scan_b, std::vector<query::NamedExpression *>{as_b}); @@ -1908,15 +1897,13 @@ TEST(TestPlanner, DistributedCartesianIndexedScanByUpperWithBothBounds) { auto sym_a = symbol_table.CreateSymbol("a", true); auto scan_a = std::make_shared<ScanAll>(nullptr, sym_a); auto sym_b = symbol_table.CreateSymbol("b", true); - query::Expression *lower_expr = IDENT("a"); - symbol_table[*lower_expr] = sym_a; + query::Expression *lower_expr = IDENT("a")->MapTo(sym_a); auto lower_bound = utils::MakeBoundExclusive(lower_expr); query::Expression *upper_expr = LITERAL(42); auto upper_bound = utils::MakeBoundExclusive(upper_expr); auto scan_b = std::make_shared<ScanAllByLabelPropertyRange>( scan_a, sym_b, label, prop, "prop", lower_bound, upper_bound); - auto ident_b = IDENT("b"); - symbol_table[*ident_b] = sym_b; + auto ident_b = IDENT("b")->MapTo(sym_b); auto as_b = NEXPR("b", ident_b); auto produce = std::make_shared<Produce>( scan_b, std::vector<query::NamedExpression *>{as_b}); diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 423f52fbf..e3bcbe0d0 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -44,7 +44,7 @@ class ExpressionEvaluatorTest : public ::testing::Test { const TypedValue &value) { auto id = storage.Create<Identifier>(name, true); auto symbol = symbol_table.CreateSymbol(name, true); - symbol_table[*id] = symbol; + id->MapTo(symbol); frame[symbol] = value; return id; } @@ -630,7 +630,7 @@ TEST_F(ExpressionEvaluatorTest, LabelsTest) { v1.add_label(dba->Label("NICE_DOG")); auto *identifier = storage.Create<Identifier>("n"); auto node_symbol = symbol_table.CreateSymbol("n", true); - symbol_table[*identifier] = node_symbol; + identifier->MapTo(node_symbol); frame[node_symbol] = v1; { auto *op = storage.Create<LabelsTest>( @@ -662,7 +662,7 @@ TEST_F(ExpressionEvaluatorTest, Aggregation) { auto aggr = storage.Create<Aggregation>(storage.Create<PrimitiveLiteral>(42), nullptr, Aggregation::Op::COUNT); auto aggr_sym = symbol_table.CreateSymbol("aggr", true); - symbol_table[*aggr] = aggr_sym; + aggr->MapTo(aggr_sym); frame[aggr_sym] = TypedValue(1); auto value = Eval(aggr); EXPECT_EQ(value.ValueInt(), 1); @@ -699,8 +699,8 @@ TEST_F(ExpressionEvaluatorTest, All) { auto *all = ALL("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(1)))); const auto x_sym = symbol_table.CreateSymbol("x", true); - symbol_table[*all->identifier_] = x_sym; - symbol_table[*ident_x] = x_sym; + all->identifier_->MapTo(x_sym); + ident_x->MapTo(x_sym); auto value = Eval(all); ASSERT_TRUE(value.IsBool()); EXPECT_FALSE(value.ValueBool()); @@ -710,7 +710,7 @@ TEST_F(ExpressionEvaluatorTest, FunctionAllNullList) { AstStorage storage; auto *all = ALL("x", LITERAL(PropertyValue::Null), WHERE(LITERAL(true))); const auto x_sym = symbol_table.CreateSymbol("x", true); - symbol_table[*all->identifier_] = x_sym; + all->identifier_->MapTo(x_sym); auto value = Eval(all); EXPECT_TRUE(value.IsNull()); } @@ -719,7 +719,7 @@ TEST_F(ExpressionEvaluatorTest, FunctionAllWhereWrongType) { AstStorage storage; auto *all = ALL("x", LIST(LITERAL(1)), WHERE(LITERAL(2))); const auto x_sym = symbol_table.CreateSymbol("x", true); - symbol_table[*all->identifier_] = x_sym; + all->identifier_->MapTo(x_sym); EXPECT_THROW(Eval(all), QueryRuntimeException); } @@ -729,8 +729,8 @@ TEST_F(ExpressionEvaluatorTest, FunctionSingle) { auto *single = SINGLE("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(1)))); const auto x_sym = symbol_table.CreateSymbol("x", true); - symbol_table[*single->identifier_] = x_sym; - symbol_table[*ident_x] = x_sym; + single->identifier_->MapTo(x_sym); + ident_x->MapTo(x_sym); auto value = Eval(single); ASSERT_TRUE(value.IsBool()); EXPECT_TRUE(value.ValueBool()); @@ -742,8 +742,8 @@ TEST_F(ExpressionEvaluatorTest, FunctionSingle2) { auto *single = SINGLE("x", LIST(LITERAL(1), LITERAL(2)), WHERE(GREATER(ident_x, LITERAL(0)))); const auto x_sym = symbol_table.CreateSymbol("x", true); - symbol_table[*single->identifier_] = x_sym; - symbol_table[*ident_x] = x_sym; + single->identifier_->MapTo(x_sym); + ident_x->MapTo(x_sym); auto value = Eval(single); ASSERT_TRUE(value.IsBool()); EXPECT_FALSE(value.ValueBool()); @@ -754,7 +754,7 @@ TEST_F(ExpressionEvaluatorTest, FunctionSingleNullList) { auto *single = SINGLE("x", LITERAL(PropertyValue::Null), WHERE(LITERAL(true))); const auto x_sym = symbol_table.CreateSymbol("x", true); - symbol_table[*single->identifier_] = x_sym; + single->identifier_->MapTo(x_sym); auto value = Eval(single); EXPECT_TRUE(value.IsNull()); } @@ -766,11 +766,11 @@ TEST_F(ExpressionEvaluatorTest, FunctionReduce) { auto *reduce = REDUCE("sum", LITERAL(0), "x", LIST(LITERAL(1), LITERAL(2)), ADD(ident_sum, ident_x)); const auto sum_sym = symbol_table.CreateSymbol("sum", true); - symbol_table[*reduce->accumulator_] = sum_sym; - symbol_table[*ident_sum] = sum_sym; + reduce->accumulator_->MapTo(sum_sym); + ident_sum->MapTo(sum_sym); const auto x_sym = symbol_table.CreateSymbol("x", true); - symbol_table[*reduce->identifier_] = x_sym; - symbol_table[*ident_x] = x_sym; + reduce->identifier_->MapTo(x_sym); + ident_x->MapTo(x_sym); auto value = Eval(reduce); ASSERT_TRUE(value.IsInt()); EXPECT_EQ(value.ValueInt(), 3); @@ -783,8 +783,8 @@ TEST_F(ExpressionEvaluatorTest, FunctionExtract) { EXTRACT("x", LIST(LITERAL(1), LITERAL(2), LITERAL(PropertyValue::Null)), ADD(ident_x, LITERAL(1))); const auto x_sym = symbol_table.CreateSymbol("x", true); - symbol_table[*extract->identifier_] = x_sym; - symbol_table[*ident_x] = x_sym; + extract->identifier_->MapTo(x_sym); + ident_x->MapTo(x_sym); auto value = Eval(extract); EXPECT_TRUE(value.IsList()); ; @@ -800,8 +800,8 @@ TEST_F(ExpressionEvaluatorTest, FunctionExtractNull) { auto *extract = EXTRACT("x", LITERAL(PropertyValue::Null), ADD(ident_x, LITERAL(1))); const auto x_sym = symbol_table.CreateSymbol("x", true); - symbol_table[*extract->identifier_] = x_sym; - symbol_table[*ident_x] = x_sym; + extract->identifier_->MapTo(x_sym); + ident_x->MapTo(x_sym); auto value = Eval(extract); EXPECT_TRUE(value.IsNull()); } @@ -811,8 +811,8 @@ TEST_F(ExpressionEvaluatorTest, FunctionExtractExceptions) { auto *ident_x = IDENT("x"); auto *extract = EXTRACT("x", LITERAL("bla"), ADD(ident_x, LITERAL(1))); const auto x_sym = symbol_table.CreateSymbol("x", true); - symbol_table[*extract->identifier_] = x_sym; - symbol_table[*ident_x] = x_sym; + extract->identifier_->MapTo(x_sym); + ident_x->MapTo(x_sym); EXPECT_THROW(Eval(extract), QueryRuntimeException); } @@ -853,10 +853,10 @@ class ExpressionEvaluatorPropertyLookup : public ExpressionEvaluatorTest { std::make_pair("age", dba->Property("age")); std::pair<std::string, storage::Property> prop_height = std::make_pair("height", dba->Property("height")); - Expression *identifier = storage.Create<Identifier>("element"); + Identifier *identifier = storage.Create<Identifier>("element"); Symbol symbol = symbol_table.CreateSymbol("element", true); - void SetUp() { symbol_table[*identifier] = symbol; } + void SetUp() { identifier->MapTo(symbol); } auto Value(std::pair<std::string, storage::Property> property) { auto *op = storage.Create<PropertyLookup>( @@ -905,7 +905,7 @@ class FunctionTest : public ExpressionEvaluatorTest { auto *ident = storage.Create<Identifier>("arg_" + std::to_string(i), true); auto sym = symbol_table.CreateSymbol("arg_" + std::to_string(i), true); - symbol_table[*ident] = sym; + ident->MapTo(sym); frame[sym] = tvs[i]; expressions.push_back(ident); } diff --git a/tests/unit/query_plan.cpp b/tests/unit/query_plan.cpp index b92de69e2..c77fe5f2b 100644 --- a/tests/unit/query_plan.cpp +++ b/tests/unit/query_plan.cpp @@ -330,9 +330,9 @@ TYPED_TEST(TestPlanner, MatchMultiPatternSameExpandStart) { // We expect the second pattern to generate only an Expand. Another // ScanAll would be redundant, as it would generate the nodes obtained from // expansion. Additionally, a uniqueness filter is expected. - CheckPlan<TypeParam>( - query, storage, ExpectScanAll(), ExpectExpand(), ExpectExpand(), - ExpectEdgeUniquenessFilter(), ExpectProduce()); + CheckPlan<TypeParam>(query, storage, ExpectScanAll(), ExpectExpand(), + ExpectExpand(), ExpectEdgeUniquenessFilter(), + ExpectProduce()); } TYPED_TEST(TestPlanner, MultiMatch) { @@ -456,12 +456,13 @@ TYPED_TEST(TestPlanner, CreateWithSum) { FakeDbAccessor dba; auto prop = dba.Property("prop"); AstStorage storage; - auto n_prop = PROPERTY_LOOKUP("n", prop); + auto ident_n = IDENT("n"); + auto n_prop = PROPERTY_LOOKUP(ident_n, prop); auto sum = SUM(n_prop); auto query = QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))), WITH(sum, AS("sum")))); auto symbol_table = query::MakeSymbolTable(query); - auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); + auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); auto aggr = ExpectAggregate({sum}, {}); auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); // We expect both the accumulation and aggregation because the part before @@ -522,13 +523,14 @@ TYPED_TEST(TestPlanner, CreateReturnSumSkipLimit) { FakeDbAccessor dba; auto prop = dba.Property("prop"); AstStorage storage; - auto n_prop = PROPERTY_LOOKUP("n", prop); + auto ident_n = IDENT("n"); + auto n_prop = PROPERTY_LOOKUP(ident_n, prop); auto sum = SUM(n_prop); auto query = QUERY( SINGLE_QUERY(CREATE(PATTERN(NODE("n"))), RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1))))); auto symbol_table = query::MakeSymbolTable(query); - auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); + auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); auto aggr = ExpectAggregate({sum}, {}); auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, @@ -558,9 +560,11 @@ TYPED_TEST(TestPlanner, CreateWithOrderByWhere) { auto r_type = "r"; AstStorage storage; auto ident_n = IDENT("n"); + auto ident_r = IDENT("r"); + auto ident_m = IDENT("m"); auto new_prop = PROPERTY_LOOKUP("new", prop); - auto r_prop = PROPERTY_LOOKUP("r", prop); - auto m_prop = PROPERTY_LOOKUP("m", prop); + auto r_prop = PROPERTY_LOOKUP(ident_r, prop); + auto m_prop = PROPERTY_LOOKUP(ident_m, prop); auto query = QUERY(SINGLE_QUERY( CREATE( PATTERN(NODE("n"), EDGE("r", Direction::OUT, {r_type}), NODE("m"))), @@ -569,9 +573,9 @@ TYPED_TEST(TestPlanner, CreateWithOrderByWhere) { auto symbol_table = query::MakeSymbolTable(query); // Since this is a write query, we expect to accumulate to old used symbols. auto acc = ExpectAccumulate({ - symbol_table.at(*ident_n), // `n` in WITH - symbol_table.at(*r_prop->expression_), // `r` in ORDER BY - symbol_table.at(*m_prop->expression_), // `m` in WHERE + symbol_table.at(*ident_n), // `n` in WITH + symbol_table.at(*ident_r), // `r` in ORDER BY + symbol_table.at(*ident_m), // `m` in WHERE }); auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), diff --git a/tests/unit/query_plan_accumulate_aggregate.cpp b/tests/unit/query_plan_accumulate_aggregate.cpp index f85131ef6..8981d8bb6 100644 --- a/tests/unit/query_plan_accumulate_aggregate.cpp +++ b/tests/unit/query_plan_accumulate_aggregate.cpp @@ -48,12 +48,10 @@ TEST(QueryPlan, Accumulate) { EdgeAtom::Direction::BOTH, {}, "m", false, GraphView::OLD); auto one = LITERAL(1); - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); auto set_n_p = std::make_shared<plan::SetProperty>(r_m.op_, prop, n_p, ADD(n_p, one)); - auto m_p = PROPERTY_LOOKUP("m", prop); - symbol_table[*m_p->expression_] = r_m.node_sym_; + auto m_p = PROPERTY_LOOKUP(IDENT("m")->MapTo(r_m.node_sym_), prop); auto set_m_p = std::make_shared<plan::SetProperty>(set_n_p, prop, m_p, ADD(m_p, one)); @@ -63,10 +61,10 @@ TEST(QueryPlan, Accumulate) { last_op, std::vector<Symbol>{n.sym_, r_m.node_sym_}); } - auto n_p_ne = NEXPR("n.p", n_p); - symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne", true); - auto m_p_ne = NEXPR("m.p", m_p); - symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne", true); + auto n_p_ne = + NEXPR("n.p", n_p)->MapTo(symbol_table.CreateSymbol("n_p_ne", true)); + auto m_p_ne = + NEXPR("m.p", m_p)->MapTo(symbol_table.CreateSymbol("m_p_ne", true)); auto produce = MakeProduce(last_op, n_p_ne, m_p_ne); auto context = MakeContext(storage, symbol_table, &dba); auto results = CollectProduce(*produce, &context); @@ -119,27 +117,25 @@ std::shared_ptr<Produce> MakeAggregationProduce( for (auto aggr_op : aggr_ops) { // TODO change this from using IDENT to using AGGREGATION // once AGGREGATION is handled properly in ExpressionEvaluation - auto named_expr = NEXPR("", IDENT("aggregation")); + auto aggr_sym = symbol_table.CreateSymbol("aggregation", true); + auto named_expr = + NEXPR("", IDENT("aggregation")->MapTo(aggr_sym)) + ->MapTo(symbol_table.CreateSymbol("named_expression", true)); named_expressions.push_back(named_expr); - symbol_table[*named_expr->expression_] = - symbol_table.CreateSymbol("aggregation", true); - symbol_table[*named_expr] = - symbol_table.CreateSymbol("named_expression", true); // the key expression is only used in COLLECT_MAP Expression *key_expr_ptr = aggr_op == Aggregation::Op::COLLECT_MAP ? LITERAL("key") : nullptr; aggregates.emplace_back( - Aggregate::Element{*aggr_inputs_it++, key_expr_ptr, aggr_op, - symbol_table[*named_expr->expression_]}); + Aggregate::Element{*aggr_inputs_it++, key_expr_ptr, aggr_op, aggr_sym}); } // Produce will also evaluate group_by expressions and return them after the // aggregations. for (auto group_by_expr : group_by_exprs) { - auto named_expr = NEXPR("", group_by_expr); + auto named_expr = + NEXPR("", group_by_expr) + ->MapTo(symbol_table.CreateSymbol("named_expression", true)); named_expressions.push_back(named_expr); - symbol_table[*named_expr] = - symbol_table.CreateSymbol("named_expression", true); } auto aggregation = std::make_shared<Aggregate>(input, aggregates, group_by_exprs, remember); @@ -179,8 +175,7 @@ class QueryPlanAggregateOps : public ::testing::Test { Aggregation::Op::COLLECT_MAP}) { // match all nodes and perform aggregations auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); std::vector<Expression *> aggregation_expressions(ops.size(), n_p); std::vector<Expression *> group_bys; @@ -326,8 +321,7 @@ TEST(QueryPlan, AggregateGroupByValues) { // match all nodes and perform aggregations auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, @@ -371,12 +365,9 @@ TEST(QueryPlan, AggregateMultipleGroupBy) { // match all nodes and perform aggregations auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_p1 = PROPERTY_LOOKUP("n", prop1); - auto n_p2 = PROPERTY_LOOKUP("n", prop2); - auto n_p3 = PROPERTY_LOOKUP("n", prop3); - symbol_table[*n_p1->expression_] = n.sym_; - symbol_table[*n_p2->expression_] = n.sym_; - symbol_table[*n_p3->expression_] = n.sym_; + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop2); + auto n_p3 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop3); auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p1}, {Aggregation::Op::COUNT}, @@ -394,9 +385,6 @@ TEST(QueryPlan, AggregateNoInput) { SymbolTable symbol_table; auto two = LITERAL(2); - auto output = NEXPR("two", IDENT("two")); - symbol_table[*output->expression_] = symbol_table.CreateSymbol("two", true); - auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two}, {Aggregation::Op::COUNT}, {}, {}); auto context = MakeContext(storage, symbol_table, dba.get()); @@ -425,8 +413,7 @@ TEST(QueryPlan, AggregateCountEdgeCases) { SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); // returns -1 when there are no results // otherwise returns MATCH (n) RETURN count(n.prop) @@ -485,10 +472,8 @@ TEST(QueryPlan, AggregateFirstValueTypes) { SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_prop_string = PROPERTY_LOOKUP("n", prop_string); - symbol_table[*n_prop_string->expression_] = n.sym_; - auto n_prop_int = PROPERTY_LOOKUP("n", prop_int); - symbol_table[*n_prop_int->expression_] = n.sym_; + auto n_prop_string = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop_string); + auto n_prop_int = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop_int); auto n_id = n_prop_string->expression_; auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) { @@ -545,10 +530,8 @@ TEST(QueryPlan, AggregateTypes) { SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_p1 = PROPERTY_LOOKUP("n", p1); - symbol_table[*n_p1->expression_] = n.sym_; - auto n_p2 = PROPERTY_LOOKUP("n", p2); - symbol_table[*n_p2->expression_] = n.sym_; + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p2); auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) { auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, @@ -599,16 +582,14 @@ TEST(QueryPlan, Unwind) { auto x = symbol_table.CreateSymbol("x", true); auto unwind_0 = std::make_shared<plan::Unwind>(nullptr, input_expr, x); - auto x_expr = IDENT("x"); - symbol_table[*x_expr] = x; + auto x_expr = IDENT("x")->MapTo(x); auto y = symbol_table.CreateSymbol("y", true); auto unwind_1 = std::make_shared<plan::Unwind>(unwind_0, x_expr, y); - auto x_ne = NEXPR("x", x_expr); - symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne", true); - auto y_ne = NEXPR("y", IDENT("y")); - symbol_table[*y_ne->expression_] = y; - symbol_table[*y_ne] = symbol_table.CreateSymbol("y_ne", true); + auto x_ne = + NEXPR("x", x_expr)->MapTo(symbol_table.CreateSymbol("x_ne", true)); + auto y_ne = NEXPR("y", IDENT("y")->MapTo(y)) + ->MapTo(symbol_table.CreateSymbol("y_ne", true)); auto produce = MakeProduce(unwind_1, x_ne, y_ne); auto context = MakeContext(storage, symbol_table, dba.get()); diff --git a/tests/unit/query_plan_bag_semantics.cpp b/tests/unit/query_plan_bag_semantics.cpp index 9d555a39c..6e630e8dc 100644 --- a/tests/unit/query_plan_bag_semantics.cpp +++ b/tests/unit/query_plan_bag_semantics.cpp @@ -151,13 +151,12 @@ TEST(QueryPlan, OrderBy) { // order by and collect results auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); auto order_by = std::make_shared<plan::OrderBy>( n.op_, std::vector<SortItem>{{order_value_pair.first, n_p}}, std::vector<Symbol>{n.sym_}); - auto n_p_ne = NEXPR("n.p", n_p); - symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n.p", true); + auto n_p_ne = + NEXPR("n.p", n_p)->MapTo(symbol_table.CreateSymbol("n.p", true)); auto produce = MakeProduce(order_by, n_p_ne); auto context = MakeContext(storage, symbol_table, &dba); auto results = CollectProduce(*produce, &context); @@ -194,10 +193,8 @@ TEST(QueryPlan, OrderByMultiple) { // order by and collect results auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_p1 = PROPERTY_LOOKUP("n", p1); - symbol_table[*n_p1->expression_] = n.sym_; - auto n_p2 = PROPERTY_LOOKUP("n", p2); - symbol_table[*n_p2->expression_] = n.sym_; + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p2); // order the results so we get // (p1: 0, p2: N-1) // (p1: 0, p2: N-2) @@ -209,10 +206,10 @@ TEST(QueryPlan, OrderByMultiple) { {Ordering::DESC, n_p2}, }, std::vector<Symbol>{n.sym_}); - auto n_p1_ne = NEXPR("n.p1", n_p1); - symbol_table[*n_p1_ne] = symbol_table.CreateSymbol("n.p1", true); - auto n_p2_ne = NEXPR("n.p2", n_p2); - symbol_table[*n_p2_ne] = symbol_table.CreateSymbol("n.p2", true); + auto n_p1_ne = + NEXPR("n.p1", n_p1)->MapTo(symbol_table.CreateSymbol("n.p1", true)); + auto n_p2_ne = + NEXPR("n.p2", n_p2)->MapTo(symbol_table.CreateSymbol("n.p2", true)); auto produce = MakeProduce(order_by, n_p1_ne, n_p2_ne); auto context = MakeContext(storage, symbol_table, &dba); auto results = CollectProduce(*produce, &context); @@ -261,8 +258,7 @@ TEST(QueryPlan, OrderByExceptions) { // order by and expect an exception auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); auto order_by = std::make_shared<plan::OrderBy>( n.op_, std::vector<SortItem>{{Ordering::ASC, n_p}}, std::vector<Symbol>{}); diff --git a/tests/unit/query_plan_common.hpp b/tests/unit/query_plan_common.hpp index 67e493edf..4a096e15e 100644 --- a/tests/unit/query_plan_common.hpp +++ b/tests/unit/query_plan_common.hpp @@ -39,7 +39,7 @@ std::vector<std::vector<TypedValue>> CollectProduce(const Produce &produce, // collect the symbols from the return clause std::vector<Symbol> symbols; for (auto named_expression : produce.named_expressions_) - symbols.emplace_back(context->symbol_table[*named_expression]); + symbols.emplace_back(context->symbol_table.at(*named_expression)); // stream out results auto cursor = produce.MakeCursor(*context->db_accessor); @@ -86,7 +86,7 @@ ScanAllTuple MakeScanAll(AstStorage &storage, SymbolTable &symbol_table, GraphView graph_view = GraphView::OLD) { auto node = NODE(identifier); auto symbol = symbol_table.CreateSymbol(identifier, true); - symbol_table[*node->identifier_] = symbol; + node->identifier_->MapTo(symbol); auto logical_op = std::make_shared<ScanAll>(input, symbol, graph_view); return ScanAllTuple{node, logical_op, symbol}; } @@ -104,7 +104,7 @@ ScanAllTuple MakeScanAllByLabel( GraphView graph_view = GraphView::OLD) { auto node = NODE(identifier); auto symbol = symbol_table.CreateSymbol(identifier, true); - symbol_table[*node->identifier_] = symbol; + node->identifier_->MapTo(symbol); auto logical_op = std::make_shared<ScanAllByLabel>(input, symbol, label, graph_view); return ScanAllTuple{node, logical_op, symbol}; @@ -126,7 +126,7 @@ ScanAllTuple MakeScanAllByLabelPropertyRange( GraphView graph_view = GraphView::OLD) { auto node = NODE(identifier); auto symbol = symbol_table.CreateSymbol(identifier, true); - symbol_table[*node->identifier_] = symbol; + node->identifier_->MapTo(symbol); auto logical_op = std::make_shared<ScanAllByLabelPropertyRange>( input, symbol, label, property, property_name, lower_bound, upper_bound, graph_view); @@ -147,7 +147,7 @@ ScanAllTuple MakeScanAllByLabelPropertyValue( GraphView graph_view = GraphView::OLD) { auto node = NODE(identifier); auto symbol = symbol_table.CreateSymbol(identifier, true); - symbol_table[*node->identifier_] = symbol; + node->identifier_->MapTo(symbol); auto logical_op = std::make_shared<ScanAllByLabelPropertyValue>( input, symbol, label, property, property_name, value, graph_view); return ScanAllTuple{node, logical_op, symbol}; @@ -170,11 +170,11 @@ ExpandTuple MakeExpand(AstStorage &storage, SymbolTable &symbol_table, GraphView graph_view) { auto edge = EDGE(edge_identifier, direction); auto edge_sym = symbol_table.CreateSymbol(edge_identifier, true); - symbol_table[*edge->identifier_] = edge_sym; + edge->identifier_->MapTo(edge_sym); auto node = NODE(node_identifier); auto node_sym = symbol_table.CreateSymbol(node_identifier, true); - symbol_table[*node->identifier_] = node_sym; + node->identifier_->MapTo(node_sym); auto op = std::make_shared<Expand>(input, input_symbol, node_sym, edge_sym, direction, edge_types, existing_node, diff --git a/tests/unit/query_plan_create_set_remove_delete.cpp b/tests/unit/query_plan_create_set_remove_delete.cpp index 7498b2f4a..39cd034b2 100644 --- a/tests/unit/query_plan_create_set_remove_delete.cpp +++ b/tests/unit/query_plan_create_set_remove_delete.cpp @@ -69,15 +69,13 @@ TEST(QueryPlan, CreateReturn) { node.properties.emplace_back(property.second, LITERAL(42)); auto create = std::make_shared<CreateNode>(nullptr, node); - auto named_expr_n = NEXPR("n", IDENT("n")); - symbol_table[*named_expr_n] = symbol_table.CreateSymbol("named_expr_n", true); - symbol_table[*named_expr_n->expression_] = node.symbol; - auto prop_lookup = PROPERTY_LOOKUP("n", property); - symbol_table[*prop_lookup->expression_] = node.symbol; - auto named_expr_n_p = NEXPR("n", prop_lookup); - symbol_table[*named_expr_n_p] = - symbol_table.CreateSymbol("named_expr_n_p", true); - symbol_table[*named_expr_n->expression_] = node.symbol; + auto named_expr_n = + NEXPR("n", IDENT("n")->MapTo(node.symbol)) + ->MapTo(symbol_table.CreateSymbol("named_expr_n", true)); + auto prop_lookup = PROPERTY_LOOKUP(IDENT("n")->MapTo(node.symbol), property); + auto named_expr_n_p = + NEXPR("n", prop_lookup) + ->MapTo(symbol_table.CreateSymbol("named_expr_n_p", true)); auto produce = MakeProduce(create, named_expr_n, named_expr_n_p); auto context = MakeContext(storage, symbol_table, &dba); @@ -267,8 +265,7 @@ TEST(QueryPlan, Delete) { // attempt to delete a vertex, and fail { auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); auto delete_op = std::make_shared<plan::Delete>( n.op_, std::vector<Expression *>{n_get}, false); auto context = MakeContext(storage, symbol_table, dba.get()); @@ -281,8 +278,7 @@ TEST(QueryPlan, Delete) { // detach delete a single vertex { auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); auto delete_op = std::make_shared<plan::Delete>( n.op_, std::vector<Expression *>{n_get}, true); Frame frame(symbol_table.max_position()); @@ -299,8 +295,7 @@ TEST(QueryPlan, Delete) { auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, GraphView::NEW); - auto r_get = storage.Create<Identifier>("r"); - symbol_table[*r_get] = r_m.edge_sym_; + auto r_get = storage.Create<Identifier>("r")->MapTo(r_m.edge_sym_); auto delete_op = std::make_shared<plan::Delete>( r_m.op_, std::vector<Expression *>{r_get}, false); auto context = MakeContext(storage, symbol_table, dba.get()); @@ -313,8 +308,7 @@ TEST(QueryPlan, Delete) { // delete all remaining vertices { auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); auto delete_op = std::make_shared<plan::Delete>( n.op_, std::vector<Expression *>{n_get}, false); auto context = MakeContext(storage, symbol_table, dba.get()); @@ -357,12 +351,9 @@ TEST(QueryPlan, DeleteTwiceDeleteBlockingEdge) { EdgeAtom::Direction::BOTH, {}, "m", false, GraphView::OLD); // getter expressions for deletion - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; - auto r_get = storage.Create<Identifier>("r"); - symbol_table[*r_get] = r_m.edge_sym_; - auto m_get = storage.Create<Identifier>("m"); - symbol_table[*m_get] = r_m.node_sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto r_get = storage.Create<Identifier>("r")->MapTo(r_m.edge_sym_); + auto m_get = storage.Create<Identifier>("m")->MapTo(r_m.node_sym_); auto delete_op = std::make_shared<plan::Delete>( r_m.op_, std::vector<Expression *>{n_get, r_get, m_get}, detach); @@ -398,15 +389,13 @@ TEST(QueryPlan, DeleteReturn) { auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); auto delete_op = std::make_shared<plan::Delete>( n.op_, std::vector<Expression *>{n_get}, true); - auto prop_lookup = PROPERTY_LOOKUP("n", prop); - symbol_table[*prop_lookup->expression_] = n.sym_; - auto n_p = storage.Create<NamedExpression>("n", prop_lookup); - symbol_table[*n_p] = symbol_table.CreateSymbol("bla", true); + auto prop_lookup = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto n_p = storage.Create<NamedExpression>("n", prop_lookup) + ->MapTo(symbol_table.CreateSymbol("bla", true)); auto produce = MakeProduce(delete_op, n_p); auto context = MakeContext(storage, symbol_table, &dba); @@ -449,8 +438,7 @@ TEST(QueryPlan, DeleteAdvance) { SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); auto delete_op = std::make_shared<plan::Delete>( n.op_, std::vector<Expression *>{n_get}, false); auto advance = std::make_shared<Accumulate>( @@ -489,13 +477,11 @@ TEST(QueryPlan, SetProperty) { auto prop1 = dba.Property("prop1"); auto literal = LITERAL(42); - auto n_p = PROPERTY_LOOKUP("n", prop1); - symbol_table[*n_p->expression_] = n.sym_; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); auto set_n_p = std::make_shared<plan::SetProperty>(r_m.op_, prop1, n_p, literal); - auto r_p = PROPERTY_LOOKUP("r", prop1); - symbol_table[*r_p->expression_] = r_m.edge_sym_; + auto r_p = PROPERTY_LOOKUP(IDENT("r")->MapTo(r_m.edge_sym_), prop1); auto set_r_p = std::make_shared<plan::SetProperty>(set_n_p, prop1, r_p, literal); auto context = MakeContext(storage, symbol_table, &dba); @@ -544,10 +530,8 @@ TEST(QueryPlan, SetProperties) { : plan::SetProperties::Op::REPLACE; // set properties on r to n, and on r to m - auto r_ident = IDENT("r"); - symbol_table[*r_ident] = r_m.edge_sym_; - auto m_ident = IDENT("m"); - symbol_table[*m_ident] = r_m.node_sym_; + auto r_ident = IDENT("r")->MapTo(r_m.edge_sym_); + auto m_ident = IDENT("m")->MapTo(r_m.node_sym_); auto set_r_to_n = std::make_shared<plan::SetProperties>(r_m.op_, n.sym_, r_ident, op); auto set_m_to_r = std::make_shared<plan::SetProperties>( @@ -647,12 +631,10 @@ TEST(QueryPlan, RemoveProperty) { MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, GraphView::OLD); - auto n_p = PROPERTY_LOOKUP("n", prop1); - symbol_table[*n_p->expression_] = n.sym_; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); auto set_n_p = std::make_shared<plan::RemoveProperty>(r_m.op_, prop1, n_p); - auto r_p = PROPERTY_LOOKUP("r", prop1); - symbol_table[*r_p->expression_] = r_m.edge_sym_; + auto r_p = PROPERTY_LOOKUP(IDENT("r")->MapTo(r_m.edge_sym_), prop1); auto set_r_p = std::make_shared<plan::RemoveProperty>(set_n_p, prop1, r_p); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(2, PullAll(*set_r_p, &context)); @@ -733,8 +715,7 @@ TEST(QueryPlan, NodeFilterSet) { LITERAL(42)); auto node_filter = std::make_shared<Filter>(expand.op_, filter_expr); // SET n.prop = n.prop + 1 - auto set_prop = PROPERTY_LOOKUP("n", prop); - symbol_table[*set_prop->expression_] = scan_all.sym_; + auto set_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), prop); auto add = ADD(set_prop, LITERAL(1)); auto set = std::make_shared<plan::SetProperty>(node_filter, prop.second, set_prop, add); @@ -771,13 +752,11 @@ TEST(QueryPlan, FilterRemove) { auto expand = MakeExpand(storage, symbol_table, scan_all.op_, scan_all.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", false, GraphView::OLD); - auto filter_prop = PROPERTY_LOOKUP("n", prop); - symbol_table[*filter_prop->expression_] = scan_all.sym_; + auto filter_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), prop); auto filter = std::make_shared<Filter>(expand.op_, LESS(filter_prop, LITERAL(43))); // REMOVE n.prop - auto rem_prop = PROPERTY_LOOKUP("n", prop); - symbol_table[*rem_prop->expression_] = scan_all.sym_; + auto rem_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), prop); auto rem = std::make_shared<plan::RemoveProperty>(filter, prop.second, rem_prop); auto context = MakeContext(storage, symbol_table, &dba); @@ -838,14 +817,12 @@ TEST(QueryPlan, Merge) { auto r_m = MakeExpand(storage, symbol_table, std::make_shared<Once>(), n.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", false, GraphView::OLD); - auto m_p = PROPERTY_LOOKUP("m", prop); - symbol_table[*m_p->expression_] = r_m.node_sym_; + auto m_p = PROPERTY_LOOKUP(IDENT("m")->MapTo(r_m.node_sym_), prop); auto m_set = std::make_shared<plan::SetProperty>(r_m.op_, prop.second, m_p, LITERAL(1)); // merge_create branch - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); auto n_set = std::make_shared<plan::SetProperty>( std::make_shared<Once>(), prop.second, n_p, LITERAL(2)); @@ -910,8 +887,7 @@ TEST(QueryPlan, SetPropertiesOnNull) { AstStorage storage; SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_ident = IDENT("n"); - symbol_table[*n_ident] = n.sym_; + auto n_ident = IDENT("n")->MapTo(n.sym_); auto optional = std::make_shared<plan::Optional>(nullptr, n.op_, std::vector<Symbol>{n.sym_}); auto set_op = std::make_shared<plan::SetProperties>( @@ -929,8 +905,6 @@ TEST(QueryPlan, SetLabelsOnNull) { AstStorage storage; SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_ident = IDENT("n"); - symbol_table[*n_ident] = n.sym_; auto optional = std::make_shared<plan::Optional>(nullptr, n.op_, std::vector<Symbol>{n.sym_}); auto set_op = std::make_shared<plan::SetLabels>( @@ -965,8 +939,6 @@ TEST(QueryPlan, RemoveLabelsOnNull) { AstStorage storage; SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_ident = IDENT("n"); - symbol_table[*n_ident] = n.sym_; auto optional = std::make_shared<plan::Optional>(nullptr, n.op_, std::vector<Symbol>{n.sym_}); auto remove_op = std::make_shared<plan::RemoveLabels>( @@ -988,13 +960,11 @@ TEST(QueryPlan, DeleteSetProperty) { SymbolTable symbol_table; // MATCH (n) DELETE n SET n.property = 42 auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); auto delete_op = std::make_shared<plan::Delete>( n.op_, std::vector<Expression *>{n_get}, false); auto prop = PROPERTY_PAIR("property"); - auto n_prop = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_prop->expression_] = n.sym_; + auto n_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); auto set_op = std::make_shared<plan::SetProperty>(delete_op, prop.second, n_prop, LITERAL(42)); auto context = MakeContext(storage, symbol_table, &dba); @@ -1013,17 +983,13 @@ TEST(QueryPlan, DeleteSetPropertiesFromMap) { SymbolTable symbol_table; // MATCH (n) DELETE n SET n = {property: 42} auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); auto delete_op = std::make_shared<plan::Delete>( n.op_, std::vector<Expression *>{n_get}, false); auto prop = PROPERTY_PAIR("property"); - auto n_prop = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_prop->expression_] = n.sym_; std::unordered_map<PropertyIx, Expression *> prop_map; prop_map.emplace(storage.GetPropertyIx(prop.first), LITERAL(42)); auto *rhs = storage.Create<MapLiteral>(prop_map); - symbol_table[*rhs] = n.sym_; for (auto op_type : {plan::SetProperties::Op::REPLACE, plan::SetProperties::Op::UPDATE}) { auto set_op = @@ -1048,15 +1014,10 @@ TEST(QueryPlan, DeleteSetPropertiesFromVertex) { SymbolTable symbol_table; // MATCH (n) DELETE n SET n = n auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); auto delete_op = std::make_shared<plan::Delete>( n.op_, std::vector<Expression *>{n_get}, false); - auto prop = PROPERTY_PAIR("property"); - auto n_prop = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_prop->expression_] = n.sym_; - auto *rhs = IDENT("n"); - symbol_table[*rhs] = n.sym_; + auto *rhs = IDENT("n")->MapTo(n.sym_); for (auto op_type : {plan::SetProperties::Op::REPLACE, plan::SetProperties::Op::UPDATE}) { auto set_op = @@ -1077,8 +1038,7 @@ TEST(QueryPlan, DeleteRemoveLabels) { SymbolTable symbol_table; // MATCH (n) DELETE n REMOVE n :label auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); auto delete_op = std::make_shared<plan::Delete>( n.op_, std::vector<Expression *>{n_get}, false); std::vector<storage::Label> labels{dba->Label("label")}; @@ -1099,13 +1059,11 @@ TEST(QueryPlan, DeleteRemoveProperty) { SymbolTable symbol_table; // MATCH (n) DELETE n REMOVE n.property auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_get = storage.Create<Identifier>("n"); - symbol_table[*n_get] = n.sym_; + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); auto delete_op = std::make_shared<plan::Delete>( n.op_, std::vector<Expression *>{n_get}, false); auto prop = PROPERTY_PAIR("property"); - auto n_prop = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_prop->expression_] = n.sym_; + auto n_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); auto rem_op = std::make_shared<plan::RemoveProperty>(delete_op, prop.second, n_prop); auto context = MakeContext(storage, symbol_table, &dba); diff --git a/tests/unit/query_plan_match_filter_return.cpp b/tests/unit/query_plan_match_filter_return.cpp index 67e0e92db..c8d296dc5 100644 --- a/tests/unit/query_plan_match_filter_return.cpp +++ b/tests/unit/query_plan_match_filter_return.cpp @@ -52,11 +52,10 @@ TEST_F(MatchReturnFixture, MatchReturn) { auto test_pull_count = [&](GraphView graph_view) { auto scan_all = MakeScanAll(storage, symbol_table, "n", nullptr, graph_view); - auto output = NEXPR("n", IDENT("n")); + auto output = + NEXPR("n", IDENT("n")->MapTo(scan_all.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(scan_all.op_, output); - symbol_table[*output->expression_] = scan_all.sym_; - symbol_table[*output] = - symbol_table.CreateSymbol("named_expression_1", true); auto context = MakeContext(storage, symbol_table, dba_.get()); return PullAll(*produce, &context); }; @@ -78,9 +77,9 @@ TEST_F(MatchReturnFixture, MatchReturnPath) { Symbol path_sym = symbol_table.CreateSymbol("path", true); auto make_path = std::make_shared<ConstructNamedPath>( scan_all.op_, path_sym, std::vector<Symbol>{scan_all.sym_}); - auto output = NEXPR("path", IDENT("path")); - symbol_table[*output->expression_] = path_sym; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); + auto output = + NEXPR("path", IDENT("path")->MapTo(path_sym)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(make_path, output); auto results = Results<query::Path>(produce); ASSERT_EQ(results.size(), 2); @@ -104,14 +103,12 @@ TEST(QueryPlan, MatchReturnCartesian) { auto n = MakeScanAll(storage, symbol_table, "n"); auto m = MakeScanAll(storage, symbol_table, "m", n.op_); - auto return_n = NEXPR("n", IDENT("n")); - symbol_table[*return_n->expression_] = n.sym_; - symbol_table[*return_n] = - symbol_table.CreateSymbol("named_expression_1", true); - auto return_m = NEXPR("m", IDENT("m")); - symbol_table[*return_m->expression_] = m.sym_; - symbol_table[*return_m] = - symbol_table.CreateSymbol("named_expression_2", true); + auto return_n = + NEXPR("n", IDENT("n")->MapTo(n.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto return_m = + NEXPR("m", IDENT("m")->MapTo(m.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); auto produce = MakeProduce(m.op_, return_n, return_m); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); @@ -138,7 +135,7 @@ TEST(QueryPlan, StandaloneReturn) { auto output = NEXPR("n", LITERAL(42)); auto produce = MakeProduce(std::shared_ptr<LogicalOperator>(nullptr), output); - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); + output->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); @@ -189,9 +186,9 @@ TEST(QueryPlan, NodeFilterLabelsAndProperties) { auto node_filter = std::make_shared<Filter>(n.op_, filter_expr); // make a named expression and a produce - auto output = NEXPR("x", IDENT("n")); - symbol_table[*output->expression_] = n.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); + auto output = + NEXPR("x", IDENT("n")->MapTo(n.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(node_filter, output); auto context = MakeContext(storage, symbol_table, &dba); @@ -244,13 +241,11 @@ TEST(QueryPlan, NodeFilterMultipleLabels) { auto node_filter = std::make_shared<Filter>(n.op_, filter_expr); // make a named expression and a produce - auto output = NEXPR("n", IDENT("n")); + auto output = + NEXPR("n", IDENT("n")->MapTo(n.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(node_filter, output); - // fill up the symbol table - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); - symbol_table[*output->expression_] = n.sym_; - auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); EXPECT_EQ(results.size(), 2); @@ -275,14 +270,12 @@ TEST(QueryPlan, Cartesian) { auto n = MakeScanAll(storage, symbol_table, "n"); auto m = MakeScanAll(storage, symbol_table, "m"); - auto return_n = NEXPR("n", IDENT("n")); - symbol_table[*return_n->expression_] = n.sym_; - symbol_table[*return_n] = - symbol_table.CreateSymbol("named_expression_1", true); - auto return_m = NEXPR("m", IDENT("m")); - symbol_table[*return_m->expression_] = m.sym_; - symbol_table[*return_m] = - symbol_table.CreateSymbol("named_expression_2", true); + auto return_n = + NEXPR("n", IDENT("n")->MapTo(n.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto return_m = + NEXPR("m", IDENT("m")->MapTo(m.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); std::vector<Symbol> left_symbols{n.sym_}; std::vector<Symbol> right_symbols{m.sym_}; @@ -311,14 +304,12 @@ TEST(QueryPlan, CartesianEmptySet) { auto n = MakeScanAll(storage, symbol_table, "n"); auto m = MakeScanAll(storage, symbol_table, "m"); - auto return_n = NEXPR("n", IDENT("n")); - symbol_table[*return_n->expression_] = n.sym_; - symbol_table[*return_n] = - symbol_table.CreateSymbol("named_expression_1", true); - auto return_m = NEXPR("m", IDENT("m")); - symbol_table[*return_m->expression_] = m.sym_; - symbol_table[*return_m] = - symbol_table.CreateSymbol("named_expression_2", true); + auto return_n = + NEXPR("n", IDENT("n")->MapTo(n.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto return_m = + NEXPR("m", IDENT("m")->MapTo(m.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); std::vector<Symbol> left_symbols{n.sym_}; std::vector<Symbol> right_symbols{m.sym_}; @@ -350,18 +341,15 @@ TEST(QueryPlan, CartesianThreeWay) { auto n = MakeScanAll(storage, symbol_table, "n"); auto m = MakeScanAll(storage, symbol_table, "m"); auto l = MakeScanAll(storage, symbol_table, "l"); - auto return_n = NEXPR("n", IDENT("n")); - symbol_table[*return_n->expression_] = n.sym_; - symbol_table[*return_n] = - symbol_table.CreateSymbol("named_expression_1", true); - auto return_m = NEXPR("m", IDENT("m")); - symbol_table[*return_m->expression_] = m.sym_; - symbol_table[*return_m] = - symbol_table.CreateSymbol("named_expression_2", true); - auto return_l = NEXPR("l", IDENT("l")); - symbol_table[*return_l->expression_] = l.sym_; - symbol_table[*return_l] = - symbol_table.CreateSymbol("named_expression_3", true); + auto return_n = + NEXPR("n", IDENT("n")->MapTo(n.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto return_m = + NEXPR("m", IDENT("m")->MapTo(m.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); + auto return_l = + NEXPR("l", IDENT("l")->MapTo(l.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_3", true)); std::vector<Symbol> n_symbols{n.sym_}; std::vector<Symbol> m_symbols{m.sym_}; @@ -420,10 +408,9 @@ TEST_F(ExpandFixture, Expand) { {}, "m", false, graph_view); // make a named expression and a produce - auto output = NEXPR("m", IDENT("m")); - symbol_table[*output->expression_] = r_m.node_sym_; - symbol_table[*output] = - symbol_table.CreateSymbol("named_expression_1", true); + auto output = + NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(r_m.op_, output); auto context = MakeContext(storage, symbol_table, dba_.get()); return PullAll(*produce, &context); @@ -456,9 +443,9 @@ TEST_F(ExpandFixture, ExpandPath) { auto path = std::make_shared<ConstructNamedPath>( r_m.op_, path_sym, std::vector<Symbol>{n.sym_, r_m.edge_sym_, r_m.node_sym_}); - auto output = NEXPR("m", IDENT("m")); - symbol_table[*output->expression_] = path_sym; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); + auto output = + NEXPR("path", IDENT("path")->MapTo(path_sym)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(path, output); std::vector<query::Path> expected_paths{{v1, r2, v3}, {v1, r1, v2}}; @@ -560,7 +547,7 @@ class QueryPlanExpandVariable : public testing::Test { auto n_to = NODE(node_to); auto n_to_sym = symbol_table.CreateSymbol(node_to, true); - symbol_table[*n_to->identifier_] = n_to_sym; + n_to->identifier_->MapTo(n_to_sym); if (std::is_same<TExpansionOperator, ExpandVariable>::value) { // convert optional ints to optional expressions @@ -588,7 +575,7 @@ class QueryPlanExpandVariable : public testing::Test { auto Edge(const std::string &identifier, EdgeAtom::Direction direction) { auto edge = EDGE(identifier, direction); auto edge_sym = symbol_table.CreateSymbol(identifier, true); - symbol_table[*edge->identifier_] = edge_sym; + edge->identifier_->MapTo(edge_sym); return edge_sym; } @@ -769,8 +756,8 @@ TEST_F(QueryPlanExpandVariable, NamedPath) { AddMatch<ExpandVariable>(nullptr, "n", 0, EdgeAtom::Direction::OUT, {}, 2, 2, e, "m", GraphView::OLD); auto find_symbol = [this](const std::string &name) { - for (const auto &pos_sym : symbol_table.table()) - if (pos_sym.second.name() == name) return pos_sym.second; + for (const auto &sym : symbol_table.table()) + if (sym.name() == name) return sym; throw std::runtime_error("Symbol not found"); }; @@ -884,7 +871,7 @@ class QueryPlanExpandWeightedShortestPath : public testing::Test { } auto ident_e = IDENT("e"); - symbol_table[*ident_e] = weight_edge; + ident_e->MapTo(weight_edge); // expand wshortest auto node_sym = existing_node_input @@ -929,7 +916,7 @@ class QueryPlanExpandWeightedShortestPath : public testing::Test { Expression *PropNe(Symbol symbol, int value) { auto ident = IDENT("inner_element"); - symbol_table[*ident] = symbol; + ident->MapTo(symbol); return NEQ(PROPERTY_LOOKUP(ident, prop), LITERAL(value)); } }; @@ -1173,15 +1160,12 @@ TEST(QueryPlan, ExpandOptional) { n.op_, r_m.op_, std::vector<Symbol>{r_m.edge_sym_, r_m.node_sym_}); // RETURN n, r, m - auto n_ne = NEXPR("n", IDENT("n")); - symbol_table[*n_ne->expression_] = n.sym_; - symbol_table[*n_ne] = symbol_table.CreateSymbol("n", true); - auto r_ne = NEXPR("r", IDENT("r")); - symbol_table[*r_ne->expression_] = r_m.edge_sym_; - symbol_table[*r_ne] = symbol_table.CreateSymbol("r", true); - auto m_ne = NEXPR("m", IDENT("m")); - symbol_table[*m_ne->expression_] = r_m.node_sym_; - symbol_table[*m_ne] = symbol_table.CreateSymbol("m", true); + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_)) + ->MapTo(symbol_table.CreateSymbol("n", true)); + auto r_ne = NEXPR("r", IDENT("r")->MapTo(r_m.edge_sym_)) + ->MapTo(symbol_table.CreateSymbol("r", true)); + auto m_ne = NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_)) + ->MapTo(symbol_table.CreateSymbol("m", true)); auto produce = MakeProduce(optional, n_ne, r_ne, m_ne); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); @@ -1214,9 +1198,8 @@ TEST(QueryPlan, OptionalMatchEmptyDB) { // OPTIONAL MATCH (n) auto n = MakeScanAll(storage, symbol_table, "n"); // RETURN n - auto n_ne = NEXPR("n", IDENT("n")); - symbol_table[*n_ne->expression_] = n.sym_; - symbol_table[*n_ne] = symbol_table.CreateSymbol("n", true); + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_)) + ->MapTo(symbol_table.CreateSymbol("n", true)); auto optional = std::make_shared<plan::Optional>(nullptr, n.op_, std::vector<Symbol>{n.sym_}); auto produce = MakeProduce(optional, n_ne); @@ -1236,19 +1219,17 @@ TEST(QueryPlan, OptionalMatchEmptyDBExpandFromNode) { auto optional = std::make_shared<plan::Optional>(nullptr, n.op_, std::vector<Symbol>{n.sym_}); // WITH n - auto n_ne = NEXPR("n", IDENT("n")); - symbol_table[*n_ne->expression_] = n.sym_; + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_)); auto with_n_sym = symbol_table.CreateSymbol("n", true); - symbol_table[*n_ne] = with_n_sym; + n_ne->MapTo(with_n_sym); auto with = MakeProduce(optional, n_ne); // MATCH (n) -[r]-> (m) auto r_m = MakeExpand(storage, symbol_table, with, with_n_sym, "r", EdgeAtom::Direction::OUT, {}, "m", false, GraphView::OLD); // RETURN m - auto m_ne = NEXPR("m", IDENT("m")); - symbol_table[*m_ne->expression_] = r_m.node_sym_; - symbol_table[*m_ne] = symbol_table.CreateSymbol("m", true); + auto m_ne = NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_)) + ->MapTo(symbol_table.CreateSymbol("m", true)); auto produce = MakeProduce(r_m.op_, m_ne); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); @@ -1279,26 +1260,24 @@ TEST(QueryPlan, OptionalMatchThenExpandToMissingNode) { auto optional = std::make_shared<plan::Optional>(nullptr, node_filter, std::vector<Symbol>{n.sym_}); // WITH n - auto n_ne = NEXPR("n", IDENT("n")); - symbol_table[*n_ne->expression_] = n.sym_; + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_)); auto with_n_sym = symbol_table.CreateSymbol("n", true); - symbol_table[*n_ne] = with_n_sym; + n_ne->MapTo(with_n_sym); auto with = MakeProduce(optional, n_ne); // MATCH (m) -[r]-> (n) auto m = MakeScanAll(storage, symbol_table, "m", with); auto edge_direction = EdgeAtom::Direction::OUT; auto edge = EDGE("r", edge_direction); auto edge_sym = symbol_table.CreateSymbol("r", true); - symbol_table[*edge->identifier_] = edge_sym; + edge->identifier_->MapTo(edge_sym); auto node = NODE("n"); - symbol_table[*node->identifier_] = with_n_sym; + node->identifier_->MapTo(with_n_sym); auto expand = std::make_shared<plan::Expand>( m.op_, m.sym_, with_n_sym, edge_sym, edge_direction, std::vector<storage::EdgeType>{}, true, GraphView::OLD); // RETURN m - auto m_ne = NEXPR("m", IDENT("m")); - symbol_table[*m_ne->expression_] = m.sym_; - symbol_table[*m_ne] = symbol_table.CreateSymbol("m", true); + auto m_ne = NEXPR("m", IDENT("m")->MapTo(m.sym_)) + ->MapTo(symbol_table.CreateSymbol("m", true)); auto produce = MakeProduce(expand, m_ne); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); @@ -1332,10 +1311,9 @@ TEST(QueryPlan, ExpandExistingNode) { std::vector<storage::EdgeType>{}, with_existing, GraphView::OLD); // make a named expression and a produce - auto output = NEXPR("n", IDENT("n")); - symbol_table[*output->expression_] = n.sym_; - symbol_table[*output] = - symbol_table.CreateSymbol("named_expression_1", true); + auto output = + NEXPR("n", IDENT("n")->MapTo(n.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(r_n.op_, output); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); @@ -1421,10 +1399,9 @@ TEST(QueryPlan, EdgeFilter) { auto edge_filter = std::make_shared<Filter>(r_m.op_, filter_expr); // make a named expression and a produce - auto output = NEXPR("m", IDENT("m")); - symbol_table[*output->expression_] = r_m.node_sym_; - symbol_table[*output] = - symbol_table.CreateSymbol("named_expression_1", true); + auto output = + NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(edge_filter, output); auto context = MakeContext(storage, symbol_table, &dba); return PullAll(*produce, &context); @@ -1462,12 +1439,10 @@ TEST(QueryPlan, EdgeFilterMultipleTypes) { GraphView::OLD); // make a named expression and a produce - auto output = NEXPR("m", IDENT("m")); + auto output = + NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(r_m.op_, output); - - // fill up the symbol table - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); - symbol_table[*output->expression_] = r_m.node_sym_; auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); EXPECT_EQ(results.size(), 2); @@ -1489,14 +1464,12 @@ TEST(QueryPlan, Filter) { SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); - auto e = PROPERTY_LOOKUP("n", property); - symbol_table[*e->expression_] = n.sym_; + auto e = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), property); auto f = std::make_shared<Filter>(n.op_, e); auto output = - storage.Create<NamedExpression>("x", storage.Create<Identifier>("n")); - symbol_table[*output->expression_] = n.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1", true); + NEXPR("x", IDENT("n")->MapTo(n.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(f, output); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(CollectProduce(*produce, &context).size(), 2); @@ -1555,13 +1528,13 @@ TEST(QueryPlan, Distinct) { auto x = symbol_table.CreateSymbol("x", true); auto unwind = std::make_shared<plan::Unwind>(nullptr, input_expr, x); auto x_expr = IDENT("x"); - symbol_table[*x_expr] = x; + x_expr->MapTo(x); auto distinct = std::make_shared<plan::Distinct>(unwind, std::vector<Symbol>{x}); auto x_ne = NEXPR("x", x_expr); - symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne", true); + x_ne->MapTo(symbol_table.CreateSymbol("x_ne", true)); auto produce = MakeProduce(distinct, x_ne); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); @@ -1599,10 +1572,9 @@ TEST(QueryPlan, ScanAllByLabel) { auto scan_all_by_label = MakeScanAllByLabel(storage, symbol_table, "n", label); // RETURN n - auto output = NEXPR("n", IDENT("n")); + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all_by_label.sym_)) + ->MapTo(symbol_table.CreateSymbol("n", true)); auto produce = MakeProduce(scan_all_by_label.op_, output); - symbol_table[*output->expression_] = scan_all_by_label.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("n", true); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); ASSERT_EQ(results.size(), 1); @@ -1645,10 +1617,9 @@ TEST(QueryPlan, ScanAllByLabelProperty) { storage, symbol_table, "n", label, prop, "prop", Bound{LITERAL(lower), lower_type}, Bound{LITERAL(upper), upper_type}); // RETURN n - auto output = NEXPR("n", IDENT("n")); + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_)) + ->MapTo(symbol_table.CreateSymbol("n", true)); auto produce = MakeProduce(scan_all.op_, output); - symbol_table[*output->expression_] = scan_all.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("n", true); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); ASSERT_EQ(results.size(), expected.size()); @@ -1707,10 +1678,9 @@ TEST(QueryPlan, ScanAllByLabelPropertyEqualityNoError) { auto scan_all = MakeScanAllByLabelPropertyValue( storage, symbol_table, "n", label, prop, "prop", LITERAL(42)); // RETURN n - auto output = NEXPR("n", IDENT("n")); + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_)) + ->MapTo(symbol_table.CreateSymbol("n", true)); auto produce = MakeProduce(scan_all.op_, output); - symbol_table[*output->expression_] = scan_all.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("n", true); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); ASSERT_EQ(results.size(), 1); @@ -1743,7 +1713,7 @@ TEST(QueryPlan, ScanAllByLabelPropertyValueError) { SymbolTable symbol_table; auto scan_all = MakeScanAll(storage, symbol_table, "m"); auto *ident_m = IDENT("m"); - symbol_table[*ident_m] = scan_all.sym_; + ident_m->MapTo(scan_all.sym_); auto scan_index = MakeScanAllByLabelPropertyValue( storage, symbol_table, "n", label, prop, "prop", ident_m, scan_all.op_); auto context = MakeContext(storage, symbol_table, dba.get()); @@ -1771,7 +1741,7 @@ TEST(QueryPlan, ScanAllByLabelPropertyRangeError) { SymbolTable symbol_table; auto scan_all = MakeScanAll(storage, symbol_table, "m"); auto *ident_m = IDENT("m"); - symbol_table[*ident_m] = scan_all.sym_; + ident_m->MapTo(scan_all.sym_); { // Lower bound isn't property value auto scan_index = MakeScanAllByLabelPropertyRange( @@ -1827,10 +1797,9 @@ TEST(QueryPlan, ScanAllByLabelPropertyEqualNull) { MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label, prop, "prop", LITERAL(TypedValue::Null)); // RETURN n - auto output = NEXPR("n", IDENT("n")); + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_)) + ->MapTo(symbol_table.CreateSymbol("n", true)); auto produce = MakeProduce(scan_all.op_, output); - symbol_table[*output->expression_] = scan_all.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("n", true); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); EXPECT_EQ(results.size(), 0); @@ -1863,10 +1832,9 @@ TEST(QueryPlan, ScanAllByLabelPropertyRangeNull) { Bound{LITERAL(TypedValue::Null), Bound::Type::INCLUSIVE}, Bound{LITERAL(TypedValue::Null), Bound::Type::EXCLUSIVE}); // RETURN n - auto output = NEXPR("n", IDENT("n")); + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_)) + ->MapTo(symbol_table.CreateSymbol("n", true)); auto produce = MakeProduce(scan_all.op_, output); - symbol_table[*output->expression_] = scan_all.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("n", true); auto context = MakeContext(storage, symbol_table, dba.get()); auto results = CollectProduce(*produce, &context); EXPECT_EQ(results.size(), 0); @@ -1895,7 +1863,7 @@ TEST(QueryPlan, ScanAllByLabelPropertyNoValueInIndexContinuation) { auto x = symbol_table.CreateSymbol("x", true); auto unwind = std::make_shared<plan::Unwind>(nullptr, input_expr, x); auto x_expr = IDENT("x"); - symbol_table[*x_expr] = x; + x_expr->MapTo(x); // MATCH (n :label {prop: x}) auto scan_all = MakeScanAllByLabelPropertyValue( @@ -1937,11 +1905,10 @@ TEST(QueryPlan, ScanAllEqualsScanAllByLabelProperty) { auto dba = db.Access(); auto scan_all_by_label_property_value = MakeScanAllByLabelPropertyValue( storage, symbol_table, "n", label, prop, "prop", LITERAL(prop_value)); - auto output = NEXPR("n", IDENT("n")); + auto output = + NEXPR("n", IDENT("n")->MapTo(scan_all_by_label_property_value.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(scan_all_by_label_property_value.op_, output); - symbol_table[*output->expression_] = scan_all_by_label_property_value.sym_; - symbol_table[*output] = - symbol_table.CreateSymbol("named_expression_1", true); auto context = MakeContext(storage, symbol_table, dba.get()); EXPECT_EQ(PullAll(*produce, &context), prop_count); }; @@ -1953,15 +1920,14 @@ TEST(QueryPlan, ScanAllEqualsScanAllByLabelProperty) { auto dba_ptr = db.Access(); auto &dba = *dba_ptr; auto scan_all = MakeScanAll(storage, symbol_table, "n"); - auto e = PROPERTY_LOOKUP("n", std::make_pair("prop", prop)); - symbol_table[*e->expression_] = scan_all.sym_; + auto e = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), + std::make_pair("prop", prop)); auto filter = std::make_shared<Filter>(scan_all.op_, EQ(e, LITERAL(prop_value))); - auto output = NEXPR("n", IDENT("n")); + auto output = + NEXPR("n", IDENT("n")->MapTo(scan_all.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); auto produce = MakeProduce(filter, output); - symbol_table[*output->expression_] = scan_all.sym_; - symbol_table[*output] = - symbol_table.CreateSymbol("named_expression_1", true); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(PullAll(*produce, &context), prop_count); }; diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 433155d10..79dd0fbf8 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -30,19 +30,20 @@ TEST_F(TestSymbolGenerator, MatchNodeReturn) { EXPECT_EQ(symbol_table.max_position(), 3); auto match = dynamic_cast<Match *>(query_ast->single_query_->clauses_[0]); auto pattern = match->patterns_[0]; - auto pattern_sym = symbol_table[*pattern->identifier_]; + auto pattern_sym = symbol_table.at(*pattern->identifier_); EXPECT_EQ(pattern_sym.type(), Symbol::Type::PATH); EXPECT_FALSE(pattern_sym.user_declared()); auto node_atom = dynamic_cast<NodeAtom *>(pattern->atoms_[0]); - auto node_sym = symbol_table[*node_atom->identifier_]; + auto node_sym = symbol_table.at(*node_atom->identifier_); EXPECT_EQ(node_sym.name(), "node_atom_1"); EXPECT_EQ(node_sym.type(), Symbol::Type::VERTEX); auto ret = dynamic_cast<Return *>(query_ast->single_query_->clauses_[1]); auto named_expr = ret->body_.named_expressions[0]; - auto column_sym = symbol_table[*named_expr]; + auto column_sym = symbol_table.at(*named_expr); EXPECT_EQ(node_sym.name(), column_sym.name()); EXPECT_NE(node_sym, column_sym); - auto ret_sym = symbol_table[*named_expr->expression_]; + auto ret_sym = + symbol_table.at(*dynamic_cast<Identifier *>(named_expr->expression_)); EXPECT_EQ(node_sym, ret_sym); } @@ -55,7 +56,7 @@ TEST_F(TestSymbolGenerator, MatchNamedPattern) { EXPECT_EQ(symbol_table.max_position(), 3); auto match = dynamic_cast<Match *>(query_ast->single_query_->clauses_[0]); auto pattern = match->patterns_[0]; - auto pattern_sym = symbol_table[*pattern->identifier_]; + auto pattern_sym = symbol_table.at(*pattern->identifier_); EXPECT_EQ(pattern_sym.type(), Symbol::Type::PATH); EXPECT_EQ(pattern_sym.name(), "p"); EXPECT_TRUE(pattern_sym.user_declared()); @@ -93,15 +94,16 @@ TEST_F(TestSymbolGenerator, CreateNodeReturn) { auto create = dynamic_cast<Create *>(query_ast->single_query_->clauses_[0]); auto pattern = create->patterns_[0]; auto node_atom = dynamic_cast<NodeAtom *>(pattern->atoms_[0]); - auto node_sym = symbol_table[*node_atom->identifier_]; + auto node_sym = symbol_table.at(*node_atom->identifier_); EXPECT_EQ(node_sym.name(), "n"); EXPECT_EQ(node_sym.type(), Symbol::Type::VERTEX); auto ret = dynamic_cast<Return *>(query_ast->single_query_->clauses_[1]); auto named_expr = ret->body_.named_expressions[0]; - auto column_sym = symbol_table[*named_expr]; + auto column_sym = symbol_table.at(*named_expr); EXPECT_EQ(node_sym.name(), column_sym.name()); EXPECT_NE(node_sym, column_sym); - auto ret_sym = symbol_table[*named_expr->expression_]; + auto ret_sym = + symbol_table.at(*dynamic_cast<Identifier *>(named_expr->expression_)); EXPECT_EQ(node_sym, ret_sym); } @@ -255,7 +257,7 @@ TEST_F(TestSymbolGenerator, MatchWithWhere) { EXPECT_EQ(node_symbol, old); auto with_n = symbol_table.at(*with_as_n); EXPECT_NE(old, with_n); - auto n = symbol_table.at(*n_prop->expression_); + auto n = symbol_table.at(*dynamic_cast<Identifier *>(n_prop->expression_)); EXPECT_EQ(n, with_n); } @@ -372,7 +374,8 @@ TEST_F(TestSymbolGenerator, MatchPropCreateNodeProp) { // symbols: pattern * 2, `node_n`, `node_m` EXPECT_EQ(symbol_table.max_position(), 4); auto n = symbol_table.at(*node_n->identifier_); - EXPECT_EQ(n, symbol_table.at(*n_prop->expression_)); + EXPECT_EQ(n, + symbol_table.at(*dynamic_cast<Identifier *>(n_prop->expression_))); auto m = symbol_table.at(*node_m->identifier_); EXPECT_NE(n, m); } @@ -551,7 +554,8 @@ TEST_F(TestSymbolGenerator, MergeOnMatchOnCreate) { EXPECT_EQ(symbol_table.max_position(), 6); auto n = symbol_table.at(*match_n->identifier_); EXPECT_EQ(n, symbol_table.at(*merge_n->identifier_)); - EXPECT_EQ(n, symbol_table.at(*n_prop->expression_)); + EXPECT_EQ(n, + symbol_table.at(*dynamic_cast<Identifier *>(n_prop->expression_))); auto r = symbol_table.at(*edge_r->identifier_); EXPECT_NE(r, n); EXPECT_EQ(r, symbol_table.at(*ident_r)); @@ -560,7 +564,8 @@ TEST_F(TestSymbolGenerator, MergeOnMatchOnCreate) { EXPECT_NE(m, n); EXPECT_NE(m, r); EXPECT_NE(m, symbol_table.at(*as_r)); - EXPECT_EQ(m, symbol_table.at(*m_prop->expression_)); + EXPECT_EQ(m, + symbol_table.at(*dynamic_cast<Identifier *>(m_prop->expression_))); } TEST_F(TestSymbolGenerator, WithUnwindRedeclareReturn) { @@ -586,7 +591,8 @@ TEST_F(TestSymbolGenerator, WithUnwindReturn) { // Symbols for: `list`, `elem`, `AS list`, `AS elem` EXPECT_EQ(symbol_table.max_position(), 4); const auto &list = symbol_table.at(*with_as_list); - EXPECT_EQ(list, symbol_table.at(*unwind->named_expression_->expression_)); + EXPECT_EQ(list, symbol_table.at(*dynamic_cast<Identifier *>( + unwind->named_expression_->expression_))); const auto &elem = symbol_table.at(*unwind->named_expression_); EXPECT_NE(list, elem); EXPECT_EQ(list, symbol_table.at(*ret_list)); @@ -612,11 +618,13 @@ TEST_F(TestSymbolGenerator, MatchCrossReferenceVariable) { // Symbols for pattern * 2, `n`, `m` and `AS n` EXPECT_EQ(symbol_table.max_position(), 5); auto n = symbol_table.at(*node_n->identifier_); - EXPECT_EQ(n, symbol_table.at(*n_prop->expression_)); + EXPECT_EQ(n, + symbol_table.at(*dynamic_cast<Identifier *>(n_prop->expression_))); EXPECT_EQ(n, symbol_table.at(*ident_n)); EXPECT_NE(n, symbol_table.at(*as_n)); auto m = symbol_table.at(*node_m->identifier_); - EXPECT_EQ(m, symbol_table.at(*m_prop->expression_)); + EXPECT_EQ(m, + symbol_table.at(*dynamic_cast<Identifier *>(m_prop->expression_))); EXPECT_NE(n, m); EXPECT_NE(m, symbol_table.at(*as_n)); } @@ -638,7 +646,8 @@ TEST_F(TestSymbolGenerator, MatchWithAsteriskReturnAsterisk) { // Symbols for pattern, `n`, `e`, `m`, `AS n.prop`. EXPECT_EQ(symbol_table.max_position(), 5); auto n = symbol_table.at(*node_n->identifier_); - EXPECT_EQ(n, symbol_table.at(*n_prop->expression_)); + EXPECT_EQ(n, + symbol_table.at(*dynamic_cast<Identifier *>(n_prop->expression_))); } TEST_F(TestSymbolGenerator, MatchReturnAsteriskSameResult) { @@ -683,7 +692,8 @@ TEST_F(TestSymbolGenerator, MatchEdgeWithIdentifierInProperty) { // Symbols for pattern, `n`, `r`, `m` and implicit in RETURN `r AS r` EXPECT_EQ(symbol_table.max_position(), 5); auto n = symbol_table.at(*node_n->identifier_); - EXPECT_EQ(n, symbol_table.at(*n_prop->expression_)); + EXPECT_EQ(n, + symbol_table.at(*dynamic_cast<Identifier *>(n_prop->expression_))); } TEST_F(TestSymbolGenerator, MatchVariablePathUsingIdentifier) { @@ -701,7 +711,8 @@ TEST_F(TestSymbolGenerator, MatchVariablePathUsingIdentifier) { // implicit in RETURN `r AS r` EXPECT_EQ(symbol_table.max_position(), 9); auto l = symbol_table.at(*node_l->identifier_); - EXPECT_EQ(l, symbol_table.at(*l_prop->expression_)); + EXPECT_EQ(l, + symbol_table.at(*dynamic_cast<Identifier *>(l_prop->expression_))); auto r = symbol_table.at(*edge->identifier_); EXPECT_EQ(r.type(), Symbol::Type::EDGE_LIST); } @@ -772,7 +783,8 @@ TEST_F(TestSymbolGenerator, MatchPropertySameIdentifier) { auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(node_n)), RETURN("n"))); auto symbol_table = query::MakeSymbolTable(query); auto n = symbol_table.at(*node_n->identifier_); - EXPECT_EQ(n, symbol_table.at(*n_prop->expression_)); + EXPECT_EQ(n, + symbol_table.at(*dynamic_cast<Identifier *>(n_prop->expression_))); } TEST_F(TestSymbolGenerator, WithReturnAll) { @@ -901,12 +913,12 @@ TEST_F(TestSymbolGenerator, MatchBfsReturn) { symbol_table.at(*bfs->filter_lambda_.inner_edge)); EXPECT_TRUE(symbol_table.at(*bfs->filter_lambda_.inner_edge).user_declared()); EXPECT_EQ(symbol_table.at(*bfs->filter_lambda_.inner_edge), - symbol_table.at(*r_prop->expression_)); + symbol_table.at(*dynamic_cast<Identifier *>(r_prop->expression_))); EXPECT_NE(symbol_table.at(*node_n->identifier_), symbol_table.at(*bfs->filter_lambda_.inner_node)); EXPECT_TRUE(symbol_table.at(*bfs->filter_lambda_.inner_node).user_declared()); EXPECT_EQ(symbol_table.at(*node_n->identifier_), - symbol_table.at(*n_prop->expression_)); + symbol_table.at(*dynamic_cast<Identifier *>(n_prop->expression_))); } TEST_F(TestSymbolGenerator, MatchBfsUsesEdgeSymbolError) { @@ -935,7 +947,8 @@ TEST_F(TestSymbolGenerator, MatchBfsUsesPreviousOuterSymbol) { QUERY(SINGLE_QUERY(MATCH(PATTERN(node_a, bfs, NODE("m"))), RETURN("r"))); auto symbol_table = query::MakeSymbolTable(query); EXPECT_EQ(symbol_table.at(*node_a->identifier_), - symbol_table.at(*bfs->filter_lambda_.expression)); + symbol_table.at( + *dynamic_cast<Identifier *>(bfs->filter_lambda_.expression))); } TEST_F(TestSymbolGenerator, MatchBfsUsesLaterSymbolError) { @@ -971,12 +984,11 @@ TEST_F(TestSymbolGenerator, MatchVariableLambdaSymbols) { // `AS res` and the auto-generated path name symbol. EXPECT_EQ(symbol_table.max_position(), 7); // All symbols except `AS res` are anonymously generated. - for (const auto &id_and_symbol : symbol_table.table()) { - const auto &symbol = id_and_symbol.second; + for (const auto &symbol : symbol_table.table()) { if (symbol.name() == "res") { EXPECT_TRUE(symbol.user_declared()); } else { - EXPECT_FALSE(id_and_symbol.second.user_declared()); + EXPECT_FALSE(symbol.user_declared()); } } } @@ -1017,14 +1029,16 @@ TEST_F(TestSymbolGenerator, MatchWShortestReturn) { symbol_table.at(*shortest->filter_lambda_.inner_edge)); EXPECT_TRUE( symbol_table.at(*shortest->filter_lambda_.inner_edge).user_declared()); - EXPECT_EQ(symbol_table.at(*shortest->weight_lambda_.inner_edge), - symbol_table.at(*r_weight->expression_)); + EXPECT_EQ( + symbol_table.at(*shortest->weight_lambda_.inner_edge), + symbol_table.at(*dynamic_cast<Identifier *>(r_weight->expression_))); EXPECT_NE(symbol_table.at(*shortest->weight_lambda_.inner_edge), symbol_table.at(*shortest->filter_lambda_.inner_edge)); EXPECT_NE(symbol_table.at(*shortest->weight_lambda_.inner_node), symbol_table.at(*shortest->filter_lambda_.inner_node)); - EXPECT_EQ(symbol_table.at(*shortest->filter_lambda_.inner_edge), - symbol_table.at(*r_filter->expression_)); + EXPECT_EQ( + symbol_table.at(*shortest->filter_lambda_.inner_edge), + symbol_table.at(*dynamic_cast<Identifier *>(r_filter->expression_))); EXPECT_TRUE( symbol_table.at(*shortest->filter_lambda_.inner_node).user_declared()); }