diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index c7a80ba98..ab0f237b3 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -82,7 +82,7 @@ class Interpreter { double query_plan_cost_estimation = 0.0; if (FLAGS_query_cost_planner) { auto plans = plan::MakeLogicalPlan( - ast_storage, symbol_table, &db_accessor); + ast_storage, symbol_table, db_accessor); double min_cost = std::numeric_limits::max(); for (auto &plan : plans) { plan::CostEstimator estimator(db_accessor); @@ -98,7 +98,7 @@ class Interpreter { query_plan_cost_estimation = min_cost; } else { logical_plan = plan::MakeLogicalPlan( - ast_storage, symbol_table, &db_accessor); + ast_storage, symbol_table, db_accessor); plan::CostEstimator cost_estimator(db_accessor); logical_plan->Accept(cost_estimator); query_plan_cost_estimation = cost_estimator.cost(); diff --git a/src/query/plan/planner.hpp b/src/query/plan/planner.hpp index 79d8a26b7..151fcf9cb 100644 --- a/src/query/plan/planner.hpp +++ b/src/query/plan/planner.hpp @@ -141,10 +141,10 @@ struct PlanningContext { /// @brief The storage is used to traverse the AST as well as create new nodes /// for use in operators. AstTreeStorage &ast_storage; - /// @brief Optional GraphDbAccessor, which may be used to get some information - /// from the database to generate better plans. The accessor is required only - /// to live long enough for the plan generation to finish. - const GraphDbAccessor *db = nullptr; + /// @brief GraphDbAccessor, which may be used to get some information from the + /// database to generate better plans. The accessor is required only to live + /// long enough for the plan generation to finish. + const GraphDbAccessor &db; /// @brief Symbol set is used to differentiate cycles in pattern matching. /// /// During planning, symbols will be added as each operator produces values @@ -216,9 +216,9 @@ std::vector CollectQueryParts(const SymbolTable &, AstTreeStorage &); /// @sa RuleBasedPlanner /// @sa VariableStartPlanner template -typename TPlanner::PlanResult MakeLogicalPlan( - AstTreeStorage &storage, SymbolTable &symbol_table, - const GraphDbAccessor *db = nullptr) { +typename TPlanner::PlanResult MakeLogicalPlan(AstTreeStorage &storage, + SymbolTable &symbol_table, + const GraphDbAccessor &db) { auto query_parts = CollectQueryParts(symbol_table, storage); PlanningContext context{symbol_table, storage, db}; return TPlanner(context).Plan(query_parts); diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index d6f50ea33..4e8afbb3c 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -655,18 +655,13 @@ void AddMatching(const Match &match, const SymbolTable &symbol_table, } const GraphDbTypes::Label &FindBestLabelIndex( - const GraphDbAccessor *db, const std::set &labels) { + const GraphDbAccessor &db, const std::set &labels) { debug_assert(!labels.empty(), "Trying to find the best label without any labels."); - if (!db) { - // We don't have a database to get index information, so just take the first - // label. - return *labels.begin(); - } return *std::min_element(labels.begin(), labels.end(), - [db](const auto &label1, const auto &label2) { - return db->vertices_count(label1) < - db->vertices_count(label2); + [&db](const auto &label1, const auto &label2) { + return db.vertices_count(label1) < + db.vertices_count(label2); }); } @@ -676,16 +671,12 @@ const GraphDbTypes::Label &FindBestLabelIndex( // function will return `false` while leaving `best_label` and `best_property` // unchanged. bool FindBestLabelPropertyIndex( - const GraphDbAccessor *db, const std::set &labels, + const GraphDbAccessor &db, const std::set &labels, const std::map> &property_filters, const Symbol &symbol, const std::unordered_set &bound_symbols, GraphDbTypes::Label &best_label, std::pair &best_property) { - if (!db) { - // Without the database, we cannot determine whether the index even exists. - return false; - } auto are_bound = [&bound_symbols](const auto &used_symbols) { for (const auto &used_symbol : used_symbols) { if (bound_symbols.find(used_symbol) == bound_symbols.end()) { @@ -699,8 +690,8 @@ bool FindBestLabelPropertyIndex( for (const auto &label : labels) { for (const auto &prop_pair : property_filters) { const auto &property = prop_pair.first; - if (db->LabelPropertyIndexExists(label, property)) { - auto vertices_count = db->vertices_count(label, property); + if (db.LabelPropertyIndexExists(label, property)) { + auto vertices_count = db.vertices_count(label, property); if (vertices_count < min_count) { for (const auto &prop_filter : prop_pair.second) { if (prop_filter.used_symbols.find(symbol) != diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index 439e0ea47..b8107524c 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -267,7 +267,9 @@ auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, template auto CheckPlan(AstTreeStorage &storage, TChecker... checker) { auto symbol_table = MakeSymbolTable(*storage.query()); - auto plan = MakeLogicalPlan(storage, symbol_table); + Dbms dbms; + auto plan = + MakeLogicalPlan(storage, symbol_table, *dbms.active()); CheckPlan(*plan, symbol_table, checker...); } @@ -285,7 +287,9 @@ TEST(TestLogicalPlanner, CreateNodeReturn) { auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(ident_n, AS("n"))); auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); - auto plan = MakeLogicalPlan(storage, symbol_table); + Dbms dbms; + auto plan = + MakeLogicalPlan(storage, symbol_table, *dbms.active()); CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce()); } @@ -556,7 +560,7 @@ TEST(TestLogicalPlanner, CreateWithSum) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); auto aggr = ExpectAggregate({sum}, {}); - auto plan = MakeLogicalPlan(storage, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); // We expect both the accumulation and aggregation because the part before // WITH updates the database. CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, @@ -593,7 +597,9 @@ TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) { RETURN("m", LIMIT(LITERAL(1)))); auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); - auto plan = MakeLogicalPlan(storage, symbol_table); + Dbms dbms; + auto plan = + MakeLogicalPlan(storage, symbol_table, *dbms.active()); // Since we have a write query, we need to have Accumulate. This is a bit // different than Neo4j 3.0, which optimizes WITH followed by RETURN as a // single RETURN clause and then moves Skip and Limit before Accumulate. This @@ -616,7 +622,7 @@ TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); auto aggr = ExpectAggregate({sum}, {}); - auto plan = MakeLogicalPlan(storage, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), ExpectSkip(), ExpectLimit()); } @@ -655,7 +661,7 @@ TEST(TestLogicalPlanner, CreateWithOrderByWhere) { symbol_table.at(*r_prop->expression_), // `r` in ORDER BY symbol_table.at(*m_prop->expression_), // `m` in WHERE }); - auto plan = MakeLogicalPlan(storage, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectCreateExpand(), acc, ExpectProduce(), ExpectFilter(), ExpectOrderBy()); } @@ -693,7 +699,7 @@ TEST(TestLogicalPlanner, MatchMerge) { auto symbol_table = MakeSymbolTable(*query); // We expect Accumulate after Merge, because it is considered as a write. auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); - auto plan = MakeLogicalPlan(storage, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectMerge(on_match, on_create), acc, ExpectProduce()); for (auto &op : on_match) delete op; @@ -748,7 +754,7 @@ TEST(TestLogicalPlanner, CreateWithDistinctSumWhereReturn) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*node_n->identifier_)}); auto aggr = ExpectAggregate({sum}, {}); - auto plan = MakeLogicalPlan(storage, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), ExpectFilter(), ExpectDistinct(), ExpectProduce()); } @@ -827,7 +833,7 @@ TEST(TestLogicalPlanner, MatchReturnAsterisk) { ret->body_.all_identifiers = true; auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), ret); auto symbol_table = MakeSymbolTable(*query); - auto plan = MakeLogicalPlan(storage, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectExpand(), ExpectProduce()); std::vector output_names; @@ -849,7 +855,7 @@ TEST(TestLogicalPlanner, MatchReturnAsteriskSum) { ret->body_.all_identifiers = true; auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret); auto symbol_table = MakeSymbolTable(*query); - auto plan = MakeLogicalPlan(storage, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); auto *produce = dynamic_cast(plan.get()); ASSERT_TRUE(produce); const auto &named_expressions = produce->named_expressions(); @@ -982,8 +988,7 @@ TEST(TestLogicalPlanner, AtomIndexedLabelProperty) { node->properties_[not_indexed] = LITERAL(0); QUERY(MATCH(PATTERN(node)), RETURN("n")); auto symbol_table = MakeSymbolTable(*storage.query()); - auto plan = - MakeLogicalPlan(storage, symbol_table, dba.get()); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); CheckPlan(*plan, symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectFilter(), ExpectProduce()); @@ -1008,8 +1013,7 @@ TEST(TestLogicalPlanner, AtomPropertyWhereLabelIndexing) { IDENT("n"), std::vector{label}))), RETURN("n")); auto symbol_table = MakeSymbolTable(*storage.query()); - auto plan = - MakeLogicalPlan(storage, symbol_table, dba.get()); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); CheckPlan(*plan, symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectFilter(), ExpectProduce()); @@ -1028,8 +1032,7 @@ TEST(TestLogicalPlanner, WhereIndexedLabelProperty) { QUERY(MATCH(PATTERN(NODE("n", label))), WHERE(EQ(PROPERTY_LOOKUP("n", property), lit_42)), RETURN("n")); auto symbol_table = MakeSymbolTable(*storage.query()); - auto plan = - MakeLogicalPlan(storage, symbol_table, dba.get()); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); CheckPlan(*plan, symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectFilter(), ExpectProduce()); @@ -1062,8 +1065,7 @@ TEST(TestLogicalPlanner, BestPropertyIndexed) { EQ(PROPERTY_LOOKUP("n", better), lit_42))), RETURN("n")); auto symbol_table = MakeSymbolTable(*storage.query()); - auto plan = - MakeLogicalPlan(storage, symbol_table, dba.get()); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); CheckPlan(*plan, symbol_table, ExpectScanAllByLabelPropertyValue(label, better, lit_42), ExpectFilter(), ExpectProduce()); @@ -1090,8 +1092,7 @@ TEST(TestLogicalPlanner, MultiPropertyIndexScan) { EQ(PROPERTY_LOOKUP("m", prop2), lit_2))), RETURN("n", "m")); auto symbol_table = MakeSymbolTable(*storage.query()); - auto plan = - MakeLogicalPlan(storage, symbol_table, dba.get()); + auto plan = MakeLogicalPlan(storage, symbol_table, *dba); CheckPlan(*plan, symbol_table, ExpectScanAllByLabelPropertyValue(label1, prop1, lit_1), ExpectFilter(), diff --git a/tests/unit/query_variable_start_planner.cpp b/tests/unit/query_variable_start_planner.cpp index a67d1d274..ea85572a9 100644 --- a/tests/unit/query_variable_start_planner.cpp +++ b/tests/unit/query_variable_start_planner.cpp @@ -68,7 +68,7 @@ void CheckPlansProduce( std::function> &)> check) { auto symbol_table = MakeSymbolTable(*storage.query()); auto plans = - MakeLogicalPlan(storage, symbol_table, &dba); + MakeLogicalPlan(storage, symbol_table, dba); EXPECT_EQ(std::distance(plans.begin(), plans.end()), expected_plan_count); for (const auto &plan : plans) { auto *produce = dynamic_cast(plan.get());