diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 8846f7771..fa471fa9c 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -577,6 +577,9 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match, #define COUNT(expr) \ storage.Create((expr), nullptr, \ query::Aggregation::Op::COUNT) +#define AVG(expr) \ + storage.Create((expr), nullptr, \ + query::Aggregation::Op::AVG) #define EQ(expr1, expr2) storage.Create((expr1), (expr2)) #define NEQ(expr1, expr2) \ storage.Create((expr1), (expr2)) diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index cc532175e..d31384d4a 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -12,6 +12,7 @@ #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_table.hpp" +#include "query/plan/distributed.hpp" #include "query/plan/operator.hpp" #include "query/plan/planner.hpp" @@ -41,6 +42,12 @@ class PlanChecker : public HierarchicalLogicalOperatorVisitor { using HierarchicalLogicalOperatorVisitor::PreVisit; using HierarchicalLogicalOperatorVisitor::Visit; + PlanChecker(const std::list> &checkers, + const SymbolTable &symbol_table) + : symbol_table_(symbol_table) { + for (const auto &checker : checkers) checkers_.emplace_back(checker.get()); + } + PlanChecker(const std::list &checkers, const SymbolTable &symbol_table) : checkers_(checkers), symbol_table_(symbol_table) {} @@ -180,9 +187,15 @@ class ExpectAccumulate : public OpChecker { class ExpectAggregate : public OpChecker { public: + ExpectAggregate(bool is_master, + const std::vector &aggregations, + const std::unordered_set &group_by) + : is_master_(is_master), + aggregations_(aggregations), + group_by_(group_by) {} ExpectAggregate(const std::vector &aggregations, const std::unordered_set &group_by) - : aggregations_(aggregations), group_by_(group_by) {} + : is_master_(false), aggregations_(aggregations), group_by_(group_by) {} void ExpectOp(Aggregate &op, const SymbolTable &symbol_table) override { auto aggr_it = aggregations_.begin(); @@ -195,7 +208,11 @@ class ExpectAggregate : public OpChecker { EXPECT_EQ(typeid(aggr_elem.key).hash_code(), typeid(aggr->expression2_).hash_code()); EXPECT_EQ(aggr_elem.op, aggr->op_); - EXPECT_EQ(aggr_elem.output_sym, symbol_table.at(*aggr)); + if (!is_master_) { + // Skip checking virtual merge aggregation symbol when the plan is + // distributed. + EXPECT_EQ(aggr_elem.output_sym, symbol_table.at(*aggr)); + } } EXPECT_EQ(aggr_it, aggregations_.end()); // TODO: Proper group by expression equality @@ -209,10 +226,17 @@ class ExpectAggregate : public OpChecker { } private: - const std::vector aggregations_; - const std::unordered_set group_by_; + bool is_master_ = false; + std::vector aggregations_; + std::unordered_set group_by_; }; +auto ExpectMasterAggregate( + const std::vector &aggregations, + const std::unordered_set &group_by) { + return ExpectAggregate(true, aggregations, group_by); +} + class ExpectMerge : public OpChecker { public: ExpectMerge(const std::list &on_match, @@ -332,6 +356,22 @@ class ExpectCreateIndex : public OpChecker { storage::Property property_; }; +class ExpectPullRemote : public OpChecker { + public: + ExpectPullRemote() {} + ExpectPullRemote(const std::vector &symbols) : symbols_(symbols) {} + + void ExpectOp(PullRemote &op, const SymbolTable &) override { + if (symbols_.empty()) + EXPECT_FALSE(op.symbols().empty()); + else + EXPECT_THAT(op.symbols(), testing::UnorderedElementsAreArray(symbols_)); + } + + private: + std::vector symbols_; +}; + auto MakeSymbolTable(query::Query &query) { SymbolTable symbol_table; SymbolGenerator symbol_generator(symbol_table); @@ -399,6 +439,62 @@ auto CheckPlan(AstTreeStorage &storage, TChecker... checker) { CheckPlan(planner.plan(), symbol_table, checker...); } +struct ExpectedDistributedPlan { + std::list> master_checkers; + std::list> worker_checkers; +}; + +template +DistributedPlan MakeDistributedPlan(query::AstTreeStorage &storage) { + database::Master db; + database::GraphDbAccessor dba(db); + auto symbol_table = MakeSymbolTable(*storage.query()); + auto planning_context = MakePlanningContext(storage, symbol_table, dba); + auto query_parts = CollectQueryParts(symbol_table, storage); + auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; + TPlanner planner(single_query_parts, planning_context); + std::atomic next_plan_id{0}; + return MakeDistributedPlan(planner.plan(), symbol_table, next_plan_id); +} + +void CheckDistributedPlan(DistributedPlan &distributed_plan, + ExpectedDistributedPlan &expected) { + PlanChecker plan_checker(expected.master_checkers, + distributed_plan.symbol_table); + distributed_plan.master_plan->Accept(plan_checker); + EXPECT_TRUE(plan_checker.checkers_.empty()); + if (expected.worker_checkers.empty()) { + EXPECT_FALSE(distributed_plan.worker_plan); + } else { + ASSERT_TRUE(distributed_plan.worker_plan); + PlanChecker plan_checker(expected.worker_checkers, + distributed_plan.symbol_table); + distributed_plan.worker_plan->Accept(plan_checker); + EXPECT_TRUE(plan_checker.checkers_.empty()); + } +} + +template +void CheckDistributedPlan(AstTreeStorage &storage, + ExpectedDistributedPlan &expected_distributed_plan) { + auto distributed_plan = MakeDistributedPlan(storage); + CheckDistributedPlan(distributed_plan, expected_distributed_plan); +} + +template +std::list> MakeCheckers(T arg) { + std::list> l; + l.emplace_back(std::make_unique(arg)); + return l; +} + +template +std::list> MakeCheckers(T arg, Rest &&... rest) { + auto l = MakeCheckers(std::forward(rest)...); + l.emplace_front(std::make_unique(arg)); + return std::move(l); +} + template class TestPlanner : public ::testing::Test {}; @@ -411,6 +507,10 @@ TYPED_TEST(TestPlanner, MatchNodeReturn) { AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n"))); CheckPlan(storage, ExpectScanAll(), ExpectProduce()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, CreateNodeReturn) { @@ -496,6 +596,10 @@ TYPED_TEST(TestPlanner, MatchLabeledNodes) { auto label = dba.Label("label"); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", label))), RETURN("n"))); CheckPlan(storage, ExpectScanAllByLabel(), ExpectProduce()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAllByLabel(), ExpectProduce(), ExpectPullRemote()), + MakeCheckers(ExpectScanAllByLabel(), ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchPathReturn) { @@ -510,6 +614,11 @@ TYPED_TEST(TestPlanner, MatchPathReturn) { RETURN("n"))); CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectProduce()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectProduce(), + ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchNamedPatternReturn) { @@ -525,6 +634,12 @@ TYPED_TEST(TestPlanner, MatchNamedPatternReturn) { RETURN("n"))); CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectConstructNamedPath(), ExpectProduce()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectConstructNamedPath(), + ExpectProduce(), ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectConstructNamedPath(), + ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchNamedPatternWithPredicateReturn) { @@ -541,6 +656,12 @@ TYPED_TEST(TestPlanner, MatchNamedPatternWithPredicateReturn) { CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectConstructNamedPath(), ExpectFilter(), ExpectProduce()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectConstructNamedPath(), + ExpectFilter(), ExpectProduce(), ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectConstructNamedPath(), + ExpectFilter(), ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, OptionalMatchNamedPatternReturn) { @@ -581,6 +702,11 @@ TYPED_TEST(TestPlanner, MatchWhereReturn) { RETURN("n"))); CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectProduce()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectFilter(), ExpectProduce(), + ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectFilter(), ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchDelete) { @@ -681,6 +807,11 @@ TYPED_TEST(TestPlanner, MultiMatchSameStart) { // MATCH clause. CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectProduce()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectProduce(), + ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectExpand(), ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchWithReturn) { @@ -691,6 +822,11 @@ TYPED_TEST(TestPlanner, MatchWithReturn) { // No accumulation since we only do reads. CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectProduce()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectProduce(), + ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MatchWithWhereReturn) { @@ -705,6 +841,12 @@ TYPED_TEST(TestPlanner, MatchWithWhereReturn) { // No accumulation since we only do reads. CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectFilter(), ExpectProduce()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectFilter(), + ExpectProduce(), ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectFilter(), + ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, CreateMultiExpand) { @@ -751,6 +893,16 @@ TYPED_TEST(TestPlanner, MatchReturnSum) { RETURN(sum, AS("sum"), n_prop2, AS("group")))); auto aggr = ExpectAggregate({sum}, {n_prop2}); CheckPlan(storage, ExpectScanAll(), aggr, ExpectProduce()); + { + auto distributed_plan = MakeDistributedPlan(storage); + auto merge_sum = SUM(IDENT("worker_sum")); + auto master_aggr = ExpectMasterAggregate({merge_sum}, {n_prop2}); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), aggr, ExpectPullRemote(), master_aggr, + ExpectProduce(), ExpectProduce()), + MakeCheckers(ExpectScanAll(), aggr)}; + CheckDistributedPlan(distributed_plan, expected); + } } TYPED_TEST(TestPlanner, CreateWithSum) { @@ -798,6 +950,11 @@ TYPED_TEST(TestPlanner, MatchReturnSkipLimit) { RETURN("n", SKIP(LITERAL(2)), LIMIT(LITERAL(1))))); CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectSkip(), ExpectLimit()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectPullRemote(), + ExpectSkip(), ExpectLimit()), + MakeCheckers(ExpectScanAll(), ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, CreateWithSkipReturnLimit) { @@ -858,6 +1015,11 @@ TYPED_TEST(TestPlanner, MatchReturnOrderBy) { QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), ret)); CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectOrderBy()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectPullRemote(), + ExpectOrderBy()), + MakeCheckers(ExpectScanAll(), ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, CreateWithOrderByWhere) { @@ -974,6 +1136,11 @@ TYPED_TEST(TestPlanner, ReturnDistinctOrderBySkipLimit) { SKIP(LITERAL(1)), LIMIT(LITERAL(1))))); CheckPlan(storage, ExpectProduce(), ExpectDistinct(), ExpectOrderBy(), ExpectSkip(), ExpectLimit()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectProduce(), ExpectDistinct(), ExpectOrderBy(), + ExpectSkip(), ExpectLimit()), + {}}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, CreateWithDistinctSumWhereReturn) { @@ -1030,6 +1197,12 @@ TYPED_TEST(TestPlanner, MatchWhereBeforeExpand) { // We expect Fitler to come immediately after ScanAll, since it only uses `n`. CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), ExpectProduce()); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), ExpectFilter(), ExpectExpand(), + ExpectProduce(), ExpectPullRemote()), + MakeCheckers(ExpectScanAll(), ExpectFilter(), ExpectExpand(), + ExpectProduce())}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, MultiMatchWhere) { @@ -1162,6 +1335,8 @@ TYPED_TEST(TestPlanner, FunctionAggregationReturn) { RETURN(FN("sqrt", sum), AS("result"), group_by_literal, AS("group_by")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); CheckPlan(storage, aggr, ExpectProduce()); + ExpectedDistributedPlan expected{MakeCheckers(aggr, ExpectProduce()), {}}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, FunctionWithoutArguments) { @@ -1169,6 +1344,8 @@ TYPED_TEST(TestPlanner, FunctionWithoutArguments) { AstTreeStorage storage; QUERY(SINGLE_QUERY(RETURN(FN("pi"), AS("pi")))); CheckPlan(storage, ExpectProduce()); + ExpectedDistributedPlan expected{MakeCheckers(ExpectProduce()), {}}; + CheckDistributedPlan(storage, expected); } TYPED_TEST(TestPlanner, ListLiteralAggregationReturn) { @@ -1735,4 +1912,37 @@ TYPED_TEST(TestPlanner, ReturnAsteriskOmitsLambdaSymbols) { } } +TYPED_TEST(TestPlanner, DistributedAvg) { + // Test MATCH (n) RETURN AVG(n.prop) AS res + AstTreeStorage storage; + database::Master db; + database::GraphDbAccessor dba(db); + auto prop = dba.Property("prop"); + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), + RETURN(AVG(PROPERTY_LOOKUP("n", prop)), AS("res")))); + auto distributed_plan = MakeDistributedPlan(storage); + auto worker_sum = SUM(PROPERTY_LOOKUP("n", prop)); + auto worker_count = COUNT(PROPERTY_LOOKUP("n", prop)); + { + ASSERT_TRUE(distributed_plan.worker_plan); + auto worker_aggr_op = + std::dynamic_pointer_cast(distributed_plan.worker_plan); + ASSERT_TRUE(worker_aggr_op); + ASSERT_EQ(worker_aggr_op->aggregations().size(), 2U); + distributed_plan.symbol_table[*worker_sum] = + worker_aggr_op->aggregations()[0].output_sym; + distributed_plan.symbol_table[*worker_count] = + worker_aggr_op->aggregations()[1].output_sym; + } + auto worker_aggr = ExpectAggregate({worker_sum, worker_count}, {}); + auto merge_sum = SUM(IDENT("worker_sum")); + auto merge_count = SUM(IDENT("worker_count")); + auto master_aggr = ExpectMasterAggregate({merge_sum, merge_count}, {}); + ExpectedDistributedPlan expected{ + MakeCheckers(ExpectScanAll(), worker_aggr, ExpectPullRemote(), + master_aggr, ExpectProduce(), ExpectProduce()), + MakeCheckers(ExpectScanAll(), worker_aggr)}; + CheckDistributedPlan(distributed_plan, expected); +} + } // namespace