// Copyright 2021 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // License, and you may not use this file except in compliance with the Business Source License. // // As of the Change Date specified in that file, in accordance with // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. #include #include #include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/plan/operator.hpp" #include "query/plan/planner.hpp" #include "query/plan/preprocess.hpp" namespace query::plan { class BaseOpChecker { public: virtual ~BaseOpChecker() {} virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0; }; class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { public: using HierarchicalLogicalOperatorVisitor::PostVisit; using HierarchicalLogicalOperatorVisitor::PreVisit; using HierarchicalLogicalOperatorVisitor::Visit; PlanChecker(const std::list> &checkers, const SymbolTable &symbol_table) : symbol_table_(symbol_table) { for (const auto &checker : checkers) checkers_.emplace_back(checker.get()); } PlanChecker(const std::list &checkers, const SymbolTable &symbol_table) : checkers_(checkers), symbol_table_(symbol_table) {} #define PRE_VISIT(TOp) \ bool PreVisit(TOp &op) override { \ CheckOp(op); \ return true; \ } #define VISIT(TOp) \ bool Visit(TOp &op) override { \ CheckOp(op); \ return true; \ } PRE_VISIT(CreateNode); PRE_VISIT(CreateExpand); PRE_VISIT(Delete); PRE_VISIT(ScanAll); PRE_VISIT(ScanAllByLabel); PRE_VISIT(ScanAllByLabelPropertyValue); PRE_VISIT(ScanAllByLabelPropertyRange); PRE_VISIT(ScanAllByLabelProperty); PRE_VISIT(ScanAllById); PRE_VISIT(Expand); PRE_VISIT(ExpandVariable); PRE_VISIT(Filter); PRE_VISIT(ConstructNamedPath); PRE_VISIT(Produce); PRE_VISIT(SetProperty); PRE_VISIT(SetProperties); PRE_VISIT(SetLabels); PRE_VISIT(RemoveProperty); PRE_VISIT(RemoveLabels); PRE_VISIT(EdgeUniquenessFilter); PRE_VISIT(Accumulate); PRE_VISIT(Aggregate); PRE_VISIT(Skip); PRE_VISIT(Limit); PRE_VISIT(OrderBy); bool PreVisit(Merge &op) override { CheckOp(op); op.input()->Accept(*this); return false; } bool PreVisit(Optional &op) override { CheckOp(op); op.input()->Accept(*this); return false; } PRE_VISIT(Unwind); PRE_VISIT(Distinct); bool Visit(Once &) override { // Ignore checking Once, it is implicitly at the end. return true; } bool PreVisit(Cartesian &op) override { CheckOp(op); return false; } PRE_VISIT(CallProcedure); #undef PRE_VISIT #undef VISIT void CheckOp(LogicalOperator &op) { ASSERT_FALSE(checkers_.empty()); checkers_.back()->CheckOp(op, symbol_table_); checkers_.pop_back(); } std::list checkers_; const SymbolTable &symbol_table_; }; template class OpChecker : public BaseOpChecker { public: void CheckOp(LogicalOperator &op, const SymbolTable &symbol_table) override { auto *expected_op = dynamic_cast(&op); ASSERT_TRUE(expected_op) << "op is '" << op.GetTypeInfo().name << "' expected '" << TOp::kType.name << "'!"; ExpectOp(*expected_op, symbol_table); } virtual void ExpectOp(TOp &, const SymbolTable &) {} }; using ExpectCreateNode = OpChecker; using ExpectCreateExpand = OpChecker; using ExpectDelete = OpChecker; using ExpectScanAll = OpChecker; using ExpectScanAllByLabel = OpChecker; using ExpectScanAllById = OpChecker; using ExpectExpand = OpChecker; using ExpectFilter = OpChecker; using ExpectConstructNamedPath = OpChecker; using ExpectProduce = OpChecker; using ExpectSetProperty = OpChecker; using ExpectSetProperties = OpChecker; using ExpectSetLabels = OpChecker; using ExpectRemoveProperty = OpChecker; using ExpectRemoveLabels = OpChecker; using ExpectEdgeUniquenessFilter = OpChecker; using ExpectSkip = OpChecker; using ExpectLimit = OpChecker; using ExpectOrderBy = OpChecker; using ExpectUnwind = OpChecker; using ExpectDistinct = OpChecker; class ExpectExpandVariable : public OpChecker { public: void ExpectOp(ExpandVariable &op, const SymbolTable &) override { EXPECT_EQ(op.type_, query::EdgeAtom::Type::DEPTH_FIRST); } }; class ExpectExpandBfs : public OpChecker { public: void ExpectOp(ExpandVariable &op, const SymbolTable &) override { EXPECT_EQ(op.type_, query::EdgeAtom::Type::BREADTH_FIRST); } }; class ExpectAccumulate : public OpChecker { public: explicit ExpectAccumulate(const std::unordered_set &symbols) : symbols_(symbols) {} void ExpectOp(Accumulate &op, const SymbolTable &) override { std::unordered_set got_symbols(op.symbols_.begin(), op.symbols_.end()); EXPECT_EQ(symbols_, got_symbols); } private: const std::unordered_set symbols_; }; class ExpectAggregate : public OpChecker { public: ExpectAggregate(const std::vector &aggregations, const std::unordered_set &group_by) : aggregations_(aggregations), group_by_(group_by) {} void ExpectOp(Aggregate &op, const SymbolTable &symbol_table) override { auto aggr_it = aggregations_.begin(); for (const auto &aggr_elem : op.aggregations_) { ASSERT_NE(aggr_it, aggregations_.end()); auto aggr = *aggr_it++; // TODO: Proper expression equality EXPECT_EQ(typeid(aggr_elem.value).hash_code(), typeid(aggr->expression1_).hash_code()); EXPECT_EQ(typeid(aggr_elem.key).hash_code(), typeid(aggr->expression2_).hash_code()); EXPECT_EQ(aggr_elem.op, aggr->op_); EXPECT_EQ(aggr_elem.output_sym, symbol_table.at(*aggr)); } EXPECT_EQ(aggr_it, aggregations_.end()); // TODO: Proper group by expression equality std::unordered_set got_group_by; for (auto *expr : op.group_by_) got_group_by.insert(typeid(*expr).hash_code()); std::unordered_set expected_group_by; for (auto *expr : group_by_) expected_group_by.insert(typeid(*expr).hash_code()); EXPECT_EQ(got_group_by, expected_group_by); } private: std::vector aggregations_; std::unordered_set group_by_; }; class ExpectMerge : public OpChecker { public: ExpectMerge(const std::list &on_match, const std::list &on_create) : on_match_(on_match), on_create_(on_create) {} void ExpectOp(Merge &merge, const SymbolTable &symbol_table) override { PlanChecker check_match(on_match_, symbol_table); merge.merge_match_->Accept(check_match); PlanChecker check_create(on_create_, symbol_table); merge.merge_create_->Accept(check_create); } private: const std::list &on_match_; const std::list &on_create_; }; class ExpectOptional : public OpChecker { public: explicit ExpectOptional(const std::list &optional) : optional_(optional) {} ExpectOptional(const std::vector &optional_symbols, const std::list &optional) : optional_symbols_(optional_symbols), optional_(optional) {} void ExpectOp(Optional &optional, const SymbolTable &symbol_table) override { if (!optional_symbols_.empty()) { EXPECT_THAT(optional.optional_symbols_, testing::UnorderedElementsAreArray(optional_symbols_)); } PlanChecker check_optional(optional_, symbol_table); optional.optional_->Accept(check_optional); } private: std::vector optional_symbols_; const std::list &optional_; }; class ExpectScanAllByLabelPropertyValue : public OpChecker { public: ExpectScanAllByLabelPropertyValue(storage::LabelId label, const std::pair &prop_pair, query::Expression *expression) : label_(label), property_(prop_pair.second), expression_(expression) {} void ExpectOp(ScanAllByLabelPropertyValue &scan_all, const SymbolTable &) override { EXPECT_EQ(scan_all.label_, label_); EXPECT_EQ(scan_all.property_, property_); // TODO: Proper expression equality EXPECT_EQ(typeid(scan_all.expression_).hash_code(), typeid(expression_).hash_code()); } private: storage::LabelId label_; storage::PropertyId property_; query::Expression *expression_; }; class ExpectScanAllByLabelPropertyRange : public OpChecker { public: ExpectScanAllByLabelPropertyRange(storage::LabelId label, storage::PropertyId property, std::optional lower_bound, std::optional upper_bound) : label_(label), property_(property), lower_bound_(lower_bound), upper_bound_(upper_bound) {} void ExpectOp(ScanAllByLabelPropertyRange &scan_all, const SymbolTable &) override { EXPECT_EQ(scan_all.label_, label_); EXPECT_EQ(scan_all.property_, property_); if (lower_bound_) { ASSERT_TRUE(scan_all.lower_bound_); // TODO: Proper expression equality EXPECT_EQ(typeid(scan_all.lower_bound_->value()).hash_code(), typeid(lower_bound_->value()).hash_code()); EXPECT_EQ(scan_all.lower_bound_->type(), lower_bound_->type()); } if (upper_bound_) { ASSERT_TRUE(scan_all.upper_bound_); // TODO: Proper expression equality EXPECT_EQ(typeid(scan_all.upper_bound_->value()).hash_code(), typeid(upper_bound_->value()).hash_code()); EXPECT_EQ(scan_all.upper_bound_->type(), upper_bound_->type()); } } private: storage::LabelId label_; storage::PropertyId property_; std::optional lower_bound_; std::optional upper_bound_; }; class ExpectScanAllByLabelProperty : public OpChecker { public: ExpectScanAllByLabelProperty(storage::LabelId label, const std::pair &prop_pair) : label_(label), property_(prop_pair.second) {} void ExpectOp(ScanAllByLabelProperty &scan_all, const SymbolTable &) override { EXPECT_EQ(scan_all.label_, label_); EXPECT_EQ(scan_all.property_, property_); } private: storage::LabelId label_; storage::PropertyId property_; }; class ExpectCartesian : public OpChecker { public: ExpectCartesian(const std::list> &left, const std::list> &right) : left_(left), right_(right) {} void ExpectOp(Cartesian &op, const SymbolTable &symbol_table) override { ASSERT_TRUE(op.left_op_); PlanChecker left_checker(left_, symbol_table); op.left_op_->Accept(left_checker); ASSERT_TRUE(op.right_op_); PlanChecker right_checker(right_, symbol_table); op.right_op_->Accept(right_checker); } private: const std::list> &left_; const std::list> &right_; }; class ExpectCallProcedure : public OpChecker { public: ExpectCallProcedure(const std::string &name, const std::vector &args, const std::vector &fields, const std::vector &result_syms) : name_(name), args_(args), fields_(fields), result_syms_(result_syms) {} void ExpectOp(CallProcedure &op, const SymbolTable &symbol_table) override { EXPECT_EQ(op.procedure_name_, name_); EXPECT_EQ(op.arguments_.size(), args_.size()); for (size_t i = 0; i < args_.size(); ++i) { const auto *op_arg = op.arguments_[i]; const auto *expected_arg = args_[i]; // TODO: Proper expression equality EXPECT_EQ(op_arg->GetTypeInfo(), expected_arg->GetTypeInfo()); } EXPECT_EQ(op.result_fields_, fields_); EXPECT_EQ(op.result_symbols_, result_syms_); } private: std::string name_; std::vector args_; std::vector fields_; std::vector result_syms_; }; template std::list> MakeCheckers(T arg) { std::list> l; l.emplace_back(std::make_unique(arg)); return l; } template std::list> MakeCheckers(T arg, Rest &&...rest) { auto l = MakeCheckers(std::forward(rest)...); l.emplace_front(std::make_unique(arg)); return std::move(l); } template TPlanner MakePlanner(TDbAccessor *dba, AstStorage &storage, SymbolTable &symbol_table, CypherQuery *query) { auto planning_context = MakePlanningContext(&storage, &symbol_table, query, dba); auto query_parts = CollectQueryParts(symbol_table, storage, query); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; return TPlanner(single_query_parts, planning_context); } class FakeDbAccessor { public: int64_t VerticesCount(storage::LabelId label) const { auto found = label_index_.find(label); if (found != label_index_.end()) return found->second; return 0; } int64_t VerticesCount(storage::LabelId label, storage::PropertyId property) const { for (auto &index : label_property_index_) { if (std::get<0>(index) == label && std::get<1>(index) == property) { return std::get<2>(index); } } return 0; } bool LabelIndexExists(storage::LabelId label) const { return label_index_.find(label) != label_index_.end(); } bool LabelPropertyIndexExists(storage::LabelId label, storage::PropertyId property) const { for (auto &index : label_property_index_) { if (std::get<0>(index) == label && std::get<1>(index) == property) { return true; } } return false; } void SetIndexCount(storage::LabelId label, int64_t count) { label_index_[label] = count; } void SetIndexCount(storage::LabelId label, storage::PropertyId property, int64_t count) { for (auto &index : label_property_index_) { if (std::get<0>(index) == label && std::get<1>(index) == property) { std::get<2>(index) = count; return; } } label_property_index_.emplace_back(label, property, count); } storage::LabelId NameToLabel(const std::string &name) { auto found = labels_.find(name); if (found != labels_.end()) return found->second; return labels_.emplace(name, storage::LabelId::FromUint(labels_.size())).first->second; } storage::LabelId Label(const std::string &name) { return NameToLabel(name); } storage::EdgeTypeId NameToEdgeType(const std::string &name) { auto found = edge_types_.find(name); if (found != edge_types_.end()) return found->second; return edge_types_.emplace(name, storage::EdgeTypeId::FromUint(edge_types_.size())).first->second; } storage::PropertyId NameToProperty(const std::string &name) { auto found = properties_.find(name); if (found != properties_.end()) return found->second; return properties_.emplace(name, storage::PropertyId::FromUint(properties_.size())).first->second; } storage::PropertyId Property(const std::string &name) { return NameToProperty(name); } std::string PropertyToName(storage::PropertyId property) const { for (const auto &kv : properties_) { if (kv.second == property) return kv.first; } LOG_FATAL("Unable to find property name"); } std::string PropertyName(storage::PropertyId property) const { return PropertyToName(property); } private: std::unordered_map labels_; std::unordered_map edge_types_; std::unordered_map properties_; std::unordered_map label_index_; std::vector> label_property_index_; }; } // namespace query::plan