From fdc389f1eb533fcef5773136af48e54ecd8fca42 Mon Sep 17 00:00:00 2001 From: Teon Banek <teon.banek@memgraph.io> Date: Thu, 24 Aug 2017 14:24:17 +0200 Subject: [PATCH] Templatize CostEstimator on DbAccessor Summary: This allows for inserting a dummy DbAccessor in tests. An unfortunate side effect of this change is that the whole implementation had to be moved from cpp to hpp. Also templatize the remaining RuleBasedPlanner implementation. Reviewers: florijan, mislav.bradac Reviewed By: mislav.bradac Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D704 --- CMakeLists.txt | 1 - src/query/interpreter.hpp | 4 +- src/query/plan/cost_estimator.cpp | 146 ------------------------ src/query/plan/cost_estimator.hpp | 157 +++++++++++++++++++++++--- src/query/plan/rule_based_planner.cpp | 100 ---------------- src/query/plan/rule_based_planner.hpp | 123 +++++++++++++++++--- tests/unit/query_cost_estimator.cpp | 19 ++-- 7 files changed, 259 insertions(+), 291 deletions(-) delete mode 100644 src/query/plan/cost_estimator.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index d20b496b4..9a3969c40 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -284,7 +284,6 @@ set(memgraph_src_files ${src_dir}/query/frontend/stripped.cpp ${src_dir}/query/interpret/awesome_memgraph_functions.cpp ${src_dir}/query/interpreter.cpp - ${src_dir}/query/plan/cost_estimator.cpp ${src_dir}/query/plan/operator.cpp ${src_dir}/query/plan/rule_based_planner.cpp ${src_dir}/query/plan/variable_start_planner.cpp diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 85429adcb..ebeee6889 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -117,7 +117,7 @@ class Interpreter { ast_storage, symbol_table, db_accessor); double min_cost = std::numeric_limits<double>::max(); for (auto &plan : plans) { - plan::CostEstimator estimator(db_accessor); + plan::CostEstimator<GraphDbAccessor> estimator(db_accessor); plan->Accept(estimator); auto cost = estimator.cost(); 
if (!logical_plan || cost < min_cost) { @@ -131,7 +131,7 @@ class Interpreter { } else { logical_plan = plan::MakeLogicalPlan<plan::RuleBasedPlanner>( ast_storage, symbol_table, db_accessor); - plan::CostEstimator cost_estimator(db_accessor); + plan::CostEstimator<GraphDbAccessor> cost_estimator(db_accessor); logical_plan->Accept(cost_estimator); query_plan_cost_estimation = cost_estimator.cost(); } diff --git a/src/query/plan/cost_estimator.cpp b/src/query/plan/cost_estimator.cpp deleted file mode 100644 index f86e10324..000000000 --- a/src/query/plan/cost_estimator.cpp +++ /dev/null @@ -1,146 +0,0 @@ -#include <experimental/optional> - -#include "cost_estimator.hpp" - -namespace query::plan { - -bool CostEstimator::PostVisit(ScanAll &) { - cardinality_ *= db_accessor_.VerticesCount(); - // ScanAll performs some work for every element that is produced - IncrementCost(CostParam::kScanAll); - return true; -} - -bool CostEstimator::PostVisit(ScanAllByLabel &scan_all_by_label) { - cardinality_ *= db_accessor_.VerticesCount(scan_all_by_label.label()); - // ScanAll performs some work for every element that is produced - IncrementCost(CostParam::kScanAllByLabel); - return true; -} - -bool CostEstimator::PostVisit(ScanAllByLabelPropertyValue &logical_op) { - // this cardinality estimation depends on the property value (expression). 
- // if it's a literal (const) we can evaluate cardinality exactly, otherwise - // we estimate - std::experimental::optional<PropertyValue> property_value = - std::experimental::nullopt; - if (auto *literal = dynamic_cast<PrimitiveLiteral *>(logical_op.expression())) - if (literal->value_.IsPropertyValue()) - property_value = - std::experimental::optional<PropertyValue>(literal->value_); - - double factor = 1.0; - if (property_value) - // get the exact influence based on ScanAll(label, property, value) - factor = db_accessor_.VerticesCount( - logical_op.label(), logical_op.property(), property_value.value()); - else - // estimate the influence as ScanAll(label, property) * filtering - factor = - db_accessor_.VerticesCount(logical_op.label(), logical_op.property()) * - CardParam::kFilter; - - cardinality_ *= factor; - - // ScanAll performs some work for every element that is produced - IncrementCost(CostParam::MakeScanAllByLabelPropertyValue); - return true; -} - -namespace { -// converts an optional ScanAll range bound into a property value -// if the bound is present and is a literal expression convertible to -// a property value. otherwise returns nullopt -std::experimental::optional<utils::Bound<PropertyValue>> BoundToPropertyValue( - std::experimental::optional<ScanAllByLabelPropertyRange::Bound> bound) { - if (bound) - if (auto *literal = dynamic_cast<PrimitiveLiteral *>(bound->value())) - return std::experimental::make_optional( - utils::Bound<PropertyValue>(literal->value_, bound->type())); - return std::experimental::nullopt; -} -} - -bool CostEstimator::PostVisit(ScanAllByLabelPropertyRange &logical_op) { - // this cardinality estimation depends on Bound expressions. 
- // if they are literals we can evaluate cardinality properly - auto lower = BoundToPropertyValue(logical_op.lower_bound()); - auto upper = BoundToPropertyValue(logical_op.upper_bound()); - - int64_t factor = 1; - if (upper || lower) - // if we have either Bound<PropertyValue>, use the value index - factor = db_accessor_.VerticesCount(logical_op.label(), - logical_op.property(), lower, upper); - else - // no values, but we still have the label - factor = - db_accessor_.VerticesCount(logical_op.label(), logical_op.property()); - - // if we failed to take either bound from the op into account, then apply - // the filtering constant to the factor - if ((logical_op.upper_bound() && !upper) || - (logical_op.lower_bound() && !lower)) - factor *= CardParam::kFilter; - - cardinality_ *= factor; - - // ScanAll performs some work for every element that is produced - IncrementCost(CostParam::MakeScanAllByLabelPropertyRange); - return true; -} - -// For the given op first increments the cardinality and then cost. -#define POST_VISIT_CARD_FIRST(NAME) \ - bool CostEstimator::PostVisit(NAME &) { \ - cardinality_ *= CardParam::k##NAME; \ - IncrementCost(CostParam::k##NAME); \ - return true; \ - } - -POST_VISIT_CARD_FIRST(Expand); -POST_VISIT_CARD_FIRST(ExpandVariable); -POST_VISIT_CARD_FIRST(ExpandBreadthFirst); - -#undef POST_VISIT_CARD_FIRST - -// For the given op first increments the cost and then cardinality. 
-#define POST_VISIT_COST_FIRST(LOGICAL_OP, PARAM_NAME) \ - bool CostEstimator::PostVisit(LOGICAL_OP &) { \ - IncrementCost(CostParam::PARAM_NAME); \ - cardinality_ *= CardParam::PARAM_NAME; \ - return true; \ - } - -POST_VISIT_COST_FIRST(Filter, kFilter) -POST_VISIT_COST_FIRST(ExpandUniquenessFilter<VertexAccessor>, - kExpandUniquenessFilter); -POST_VISIT_COST_FIRST(ExpandUniquenessFilter<EdgeAccessor>, - kExpandUniquenessFilter); - -#undef POST_VISIT_COST_FIRST - -bool CostEstimator::PostVisit(Unwind &unwind) { - // Unwind cost depends more on the number of lists that get unwound - // much less on the number of outputs - // for that reason first increment cost, then modify cardinality - IncrementCost(CostParam::kUnwind); - - // try to determine how many values will be yielded by Unwind - // if the Unwind expression is a list literal, we can deduce cardinality - // exactly, otherwise we approximate - int unwind_value; - if (auto literal = - dynamic_cast<query::ListLiteral *>(unwind.input_expression())) - unwind_value = literal->elements_.size(); - else - unwind_value = MiscParam::kUnwindNoLiteral; - - cardinality_ *= unwind_value; - return true; -} - -bool CostEstimator::Visit(Once &) { return true; } -bool CostEstimator::Visit(CreateIndex &) { return true; } - -} // namespace query::plan diff --git a/src/query/plan/cost_estimator.hpp b/src/query/plan/cost_estimator.hpp index d4733c7e2..54aa62e33 100644 --- a/src/query/plan/cost_estimator.hpp +++ b/src/query/plan/cost_estimator.hpp @@ -33,6 +33,7 @@ namespace query::plan { * actual query execution for plan A is less then that of plan B. It can NOT be * used to estimate how MUCH execution between A and B will differ. 
*/ +template <class TDbAccessor> class CostEstimator : public HierarchicalLogicalOperatorVisitor { public: struct CostParam { @@ -63,22 +64,133 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { using HierarchicalLogicalOperatorVisitor::PreVisit; using HierarchicalLogicalOperatorVisitor::PostVisit; - CostEstimator(const GraphDbAccessor &db_accessor) - : db_accessor_(db_accessor) {} + CostEstimator(const TDbAccessor &db_accessor) : db_accessor_(db_accessor) {} - bool PostVisit(ScanAll &) override; - bool PostVisit(ScanAllByLabel &scan_all_by_label) override; - bool PostVisit(ScanAllByLabelPropertyValue &logical_op) override; - bool PostVisit(ScanAllByLabelPropertyRange &logical_op) override; - bool PostVisit(Expand &) override; - bool PostVisit(ExpandVariable &) override; - bool PostVisit(ExpandBreadthFirst &) override; - bool PostVisit(Filter &) override; - bool PostVisit(ExpandUniquenessFilter<VertexAccessor> &) override; - bool PostVisit(ExpandUniquenessFilter<EdgeAccessor> &) override; - bool PostVisit(Unwind &unwind) override; - bool Visit(Once &) override; - bool Visit(CreateIndex &) override; + bool PostVisit(ScanAll &) override { + cardinality_ *= db_accessor_.VerticesCount(); + // ScanAll performs some work for every element that is produced + IncrementCost(CostParam::kScanAll); + return true; + } + + bool PostVisit(ScanAllByLabel &scan_all_by_label) override { + cardinality_ *= db_accessor_.VerticesCount(scan_all_by_label.label()); + // ScanAll performs some work for every element that is produced + IncrementCost(CostParam::kScanAllByLabel); + return true; + } + + bool PostVisit(ScanAllByLabelPropertyValue &logical_op) override { + // this cardinality estimation depends on the property value (expression). 
+ // if it's a literal (const) we can evaluate cardinality exactly, otherwise + // we estimate + std::experimental::optional<PropertyValue> property_value = + std::experimental::nullopt; + if (auto *literal = + dynamic_cast<PrimitiveLiteral *>(logical_op.expression())) + if (literal->value_.IsPropertyValue()) + property_value = + std::experimental::optional<PropertyValue>(literal->value_); + + double factor = 1.0; + if (property_value) + // get the exact influence based on ScanAll(label, property, value) + factor = db_accessor_.VerticesCount( + logical_op.label(), logical_op.property(), property_value.value()); + else + // estimate the influence as ScanAll(label, property) * filtering + factor = db_accessor_.VerticesCount(logical_op.label(), + logical_op.property()) * + CardParam::kFilter; + + cardinality_ *= factor; + + // ScanAll performs some work for every element that is produced + IncrementCost(CostParam::MakeScanAllByLabelPropertyValue); + return true; + } + + bool PostVisit(ScanAllByLabelPropertyRange &logical_op) override { + // this cardinality estimation depends on Bound expressions. 
+ // if they are literals we can evaluate cardinality properly + auto lower = BoundToPropertyValue(logical_op.lower_bound()); + auto upper = BoundToPropertyValue(logical_op.upper_bound()); + + int64_t factor = 1; + if (upper || lower) + // if we have either Bound<PropertyValue>, use the value index + factor = db_accessor_.VerticesCount(logical_op.label(), + logical_op.property(), lower, upper); + else + // no values, but we still have the label + factor = + db_accessor_.VerticesCount(logical_op.label(), logical_op.property()); + + // if we failed to take either bound from the op into account, then apply + // the filtering constant to the factor + if ((logical_op.upper_bound() && !upper) || + (logical_op.lower_bound() && !lower)) + factor *= CardParam::kFilter; + + cardinality_ *= factor; + + // ScanAll performs some work for every element that is produced + IncrementCost(CostParam::MakeScanAllByLabelPropertyRange); + return true; + } + +// For the given op first increments the cardinality and then cost. +#define POST_VISIT_CARD_FIRST(NAME) \ + bool PostVisit(NAME &) override { \ + cardinality_ *= CardParam::k##NAME; \ + IncrementCost(CostParam::k##NAME); \ + return true; \ + } + + POST_VISIT_CARD_FIRST(Expand); + POST_VISIT_CARD_FIRST(ExpandVariable); + POST_VISIT_CARD_FIRST(ExpandBreadthFirst); + +#undef POST_VISIT_CARD_FIRST + +// For the given op first increments the cost and then cardinality. 
+#define POST_VISIT_COST_FIRST(LOGICAL_OP, PARAM_NAME) \ + bool PostVisit(LOGICAL_OP &) override { \ + IncrementCost(CostParam::PARAM_NAME); \ + cardinality_ *= CardParam::PARAM_NAME; \ + return true; \ + } + + POST_VISIT_COST_FIRST(Filter, kFilter) + POST_VISIT_COST_FIRST(ExpandUniquenessFilter<VertexAccessor>, + kExpandUniquenessFilter); + POST_VISIT_COST_FIRST(ExpandUniquenessFilter<EdgeAccessor>, + kExpandUniquenessFilter); + +#undef POST_VISIT_COST_FIRST + + bool PostVisit(Unwind &unwind) override { + // Unwind cost depends more on the number of lists that get unwound + // much less on the number of outputs + // for that reason first increment cost, then modify cardinality + IncrementCost(CostParam::kUnwind); + + // try to determine how many values will be yielded by Unwind + // if the Unwind expression is a list literal, we can deduce cardinality + // exactly, otherwise we approximate + int unwind_value; + if (auto literal = + dynamic_cast<query::ListLiteral *>(unwind.input_expression())) + unwind_value = literal->elements_.size(); + else + unwind_value = MiscParam::kUnwindNoLiteral; + + cardinality_ *= unwind_value; + return true; + } + + bool Visit(Once &) override { return true; } + bool Visit(CreateIndex &) override { return true; } auto cost() const { return cost_; } auto cardinality() const { return cardinality_; } @@ -93,9 +205,22 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { double cardinality_{1}; // // accessor used for cardinality estimates in ScanAll and ScanAllByLabel - const GraphDbAccessor &db_accessor_; + const TDbAccessor &db_accessor_; void IncrementCost(double param) { cost_ += param * cardinality_; } + + // converts an optional ScanAll range bound into a property value + // if the bound is present and is a literal expression convertible to + // a property value. 
otherwise returns nullopt + static std::experimental::optional<utils::Bound<PropertyValue>> + BoundToPropertyValue( + std::experimental::optional<ScanAllByLabelPropertyRange::Bound> bound) { + if (bound) + if (auto *literal = dynamic_cast<PrimitiveLiteral *>(bound->value())) + return std::experimental::make_optional( + utils::Bound<PropertyValue>(literal->value_, bound->type())); + return std::experimental::nullopt; + } }; } // namespace query::plan diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index ab1b25363..33089dc64 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -590,71 +590,6 @@ void AddMatching(const Match &match, SymbolTable &symbol_table, matching); } -const GraphDbTypes::Label &FindBestLabelIndex( - const GraphDbAccessor &db, const std::set<GraphDbTypes::Label> &labels) { - debug_assert(!labels.empty(), - "Trying to find the best label without any labels."); - return *std::min_element(labels.begin(), labels.end(), - [&db](const auto &label1, const auto &label2) { - return db.VerticesCount(label1) < - db.VerticesCount(label2); - }); -} - -// Finds the label-property combination which has indexed the lowest amount of -// vertices. `best_label` and `best_property` will be set to that combination -// and the function will return `true`. If the index cannot be found, the -// function will return `false` while leaving `best_label` and `best_property` -// unchanged. 
-bool FindBestLabelPropertyIndex( - const GraphDbAccessor &db, const std::set<GraphDbTypes::Label> &labels, - const std::map<GraphDbTypes::Property, std::vector<Filters::PropertyFilter>> - &property_filters, - const Symbol &symbol, const std::unordered_set<Symbol> &bound_symbols, - GraphDbTypes::Label &best_label, - std::pair<GraphDbTypes::Property, Filters::PropertyFilter> &best_property) { - auto are_bound = [&bound_symbols](const auto &used_symbols) { - for (const auto &used_symbol : used_symbols) { - if (bound_symbols.find(used_symbol) == bound_symbols.end()) { - return false; - } - } - return true; - }; - bool found = false; - auto min_count = std::numeric_limits<decltype(db.VerticesCount( - GraphDbTypes::Label{}, GraphDbTypes::Property{}))>::max(); - for (const auto &label : labels) { - for (const auto &prop_pair : property_filters) { - const auto &property = prop_pair.first; - if (db.LabelPropertyIndexExists(label, property)) { - auto vertices_count = db.VerticesCount(label, property); - if (vertices_count < min_count) { - for (const auto &prop_filter : prop_pair.second) { - if (prop_filter.used_symbols.find(symbol) != - prop_filter.used_symbols.end()) { - // Skip filter expressions which use the symbol whose property we - // are looking up. We cannot scan by such expressions. For - // example, in `n.a = 2 + n.b` both sides of `=` refer to `n`, so - // we cannot scan `n` by property index. - continue; - } - if (are_bound(prop_filter.used_symbols)) { - // Take the first property filter which uses bound symbols. 
- best_label = label; - best_property = {property, prop_filter}; - min_count = vertices_count; - found = true; - break; - } - } - } - } - } - } - return found; -} - } // namespace // Analyzes the filter expression by collecting information on filtering labels @@ -903,41 +838,6 @@ LogicalOperator *GenFilters( return last_op; } -ScanAll *GenScanByIndex( - LogicalOperator *last_op, const GraphDbAccessor &db, - const Symbol &node_symbol, const MatchContext &context, - const std::set<GraphDbTypes::Label> &labels, - const std::map<GraphDbTypes::Property, std::vector<Filters::PropertyFilter>> - &properties) { - debug_assert(!labels.empty(), - "Without labels, indexed data cannot be scanned."); - // First, try to see if we can use label+property index. If not, use just the - // label index (which ought to exist). - GraphDbTypes::Label best_label; - std::pair<GraphDbTypes::Property, Filters::PropertyFilter> best_property; - if (FindBestLabelPropertyIndex(db, labels, properties, node_symbol, - context.bound_symbols, best_label, - best_property)) { - const auto &prop_filter = best_property.second; - if (prop_filter.lower_bound || prop_filter.upper_bound) { - return new ScanAllByLabelPropertyRange( - std::shared_ptr<LogicalOperator>(last_op), node_symbol, best_label, - best_property.first, prop_filter.lower_bound, prop_filter.upper_bound, - context.graph_view); - } else { - debug_assert( - prop_filter.expression, - "Property filter should either have bounds or an expression."); - return new ScanAllByLabelPropertyValue( - std::shared_ptr<LogicalOperator>(last_op), node_symbol, best_label, - best_property.first, prop_filter.expression, context.graph_view); - } - } - auto label = FindBestLabelIndex(db, labels); - return new ScanAllByLabel(std::shared_ptr<LogicalOperator>(last_op), - node_symbol, label, context.graph_view); -} - LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op, SymbolTable &symbol_table, bool is_write, const std::unordered_set<Symbol> 
&bound_symbols, diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index 9c7be2840..795686ad4 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -196,13 +196,6 @@ LogicalOperator *GenFilters( &all_filters, AstTreeStorage &storage); -ScanAll *GenScanByIndex( - LogicalOperator *last_op, const GraphDbAccessor &db, - const Symbol &node_symbol, const MatchContext &context, - const std::set<GraphDbTypes::Label> &labels, - const std::map<GraphDbTypes::Property, std::vector<Filters::PropertyFilter>> - &properties); - LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op, SymbolTable &symbol_table, bool is_write, const std::unordered_set<Symbol> &bound_symbols, @@ -265,8 +258,7 @@ class RuleBasedPlanner { input_op = GenMerge(*merge, input_op, query_part.merge_matching[merge_id++]); // Treat MERGE clause as write, because we do not know if it will - // create - // anything. + // create anything. is_write = true; } else if (auto *with = dynamic_cast<query::With *>(clause)) { input_op = @@ -302,6 +294,106 @@ class RuleBasedPlanner { private: TPlanningContext &context_; + // Finds the label-property combination which has indexed the lowest amount of + // vertices. `best_label` and `best_property` will be set to that combination + // and the function will return `true`. If the index cannot be found, the + // function will return `false` while leaving `best_label` and `best_property` + // unchanged. 
+ bool FindBestLabelPropertyIndex( + const std::set<GraphDbTypes::Label> &labels, + const std::map<GraphDbTypes::Property, + std::vector<Filters::PropertyFilter>> &property_filters, + const Symbol &symbol, const std::unordered_set<Symbol> &bound_symbols, + GraphDbTypes::Label &best_label, + std::pair<GraphDbTypes::Property, Filters::PropertyFilter> + &best_property) { + auto are_bound = [&bound_symbols](const auto &used_symbols) { + for (const auto &used_symbol : used_symbols) { + if (bound_symbols.find(used_symbol) == bound_symbols.end()) { + return false; + } + } + return true; + }; + bool found = false; + auto min_count = std::numeric_limits<decltype(context_.db.VerticesCount( + GraphDbTypes::Label{}, GraphDbTypes::Property{}))>::max(); + for (const auto &label : labels) { + for (const auto &prop_pair : property_filters) { + const auto &property = prop_pair.first; + if (context_.db.LabelPropertyIndexExists(label, property)) { + auto vertices_count = context_.db.VerticesCount(label, property); + if (vertices_count < min_count) { + for (const auto &prop_filter : prop_pair.second) { + if (prop_filter.used_symbols.find(symbol) != + prop_filter.used_symbols.end()) { + // Skip filter expressions which use the symbol whose property + // we are looking up. We cannot scan by such expressions. For + // example, in `n.a = 2 + n.b` both sides of `=` refer to `n`, + // so we cannot scan `n` by property index. + continue; + } + if (are_bound(prop_filter.used_symbols)) { + // Take the first property filter which uses bound symbols. 
+ best_label = label; + best_property = {property, prop_filter}; + min_count = vertices_count; + found = true; + break; + } + } + } + } + } + } + return found; + } + + const GraphDbTypes::Label &FindBestLabelIndex( + const std::set<GraphDbTypes::Label> &labels) { + debug_assert(!labels.empty(), + "Trying to find the best label without any labels."); + return *std::min_element(labels.begin(), labels.end(), + [this](const auto &label1, const auto &label2) { + return context_.db.VerticesCount(label1) < + context_.db.VerticesCount(label2); + }); + } + + ScanAll *GenScanByIndex( + LogicalOperator *last_op, const Symbol &node_symbol, + const MatchContext &context, const std::set<GraphDbTypes::Label> &labels, + const std::map<GraphDbTypes::Property, + std::vector<Filters::PropertyFilter>> &properties) { + debug_assert(!labels.empty(), + "Without labels, indexed data cannot be scanned."); + // First, try to see if we can use label+property index. If not, use just + // the label index (which ought to exist). 
+ GraphDbTypes::Label best_label; + std::pair<GraphDbTypes::Property, Filters::PropertyFilter> best_property; + if (FindBestLabelPropertyIndex(labels, properties, node_symbol, + context.bound_symbols, best_label, + best_property)) { + const auto &prop_filter = best_property.second; + if (prop_filter.lower_bound || prop_filter.upper_bound) { + return new ScanAllByLabelPropertyRange( + std::shared_ptr<LogicalOperator>(last_op), node_symbol, best_label, + best_property.first, prop_filter.lower_bound, + prop_filter.upper_bound, context.graph_view); + } else { + debug_assert( + prop_filter.expression, + "Property filter should either have bounds or an expression."); + return new ScanAllByLabelPropertyValue( + std::shared_ptr<LogicalOperator>(last_op), node_symbol, best_label, + best_property.first, prop_filter.expression, context.graph_view); + } + } + auto label = FindBestLabelIndex(labels); + return new ScanAllByLabel(std::shared_ptr<LogicalOperator>(last_op), + node_symbol, label, context.graph_view); + } + LogicalOperator *PlanMatching(const Matching &matching, LogicalOperator *input_op, MatchContext &match_context) { @@ -312,8 +404,7 @@ class RuleBasedPlanner { auto all_filters = matching.filters.all_filters(); // Try to generate any filters even before the 1st match operator. This // optimizes the optional match which filters only on symbols bound in - // regular - // match. + // regular match. 
auto *last_op = impl::GenFilters(input_op, bound_symbols, all_filters, storage); for (const auto &expansion : matching.expansions) { @@ -334,8 +425,8 @@ class RuleBasedPlanner { std::map<GraphDbTypes::Property, std::vector<Filters::PropertyFilter>>()) .first; - last_op = impl::GenScanByIndex(last_op, context_.db, node1_symbol, - match_context, labels, properties); + last_op = GenScanByIndex(last_op, node1_symbol, match_context, labels, + properties); } match_context.new_symbols.emplace_back(node1_symbol); last_op = @@ -387,8 +478,7 @@ class RuleBasedPlanner { } if (!existing_edge) { // Ensure Cyphermorphism (different edge symbols always map to - // different - // edges). + // different edges). for (const auto &edge_symbols : matching.edge_symbols) { if (edge_symbols.find(edge_symbol) == edge_symbols.end()) { continue; @@ -419,8 +509,7 @@ class RuleBasedPlanner { auto GenMerge(query::Merge &merge, LogicalOperator *input_op, const Matching &matching) { // Copy the bound symbol set, because we don't want to use the updated - // version - // when generating the create part. + // version when generating the create part. std::unordered_set<Symbol> bound_symbols_copy(context_.bound_symbols); MatchContext match_ctx{context_.symbol_table, bound_symbols_copy, GraphView::NEW}; diff --git a/tests/unit/query_cost_estimator.cpp b/tests/unit/query_cost_estimator.cpp index 8a666068e..da5cfaed9 100644 --- a/tests/unit/query_cost_estimator.cpp +++ b/tests/unit/query_cost_estimator.cpp @@ -11,9 +11,9 @@ using namespace query; using namespace query::plan; -using CardParam = CostEstimator::CardParam; -using CostParam = CostEstimator::CostParam; -using MiscParam = CostEstimator::MiscParam; +using CardParam = CostEstimator<GraphDbAccessor>::CardParam; +using CostParam = CostEstimator<GraphDbAccessor>::CostParam; +using MiscParam = CostEstimator<GraphDbAccessor>::MiscParam; /** A fixture for cost estimation. Sets up the database * and accessor (adds some vertices). 
Provides convenience @@ -61,7 +61,7 @@ class QueryCostEstimator : public ::testing::Test { } auto Cost() { - CostEstimator cost_estimator(*dba); + CostEstimator<GraphDbAccessor> cost_estimator(*dba); last_op_->Accept(cost_estimator); return cost_estimator.cost(); } @@ -187,11 +187,12 @@ TEST_F(QueryCostEstimator, ExpandUniquenessFilter) { } TEST_F(QueryCostEstimator, UnwindLiteral) { - TEST_OP(MakeOp<query::plan::Unwind>( - last_op_, storage_.Create<ListLiteral>( - std::vector<Expression *>(7, nullptr)), - NextSymbol()), - CostParam::kUnwind, 7); + TEST_OP( + MakeOp<query::plan::Unwind>( + last_op_, + storage_.Create<ListLiteral>(std::vector<Expression *>(7, nullptr)), + NextSymbol()), + CostParam::kUnwind, 7); } TEST_F(QueryCostEstimator, UnwindNoLiteral) {