// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // License, and you may not use this file except in compliance with the Business Source License. // // As of the Change Date specified in that file, in accordance with // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. #include #include #include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/v2/plan/operator.hpp" #include "query/v2/plan/planner.hpp" #include "query/v2/plan/preprocess.hpp" namespace memgraph::query::v2::plan { class BaseOpChecker { public: virtual ~BaseOpChecker() {} virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0; }; class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { public: using HierarchicalLogicalOperatorVisitor::PostVisit; using HierarchicalLogicalOperatorVisitor::PreVisit; using HierarchicalLogicalOperatorVisitor::Visit; PlanChecker(const std::list> &checkers, const SymbolTable &symbol_table) : symbol_table_(symbol_table) { for (const auto &checker : checkers) checkers_.emplace_back(checker.get()); } PlanChecker(const std::list &checkers, const SymbolTable &symbol_table) : checkers_(checkers), symbol_table_(symbol_table) {} #define PRE_VISIT(TOp) \ bool PreVisit(TOp &op) override { \ CheckOp(op); \ return true; \ } #define VISIT(TOp) \ bool Visit(TOp &op) override { \ CheckOp(op); \ return true; \ } PRE_VISIT(CreateNode); PRE_VISIT(CreateExpand); PRE_VISIT(Delete); PRE_VISIT(ScanAll); PRE_VISIT(ScanAllByLabel); PRE_VISIT(ScanAllByLabelPropertyValue); PRE_VISIT(ScanAllByLabelPropertyRange); PRE_VISIT(ScanAllByLabelProperty); PRE_VISIT(Expand); PRE_VISIT(ExpandVariable); PRE_VISIT(Filter); PRE_VISIT(ConstructNamedPath); PRE_VISIT(Produce); PRE_VISIT(SetProperty); PRE_VISIT(SetProperties); PRE_VISIT(SetLabels); PRE_VISIT(RemoveProperty); PRE_VISIT(RemoveLabels); PRE_VISIT(EdgeUniquenessFilter); PRE_VISIT(Accumulate); PRE_VISIT(Aggregate); PRE_VISIT(Skip); PRE_VISIT(Limit); PRE_VISIT(OrderBy); bool PreVisit(Merge &op) override { CheckOp(op); op.input()->Accept(*this); return false; } bool PreVisit(Optional &op) override { CheckOp(op); op.input()->Accept(*this); return false; } PRE_VISIT(Unwind); PRE_VISIT(Distinct); bool PreVisit(Foreach &op) override { CheckOp(op); return false; } bool Visit(Once &) override { // Ignore checking Once, it is implicitly at the end. return true; } bool PreVisit(Cartesian &op) override { CheckOp(op); return false; } PRE_VISIT(CallProcedure); #undef PRE_VISIT #undef VISIT void CheckOp(LogicalOperator &op) { ASSERT_FALSE(checkers_.empty()); checkers_.back()->CheckOp(op, symbol_table_); checkers_.pop_back(); } std::list checkers_; const SymbolTable &symbol_table_; }; template class OpChecker : public BaseOpChecker { public: void CheckOp(LogicalOperator &op, const SymbolTable &symbol_table) override { auto *expected_op = dynamic_cast(&op); ASSERT_TRUE(expected_op) << "op is '" << op.GetTypeInfo().name << "' expected '" << TOp::kType.name << "'!"; ExpectOp(*expected_op, symbol_table); } virtual void ExpectOp(TOp &, const SymbolTable &) {} }; // using ExpectScanAllByPrimaryKey = OpChecker; using ExpectCreateNode = OpChecker; using ExpectCreateExpand = OpChecker; using ExpectDelete = OpChecker; using ExpectScanAll = OpChecker; using ExpectScanAllByLabel = OpChecker; // using ExpectScanAllById = OpChecker; using ExpectExpand = OpChecker; using ExpectFilter = OpChecker; using ExpectConstructNamedPath = OpChecker; using ExpectProduce = OpChecker; using ExpectSetProperty = OpChecker; using ExpectSetProperties = OpChecker; using ExpectSetLabels = OpChecker; using ExpectRemoveProperty = OpChecker; using ExpectRemoveLabels = OpChecker; using ExpectEdgeUniquenessFilter = OpChecker; using ExpectSkip = OpChecker; using ExpectLimit = OpChecker; using ExpectOrderBy = OpChecker; using ExpectUnwind = OpChecker; using ExpectDistinct = OpChecker; // class ExpectForeach : public OpChecker { // public: // ExpectForeach(const std::list &input, const std::list &updates) // : input_(input), updates_(updates) {} // void ExpectOp(Foreach &foreach, const SymbolTable &symbol_table) override { // PlanChecker check_input(input_, symbol_table); // foreach // .input_->Accept(check_input); // PlanChecker check_updates(updates_, symbol_table); // foreach // .update_clauses_->Accept(check_updates); // } // private: // std::list input_; // std::list updates_; // }; // class ExpectExpandVariable : public OpChecker { // public: // void ExpectOp(ExpandVariable &op, const SymbolTable &) override { // EXPECT_EQ(op.type_, memgraph::query::EdgeAtom::Type::DEPTH_FIRST); // } // }; // class ExpectExpandBfs : public OpChecker { // public: // void ExpectOp(ExpandVariable &op, const SymbolTable &) override { // EXPECT_EQ(op.type_, memgraph::query::EdgeAtom::Type::BREADTH_FIRST); // } // }; // class ExpectAccumulate : public OpChecker { // public: // explicit ExpectAccumulate(const std::unordered_set &symbols) : symbols_(symbols) {} // void ExpectOp(Accumulate &op, const SymbolTable &) override { // std::unordered_set got_symbols(op.symbols_.begin(), op.symbols_.end()); // EXPECT_EQ(symbols_, got_symbols); // } // private: // const std::unordered_set symbols_; // }; // class ExpectAggregate : public OpChecker { // public: // ExpectAggregate(const std::vector &aggregations, // const std::unordered_set &group_by) // : aggregations_(aggregations), group_by_(group_by) {} // void ExpectOp(Aggregate &op, const SymbolTable &symbol_table) override { // auto aggr_it = aggregations_.begin(); // for (const auto &aggr_elem : op.aggregations_) { // ASSERT_NE(aggr_it, aggregations_.end()); // auto aggr = *aggr_it++; // // TODO: Proper expression equality // EXPECT_EQ(typeid(aggr_elem.value).hash_code(), typeid(aggr->expression1_).hash_code()); // EXPECT_EQ(typeid(aggr_elem.key).hash_code(), typeid(aggr->expression2_).hash_code()); // EXPECT_EQ(aggr_elem.op, aggr->op_); // EXPECT_EQ(aggr_elem.output_sym, symbol_table.at(*aggr)); // } // EXPECT_EQ(aggr_it, aggregations_.end()); // // TODO: Proper group by expression equality // std::unordered_set got_group_by; // for (auto *expr : op.group_by_) got_group_by.insert(typeid(*expr).hash_code()); // std::unordered_set expected_group_by; // for (auto *expr : group_by_) expected_group_by.insert(typeid(*expr).hash_code()); // EXPECT_EQ(got_group_by, expected_group_by); // } // private: // std::vector aggregations_; // std::unordered_set group_by_; // }; // class ExpectMerge : public OpChecker { // public: // ExpectMerge(const std::list &on_match, const std::list &on_create) // : on_match_(on_match), on_create_(on_create) {} // void ExpectOp(Merge &merge, const SymbolTable &symbol_table) override { // PlanChecker check_match(on_match_, symbol_table); // merge.merge_match_->Accept(check_match); // PlanChecker check_create(on_create_, symbol_table); // merge.merge_create_->Accept(check_create); // } // private: // const std::list &on_match_; // const std::list &on_create_; // }; // class ExpectOptional : public OpChecker { // public: // explicit ExpectOptional(const std::list &optional) : optional_(optional) {} // ExpectOptional(const std::vector &optional_symbols, const std::list &optional) // : optional_symbols_(optional_symbols), optional_(optional) {} // void ExpectOp(Optional &optional, const SymbolTable &symbol_table) override { // if (!optional_symbols_.empty()) { // EXPECT_THAT(optional.optional_symbols_, testing::UnorderedElementsAreArray(optional_symbols_)); // } // PlanChecker check_optional(optional_, symbol_table); // optional.optional_->Accept(check_optional); // } // private: // std::vector optional_symbols_; // const std::list &optional_; // }; // class ExpectScanAllByLabelPropertyValue : public OpChecker { // public: // ExpectScanAllByLabelPropertyValue(memgraph::storage::LabelId label, // const std::pair &prop_pair, // memgraph::query::Expression *expression) // : label_(label), property_(prop_pair.second), expression_(expression) {} // void ExpectOp(ScanAllByLabelPropertyValue &scan_all, const SymbolTable &) override { // EXPECT_EQ(scan_all.label_, label_); // EXPECT_EQ(scan_all.property_, property_); // // TODO: Proper expression equality // EXPECT_EQ(typeid(scan_all.expression_).hash_code(), typeid(expression_).hash_code()); // } // private: // memgraph::storage::LabelId label_; // memgraph::storage::PropertyId property_; // memgraph::query::Expression *expression_; // }; // class ExpectScanAllByLabelPropertyRange : public OpChecker { // public: // ExpectScanAllByLabelPropertyRange(memgraph::storage::LabelId label, memgraph::storage::PropertyId property, // std::optional lower_bound, // std::optional upper_bound) // : label_(label), property_(property), lower_bound_(lower_bound), upper_bound_(upper_bound) {} // void ExpectOp(ScanAllByLabelPropertyRange &scan_all, const SymbolTable &) override { // EXPECT_EQ(scan_all.label_, label_); // EXPECT_EQ(scan_all.property_, property_); // if (lower_bound_) { // ASSERT_TRUE(scan_all.lower_bound_); // // TODO: Proper expression equality // EXPECT_EQ(typeid(scan_all.lower_bound_->value()).hash_code(), typeid(lower_bound_->value()).hash_code()); // EXPECT_EQ(scan_all.lower_bound_->type(), lower_bound_->type()); // } // if (upper_bound_) { // ASSERT_TRUE(scan_all.upper_bound_); // // TODO: Proper expression equality // EXPECT_EQ(typeid(scan_all.upper_bound_->value()).hash_code(), typeid(upper_bound_->value()).hash_code()); // EXPECT_EQ(scan_all.upper_bound_->type(), upper_bound_->type()); // } // } // private: // memgraph::storage::LabelId label_; // memgraph::storage::PropertyId property_; // std::optional lower_bound_; // std::optional upper_bound_; // }; // class ExpectScanAllByLabelProperty : public OpChecker { // public: // ExpectScanAllByLabelProperty(memgraph::storage::LabelId label, // const std::pair &prop_pair) // : label_(label), property_(prop_pair.second) {} // void ExpectOp(ScanAllByLabelProperty &scan_all, const SymbolTable &) override { // EXPECT_EQ(scan_all.label_, label_); // EXPECT_EQ(scan_all.property_, property_); // } // private: // memgraph::storage::LabelId label_; // memgraph::storage::PropertyId property_; // }; class ExpectScanAllByPrimaryKey : public OpChecker { public: ExpectScanAllByPrimaryKey(memgraph::storage::v3::LabelId label, const std::vector &properties) : label_(label), properties_(properties) {} void ExpectOp(v2::plan::ScanAllByPrimaryKey &scan_all, const SymbolTable &) override { EXPECT_EQ(scan_all.label_, label_); // EXPECT_EQ(scan_all.property_, property_); // TODO(gvolfing) maybe assert the size of the 2 vectors. // TODO(gvolfing) maybe use some std alg if Expression lets us. bool primary_property_match = true; for (const auto &expected_prop : properties_) { bool has_match = false; for (const auto &prop : scan_all.primary_key_) { if (typeid(prop).hash_code() == typeid(expected_prop).hash_code()) { has_match = true; } } if (!has_match) { primary_property_match = false; } } EXPECT_TRUE(primary_property_match); } private: memgraph::storage::v3::LabelId label_; std::vector properties_; }; class ExpectCartesian : public OpChecker { public: ExpectCartesian(const std::list> &left, const std::list> &right) : left_(left), right_(right) {} void ExpectOp(Cartesian &op, const SymbolTable &symbol_table) override { ASSERT_TRUE(op.left_op_); PlanChecker left_checker(left_, symbol_table); op.left_op_->Accept(left_checker); ASSERT_TRUE(op.right_op_); PlanChecker right_checker(right_, symbol_table); op.right_op_->Accept(right_checker); } private: const std::list> &left_; const std::list> &right_; }; class ExpectCallProcedure : public OpChecker { public: ExpectCallProcedure(const std::string &name, const std::vector &args, const std::vector &fields, const std::vector &result_syms) : name_(name), args_(args), fields_(fields), result_syms_(result_syms) {} void ExpectOp(CallProcedure &op, const SymbolTable &symbol_table) override { EXPECT_EQ(op.procedure_name_, name_); EXPECT_EQ(op.arguments_.size(), args_.size()); for (size_t i = 0; i < args_.size(); ++i) { const auto *op_arg = op.arguments_[i]; const auto *expected_arg = args_[i]; // TODO: Proper expression equality EXPECT_EQ(op_arg->GetTypeInfo(), expected_arg->GetTypeInfo()); } EXPECT_EQ(op.result_fields_, fields_); EXPECT_EQ(op.result_symbols_, result_syms_); } private: std::string name_; std::vector args_; std::vector fields_; std::vector result_syms_; }; template std::list> MakeCheckers(T arg) { std::list> l; l.emplace_back(std::make_unique(arg)); return l; } template std::list> MakeCheckers(T arg, Rest &&...rest) { auto l = MakeCheckers(std::forward(rest)...); l.emplace_front(std::make_unique(arg)); return std::move(l); } template TPlanner MakePlanner(TDbAccessor *dba, AstStorage &storage, SymbolTable &symbol_table, CypherQuery *query) { auto planning_context = MakePlanningContext(&storage, &symbol_table, query, dba); auto query_parts = CollectQueryParts(symbol_table, storage, query); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; return TPlanner(single_query_parts, planning_context); } class FakeDistributedDbAccessor { public: int64_t VerticesCount(memgraph::storage::v3::LabelId label) const { auto found = label_index_.find(label); if (found != label_index_.end()) return found->second; return 0; } int64_t VerticesCount(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property) const { for (auto &index : label_property_index_) { if (std::get<0>(index) == label && std::get<1>(index) == property) { return std::get<2>(index); } } return 0; } bool LabelIndexExists(memgraph::storage::v3::LabelId label) const { return label_index_.find(label) != label_index_.end(); } bool LabelPropertyIndexExists(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property) const { for (auto &index : label_property_index_) { if (std::get<0>(index) == label && std::get<1>(index) == property) { return true; } } return false; } void SetIndexCount(memgraph::storage::v3::LabelId label, int64_t count) { label_index_[label] = count; } void SetIndexCount(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property, int64_t count) { for (auto &index : label_property_index_) { if (std::get<0>(index) == label && std::get<1>(index) == property) { std::get<2>(index) = count; return; } } label_property_index_.emplace_back(label, property, count); } memgraph::storage::v3::LabelId NameToLabel(const std::string &name) { auto found = primary_labels_.find(name); if (found != primary_labels_.end()) return found->second; return primary_labels_.emplace(name, memgraph::storage::v3::LabelId::FromUint(primary_labels_.size())) .first->second; } memgraph::storage::v3::LabelId Label(const std::string &name) { return NameToLabel(name); } memgraph::storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) { auto found = edge_types_.find(name); if (found != edge_types_.end()) return found->second; return edge_types_.emplace(name, memgraph::storage::v3::EdgeTypeId::FromUint(edge_types_.size())).first->second; } memgraph::storage::v3::PropertyId NameToPrimaryProperty(const std::string &name) { auto found = primary_properties_.find(name); if (found != primary_properties_.end()) return found->second; return primary_properties_.emplace(name, memgraph::storage::v3::PropertyId::FromUint(primary_properties_.size())) .first->second; } memgraph::storage::v3::PropertyId NameToSecondaryProperty(const std::string &name) { auto found = secondary_properties_.find(name); if (found != secondary_properties_.end()) return found->second; return secondary_properties_ .emplace(name, memgraph::storage::v3::PropertyId::FromUint(secondary_properties_.size())) .first->second; } memgraph::storage::v3::PropertyId PrimaryProperty(const std::string &name) { return NameToPrimaryProperty(name); } memgraph::storage::v3::PropertyId SecondaryProperty(const std::string &name) { return NameToSecondaryProperty(name); } std::string PrimaryPropertyToName(memgraph::storage::v3::PropertyId property) const { for (const auto &kv : primary_properties_) { if (kv.second == property) return kv.first; } LOG_FATAL("Unable to find primary property name"); } std::string SecondaryPropertyToName(memgraph::storage::v3::PropertyId property) const { for (const auto &kv : secondary_properties_) { if (kv.second == property) return kv.first; } LOG_FATAL("Unable to find secondary property name"); } std::string PrimaryPropertyName(memgraph::storage::v3::PropertyId property) const { return PrimaryPropertyToName(property); } std::string SecondaryPropertyName(memgraph::storage::v3::PropertyId property) const { return SecondaryPropertyToName(property); } memgraph::storage::v3::PropertyId NameToProperty(const std::string &name) { return storage::v3::PropertyId::FromUint(0); } std::vector ExtractPrimaryKey(storage::v3::LabelId label, std::vector property_filters) { return std::vector{}; } private: std::unordered_map primary_labels_; std::unordered_map secondary_labels_; std::unordered_map edge_types_; std::unordered_map primary_properties_; std::unordered_map secondary_properties_; std::unordered_map label_index_; std::vector> label_property_index_; }; } // namespace memgraph::query::v2::plan