// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // License, and you may not use this file except in compliance with the Business Source License. // // As of the Change Date specified in that file, in accordance with // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. #include #include #include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/v2/plan/operator.hpp" #include "query/v2/plan/planner.hpp" #include "query/v2/plan/preprocess.hpp" #include "utils/exceptions.hpp" namespace memgraph::query::v2::plan { class BaseOpChecker { public: virtual ~BaseOpChecker() {} virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0; }; class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { public: using HierarchicalLogicalOperatorVisitor::PostVisit; using HierarchicalLogicalOperatorVisitor::PreVisit; using HierarchicalLogicalOperatorVisitor::Visit; PlanChecker(const std::list> &checkers, const SymbolTable &symbol_table) : symbol_table_(symbol_table) { for (const auto &checker : checkers) checkers_.emplace_back(checker.get()); } PlanChecker(const std::list &checkers, const SymbolTable &symbol_table) : checkers_(checkers), symbol_table_(symbol_table) {} #define PRE_VISIT(TOp) \ bool PreVisit(TOp &op) override { \ CheckOp(op); \ return true; \ } #define VISIT(TOp) \ bool Visit(TOp &op) override { \ CheckOp(op); \ return true; \ } PRE_VISIT(CreateNode); PRE_VISIT(CreateExpand); PRE_VISIT(Delete); PRE_VISIT(ScanAll); PRE_VISIT(ScanAllByLabel); PRE_VISIT(ScanAllByLabelPropertyValue); PRE_VISIT(ScanAllByLabelPropertyRange); PRE_VISIT(ScanAllByLabelProperty); PRE_VISIT(ScanByPrimaryKey); PRE_VISIT(Expand); PRE_VISIT(ExpandVariable); PRE_VISIT(Filter); PRE_VISIT(ConstructNamedPath); PRE_VISIT(Produce); PRE_VISIT(SetProperty); PRE_VISIT(SetProperties); PRE_VISIT(SetLabels); PRE_VISIT(RemoveProperty); PRE_VISIT(RemoveLabels); PRE_VISIT(EdgeUniquenessFilter); PRE_VISIT(Accumulate); PRE_VISIT(Aggregate); PRE_VISIT(Skip); PRE_VISIT(Limit); PRE_VISIT(OrderBy); bool PreVisit(Merge &op) override { CheckOp(op); op.input()->Accept(*this); return false; } bool PreVisit(Optional &op) override { CheckOp(op); op.input()->Accept(*this); return false; } PRE_VISIT(Unwind); PRE_VISIT(Distinct); bool PreVisit(Foreach &op) override { CheckOp(op); return false; } bool Visit(Once &) override { // Ignore checking Once, it is implicitly at the end. return true; } bool PreVisit(Cartesian &op) override { CheckOp(op); return false; } PRE_VISIT(CallProcedure); #undef PRE_VISIT #undef VISIT void CheckOp(LogicalOperator &op) { ASSERT_FALSE(checkers_.empty()); checkers_.back()->CheckOp(op, symbol_table_); checkers_.pop_back(); } std::list checkers_; const SymbolTable &symbol_table_; }; template class OpChecker : public BaseOpChecker { public: void CheckOp(LogicalOperator &op, const SymbolTable &symbol_table) override { auto *expected_op = dynamic_cast(&op); ASSERT_TRUE(expected_op) << "op is '" << op.GetTypeInfo().name << "' expected '" << TOp::kType.name << "'!"; ExpectOp(*expected_op, symbol_table); } virtual void ExpectOp(TOp &, const SymbolTable &) {} }; using ExpectCreateNode = OpChecker; using ExpectCreateExpand = OpChecker; using ExpectDelete = OpChecker; using ExpectScanAll = OpChecker; using ExpectScanAllByLabel = OpChecker; using ExpectExpand = OpChecker; using ExpectFilter = OpChecker; using ExpectConstructNamedPath = OpChecker; using ExpectProduce = OpChecker; using ExpectSetProperty = OpChecker; using ExpectSetProperties = OpChecker; using ExpectSetLabels = OpChecker; using ExpectRemoveProperty = OpChecker; using ExpectRemoveLabels = OpChecker; using ExpectEdgeUniquenessFilter = OpChecker; using ExpectSkip = OpChecker; using ExpectLimit = OpChecker; using ExpectOrderBy = OpChecker; using ExpectUnwind = OpChecker; using ExpectDistinct = OpChecker; class ExpectScanAllByLabelPropertyValue : public OpChecker { public: ExpectScanAllByLabelPropertyValue(memgraph::storage::v3::LabelId label, const std::pair &prop_pair, memgraph::query::v2::Expression *expression) : label_(label), property_(prop_pair.second), expression_(expression) {} void ExpectOp(ScanAllByLabelPropertyValue &scan_all, const SymbolTable &) override { EXPECT_EQ(scan_all.label_, label_); EXPECT_EQ(scan_all.property_, property_); // TODO: Proper expression equality EXPECT_EQ(typeid(scan_all.expression_).hash_code(), typeid(expression_).hash_code()); } private: memgraph::storage::v3::LabelId label_; memgraph::storage::v3::PropertyId property_; memgraph::query::v2::Expression *expression_; }; class ExpectScanByPrimaryKey : public OpChecker { public: ExpectScanByPrimaryKey(memgraph::storage::v3::LabelId label, const std::vector &properties) : label_(label), properties_(properties) {} void ExpectOp(v2::plan::ScanByPrimaryKey &scan_all, const SymbolTable &) override { EXPECT_EQ(scan_all.label_, label_); bool primary_property_match = true; for (const auto &expected_prop : properties_) { bool has_match = false; for (const auto &prop : scan_all.primary_key_) { if (typeid(prop).hash_code() == typeid(expected_prop).hash_code()) { has_match = true; } } if (!has_match) { primary_property_match = false; } } EXPECT_TRUE(primary_property_match); } private: memgraph::storage::v3::LabelId label_; std::vector properties_; }; class ExpectCartesian : public OpChecker { public: ExpectCartesian(const std::list> &left, const std::list> &right) : left_(left), right_(right) {} void ExpectOp(Cartesian &op, const SymbolTable &symbol_table) override { ASSERT_TRUE(op.left_op_); PlanChecker left_checker(left_, symbol_table); op.left_op_->Accept(left_checker); ASSERT_TRUE(op.right_op_); PlanChecker right_checker(right_, symbol_table); op.right_op_->Accept(right_checker); } private: const std::list> &left_; const std::list> &right_; }; class ExpectCallProcedure : public OpChecker { public: ExpectCallProcedure(const std::string &name, const std::vector &args, const std::vector &fields, const std::vector &result_syms) : name_(name), args_(args), fields_(fields), result_syms_(result_syms) {} void ExpectOp(CallProcedure &op, const SymbolTable &symbol_table) override { EXPECT_EQ(op.procedure_name_, name_); EXPECT_EQ(op.arguments_.size(), args_.size()); for (size_t i = 0; i < args_.size(); ++i) { const auto *op_arg = op.arguments_[i]; const auto *expected_arg = args_[i]; // TODO: Proper expression equality EXPECT_EQ(op_arg->GetTypeInfo(), expected_arg->GetTypeInfo()); } EXPECT_EQ(op.result_fields_, fields_); EXPECT_EQ(op.result_symbols_, result_syms_); } private: std::string name_; std::vector args_; std::vector fields_; std::vector result_syms_; }; template std::list> MakeCheckers(T arg) { std::list> l; l.emplace_back(std::make_unique(arg)); return l; } template std::list> MakeCheckers(T arg, Rest &&...rest) { auto l = MakeCheckers(std::forward(rest)...); l.emplace_front(std::make_unique(arg)); return std::move(l); } template TPlanner MakePlanner(TDbAccessor *dba, AstStorage &storage, SymbolTable &symbol_table, CypherQuery *query) { auto planning_context = MakePlanningContext(&storage, &symbol_table, query, dba); auto query_parts = CollectQueryParts(symbol_table, storage, query); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; return TPlanner(single_query_parts, planning_context); } class FakeDistributedDbAccessor { public: int64_t VerticesCount(memgraph::storage::v3::LabelId label) const { auto found = label_index_.find(label); if (found != label_index_.end()) return found->second; return 0; } int64_t VerticesCount(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property) const { for (auto &index : label_property_index_) { if (std::get<0>(index) == label && std::get<1>(index) == property) { return std::get<2>(index); } } return 0; } bool LabelIndexExists(memgraph::storage::v3::LabelId label) const { throw utils::NotYetImplemented("Label indicies are yet to be implemented."); } bool LabelPropertyIndexExists(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property) const { for (auto &index : label_property_index_) { if (std::get<0>(index) == label && std::get<1>(index) == property) { return true; } } return false; } bool PrimaryLabelExists(storage::v3::LabelId label) { return label_index_.find(label) != label_index_.end(); } void SetIndexCount(memgraph::storage::v3::LabelId label, int64_t count) { label_index_[label] = count; } void SetIndexCount(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property, int64_t count) { for (auto &index : label_property_index_) { if (std::get<0>(index) == label && std::get<1>(index) == property) { std::get<2>(index) = count; return; } } label_property_index_.emplace_back(label, property, count); } memgraph::storage::v3::LabelId NameToLabel(const std::string &name) { auto found = primary_labels_.find(name); if (found != primary_labels_.end()) return found->second; return primary_labels_.emplace(name, memgraph::storage::v3::LabelId::FromUint(primary_labels_.size())) .first->second; } memgraph::storage::v3::LabelId Label(const std::string &name) { return NameToLabel(name); } memgraph::storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) { auto found = edge_types_.find(name); if (found != edge_types_.end()) return found->second; return edge_types_.emplace(name, memgraph::storage::v3::EdgeTypeId::FromUint(edge_types_.size())).first->second; } memgraph::storage::v3::PropertyId NameToPrimaryProperty(const std::string &name) { auto found = primary_properties_.find(name); if (found != primary_properties_.end()) return found->second; return primary_properties_.emplace(name, memgraph::storage::v3::PropertyId::FromUint(primary_properties_.size())) .first->second; } memgraph::storage::v3::PropertyId NameToSecondaryProperty(const std::string &name) { auto found = secondary_properties_.find(name); if (found != secondary_properties_.end()) return found->second; return secondary_properties_ .emplace(name, memgraph::storage::v3::PropertyId::FromUint(secondary_properties_.size())) .first->second; } memgraph::storage::v3::PropertyId PrimaryProperty(const std::string &name) { return NameToPrimaryProperty(name); } memgraph::storage::v3::PropertyId SecondaryProperty(const std::string &name) { return NameToSecondaryProperty(name); } std::string PrimaryPropertyToName(memgraph::storage::v3::PropertyId property) const { for (const auto &kv : primary_properties_) { if (kv.second == property) return kv.first; } LOG_FATAL("Unable to find primary property name"); } std::string SecondaryPropertyToName(memgraph::storage::v3::PropertyId property) const { for (const auto &kv : secondary_properties_) { if (kv.second == property) return kv.first; } LOG_FATAL("Unable to find secondary property name"); } std::string PrimaryPropertyName(memgraph::storage::v3::PropertyId property) const { return PrimaryPropertyToName(property); } std::string SecondaryPropertyName(memgraph::storage::v3::PropertyId property) const { return SecondaryPropertyToName(property); } memgraph::storage::v3::PropertyId NameToProperty(const std::string &name) { auto find_in_prim_properties = primary_properties_.find(name); if (find_in_prim_properties != primary_properties_.end()) { return find_in_prim_properties->second; } auto find_in_secondary_properties = secondary_properties_.find(name); if (find_in_secondary_properties != secondary_properties_.end()) { return find_in_secondary_properties->second; } LOG_FATAL("The property does not exist as a primary or a secondary property."); return memgraph::storage::v3::PropertyId::FromUint(0); } std::vector GetSchemaForLabel(storage::v3::LabelId label) { auto schema_properties = schemas_.at(label); std::vector ret; std::transform(schema_properties.begin(), schema_properties.end(), std::back_inserter(ret), [](const auto &prop) { memgraph::storage::v3::SchemaProperty schema_prop = { .property_id = prop, // This should not be hardcoded, but for testing purposes it will suffice. .type = memgraph::common::SchemaType::INT}; return schema_prop; }); return ret; } std::vector> ExtractPrimaryKey( storage::v3::LabelId label, std::vector property_filters) { MG_ASSERT(schemas_.contains(label), "You did not specify the Schema for this label! Use FakeDistributedDbAccessor::CreateSchema(...)."); std::vector> pk; const auto schema = GetSchemaPropertiesForLabel(label); std::vector schema_properties; schema_properties.reserve(schema.size()); std::transform(schema.begin(), schema.end(), std::back_inserter(schema_properties), [](const auto &schema_elem) { return schema_elem; }); for (const auto &property_filter : property_filters) { const auto &property_id = NameToProperty(property_filter.property_filter->property_.name); if (std::find(schema_properties.begin(), schema_properties.end(), property_id) != schema_properties.end()) { pk.emplace_back(std::make_pair(property_filter.expression, property_filter)); } } return pk.size() == schema_properties.size() ? pk : std::vector>{}; } std::vector GetSchemaPropertiesForLabel(storage::v3::LabelId label) { return schemas_.at(label); } void CreateSchema(const memgraph::storage::v3::LabelId primary_label, const std::vector &schemas_types) { MG_ASSERT(!schemas_.contains(primary_label), "You already created the schema for this label!"); schemas_.emplace(primary_label, schemas_types); } private: std::unordered_map primary_labels_; std::unordered_map secondary_labels_; std::unordered_map edge_types_; std::unordered_map primary_properties_; std::unordered_map secondary_properties_; std::unordered_map label_index_; std::vector> label_property_index_; std::unordered_map> schemas_; }; } // namespace memgraph::query::v2::plan