diff --git a/src/query/plan/distributed.cpp b/src/query/plan/distributed.cpp index 43c422908..bd2e963d1 100644 --- a/src/query/plan/distributed.cpp +++ b/src/query/plan/distributed.cpp @@ -48,29 +48,46 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor { } // ScanAll are all done on each machine locally. - bool PreVisit(ScanAll &) override { return true; } + bool PreVisit(ScanAll &scan) override { + prev_ops_.push_back(&scan); + return true; + } bool PostVisit(ScanAll &) override { + prev_ops_.pop_back(); RaiseIfCartesian(); RaiseIfHasWorkerPlan(); has_scan_all_ = true; return true; } - bool PreVisit(ScanAllByLabel &) override { return true; } + + bool PreVisit(ScanAllByLabel &scan) override { + prev_ops_.push_back(&scan); + return true; + } bool PostVisit(ScanAllByLabel &) override { + prev_ops_.pop_back(); RaiseIfCartesian(); RaiseIfHasWorkerPlan(); has_scan_all_ = true; return true; } - bool PreVisit(ScanAllByLabelPropertyRange &) override { return true; } + bool PreVisit(ScanAllByLabelPropertyRange &scan) override { + prev_ops_.push_back(&scan); + return true; + } bool PostVisit(ScanAllByLabelPropertyRange &) override { + prev_ops_.pop_back(); RaiseIfCartesian(); RaiseIfHasWorkerPlan(); has_scan_all_ = true; return true; } - bool PreVisit(ScanAllByLabelPropertyValue &) override { return true; } + bool PreVisit(ScanAllByLabelPropertyValue &scan) override { + prev_ops_.push_back(&scan); + return true; + } bool PostVisit(ScanAllByLabelPropertyValue &) override { + prev_ops_.pop_back(); RaiseIfCartesian(); RaiseIfHasWorkerPlan(); has_scan_all_ = true; @@ -79,28 +96,46 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor { // Expand is done locally on each machine with RPC calls for worker-boundary // crossing edges. - bool PreVisit(Expand &) override { return true; } + bool PreVisit(Expand &exp) override { + prev_ops_.push_back(&exp); + return true; + } // TODO: ExpandVariable // The following operators filter the frame or put something on it. They // should be worker local. - bool PreVisit(ConstructNamedPath &) override { return true; } - bool PreVisit(Filter &) override { return true; } - bool PreVisit(ExpandUniquenessFilter &) override { + bool PreVisit(ConstructNamedPath &op) override { + prev_ops_.push_back(&op); return true; } - bool PreVisit(ExpandUniquenessFilter &) override { + bool PreVisit(Filter &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PreVisit(ExpandUniquenessFilter &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PreVisit(ExpandUniquenessFilter &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PreVisit(Optional &op) override { + prev_ops_.push_back(&op); return true; } - bool PreVisit(Optional &) override { return true; } // Skip needs to skip only the first N results from *all* of the results. // Therefore, the earliest (deepest in the plan tree) encountered Skip will // break the plan in 2 parts. // 1) Master plan with Skip and everything above it. // 2) Worker plan with operators below Skip, but without Skip itself. - bool PreVisit(Skip &) override { return true; } + bool PreVisit(Skip &skip) override { + prev_ops_.push_back(&skip); + return true; + } bool PostVisit(Skip &skip) override { + prev_ops_.pop_back(); if (ShouldSplit()) { auto input = skip.input(); distributed_plan_.worker_plan = input; @@ -116,8 +151,12 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor { // improve the execution speed of workers. So, the 2 parts of the plan are: // 1) Master plan with Limit and everything above. // 2) Worker plan with operators below Limit, but including Limit itself. - bool PreVisit(Limit &) override { return true; } + bool PreVisit(Limit &limit) override { + prev_ops_.push_back(&limit); + return true; + } bool PostVisit(Limit &limit) override { + prev_ops_.pop_back(); if (ShouldSplit()) { // Shallow copy Limit distributed_plan_.worker_plan = std::make_shared(limit); @@ -132,8 +171,12 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor { // OrderBy is an associative operator, this means we can do ordering // on workers and then merge the results on master. This requires a more // involved solution, so for now treat OrderBy just like Split. - bool PreVisit(OrderBy &) override { return true; } + bool PreVisit(OrderBy &order_by) override { + prev_ops_.push_back(&order_by); + return true; + } bool PostVisit(OrderBy &order_by) override { + prev_ops_.pop_back(); // TODO: Associative combination of OrderBy if (ShouldSplit()) { auto input = order_by.input(); @@ -146,8 +189,12 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor { } // Treat Distinct just like Limit. - bool PreVisit(Distinct &) override { return true; } + bool PreVisit(Distinct &distinct) override { + prev_ops_.push_back(&distinct); + return true; + } bool PostVisit(Distinct &distinct) override { + prev_ops_.pop_back(); if (ShouldSplit()) { // Shallow copy Distinct distributed_plan_.worker_plan = std::make_shared(distinct); @@ -173,8 +220,12 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor { // // Non-associative aggregation needs to see all of the results and is // completely done on master. - bool PreVisit(Aggregate &) override { return true; } + bool PreVisit(Aggregate &aggr_op) override { + prev_ops_.push_back(&aggr_op); + return true; + } bool PostVisit(Aggregate &aggr_op) override { + prev_ops_.pop_back(); if (!ShouldSplit()) { // We have already split the plan, so the aggregation we are visiting is // on master. @@ -311,8 +362,12 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor { return true; } - bool PreVisit(Produce &) override { return true; } + bool PreVisit(Produce &produce) override { + prev_ops_.push_back(&produce); + return true; + } bool PostVisit(Produce &produce) override { + prev_ops_.pop_back(); if (!master_aggr_) return true; // We have to rewire master/worker aggregation. DCHECK(worker_aggr_); @@ -323,22 +378,113 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor { return true; } + bool PreVisit(Unwind &op) override { + prev_ops_.push_back(&op); + return true; + } + bool Visit(Once &) override { return true; } bool Visit(CreateIndex &) override { return true; } - // TODO: Write operators, accumulate and unwind + // Accumulate is used only if the query performs any writes. In such a case, + // we need to synchronize the work done on master and all workers. + // Synchronization will force applying changes to distributed storage, and + // then we can continue with the rest of the plan. Currently, the remainder of + // the plan is executed on master. In the future, when we support Cartesian + // products after the WITH clause, we will need to split the plan in more + // subparts to be executed on workers. + bool PreVisit(Accumulate &acc) override { + prev_ops_.push_back(&acc); + return true; + } + bool PostVisit(Accumulate &acc) override { + prev_ops_.pop_back(); + if (!ShouldSplit()) return true; + if (acc.advance_command()) + throw utils::NotYetImplemented("WITH clause distributed planning"); + // Accumulate on workers, but set advance_command to false, because the + // Synchronize operator should do that in distributed execution. + distributed_plan_.worker_plan = + std::make_shared(acc.input(), acc.symbols(), false); + // Create a synchronization point. Use pull remote to fetch accumulated + // symbols from workers. Local input operations are the same as on workers. + auto pull_remote = std::make_shared( + nullptr, distributed_plan_.plan_id, acc.symbols()); + auto sync = std::make_shared( + distributed_plan_.worker_plan, pull_remote, acc.advance_command()); + auto *prev_op = prev_ops_.back(); + // Wire the previous operator (on master) into our synchronization operator. + // TODO: Find a better way to replace the previous operation's input than + // using dynamic casting. + if (auto *produce = dynamic_cast(prev_op)) { + produce->set_input(sync); + } else if (auto *aggr_op = dynamic_cast(prev_op)) { + aggr_op->set_input(sync); + } else { + throw utils::NotYetImplemented("WITH clause distributed planning"); + } + return true; + } + + bool PreVisit(CreateNode &op) override { + // TODO: Creation needs to be modified if running on master, so as to + // distribute node creation to workers. + prev_ops_.push_back(&op); + return true; + } + + bool PreVisit(CreateExpand &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PreVisit(Delete &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PreVisit(SetProperty &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PreVisit(SetProperties &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PreVisit(SetLabels &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PreVisit(RemoveProperty &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PreVisit(RemoveLabels &op) override { + prev_ops_.push_back(&op); + return true; + } protected: bool DefaultPreVisit() override { throw utils::NotYetImplemented("distributed planning"); } + bool DefaultPostVisit() override { + prev_ops_.pop_back(); + return true; + } + private: DistributedPlan &distributed_plan_; // Used for rewiring the master/worker aggregation in PostVisit(Produce) std::shared_ptr worker_aggr_; std::unique_ptr master_aggr_; + std::vector prev_ops_; bool has_scan_all_ = false; void RaiseIfCartesian() { diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index de32f923f..d2ec5d526 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -1474,7 +1474,9 @@ class Accumulate : public LogicalOperator { std::unique_ptr MakeCursor( database::GraphDbAccessor &db) const override; + auto input() const { return input_; } const auto &symbols() const { return symbols_; }; + auto advance_command() const { return advance_command_; } private: std::shared_ptr input_; @@ -2331,6 +2333,10 @@ class Synchronize : public LogicalOperator { std::unique_ptr MakeCursor( database::GraphDbAccessor &db) const override; + auto input() const { return input_; } + auto pull_remote() const { return pull_remote_; } + auto advance_command() const { return advance_command_; } + private: std::shared_ptr input_; std::shared_ptr pull_remote_; diff --git a/tests/manual/query_planner.cpp b/tests/manual/query_planner.cpp index 502109824..1c5812ff8 100644 --- a/tests/manual/query_planner.cpp +++ b/tests/manual/query_planner.cpp @@ -518,6 +518,16 @@ class PlanPrinter : public query::plan::HierarchicalLogicalOperatorVisitor { --depth_; return true; } + + bool PreVisit(query::plan::Synchronize &op) override { + WithPrintLn([&op](auto &out) { + out << "* Synchronize"; + if (op.advance_command()) out << " (ADV CMD)"; + }); + Branch(*op.pull_remote()); + op.input()->Accept(*this); + return false; + } #undef PRE_VISIT private: diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index 7b35581fa..10ab2e7b2 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -18,6 +19,13 @@ #include "query_common.hpp" +namespace query { +::std::ostream &operator<<(::std::ostream &os, const Symbol &sym) { + return os << "Symbol{\"" << sym.name() << "\" [" << sym.position() << "] " + << Symbol::TypeToString(sym.type()) << "}"; +} +} // namespace query + using namespace query::plan; using query::AstTreeStorage; using query::SingleQuery; @@ -107,6 +115,12 @@ class PlanChecker : public HierarchicalLogicalOperatorVisitor { PRE_VISIT(ProduceRemote); PRE_VISIT(PullRemote); + + bool PreVisit(Synchronize &op) override { + CheckOp(op); + op.input()->Accept(*this); + return false; + } #undef PRE_VISIT std::list checkers_; @@ -362,16 +376,27 @@ class ExpectPullRemote : public OpChecker { ExpectPullRemote(const std::vector &symbols) : symbols_(symbols) {} void ExpectOp(PullRemote &op, const SymbolTable &) override { - if (symbols_.empty()) - EXPECT_FALSE(op.symbols().empty()); - else - EXPECT_THAT(op.symbols(), testing::UnorderedElementsAreArray(symbols_)); + EXPECT_THAT(op.symbols(), testing::UnorderedElementsAreArray(symbols_)); } private: std::vector symbols_; }; +class ExpectSynchronize : public OpChecker { + public: + ExpectSynchronize() {} + ExpectSynchronize(const std::vector &symbols) + : expect_pull_(symbols) {} + + void ExpectOp(Synchronize &op, const SymbolTable &symbol_table) override { + expect_pull_.ExpectOp(*op.pull_remote(), symbol_table); + } + + private: + ExpectPullRemote expect_pull_; +}; + auto MakeSymbolTable(query::Query &query) { SymbolTable symbol_table; SymbolGenerator symbol_generator(symbol_table); @@ -417,6 +442,16 @@ class SerializedPlanner { std::unique_ptr plan_; }; +template +TPlanner MakePlanner(database::MasterBase &master_db, AstTreeStorage &storage, + SymbolTable &symbol_table) { + database::GraphDbAccessor dba(master_db); + auto planning_context = MakePlanningContext(storage, symbol_table, dba); + auto query_parts = CollectQueryParts(symbol_table, storage); + auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; + return TPlanner(single_query_parts, planning_context); +} + template auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, TChecker... checker) { @@ -430,12 +465,7 @@ template auto CheckPlan(AstTreeStorage &storage, TChecker... checker) { auto symbol_table = MakeSymbolTable(*storage.query()); database::SingleNode db; - database::GraphDbAccessor dba(db); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TPlanner planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, checker...); } @@ -447,12 +477,8 @@ struct ExpectedDistributedPlan { template DistributedPlan MakeDistributedPlan(query::AstTreeStorage &storage) { database::Master db; - database::GraphDbAccessor dba(db); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TPlanner planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); std::atomic next_plan_id{0}; return MakeDistributedPlan(planner.plan(), symbol_table, next_plan_id); } @@ -474,6 +500,14 @@ void CheckDistributedPlan(DistributedPlan &distributed_plan, } } +void CheckDistributedPlan(const LogicalOperator &plan, + const SymbolTable &symbol_table, + ExpectedDistributedPlan &expected_distributed_plan) { + std::atomic next_plan_id{0}; + auto distributed_plan = MakeDistributedPlan(plan, symbol_table, next_plan_id); + CheckDistributedPlan(distributed_plan, expected_distributed_plan); +} + template void CheckDistributedPlan(AstTreeStorage &storage, ExpectedDistributedPlan &expected_distributed_plan) { @@ -505,12 +539,17 @@ TYPED_TEST_CASE(TestPlanner, PlannerTypes); TYPED_TEST(TestPlanner, MatchNodeReturn) { // Test MATCH (n) RETURN n AstTreeStorage storage; - QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectProduce()); + auto *as_n = NEXPR("n", IDENT("n")); + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(as_n))); + auto symbol_table = MakeSymbolTable(*storage.query()); + database::Master db; + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_n)}); ExpectedDistributedPlan expected{ - MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectProduce(), pull), MakeCheckers(ExpectScanAll(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, CreateNodeReturn) { @@ -522,14 +561,17 @@ TYPED_TEST(TestPlanner, CreateNodeReturn) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); database::SingleNode db; - database::GraphDbAccessor dba(db); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, ExpectProduce()); + { + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectCreateNode(), acc, ExpectProduce()), {}}; + std::atomic next_plan_id{0}; + auto distributed_plan = + MakeDistributedPlan(planner.plan(), symbol_table, next_plan_id); + CheckDistributedPlan(distributed_plan, expected); + } } TYPED_TEST(TestPlanner, CreateExpand) { @@ -586,6 +628,10 @@ TYPED_TEST(TestPlanner, MatchCreateExpand) { CREATE(PATTERN(NODE("n"), EDGE("r", Direction::OUT, {relationship}), NODE("m"))))); CheckPlan(storage, ExpectScanAll(), ExpectCreateExpand()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectCreateExpand(), ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectCreateExpand())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchLabeledNodes) { @@ -594,12 +640,17 @@ TYPED_TEST(TestPlanner, MatchLabeledNodes) { database::SingleNode db; database::GraphDbAccessor dba(db); auto label = dba.Label("label"); - QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", label))), RETURN("n"))); - CheckPlan(storage, ExpectScanAllByLabel(), ExpectProduce()); + auto *as_n = NEXPR("n", IDENT("n")); + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", label))), RETURN(as_n))); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabel(), + ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_n)}); ExpectedDistributedPlan expected{ - MakeCheckers(ExpectScanAllByLabel(), ExpectProduce(), ExpectPullRemote()), + MakeCheckers(ExpectScanAllByLabel(), ExpectProduce(), pull), MakeCheckers(ExpectScanAllByLabel(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, MatchPathReturn) { @@ -608,17 +659,20 @@ TYPED_TEST(TestPlanner, MatchPathReturn) { database::SingleNode db; database::GraphDbAccessor dba(db); auto relationship = dba.EdgeType("relationship"); + auto *as_n = NEXPR("n", IDENT("n")); QUERY(SINGLE_QUERY( MATCH(PATTERN(NODE("n"), EDGE("r", Direction::BOTH, {relationship}), NODE("m"))), - RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), - ExpectProduce()); + RETURN(as_n))); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), + ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_n)}); ExpectedDistributedPlan expected{ - MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectProduce(), - ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectProduce(), pull), MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, MatchNamedPatternReturn) { @@ -627,47 +681,53 @@ TYPED_TEST(TestPlanner, MatchNamedPatternReturn) { database::SingleNode db; database::GraphDbAccessor dba(db); auto relationship = dba.EdgeType("relationship"); + auto *as_p = NEXPR("p", IDENT("p")); QUERY(SINGLE_QUERY( MATCH(NAMED_PATTERN("p", NODE("n"), EDGE("r", Direction::BOTH, {relationship}), NODE("m"))), - RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), - ExpectConstructNamedPath(), ExpectProduce()); + RETURN(as_p))); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), + ExpectConstructNamedPath(), ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_p)}); ExpectedDistributedPlan expected{ MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectConstructNamedPath(), - ExpectProduce(), ExpectPullRemote()), + ExpectProduce(), pull), MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectConstructNamedPath(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, MatchNamedPatternWithPredicateReturn) { - // Test MATCH p = (n) -[r :relationship]- (m) RETURN p + // Test MATCH p = (n) -[r :relationship]- (m) WHERE 2 = p RETURN p AstTreeStorage storage; database::SingleNode db; database::GraphDbAccessor dba(db); auto relationship = dba.EdgeType("relationship"); + auto *as_p = NEXPR("p", IDENT("p")); QUERY(SINGLE_QUERY( MATCH(NAMED_PATTERN("p", NODE("n"), EDGE("r", Direction::BOTH, {relationship}), NODE("m"))), - WHERE(EQ(LITERAL(2), IDENT("p"))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), - ExpectConstructNamedPath(), ExpectFilter(), - ExpectProduce()); + WHERE(EQ(LITERAL(2), IDENT("p"))), RETURN(as_p))); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), + ExpectConstructNamedPath(), ExpectFilter(), ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_p)}); ExpectedDistributedPlan expected{ MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectConstructNamedPath(), - ExpectFilter(), ExpectProduce(), ExpectPullRemote()), + ExpectFilter(), ExpectProduce(), pull), MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectConstructNamedPath(), ExpectFilter(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, OptionalMatchNamedPatternReturn) { // Test OPTIONAL MATCH p = (n) -[r]- (m) RETURN p database::SingleNode db; - database::GraphDbAccessor dba(db); AstTreeStorage storage; auto node_n = NODE("n"); auto edge = EDGE("r"); @@ -675,10 +735,6 @@ TYPED_TEST(TestPlanner, OptionalMatchNamedPatternReturn) { auto pattern = NAMED_PATTERN("p", node_n, edge, node_m); QUERY(SINGLE_QUERY(OPTIONAL_MATCH(pattern), RETURN("p"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; std::list optional{new ExpectScanAll(), new ExpectExpand(), new ExpectConstructNamedPath()}; auto get_symbol = [&symbol_table](const auto *ast_node) { @@ -686,7 +742,7 @@ TYPED_TEST(TestPlanner, OptionalMatchNamedPatternReturn) { }; std::vector optional_symbols{get_symbol(pattern), get_symbol(node_n), get_symbol(edge), get_symbol(node_m)}; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectOptional(optional_symbols, optional), ExpectProduce()); } @@ -697,16 +753,19 @@ TYPED_TEST(TestPlanner, MatchWhereReturn) { database::SingleNode db; database::GraphDbAccessor dba(db); auto property = dba.Property("property"); + auto *as_n = NEXPR("n", IDENT("n")); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), WHERE(LESS(PROPERTY_LOOKUP("n", property), LITERAL(42))), - RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectFilter(), - ExpectProduce()); + RETURN(as_n))); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectFilter(), + ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_n)}); ExpectedDistributedPlan expected{ - MakeCheckers(ExpectScanAll(), ExpectFilter(), ExpectProduce(), - ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectFilter(), ExpectProduce(), pull), MakeCheckers(ExpectScanAll(), ExpectFilter(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, MatchDelete) { @@ -714,6 +773,10 @@ TYPED_TEST(TestPlanner, MatchDelete) { AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n")))); CheckPlan(storage, ExpectScanAll(), ExpectDelete()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectDelete(), ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectDelete())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchNodeSet) { @@ -728,6 +791,12 @@ TYPED_TEST(TestPlanner, MatchNodeSet) { SET("n", IDENT("n")), SET("n", {label}))); CheckPlan(storage, ExpectScanAll(), ExpectSetProperty(), ExpectSetProperties(), ExpectSetLabels()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectSetProperty(), ExpectSetProperties(), + ExpectSetLabels(), ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectSetProperty(), ExpectSetProperties(), + ExpectSetLabels())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchRemove) { @@ -741,6 +810,12 @@ TYPED_TEST(TestPlanner, MatchRemove) { REMOVE(PROPERTY_LOOKUP("n", prop)), REMOVE("n", {label}))); CheckPlan(storage, ExpectScanAll(), ExpectRemoveProperty(), ExpectRemoveLabels()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectRemoveProperty(), + ExpectRemoveLabels(), ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectRemoveProperty(), + ExpectRemoveLabels())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchMultiPattern) { @@ -800,33 +875,41 @@ TYPED_TEST(TestPlanner, MultiMatch) { TYPED_TEST(TestPlanner, MultiMatchSameStart) { // Test MATCH (n) MATCH (n) -[r]- (m) RETURN n AstTreeStorage storage; + auto *as_n = NEXPR("n", IDENT("n")); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), - RETURN("n"))); + RETURN(as_n))); // Similar to MatchMultiPatternSameStart, we expect only Expand from second // MATCH clause. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), - ExpectProduce()); + auto symbol_table = MakeSymbolTable(*storage.query()); + database::SingleNode db; + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), + ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_n)}); ExpectedDistributedPlan expected{ - MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectProduce(), - ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectProduce(), pull), MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, MatchWithReturn) { // Test MATCH (old) WITH old AS new RETURN new AstTreeStorage storage; + auto *as_new = NEXPR("new", IDENT("new")); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), WITH("old", AS("new")), - RETURN("new"))); + RETURN(as_new))); // No accumulation since we only do reads. - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), - ExpectProduce()); + auto symbol_table = MakeSymbolTable(*storage.query()); + database::SingleNode db; + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectProduce(), + ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_new)}); ExpectedDistributedPlan expected{ - MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectProduce(), - ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectProduce(), pull), MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, MatchWithWhereReturn) { @@ -835,18 +918,22 @@ TYPED_TEST(TestPlanner, MatchWithWhereReturn) { database::GraphDbAccessor dba(db); auto prop = dba.Property("prop"); AstTreeStorage storage; + auto *as_new = NEXPR("new", IDENT("new")); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), WITH("old", AS("new")), WHERE(LESS(PROPERTY_LOOKUP("new", prop), LITERAL(42))), - RETURN("new"))); + RETURN(as_new))); // No accumulation since we only do reads. - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), - ExpectFilter(), ExpectProduce()); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectProduce(), + ExpectFilter(), ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_new)}); ExpectedDistributedPlan expected{ MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectFilter(), - ExpectProduce(), ExpectPullRemote()), + ExpectProduce(), pull), MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectFilter(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, CreateMultiExpand) { @@ -892,14 +979,21 @@ TYPED_TEST(TestPlanner, MatchReturnSum) { QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(sum, AS("sum"), n_prop2, AS("group")))); auto aggr = ExpectAggregate({sum}, {n_prop2}); - CheckPlan(storage, ExpectScanAll(), aggr, ExpectProduce()); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), aggr, + ExpectProduce()); { - auto distributed_plan = MakeDistributedPlan(storage); + std::atomic next_plan_id{0}; + auto distributed_plan = + MakeDistributedPlan(planner.plan(), symbol_table, next_plan_id); auto merge_sum = SUM(IDENT("worker_sum")); auto master_aggr = ExpectMasterAggregate({merge_sum}, {n_prop2}); + ExpectPullRemote pull( + {symbol_table.at(*sum), symbol_table.at(*n_prop2->expression_)}); ExpectedDistributedPlan expected{ - MakeCheckers(ExpectScanAll(), aggr, ExpectPullRemote(), master_aggr, - ExpectProduce(), ExpectProduce()), + MakeCheckers(ExpectScanAll(), aggr, pull, master_aggr, ExpectProduce(), + ExpectProduce()), MakeCheckers(ExpectScanAll(), aggr)}; CheckDistributedPlan(distributed_plan, expected); } @@ -918,11 +1012,7 @@ TYPED_TEST(TestPlanner, CreateWithSum) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); auto aggr = ExpectAggregate({sum}, {}); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); // We expect both the accumulation and aggregation because the part before // WITH updates the database. CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, @@ -941,20 +1031,30 @@ TYPED_TEST(TestPlanner, MatchWithCreate) { PATTERN(NODE("a"), EDGE("r", Direction::OUT, {r_type}), NODE("b"))))); CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectCreateExpand()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectCreateExpand(), + ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectCreateExpand())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchReturnSkipLimit) { // Test MATCH (n) RETURN n SKIP 2 LIMIT 1 AstTreeStorage storage; + auto *as_n = NEXPR("n", IDENT("n")); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), - RETURN("n", SKIP(LITERAL(2)), LIMIT(LITERAL(1))))); - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectSkip(), - ExpectLimit()); + RETURN(as_n, SKIP(LITERAL(2)), LIMIT(LITERAL(1))))); + auto symbol_table = MakeSymbolTable(*storage.query()); + database::SingleNode db; + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectProduce(), + ExpectSkip(), ExpectLimit()); + ExpectPullRemote pull({symbol_table.at(*as_n)}); ExpectedDistributedPlan expected{ - MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectPullRemote(), - ExpectSkip(), ExpectLimit()), + MakeCheckers(ExpectScanAll(), ExpectProduce(), pull, ExpectSkip(), + ExpectLimit()), MakeCheckers(ExpectScanAll(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, CreateWithSkipReturnLimit) { @@ -967,12 +1067,7 @@ TYPED_TEST(TestPlanner, CreateWithSkipReturnLimit) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); database::SingleNode db; - database::GraphDbAccessor dba(db); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); // Since we have a write query, we need to have Accumulate. This is a bit // different than Neo4j 3.0, which optimizes WITH followed by RETURN as a // single RETURN clause and then moves Skip and Limit before Accumulate. This @@ -996,11 +1091,7 @@ TYPED_TEST(TestPlanner, CreateReturnSumSkipLimit) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); auto aggr = ExpectAggregate({sum}, {}); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), ExpectSkip(), ExpectLimit()); } @@ -1011,15 +1102,18 @@ TYPED_TEST(TestPlanner, MatchReturnOrderBy) { database::GraphDbAccessor dba(db); auto prop = dba.Property("prop"); AstTreeStorage storage; - auto ret = RETURN("n", ORDER_BY(PROPERTY_LOOKUP("n", prop))); + auto *as_n = NEXPR("n", IDENT("n")); + auto ret = RETURN(as_n, ORDER_BY(PROPERTY_LOOKUP("n", prop))); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), ret)); - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), - ExpectOrderBy()); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectProduce(), + ExpectOrderBy()); + ExpectPullRemote pull({symbol_table.at(*as_n)}); ExpectedDistributedPlan expected{ - MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectPullRemote(), - ExpectOrderBy()), + MakeCheckers(ExpectScanAll(), ExpectProduce(), pull, ExpectOrderBy()), MakeCheckers(ExpectScanAll(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, CreateWithOrderByWhere) { @@ -1046,11 +1140,7 @@ TYPED_TEST(TestPlanner, CreateWithOrderByWhere) { symbol_table.at(*r_prop->expression_), // `r` in ORDER BY symbol_table.at(*m_prop->expression_), // `m` in WHERE }); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), ExpectCreateExpand(), acc, ExpectProduce(), ExpectOrderBy(), ExpectFilter()); @@ -1090,11 +1180,7 @@ TYPED_TEST(TestPlanner, MatchMerge) { auto symbol_table = MakeSymbolTable(*query); // We expect Accumulate after Merge, because it is considered as a write. auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectMerge(on_match, on_create), acc, ExpectProduce()); for (auto &op : on_match) delete op; @@ -1122,11 +1208,21 @@ TYPED_TEST(TestPlanner, MatchOptionalMatchWhereReturn) { TYPED_TEST(TestPlanner, MatchUnwindReturn) { // Test MATCH (n) UNWIND [1,2,3] AS x RETURN n, x AstTreeStorage storage; + auto *as_n = NEXPR("n", IDENT("n")); + auto *as_x = NEXPR("x", IDENT("x")); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), UNWIND(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), AS("x")), - RETURN("n", "x"))); - CheckPlan(storage, ExpectScanAll(), ExpectUnwind(), - ExpectProduce()); + RETURN(as_n, as_x))); + auto symbol_table = MakeSymbolTable(*storage.query()); + database::SingleNode db; + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectUnwind(), + ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_n), symbol_table.at(*as_x)}); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectUnwind(), ExpectProduce(), pull), + MakeCheckers(ExpectScanAll(), ExpectUnwind(), ExpectProduce())}; + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, ReturnDistinctOrderBySkipLimit) { @@ -1157,11 +1253,7 @@ TYPED_TEST(TestPlanner, CreateWithDistinctSumWhereReturn) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*node_n->identifier_)}); auto aggr = ExpectAggregate({sum}, {}); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), ExpectDistinct(), ExpectFilter(), ExpectProduce()); } @@ -1191,18 +1283,22 @@ TYPED_TEST(TestPlanner, MatchWhereBeforeExpand) { database::GraphDbAccessor dba(db); auto prop = dba.Property("prop"); AstTreeStorage storage; + auto *as_n = NEXPR("n", IDENT("n")); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), WHERE(LESS(PROPERTY_LOOKUP("n", prop), LITERAL(42))), - RETURN("n"))); + RETURN(as_n))); // We expect Fitler to come immediately after ScanAll, since it only uses `n`. - CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), - ExpectProduce()); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto planner = MakePlanner(db, storage, symbol_table); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectFilter(), + ExpectExpand(), ExpectProduce()); + ExpectPullRemote pull({symbol_table.at(*as_n)}); ExpectedDistributedPlan expected{ MakeCheckers(ExpectScanAll(), ExpectFilter(), ExpectExpand(), - ExpectProduce(), ExpectPullRemote()), + ExpectProduce(), pull), MakeCheckers(ExpectScanAll(), ExpectFilter(), ExpectExpand(), ExpectProduce())}; - CheckDistributedPlan(storage, expected); + CheckDistributedPlan(planner.plan(), symbol_table, expected); } TYPED_TEST(TestPlanner, MultiMatchWhere) { @@ -1250,11 +1346,7 @@ TYPED_TEST(TestPlanner, MatchReturnAsterisk) { auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), ret)); auto symbol_table = MakeSymbolTable(*query); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), ExpectProduce()); std::vector output_names; @@ -1276,11 +1368,7 @@ TYPED_TEST(TestPlanner, MatchReturnAsteriskSum) { ret->body_.all_identifiers = true; auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), ret)); auto symbol_table = MakeSymbolTable(*query); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); auto *produce = dynamic_cast(&planner.plan()); ASSERT_TRUE(produce); const auto &named_expressions = produce->named_expressions(); @@ -1465,18 +1553,13 @@ TYPED_TEST(TestPlanner, AtomIndexedLabelProperty) { dba.Commit(); database::GraphDbAccessor(db).BuildIndex(label, property.second); { - database::GraphDbAccessor dba(db); auto node = NODE("n", label); auto lit_42 = LITERAL(42); node->properties_[property] = lit_42; node->properties_[not_indexed] = LITERAL(0); QUERY(SINGLE_QUERY(MATCH(PATTERN(node)), RETURN("n"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectFilter(), ExpectProduce()); @@ -1493,7 +1576,6 @@ TYPED_TEST(TestPlanner, AtomPropertyWhereLabelIndexing) { auto not_indexed = PROPERTY_PAIR("not_indexed"); dba.BuildIndex(label, property.second); { - database::GraphDbAccessor dba(db); auto node = NODE("n"); auto lit_42 = LITERAL(42); node->properties_[property] = lit_42; @@ -1504,11 +1586,7 @@ TYPED_TEST(TestPlanner, AtomPropertyWhereLabelIndexing) { IDENT("n"), std::vector{label}))), RETURN("n"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectFilter(), ExpectProduce()); @@ -1524,17 +1602,12 @@ TYPED_TEST(TestPlanner, WhereIndexedLabelProperty) { auto property = PROPERTY_PAIR("property"); dba.BuildIndex(label, property.second); { - database::GraphDbAccessor dba(db); auto lit_42 = LITERAL(42); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", label))), WHERE(EQ(PROPERTY_LOOKUP("n", property), lit_42)), RETURN("n"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectProduce()); @@ -1569,11 +1642,7 @@ TYPED_TEST(TestPlanner, BestPropertyIndexed) { EQ(PROPERTY_LOOKUP("n", better), lit_42))), RETURN("n"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, better, lit_42), ExpectFilter(), ExpectProduce()); @@ -1590,7 +1659,6 @@ TYPED_TEST(TestPlanner, MultiPropertyIndexScan) { auto prop2 = PROPERTY_PAIR("prop2"); database::GraphDbAccessor(db).BuildIndex(label1, prop1.second); database::GraphDbAccessor(db).BuildIndex(label2, prop2.second); - database::GraphDbAccessor dba(db); AstTreeStorage storage; auto lit_1 = LITERAL(1); auto lit_2 = LITERAL(2); @@ -1600,11 +1668,7 @@ TYPED_TEST(TestPlanner, MultiPropertyIndexScan) { EQ(PROPERTY_LOOKUP("m", prop2), lit_2))), RETURN("n", "m"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label1, prop1, lit_1), ExpectScanAllByLabelPropertyValue(label2, prop2, lit_2), @@ -1618,23 +1682,18 @@ TYPED_TEST(TestPlanner, WhereIndexedLabelPropertyRange) { auto label = database::GraphDbAccessor(db).Label("label"); auto property = database::GraphDbAccessor(db).Property("property"); database::GraphDbAccessor(db).BuildIndex(label, property); - database::GraphDbAccessor dba(db); AstTreeStorage storage; auto lit_42 = LITERAL(42); auto n_prop = PROPERTY_LOOKUP("n", property); - auto check_planned_range = [&label, &property, &dba](const auto &rel_expr, - auto lower_bound, - auto upper_bound) { + auto check_planned_range = [&label, &property, &db](const auto &rel_expr, + auto lower_bound, + auto upper_bound) { // Shadow the first storage, so that the query is created in this one. AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", label))), WHERE(rel_expr), RETURN("n"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyRange(label, property, lower_bound, upper_bound), @@ -1674,18 +1733,13 @@ TYPED_TEST(TestPlanner, UnableToUsePropertyIndex) { auto property = dba.Property("property"); dba.BuildIndex(label, property); { - database::GraphDbAccessor dba(db); AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", label))), WHERE(EQ(PROPERTY_LOOKUP("n", property), PROPERTY_LOOKUP("n", property))), RETURN("n"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); // We can only get ScanAllByLabelIndex, because we are comparing properties // with those on the same node. CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabel(), @@ -1701,7 +1755,6 @@ TYPED_TEST(TestPlanner, SecondPropertyIndex) { auto property = PROPERTY_PAIR("property"); dba.BuildIndex(label, dba.Property("property")); { - database::GraphDbAccessor dba(db); AstTreeStorage storage; auto n_prop = PROPERTY_LOOKUP("n", property); auto m_prop = PROPERTY_LOOKUP("m", property); @@ -1709,11 +1762,7 @@ TYPED_TEST(TestPlanner, SecondPropertyIndex) { MATCH(PATTERN(NODE("n", label)), PATTERN(NODE("m", label))), WHERE(EQ(m_prop, n_prop)), RETURN("n"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); CheckPlan( planner.plan(), symbol_table, ExpectScanAllByLabel(), // Note: We are scanning for m, therefore property should equal n_prop. @@ -1820,11 +1869,7 @@ TYPED_TEST(TestPlanner, MatchDoubleScanToExpandExisting) { QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m", label))), RETURN("r"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); // We expect 2x ScanAll and then Expand, since we are guessing that is // faster (due to low label index vertex count). CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), @@ -1850,18 +1895,13 @@ TYPED_TEST(TestPlanner, MatchScanToExpand) { vertex.PropsSet(property, 1); dba.Commit(); { - database::GraphDbAccessor dba(db); AstTreeStorage storage; auto node_m = NODE("m", label); node_m->properties_[std::make_pair("property", property)] = LITERAL(1); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), node_m)), RETURN("r"))); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); // We expect 1x ScanAllByLabel and then Expand, since we are guessing that // is faster (due to high label index vertex count). CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), @@ -1887,7 +1927,6 @@ TYPED_TEST(TestPlanner, MatchWhereAndSplit) { TYPED_TEST(TestPlanner, ReturnAsteriskOmitsLambdaSymbols) { // Test MATCH (n) -[r* (ie, in | true)]- (m) RETURN * database::SingleNode db; - database::GraphDbAccessor dba(db); AstTreeStorage storage; auto edge = EDGE_VARIABLE("r", Direction::BOTH); edge->inner_edge_ = IDENT("ie"); @@ -1897,11 +1936,7 @@ TYPED_TEST(TestPlanner, ReturnAsteriskOmitsLambdaSymbols) { ret->body_.all_identifiers = true; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), ret)); auto symbol_table = MakeSymbolTable(*storage.query()); - auto planning_context = MakePlanningContext(storage, symbol_table, dba); - auto query_parts = CollectQueryParts(symbol_table, storage); - ASSERT_TRUE(query_parts.query_parts.size() > 0); - auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - TypeParam planner(single_query_parts, planning_context); + auto planner = MakePlanner(db, storage, symbol_table); auto *produce = dynamic_cast(&planner.plan()); ASSERT_TRUE(produce); std::vector outputs; @@ -1924,6 +1959,7 @@ TYPED_TEST(TestPlanner, DistributedAvg) { QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(AVG(PROPERTY_LOOKUP("n", prop)), AS("res")))); auto distributed_plan = MakeDistributedPlan(storage); + auto &symbol_table = distributed_plan.symbol_table; auto worker_sum = SUM(PROPERTY_LOOKUP("n", prop)); auto worker_count = COUNT(PROPERTY_LOOKUP("n", prop)); { @@ -1932,20 +1968,38 @@ TYPED_TEST(TestPlanner, DistributedAvg) { std::dynamic_pointer_cast(distributed_plan.worker_plan); ASSERT_TRUE(worker_aggr_op); ASSERT_EQ(worker_aggr_op->aggregations().size(), 2U); - distributed_plan.symbol_table[*worker_sum] = - worker_aggr_op->aggregations()[0].output_sym; - distributed_plan.symbol_table[*worker_count] = - worker_aggr_op->aggregations()[1].output_sym; + symbol_table[*worker_sum] = worker_aggr_op->aggregations()[0].output_sym; + symbol_table[*worker_count] = worker_aggr_op->aggregations()[1].output_sym; } auto worker_aggr = ExpectAggregate({worker_sum, worker_count}, {}); auto merge_sum = SUM(IDENT("worker_sum")); auto merge_count = SUM(IDENT("worker_count")); auto master_aggr = ExpectMasterAggregate({merge_sum, merge_count}, {}); + ExpectPullRemote pull( + {symbol_table.at(*worker_sum), symbol_table.at(*worker_count)}); ExpectedDistributedPlan expected{ - MakeCheckers(ExpectScanAll(), worker_aggr, ExpectPullRemote(), - master_aggr, ExpectProduce(), ExpectProduce()), + MakeCheckers(ExpectScanAll(), worker_aggr, pull, master_aggr, + ExpectProduce(), ExpectProduce()), MakeCheckers(ExpectScanAll(), worker_aggr)}; CheckDistributedPlan(distributed_plan, expected); } +TYPED_TEST(TestPlanner, DistributedMatchCreateReturn) { + // Test MATCH (n) CREATE (m) RETURN m + AstTreeStorage storage; + auto *ident_m = IDENT("m"); + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), CREATE(PATTERN(NODE("m"))), + RETURN(ident_m, AS("m")))); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto acc = ExpectAccumulate({symbol_table.at(*ident_m)}); + database::Master db; + auto planner = MakePlanner(db, storage, symbol_table); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectCreateNode(), acc, + ExpectSynchronize({symbol_table.at(*ident_m)}), + ExpectProduce()), + MakeCheckers(ExpectScanAll(), ExpectCreateNode(), acc)}; + CheckDistributedPlan(planner.plan(), symbol_table, expected); +} + } // namespace