diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 643271f35..c5aea834a 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -9,14 +9,32 @@ namespace query { namespace plan { +void Once::Accept(LogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + visitor.PostVisit(*this); + } +} +std::unique_ptr Once::MakeCursor(GraphDbAccessor &db) { + return std::make_unique(); +} + +bool Once::OnceCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { + if (!did_pull_) { + did_pull_ = true; + return true; + } + return false; +} + CreateNode::CreateNode(const NodeAtom *node_atom, const std::shared_ptr &input) - : node_atom_(node_atom), input_(input) {} + : node_atom_(node_atom), input_(input ? input : std::make_shared()) {} void CreateNode::Accept(LogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { visitor.Visit(*this); - if (input_) input_->Accept(visitor); + input_->Accept(visitor); visitor.PostVisit(*this); } } @@ -27,21 +45,12 @@ std::unique_ptr CreateNode::MakeCursor(GraphDbAccessor &db) { CreateNode::CreateNodeCursor::CreateNodeCursor(const CreateNode &self, GraphDbAccessor &db) - : self_(self), - db_(db), - input_cursor_(self.input_ ? self.input_->MakeCursor(db) : nullptr) {} + : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {} bool CreateNode::CreateNodeCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { - if (input_cursor_) { - if (input_cursor_->Pull(frame, symbol_table)) { - Create(frame, symbol_table); - return true; - } else - return false; - } else if (!did_create_) { + if (input_cursor_->Pull(frame, symbol_table)) { Create(frame, symbol_table); - did_create_ = true; return true; } else return false; @@ -155,17 +164,14 @@ void CreateExpand::CreateExpandCursor::CreateEdge( frame[symbol_table.at(*self_.edge_atom_->identifier_)] = edge; } -ScanAll::ScanAll(const NodeAtom *node_atom) - : node_atom_(node_atom), input_(nullptr) {} - ScanAll::ScanAll(const NodeAtom *node_atom, const std::shared_ptr &input) - : node_atom_(node_atom), input_(input) {} + : node_atom_(node_atom), input_(input ? input : std::make_shared()) {} void ScanAll::Accept(LogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { visitor.Visit(*this); - if (input_) input_->Accept(visitor); + input_->Accept(visitor); visitor.PostVisit(*this); } } @@ -176,26 +182,19 @@ std::unique_ptr ScanAll::MakeCursor(GraphDbAccessor &db) { ScanAll::ScanAllCursor::ScanAllCursor(const ScanAll &self, GraphDbAccessor &db) : self_(self), - input_cursor_(self.input_ ? self.input_->MakeCursor(db) : nullptr), + input_cursor_(self.input_->MakeCursor(db)), vertices_(db.vertices()), - vertices_it_(vertices_.begin()) {} + vertices_it_(vertices_.end()) {} bool ScanAll::ScanAllCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { - if (input_cursor_) { - // using an input. we need to pull from it if we are in the first pull - // of this cursor, or if we have exhausted vertices_it_ - if (first_pull_ || vertices_it_ == vertices_.end()) { - first_pull_ = false; - // if the input is empty, we are for sure done - if (!input_cursor_->Pull(frame, symbol_table)) return false; - vertices_it_ = vertices_.begin(); - } + if (vertices_it_ == vertices_.end()) { + if (!input_cursor_->Pull(frame, symbol_table)) return false; + vertices_it_ = vertices_.begin(); } - // if we have no more vertices, we're done (if input_ is set we have - // just tried to re-init vertices_it_, and if not we only iterate - // through it once + // if vertices_ is empty then we are done even though we have just + // reinitialized vertices_it_ if (vertices_it_ == vertices_.end()) return false; frame[symbol_table.at(*self_.node_atom_->identifier_)] = *vertices_it_++; @@ -478,12 +477,13 @@ bool Filter::FilterCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { Produce::Produce(const std::shared_ptr &input, const std::vector named_expressions) - : input_(input), named_expressions_(named_expressions) {} + : input_(input ? input : std::make_shared()), + named_expressions_(named_expressions) {} void Produce::Accept(LogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { visitor.Visit(*this); - if (input_) input_->Accept(visitor); + input_->Accept(visitor); visitor.PostVisit(*this); } } @@ -497,27 +497,18 @@ const std::vector &Produce::named_expressions() { } Produce::ProduceCursor::ProduceCursor(const Produce &self, GraphDbAccessor &db) - : self_(self), - input_cursor_(self.input_ ? self_.input_->MakeCursor(db) : nullptr) {} + : self_(self), input_cursor_(self_.input_->MakeCursor(db)) {} bool Produce::ProduceCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { - ExpressionEvaluator evaluator(frame, symbol_table); - // Produce should always yield the latest results. - evaluator.SwitchNew(); - if (input_cursor_) { - if (input_cursor_->Pull(frame, symbol_table)) { - for (auto named_expr : self_.named_expressions_) - named_expr->Accept(evaluator); - return true; - } - return false; - } else if (!did_produce_) { + if (input_cursor_->Pull(frame, symbol_table)) { + ExpressionEvaluator evaluator(frame, symbol_table); + // Produce should always yield the latest results. + evaluator.SwitchNew(); for (auto named_expr : self_.named_expressions_) named_expr->Accept(evaluator); - did_produce_ = true; return true; - } else - return false; + } + return false; } Delete::Delete(const std::shared_ptr &input_, @@ -973,7 +964,7 @@ Aggregate::Aggregate(const std::shared_ptr &input, const std::vector &aggregations, const std::vector &group_by, const std::vector &remember) - : input_(input), + : input_(input ? input : std::make_shared()), aggregations_(aggregations), group_by_(group_by), remember_(remember) {} @@ -992,8 +983,7 @@ std::unique_ptr Aggregate::MakeCursor(GraphDbAccessor &db) { Aggregate::AggregateCursor::AggregateCursor(Aggregate &self, GraphDbAccessor &db) - : self_(self), - input_cursor_(self.input_ ? self_.input_->MakeCursor(db) : nullptr) {} + : self_(self), input_cursor_(self_.input_->MakeCursor(db)) {} bool Aggregate::AggregateCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { @@ -1023,10 +1013,7 @@ void Aggregate::AggregateCursor::ProcessAll(Frame &frame, const SymbolTable &symbol_table) { ExpressionEvaluator evaluator(frame, symbol_table); evaluator.SwitchNew(); - if (input_cursor_) - while (input_cursor_->Pull(frame, symbol_table)) - ProcessOne(frame, symbol_table, evaluator); - else + while (input_cursor_->Pull(frame, symbol_table)) ProcessOne(frame, symbol_table, evaluator); // calculate AVG aggregations (so far they have only been summed) diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 6ba204a68..16480bc42 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -43,6 +43,7 @@ class Cursor { virtual ~Cursor() {} }; +class Once; class CreateNode; class CreateExpand; class ScanAll; @@ -68,10 +69,10 @@ class OrderBy; /** @brief Base class for visitors of @c LogicalOperator class hierarchy. */ using LogicalOperatorVisitor = - ::utils::Visitor, + ::utils::Visitor, ExpandUniquenessFilter, Accumulate, AdvanceCommand, Aggregate, Skip, Limit, OrderBy>; @@ -91,6 +92,24 @@ class LogicalOperator : public ::utils::Visitable { virtual ~LogicalOperator() {} }; +/** + * A logical operator whose Cursor returns true on the first Pull + * and false on every following Pull. + */ +class Once : public LogicalOperator { + public: + void Accept(LogicalOperatorVisitor &visitor) override; + std::unique_ptr MakeCursor(GraphDbAccessor &db) override; + + private: + class OnceCursor : public Cursor { + public: + bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + private: + bool did_pull_{false}; + }; +}; + /** @brief Operator for creating a node. * * This op is used both for creating a single node (`CREATE` statement without @@ -126,11 +145,7 @@ class CreateNode : public LogicalOperator { private: const CreateNode &self_; GraphDbAccessor &db_; - // optional, used in situations in which this create op - // pulls from an input (in MATCH CREATE, CREATE ... CREATE) const std::unique_ptr input_cursor_; - // control switch when creating only one node (nullptr input) - bool did_create_{false}; /** * Creates a single node and places it in the frame. @@ -227,7 +242,6 @@ class CreateExpand : public LogicalOperator { */ class ScanAll : public LogicalOperator { public: - ScanAll(const NodeAtom *node_atom); ScanAll(const NodeAtom *node_atom, const std::shared_ptr &input); void Accept(LogicalOperatorVisitor &visitor) override; @@ -247,8 +261,6 @@ class ScanAll : public LogicalOperator { const std::unique_ptr input_cursor_; decltype(std::declval().vertices()) vertices_; decltype(vertices_.begin()) vertices_it_; - // if this is the first pull from this cursor - bool first_pull_{true}; }; }; @@ -478,7 +490,7 @@ class Filter : public LogicalOperator { * for the RETURN clause). * * Supports optional input. When the input is provided, - * it is Pulled from and the Produce succeds once for + * it is Pulled from and the Produce succeeds once for * every input Pull (typically a MATCH/RETURN query). * When the input is not provided (typically a standalone * RETURN clause) the Produce's pull succeeds exactly once. @@ -502,10 +514,7 @@ class Produce : public LogicalOperator { private: const Produce &self_; - // optional, see class documentation const std::unique_ptr input_cursor_; - // control switch when creating only one node (nullptr input) - bool did_produce_{false}; }; }; @@ -875,7 +884,6 @@ class Aggregate : public LogicalOperator { }; Aggregate &self_; - // optional std::unique_ptr input_cursor_; // storage for aggregated data // map key is the list of group-by values @@ -896,7 +904,7 @@ class Aggregate : public LogicalOperator { /** * Pulls from the input operator until exhausted and aggregates the * results. If the input operator is not provided, a single call - * to ProccessOne is issued. + * to ProcessOne is issued. * * Accumulation automatically groups the results so that `aggregation_` * cache cardinality depends on number of