diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index b8583ee59..e86a1337b 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -160,45 +160,70 @@ void CreateExpand::CreateExpandCursor::CreateEdge( frame[symbol_table.at(*self_.edge_atom_->identifier_)] = edge; } -ScanAll::ScanAll(const NodeAtom *node_atom, - const std::shared_ptr<LogicalOperator> &input, - GraphView graph_view) - : node_atom_(node_atom), - input_(input ? input : std::make_shared<Once>()), +template <class TVertices> +class ScanAllCursor : public Cursor { + public: + ScanAllCursor(Symbol output_symbol, std::unique_ptr<Cursor> input_cursor, + TVertices vertices) + : output_symbol_(output_symbol), + input_cursor_(std::move(input_cursor)), + vertices_(std::move(vertices)), + vertices_it_(vertices_.end()) {} + + bool Pull(Frame &frame, const SymbolTable &symbol_table) override { + if (vertices_it_ == vertices_.end()) { + if (!input_cursor_->Pull(frame, symbol_table)) return false; + vertices_it_ = vertices_.begin(); + } + + // if vertices_ is empty then we are done even though we have just + // reinitialized vertices_it_ + if (vertices_it_ == vertices_.end()) return false; + + frame[output_symbol_] = *vertices_it_++; + return true; + } + + void Reset() override { + input_cursor_->Reset(); + vertices_it_ = vertices_.end(); + } + + private: + const Symbol output_symbol_; + const std::unique_ptr<Cursor> input_cursor_; + TVertices vertices_; + decltype(vertices_.begin()) vertices_it_; +}; + +ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, GraphView graph_view) + : input_(input ? input : std::make_shared<Once>()), + output_symbol_(output_symbol), graph_view_(graph_view) { permanent_assert(graph_view != GraphView::AS_IS, "ScanAll must have explicitly defined GraphView") } + ACCEPT_WITH_INPUT(ScanAll) std::unique_ptr<Cursor> ScanAll::MakeCursor(GraphDbAccessor &db) { - return std::make_unique<ScanAllCursor>(*this, db); + auto vertices = db.vertices(graph_view_ == GraphView::NEW); + return std::make_unique<ScanAllCursor<decltype(vertices)>>( + output_symbol_, input_->MakeCursor(db), std::move(vertices)); } -ScanAll::ScanAllCursor::ScanAllCursor(const ScanAll &self, GraphDbAccessor &db) - : self_(self), - input_cursor_(self.input_->MakeCursor(db)), - vertices_(db.vertices(self.graph_view_ == GraphView::NEW)), - vertices_it_(vertices_.end()) {} +ScanAllByLabel::ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, GraphDbTypes::Label label, + GraphView graph_view) + : ScanAll(input, output_symbol, graph_view), label_(label) {} -bool ScanAll::ScanAllCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { - if (vertices_it_ == vertices_.end()) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; - vertices_it_ = vertices_.begin(); - } +ACCEPT_WITH_INPUT(ScanAllByLabel) - // if vertices_ is empty then we are done even though we have just - // reinitialized vertices_it_ - if (vertices_it_ == vertices_.end()) return false; - - frame[symbol_table.at(*self_.node_atom_->identifier_)] = *vertices_it_++; - return true; -} - -void ScanAll::ScanAllCursor::Reset() { - input_cursor_->Reset(); - vertices_it_ = vertices_.end(); +std::unique_ptr<Cursor> ScanAllByLabel::MakeCursor(GraphDbAccessor &db) { + auto vertices = db.vertices(label_, graph_view_ == GraphView::NEW); + return std::make_unique<ScanAllCursor<decltype(vertices)>>( + output_symbol_, input_->MakeCursor(db), std::move(vertices)); } Expand::Expand(const NodeAtom *node_atom, const EdgeAtom *edge_atom, diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 785a92e36..70010420b 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -53,6 +53,7 @@ class Once; class CreateNode; class CreateExpand; class ScanAll; +class ScanAllByLabel; class Expand; class Filter; class Produce; @@ -76,9 +77,9 @@ class Unwind; class Distinct; using LogicalOperatorCompositeVisitor = ::utils::CompositeVisitor< - Once, CreateNode, CreateExpand, ScanAll, Expand, Filter, Produce, Delete, - SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, - ExpandUniquenessFilter<VertexAccessor>, + Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, Expand, Filter, + Produce, Delete, SetProperty, SetProperties, SetLabels, RemoveProperty, + RemoveLabels, ExpandUniquenessFilter<VertexAccessor>, ExpandUniquenessFilter<EdgeAccessor>, Accumulate, AdvanceCommand, Aggregate, Skip, Limit, OrderBy, Merge, Optional, Unwind, Distinct>; @@ -288,29 +289,39 @@ class CreateExpand : public LogicalOperator { */ class ScanAll : public LogicalOperator { public: - ScanAll(const NodeAtom *node_atom, - const std::shared_ptr<LogicalOperator> &input, + ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, GraphView graph_view = GraphView::OLD); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override; - private: - const NodeAtom *node_atom_ = nullptr; + protected: const std::shared_ptr<LogicalOperator> input_; + const Symbol output_symbol_; + /** + * @brief Controls which graph state is used to produce vertices. + * + * If @c GraphView::OLD, @c ScanAll will produce vertices visible in the + * previous graph state, before modifications done by current transaction & + * command. With @c GraphView::NEW, all vertices will be produced the current + * transaction sees along with their modifications. + */ const GraphView graph_view_; +}; - class ScanAllCursor : public Cursor { - public: - ScanAllCursor(const ScanAll &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; - void Reset() override; +/** + * @brief Behaves like @c ScanAll, but this operator produces only vertices with + * given label. + */ +class ScanAllByLabel : public ScanAll { + public: + ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, GraphDbTypes::Label label, + GraphView graph_view = GraphView::OLD); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override; - private: - const ScanAll &self_; - const std::unique_ptr<Cursor> input_cursor_; - decltype(std::declval<GraphDbAccessor>().vertices()) vertices_; - decltype(vertices_.begin()) vertices_it_; - }; + private: + const GraphDbTypes::Label label_; }; /** diff --git a/src/query/plan/planner.cpp b/src/query/plan/planner.cpp index 0a8b930d8..472cbb5ab 100644 --- a/src/query/plan/planner.cpp +++ b/src/query/plan/planner.cpp @@ -269,8 +269,8 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op, // Otherwise, someone else generates it (e.g. a previous ScanAll). const auto &node_symbol = symbol_table.at(*node->identifier_); if (BindSymbol(bound_symbols, node_symbol)) { - last_op = new ScanAll(node, std::shared_ptr<LogicalOperator>(last_op), - context.graph_view); + last_op = new ScanAll(std::shared_ptr<LogicalOperator>(last_op), + node_symbol, context.graph_view); context.new_symbols.emplace_back(node_symbol); } return GenFilters(last_op, bound_symbols, context.filters, storage); diff --git a/tests/unit/query_plan_common.hpp b/tests/unit/query_plan_common.hpp index 24b6bbd53..fa6bcfa22 100644 --- a/tests/unit/query_plan_common.hpp +++ b/tests/unit/query_plan_common.hpp @@ -85,17 +85,35 @@ struct ScanAllTuple { * Creates and returns a tuple of stuff for a scan-all starting * from the node with the given name. * - * Returns (node_atom, scan_all_logical_op, symbol). + * Returns ScanAllTuple(node_atom, scan_all_logical_op, symbol). */ ScanAllTuple MakeScanAll(AstTreeStorage &storage, SymbolTable &symbol_table, const std::string &identifier, std::shared_ptr<LogicalOperator> input = {nullptr}, GraphView graph_view = GraphView::OLD) { auto node = NODE(identifier); - auto logical_op = std::make_shared<ScanAll>(node, input, graph_view); auto symbol = symbol_table.CreateSymbol(identifier, true); symbol_table[*node->identifier_] = symbol; - // return std::make_tuple(node, logical_op, symbol); + auto logical_op = std::make_shared<ScanAll>(input, symbol, graph_view); + return ScanAllTuple{node, logical_op, symbol}; +} + +/** + * Creates and returns a tuple of stuff for a scan-all starting + * from the node with the given name and label. + * + * Returns ScanAllTuple(node_atom, scan_all_logical_op, symbol). + */ +ScanAllTuple MakeScanAllByLabel( + AstTreeStorage &storage, SymbolTable &symbol_table, + const std::string &identifier, const GraphDbTypes::Label &label, + std::shared_ptr<LogicalOperator> input = {nullptr}, + GraphView graph_view = GraphView::OLD) { + auto node = NODE(identifier); + auto symbol = symbol_table.CreateSymbol(identifier, true); + symbol_table[*node->identifier_] = symbol; + auto logical_op = + std::make_shared<ScanAllByLabel>(input, symbol, label, graph_view); return ScanAllTuple{node, logical_op, symbol}; } diff --git a/tests/unit/query_plan_match_filter_return.cpp b/tests/unit/query_plan_match_filter_return.cpp index 3d0d0d214..da1400c56 100644 --- a/tests/unit/query_plan_match_filter_return.cpp +++ b/tests/unit/query_plan_match_filter_return.cpp @@ -44,11 +44,10 @@ TEST(QueryPlan, MatchReturn) { return PullAll(produce, *dba, symbol_table); }; - // TODO uncomment once the functionality is implemented - // EXPECT_EQ(2, test_pull_count(GraphView::NEW)); + EXPECT_EQ(2, test_pull_count(GraphView::NEW)); EXPECT_EQ(2, test_pull_count(GraphView::OLD)); dba->insert_vertex(); - // EXPECT_EQ(3, test_pull_count(GraphView::NEW)); + EXPECT_EQ(3, test_pull_count(GraphView::NEW)); EXPECT_EQ(2, test_pull_count(GraphView::OLD)); dba->advance_command(); EXPECT_EQ(3, test_pull_count(GraphView::OLD)); @@ -804,3 +803,31 @@ TEST(QueryPlan, Distinct) { {3, "two", TypedValue::Null, 3, true, false, "TWO", TypedValue::Null}, {3, "two", TypedValue::Null, true, false, "TWO"}, false); } + +TEST(QueryPlan, ScanAllByLabel) { + Dbms dbms; + auto dba = dbms.active(); + // Add a vertex with a label and one without. + auto label = dba->label("label"); + auto labeled_vertex = dba->insert_vertex(); + labeled_vertex.add_label(label); + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(2, CountIterable(dba->vertices())); + // MATCH (n :label) + AstTreeStorage storage; + SymbolTable symbol_table; + auto scan_all_by_label = + MakeScanAllByLabel(storage, symbol_table, "n", label); + // RETURN n + auto output = NEXPR("n", IDENT("n")); + auto produce = MakeProduce(scan_all_by_label.op_, output); + symbol_table[*output->expression_] = scan_all_by_label.sym_; + symbol_table[*output] = symbol_table.CreateSymbol("n", true); + auto result_stream = CollectProduce(produce, symbol_table, *dba); + auto results = result_stream.GetResults(); + ASSERT_EQ(results.size(), 1); + auto result_row = results[0]; + ASSERT_EQ(result_row.size(), 1); + EXPECT_EQ(result_row[0].Value<VertexAccessor>(), labeled_vertex); +}