diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index cac2021d9..be5917a74 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -64,11 +64,13 @@ class OrOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - // TODO: Should we short-circuit? - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + // TODO: Should we short-circuit? + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -80,10 +82,12 @@ class XorOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -95,11 +99,13 @@ class AndOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - // TODO: Should we short-circuit? - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + // TODO: Should we short-circuit? + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -111,10 +117,12 @@ class AdditionOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -126,10 +134,12 @@ class SubtractionOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -141,10 +151,12 @@ class MultiplicationOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -156,10 +168,12 @@ class DivisionOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -171,10 +185,12 @@ class ModOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -186,10 +202,12 @@ class NotEqualOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -201,10 +219,12 @@ class EqualOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -216,10 +236,12 @@ class LessOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -231,10 +253,12 @@ class GreaterOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -246,10 +270,12 @@ class LessEqualOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -261,10 +287,12 @@ class GreaterEqualOperator : public BinaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression1_->Accept(visitor); - expression2_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -276,9 +304,11 @@ class NotOperator : public UnaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -290,9 +320,11 @@ class UnaryPlusOperator : public UnaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -304,9 +336,11 @@ class UnaryMinusOperator : public UnaryOperator { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -342,9 +376,11 @@ class PropertyLookup : public Expression { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression_->Accept(visitor); + visitor.PostVisit(*this); + } } Expression *expression_ = nullptr; @@ -371,9 +407,11 @@ class Aggregation : public UnaryOperator { Op op_; void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression_->Accept(visitor); + visitor.PostVisit(*this); + } } protected: @@ -386,9 +424,11 @@ class NamedExpression : public Tree { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression_->Accept(visitor); + visitor.PostVisit(*this); + } } std::string name_; @@ -418,9 +458,11 @@ class NodeAtom : public PatternAtom { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - identifier_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + identifier_->Accept(visitor); + visitor.PostVisit(*this); + } } std::vector labels_; @@ -441,9 +483,11 @@ class EdgeAtom : public PatternAtom { enum class Direction { LEFT, RIGHT, BOTH }; void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - identifier_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + identifier_->Accept(visitor); + visitor.PostVisit(*this); + } } Direction direction_ = Direction::BOTH; @@ -469,11 +513,13 @@ class Pattern : public Tree { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - for (auto &part : atoms_) { - part->Accept(visitor); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + for (auto &part : atoms_) { + part->Accept(visitor); + } + visitor.PostVisit(*this); } - visitor.PostVisit(*this); } Identifier *identifier_ = nullptr; std::vector atoms_; @@ -487,11 +533,13 @@ class Query : public Tree { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - for (auto &clause : clauses_) { - clause->Accept(visitor); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + for (auto &clause : clauses_) { + clause->Accept(visitor); + } + visitor.PostVisit(*this); } - visitor.PostVisit(*this); } std::vector clauses_; @@ -506,11 +554,13 @@ class Create : public Clause { Create(int uid) : Clause(uid) {} std::vector patterns_; void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - for (auto &pattern : patterns_) { - pattern->Accept(visitor); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + for (auto &pattern : patterns_) { + pattern->Accept(visitor); + } + visitor.PostVisit(*this); } - visitor.PostVisit(*this); } }; @@ -519,9 +569,11 @@ class Where : public Tree { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - expression_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + expression_->Accept(visitor); + visitor.PostVisit(*this); + } } Expression *expression_ = nullptr; @@ -535,14 +587,16 @@ class Match : public Clause { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - for (auto &pattern : patterns_) { - pattern->Accept(visitor); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + for (auto &pattern : patterns_) { + pattern->Accept(visitor); + } + if (where_) { + where_->Accept(visitor); + } + visitor.PostVisit(*this); } - if (where_) { - where_->Accept(visitor); - } - visitor.PostVisit(*this); } std::vector patterns_; Where *where_ = nullptr; @@ -556,11 +610,13 @@ class Return : public Clause { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - for (auto &expr : named_expressions_) { - expr->Accept(visitor); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + for (auto &expr : named_expressions_) { + expr->Accept(visitor); + } + visitor.PostVisit(*this); } - visitor.PostVisit(*this); } std::vector named_expressions_; @@ -573,12 +629,14 @@ class With : public Clause { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - for (auto &expr : named_expressions_) { - expr->Accept(visitor); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + for (auto &expr : named_expressions_) { + expr->Accept(visitor); + } + if (where_) where_->Accept(visitor); + visitor.PostVisit(*this); } - if (where_) where_->Accept(visitor); - visitor.PostVisit(*this); } bool distinct_ = false; @@ -594,11 +652,13 @@ class Delete : public Clause { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - for (auto &expr : expressions_) { - expr->Accept(visitor); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + for (auto &expr : expressions_) { + expr->Accept(visitor); + } + visitor.PostVisit(*this); } - visitor.PostVisit(*this); } std::vector expressions_; bool detach_ = false; @@ -612,10 +672,12 @@ class SetProperty : public Clause { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - property_lookup_->Accept(visitor); - expression_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + property_lookup_->Accept(visitor); + expression_->Accept(visitor); + visitor.PostVisit(*this); + } } PropertyLookup *property_lookup_ = nullptr; Expression *expression_ = nullptr; @@ -633,10 +695,12 @@ class SetProperties : public Clause { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - identifier_->Accept(visitor); - expression_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + identifier_->Accept(visitor); + expression_->Accept(visitor); + visitor.PostVisit(*this); + } } Identifier *identifier_ = nullptr; Expression *expression_ = nullptr; @@ -657,9 +721,11 @@ class SetLabels : public Clause { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - identifier_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + identifier_->Accept(visitor); + visitor.PostVisit(*this); + } } Identifier *identifier_ = nullptr; std::vector labels_; @@ -676,9 +742,11 @@ class RemoveProperty : public Clause { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - property_lookup_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + property_lookup_->Accept(visitor); + visitor.PostVisit(*this); + } } PropertyLookup *property_lookup_ = nullptr; @@ -693,9 +761,11 @@ class RemoveLabels : public Clause { public: void Accept(TreeVisitorBase &visitor) override { - visitor.Visit(*this); - identifier_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + identifier_->Accept(visitor); + visitor.PostVisit(*this); + } } Identifier *identifier_ = nullptr; std::vector labels_; diff --git a/src/query/frontend/logical/operator.cpp b/src/query/frontend/logical/operator.cpp index e6ccfbc42..d142d3033 100644 --- a/src/query/frontend/logical/operator.cpp +++ b/src/query/frontend/logical/operator.cpp @@ -12,9 +12,11 @@ CreateNode::CreateNode(NodeAtom *node_atom, : node_atom_(node_atom), input_(input) {} void CreateNode::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - if (input_) input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + if (input_) input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr CreateNode::MakeCursor(GraphDbAccessor &db) { @@ -69,9 +71,11 @@ CreateExpand::CreateExpand(NodeAtom *node_atom, EdgeAtom *edge_atom, node_existing_(node_existing) {} void CreateExpand::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr CreateExpand::MakeCursor(GraphDbAccessor &db) { @@ -155,9 +159,11 @@ ScanAll::ScanAll(NodeAtom *node_atom, std::shared_ptr input) : node_atom_(node_atom), input_(input) {} void ScanAll::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - if (input_) input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + if (input_) input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr ScanAll::MakeCursor(GraphDbAccessor &db) { @@ -202,9 +208,11 @@ Expand::Expand(NodeAtom *node_atom, EdgeAtom *edge_atom, edge_cycle_(edge_cycle) {} void Expand::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr Expand::MakeCursor(GraphDbAccessor &db) { @@ -323,9 +331,11 @@ NodeFilter::NodeFilter(std::shared_ptr input, : input_(input), input_symbol_(input_symbol), node_atom_(node_atom) {} void NodeFilter::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr NodeFilter::MakeCursor(GraphDbAccessor &db) { @@ -373,9 +383,11 @@ EdgeFilter::EdgeFilter(std::shared_ptr input, : input_(input), input_symbol_(input_symbol), edge_atom_(edge_atom) {} void EdgeFilter::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr EdgeFilter::MakeCursor(GraphDbAccessor &db) { @@ -426,9 +438,11 @@ Filter::Filter(const std::shared_ptr &input_, : input_(input_), expression_(expression_) {} void Filter::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr Filter::MakeCursor(GraphDbAccessor &db) { @@ -458,9 +472,11 @@ Produce::Produce(std::shared_ptr input, : input_(input), named_expressions_(named_expressions) {} void Produce::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - if (input_) input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + if (input_) input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr Produce::MakeCursor(GraphDbAccessor &db) { @@ -499,9 +515,11 @@ Delete::Delete(const std::shared_ptr &input_, : input_(input_), expressions_(expressions), detach_(detach_) {} void Delete::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr Delete::MakeCursor(GraphDbAccessor &db) { @@ -550,9 +568,11 @@ SetProperty::SetProperty(const std::shared_ptr input, : input_(input), lhs_(lhs), rhs_(rhs) {} void SetProperty::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr SetProperty::MakeCursor(GraphDbAccessor &db) { @@ -599,9 +619,11 @@ SetProperties::SetProperties(const std::shared_ptr input, : input_(input), input_symbol_(input_symbol), rhs_(rhs), op_(op) {} void SetProperties::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr SetProperties::MakeCursor(GraphDbAccessor &db) { @@ -682,9 +704,11 @@ SetLabels::SetLabels(const std::shared_ptr input, : input_(input), input_symbol_(input_symbol), labels_(labels) {} void SetLabels::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr SetLabels::MakeCursor(GraphDbAccessor &db) { @@ -711,9 +735,11 @@ RemoveProperty::RemoveProperty(const std::shared_ptr input, : input_(input), lhs_(lhs) {} void RemoveProperty::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr RemoveProperty::MakeCursor(GraphDbAccessor &db) { @@ -756,9 +782,11 @@ RemoveLabels::RemoveLabels(const std::shared_ptr input, : input_(input), input_symbol_(input_symbol), labels_(labels) {} void RemoveLabels::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr RemoveLabels::MakeCursor(GraphDbAccessor &db) { @@ -792,9 +820,11 @@ ExpandUniquenessFilter::ExpandUniquenessFilter( template void ExpandUniquenessFilter::Accept( LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } template @@ -867,9 +897,11 @@ Accumulate::Accumulate(std::shared_ptr input, : input_(input), symbols_(symbols), advance_command_(advance_command) {} void Accumulate::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr Accumulate::MakeCursor(GraphDbAccessor &db) { return std::make_unique(*this, db); @@ -911,9 +943,11 @@ Aggregate::Aggregate(const std::shared_ptr &input, : input_(input), aggregations_(aggregations), group_by_(group_by) {} void Aggregate::Accept(LogicalOperatorVisitor &visitor) { - visitor.Visit(*this); - input_->Accept(visitor); - visitor.PostVisit(*this); + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } } std::unique_ptr Aggregate::MakeCursor(GraphDbAccessor &db) { diff --git a/src/utils/visitor/visitable.hpp b/src/utils/visitor/visitable.hpp index bca5a40ee..e5a530399 100644 --- a/src/utils/visitor/visitable.hpp +++ b/src/utils/visitor/visitable.hpp @@ -18,6 +18,11 @@ namespace utils { /// public: /// void Accept(ExpressionVisitor &visitor) override { /// // Implement custom Accept. +/// if (visitor.PreVisit(*this)) { +/// visitor.Visit(*this); +/// ... // e.g. send visitor to children +/// visitor.PostVisit(*this); +/// } /// } /// }; /// @@ -34,8 +39,10 @@ class Visitable { /// @sa utils::Visitable #define DEFVISITABLE(TVisitor) \ void Accept(TVisitor &visitor) override { \ - visitor.Visit(*this); \ - visitor.PostVisit(*this); \ + if (visitor.PreVisit(*this)) { \ + visitor.Visit(*this); \ + visitor.PostVisit(*this); \ + } \ } }; diff --git a/src/utils/visitor/visitor.hpp b/src/utils/visitor/visitor.hpp index f838f23bd..27ba390db 100644 --- a/src/utils/visitor/visitor.hpp +++ b/src/utils/visitor/visitor.hpp @@ -11,6 +11,7 @@ class VisitorBase { public: virtual ~VisitorBase() = default; + virtual bool PreVisit(T &) { return true; } virtual void Visit(T &) {} virtual void PostVisit(T &) {} }; @@ -22,9 +23,11 @@ template class RecursiveVisitorBase : public VisitorBase, public RecursiveVisitorBase { public: + using VisitorBase::PreVisit; using VisitorBase::Visit; using VisitorBase::PostVisit; + using RecursiveVisitorBase::PreVisit; using RecursiveVisitorBase::Visit; using RecursiveVisitorBase::PostVisit; }; @@ -32,19 +35,22 @@ class RecursiveVisitorBase template class RecursiveVisitorBase : public VisitorBase { public: + using VisitorBase::PreVisit; using VisitorBase::Visit; using VisitorBase::PostVisit; }; } // namespace detail -/// Inherit from this class if you want to visit TVisitable types. +/// @brief Inherit from this class if you want to visit TVisitable types. +/// /// Example usage: /// /// // Typedef for convenience or to establish a base class of visitors. /// typedef Visitor ExpressionVisitorBase; /// class ExpressionVisitor : public ExpressionVisitorBase { /// public: +/// using ExpressionVisitorBase::PreVisit; /// using ExpressionVisitorBase::Visit; /// using ExpressionVisitorBase::PostVisit; /// @@ -57,6 +63,7 @@ class RecursiveVisitorBase : public VisitorBase { template class Visitor : public detail::RecursiveVisitorBase { public: + using detail::RecursiveVisitorBase::PreVisit; using detail::RecursiveVisitorBase::Visit; using detail::RecursiveVisitorBase::PostVisit; };