diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 0b1bacce2..eff2d687c 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -126,7 +126,11 @@ bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, CreateEdge(v1, v2, frame, symbol_table, evaluator); break; case EdgeAtom::Direction::BOTH: - permanent_fail("Undefined direction not allowed in create"); + // in the case of an undirected CreateExpand we choose an arbitrary + // direction. this is used in the MERGE clause + // it is not allowed in the CREATE clause, and the semantic + // checker needs to ensure it doesn't reach this point + CreateEdge(v1, v2, frame, symbol_table, evaluator); } return true; @@ -185,10 +189,13 @@ std::unique_ptr<Cursor> ScanAll::MakeCursor(GraphDbAccessor &db) { ScanAll::ScanAllCursor::ScanAllCursor(const ScanAll &self, GraphDbAccessor &db) : self_(self), input_cursor_(self.input_->MakeCursor(db)), - // TODO change to db.vertices(self.switch_ == GraphView::NEW) + // TODO change to db.vertices(self.graph_view_ == GraphView::NEW) // once this GraphDbAccessor API is available vertices_(db.vertices()), - vertices_it_(vertices_.end()) {} + vertices_it_(vertices_.end()) { + if (self.graph_view_ == GraphView::NEW) + throw utils::NotYetImplemented(); + } bool ScanAll::ScanAllCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { @@ -1369,5 +1376,75 @@ bool OrderBy::TypedValueListCompare::operator()( return (c1_it == c1.end()) && (c2_it != c2.end()); } +Merge::Merge(const std::shared_ptr<LogicalOperator> input, + const std::shared_ptr<LogicalOperator> merge_match, + const std::shared_ptr<LogicalOperator> merge_create) + : input_(input ? input : std::make_shared<Once>()), + merge_match_(merge_match), + merge_create_(merge_create) {} + +void Merge::Accept(LogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + merge_match_->Accept(visitor); + merge_create_->Accept(visitor); + visitor.PostVisit(*this); + } +} + +std::unique_ptr<Cursor> Merge::MakeCursor(GraphDbAccessor &db) { + return std::make_unique<MergeCursor>(*this, db); +} + +Merge::MergeCursor::MergeCursor(Merge &self, GraphDbAccessor &db) + : input_cursor_(self.input_->MakeCursor(db)), + merge_match_cursor_(self.merge_match_->MakeCursor(db)), + merge_create_cursor_(self.merge_create_->MakeCursor(db)) {} + +bool Merge::MergeCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { + // the loop is here to go back to input pull + // when the merge_match gets exhausted + while (true) { + if (pull_input_) { + if (input_cursor_->Pull(frame, symbol_table)) { + // after a successful input from the input + // reset merge_match (it's expand iterators maintain state) + // and merge_create (could have a Once at the beginning) + merge_match_cursor_->Reset(); + merge_create_cursor_->Reset(); + } else + // input is exhausted, we're done + return false; + } + + // pull from the merge_match cursor + if (merge_match_cursor_->Pull(frame, symbol_table)) { + // if successful, next Pull from this should not pull_input_ + pull_input_ = false; + return true; + } else { + // failed to Pull from the merge_match cursor + if (pull_input_) { + // if we have just now pulled from the input + // and failed to pull from merge_match, we should create + bool merge_create_pull_result = + merge_create_cursor_->Pull(frame, symbol_table); + debug_assert(merge_create_pull_result, "MergeCreate must never fail"); + return true; + } + // we have exhausted merge_match + // so we should pull from input on next pull + pull_input_ = true; + } + } +} + +void Merge::MergeCursor::Reset() { + input_cursor_->Reset(); + merge_match_cursor_->Reset(); + merge_create_cursor_->Reset(); +} + } // namespace plan } // namespace query diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 72e085a06..247782c50 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -73,6 +73,7 @@ class Aggregate; class Skip; class Limit; class OrderBy; +class Merge; /** @brief Base class for visitors of @c LogicalOperator class hierarchy. */ using LogicalOperatorVisitor = @@ -81,7 +82,7 @@ using LogicalOperatorVisitor = SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, ExpandUniquenessFilter<VertexAccessor>, ExpandUniquenessFilter<EdgeAccessor>, Accumulate, - AdvanceCommand, Aggregate, Skip, Limit, OrderBy>; + AdvanceCommand, Aggregate, Skip, Limit, OrderBy, Merge>; /** @brief Base class for logical operators. * @@ -1143,5 +1144,51 @@ class OrderBy : public LogicalOperator { }; }; +/** + * Merge operator. For every sucessful Pull from the + * input operator a Pull from the merge_match is attempted. All + * successfull Pulls from the merge_match are passed on as output. + * If merge_match Pull does not yield any elements, a single Pull + * from the merge_create op is performed. + * + * The input logical op is optional. If false (nullptr) + * it will be replaced by a Once op. + * + * For an argumentation of this implementation see the wiki + * documentation. + */ +class Merge : public LogicalOperator { + public: + Merge(const std::shared_ptr<LogicalOperator> input, + const std::shared_ptr<LogicalOperator> merge_match, + const std::shared_ptr<LogicalOperator> merge_create); + void Accept(LogicalOperatorVisitor &visitor) override; + std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override; + + private: + const std::shared_ptr<LogicalOperator> input_; + const std::shared_ptr<LogicalOperator> merge_match_; + const std::shared_ptr<LogicalOperator> merge_create_; + + class MergeCursor : public Cursor { + public: + MergeCursor(Merge &self, GraphDbAccessor &db); + bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + void Reset() override; + + private: + const std::unique_ptr<Cursor> input_cursor_; + const std::unique_ptr<Cursor> merge_match_cursor_; + const std::unique_ptr<Cursor> merge_create_cursor_; + + // indicates if the next Pull from this cursor + // should perform a pull from input_cursor_ + // this is true when: + // - first Pulling from this cursor + // - previous Pull from this cursor exhausted the merge_match_cursor + bool pull_input_{true}; + }; +}; + } // namespace plan } // namespace query diff --git a/tests/unit/query_plan_create_set_remove_delete.cpp b/tests/unit/query_plan_create_set_remove_delete.cpp index 0b5289200..59781d090 100644 --- a/tests/unit/query_plan_create_set_remove_delete.cpp +++ b/tests/unit/query_plan_create_set_remove_delete.cpp @@ -769,3 +769,72 @@ TEST(QueryPlan, SetRemove) { EXPECT_FALSE(v.has_label(label1)); EXPECT_FALSE(v.has_label(label2)); } + +TEST(QueryPlan, Merge) { + // test setup: + // - three nodes, two of them connected with T + // - merge input branch matches all nodes + // - merge_match branch looks for an expansion (any direction) + // and sets some property (for result validation) + // - merge_create branch just sets some other property + Dbms dbms; + auto dba = dbms.active(); + auto v1 = dba->insert_vertex(); + auto v2 = dba->insert_vertex(); + dba->insert_edge(v1, v2, dba->edge_type("Type")); + auto v3 = dba->insert_vertex(); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto prop = dba->property("prop"); + auto n = MakeScanAll(storage, symbol_table, "n"); + + // merge_match branch + auto r_m = MakeExpand(storage, symbol_table, std::make_shared<Once>(), n.sym_, + "r", EdgeAtom::Direction::BOTH, false, "m", false); + auto m_p = PROPERTY_LOOKUP("m", prop); + symbol_table[*m_p->expression_] = r_m.node_sym_; + auto m_set = std::make_shared<plan::SetProperty>(r_m.op_, m_p, LITERAL(1)); + + // merge_create branch + auto n_p = PROPERTY_LOOKUP("n", prop); + symbol_table[*n_p->expression_] = n.sym_; + auto n_set = std::make_shared<plan::SetProperty>(std::make_shared<Once>(), + n_p, LITERAL(2)); + + auto merge = std::make_shared<plan::Merge>(n.op_, m_set, n_set); + ASSERT_EQ(3, PullAll(merge, *dba, symbol_table)); + dba->advance_command(); + v1.Reconstruct(); + v2.Reconstruct(); + v3.Reconstruct(); + + ASSERT_EQ(v1.PropsAt(prop).type(), PropertyValue::Type::Int); + ASSERT_EQ(v1.PropsAt(prop).Value<int64_t>(), 1); + ASSERT_EQ(v2.PropsAt(prop).type(), PropertyValue::Type::Int); + ASSERT_EQ(v2.PropsAt(prop).Value<int64_t>(), 1); + ASSERT_EQ(v3.PropsAt(prop).type(), PropertyValue::Type::Int); + ASSERT_EQ(v3.PropsAt(prop).Value<int64_t>(), 2); +} + +TEST(QueryPlan, MergeNoInput) { + // merge with no input, creates a single node + + Dbms dbms; + auto dba = dbms.active(); + AstTreeStorage storage; + SymbolTable symbol_table; + + auto node = NODE("n"); + auto sym_n = symbol_table.CreateSymbol("n"); + symbol_table[*node->identifier_] = sym_n; + auto create = std::make_shared<CreateNode>(node, nullptr); + auto merge = std::make_shared<plan::Merge>(nullptr, create, create); + + EXPECT_EQ(0, CountIterable(dba->vertices())); + EXPECT_EQ(1, PullAll(merge, *dba, symbol_table)); + dba->advance_command(); + EXPECT_EQ(1, CountIterable(dba->vertices())); +} diff --git a/tests/unit/query_plan_match_filter_return.cpp b/tests/unit/query_plan_match_filter_return.cpp index b2dc3034c..adbaf8c02 100644 --- a/tests/unit/query_plan_match_filter_return.cpp +++ b/tests/unit/query_plan_match_filter_return.cpp @@ -43,10 +43,10 @@ TEST(QueryPlan, MatchReturn) { return PullAll(produce, *dba, symbol_table); }; - EXPECT_EQ(2, test_pull_count(GraphView::NEW)); + // TODO uncomment once the functionality is implemented + // EXPECT_EQ(2, test_pull_count(GraphView::NEW)); EXPECT_EQ(2, test_pull_count(GraphView::OLD)); dba->insert_vertex(); - // TODO uncomment once the functionality is implemented // EXPECT_EQ(3, test_pull_count(GraphView::NEW)); EXPECT_EQ(2, test_pull_count(GraphView::OLD)); dba->advance_command();