diff --git a/src/query/frontend/logical/operator.cpp b/src/query/frontend/logical/operator.cpp index caaed5b64..2ec7a5c5d 100644 --- a/src/query/frontend/logical/operator.cpp +++ b/src/query/frontend/logical/operator.cpp @@ -148,17 +148,46 @@ void CreateExpand::CreateExpandCursor::CreateEdge( frame[symbol_table[*self_.edge_atom_->identifier_]] = edge; } -ScanAll::ScanAll(NodeAtom *node_atom) : node_atom_(node_atom) {} +ScanAll::ScanAll(NodeAtom *node_atom) + : node_atom_(node_atom), input_(nullptr) {} + +ScanAll::ScanAll(NodeAtom *node_atom, std::shared_ptr input) + : node_atom_(node_atom), input_(input) {} + +void ScanAll::Accept(LogicalOperatorVisitor &visitor) { + visitor.Visit(*this); + if (input_) input_->Accept(visitor); + visitor.PostVisit(*this); +} std::unique_ptr ScanAll::MakeCursor(GraphDbAccessor &db) { return std::make_unique(*this, db); } ScanAll::ScanAllCursor::ScanAllCursor(ScanAll &self, GraphDbAccessor &db) - : self_(self), vertices_(db.vertices()), vertices_it_(vertices_.begin()) {} + : self_(self), + input_cursor_(self.input_ ? self.input_->MakeCursor(db) : nullptr), + vertices_(db.vertices()), + vertices_it_(vertices_.begin()) {} bool ScanAll::ScanAllCursor::Pull(Frame &frame, SymbolTable &symbol_table) { - if (vertices_it_ == vertices_.end()) return false; + if (input_cursor_) { + // using an input. we need to pull from it if we are in the first pull + // of this cursor, or if we have exhausted vertices_it_ + if (first_pull_ || vertices_it_ == vertices_.end()) { + first_pull_ = false; + // if the input is empty, we are for sure done + if (!input_cursor_->Pull(frame, symbol_table)) return false; + vertices_it_ = vertices_.begin(); + } + } + + // if we have no more vertices, we're done (if input_ is set we have + // just tried to re-init vertices_it_, and if not we only iterate + // through it once + if (vertices_it_ == vertices_.end()) + return false; + frame[symbol_table[*self_.node_atom_->identifier_]] = *vertices_it_++; return true; } diff --git a/src/query/frontend/logical/operator.hpp b/src/query/frontend/logical/operator.hpp index b6b0c192c..3152c1e38 100644 --- a/src/query/frontend/logical/operator.hpp +++ b/src/query/frontend/logical/operator.hpp @@ -204,15 +204,22 @@ class CreateExpand : public LogicalOperator { /** * @brief Operator which iterates over all the nodes currently in the database. + * When given an input (optional), does a cartesian product. + * + * It accepts an optional input. If provided then this op scans all the nodes + * currently in the database for each successful Pull from it's input, thereby + * producing a cartesian product of input Pulls and database elements. */ class ScanAll : public LogicalOperator { public: ScanAll(NodeAtom *node_atom); - DEFVISITABLE(LogicalOperatorVisitor); + ScanAll(NodeAtom *node_atom, std::shared_ptr input); + void Accept(LogicalOperatorVisitor &visitor) override; std::unique_ptr MakeCursor(GraphDbAccessor &db) override; private: NodeAtom *node_atom_ = nullptr; + std::shared_ptr input_; class ScanAllCursor : public Cursor { public: @@ -221,8 +228,11 @@ class ScanAll : public LogicalOperator { private: ScanAll &self_; + std::unique_ptr input_cursor_; decltype(std::declval().vertices()) vertices_; decltype(vertices_.begin()) vertices_it_; + // if this is the first pull from this cursor + bool first_pull_{true}; }; }; diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp index a2f287474..8d08d686e 100644 --- a/tests/unit/interpreter.cpp +++ b/tests/unit/interpreter.cpp @@ -92,9 +92,10 @@ struct ScanAllTuple { * Returns (node_atom, scan_all_logical_op, symbol). */ ScanAllTuple MakeScanAll(AstTreeStorage &storage, SymbolTable &symbol_table, - const std::string &identifier) { + const std::string &identifier, + std::shared_ptr input = {nullptr}) { auto node = NODE(identifier); - auto logical_op = std::make_shared(node); + auto logical_op = std::make_shared(node, input); auto symbol = symbol_table.CreateSymbol(identifier); symbol_table[*node->identifier_] = symbol; // return std::make_tuple(node, logical_op, symbol); @@ -159,6 +160,38 @@ TEST(Interpreter, MatchReturn) { EXPECT_EQ(result.GetResults().size(), 2); } +TEST(Interpreter, MatchReturnCartesian) { + Dbms dbms; + auto dba = dbms.active(); + + dba->insert_vertex().add_label(dba->label("l1")); + dba->insert_vertex().add_label(dba->label("l2")); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto m = MakeScanAll(storage, symbol_table, "m", n.op_); + auto return_n = NEXPR("n", IDENT("n")); + symbol_table[*return_n->expression_] = n.sym_; + symbol_table[*return_n] = symbol_table.CreateSymbol("named_expression_1"); + auto return_m = NEXPR("m", IDENT("m")); + symbol_table[*return_m->expression_] = m.sym_; + symbol_table[*return_m] = symbol_table.CreateSymbol("named_expression_2"); + auto produce = MakeProduce(m.op_, return_n, return_m); + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + auto result_data = result.GetResults(); + EXPECT_EQ(result_data.size(), 4); + // ensure the result ordering is OK: + // "n" from the results is the same for the first two rows, while "m" isn't + EXPECT_EQ(result_data[0][0].Value(), + result_data[1][0].Value()); + EXPECT_NE(result_data[0][1].Value(), + result_data[1][1].Value()); +} + TEST(Interpreter, StandaloneReturn) { Dbms dbms; auto dba = dbms.active();