diff --git a/src/flags/CMakeLists.txt b/src/flags/CMakeLists.txt index 6f4c1b748..e8988756f 100644 --- a/src/flags/CMakeLists.txt +++ b/src/flags/CMakeLists.txt @@ -5,6 +5,7 @@ add_library(mg-flags STATIC audit.cpp log_level.cpp memory_limit.cpp run_time_configurable.cpp - storage_mode.cpp) + storage_mode.cpp + query.cpp) target_include_directories(mg-flags PUBLIC ${CMAKE_SOURCE_DIR}/include) target_link_libraries(mg-flags PUBLIC spdlog::spdlog mg-settings mg-utils) diff --git a/src/flags/all.hpp b/src/flags/all.hpp index 9cdcd1e8f..f7b44272a 100644 --- a/src/flags/all.hpp +++ b/src/flags/all.hpp @@ -16,5 +16,6 @@ #include "flags/isolation_level.hpp" #include "flags/log_level.hpp" #include "flags/memory_limit.hpp" +#include "flags/query.hpp" #include "flags/run_time_configurable.hpp" #include "flags/storage_mode.hpp" diff --git a/src/flags/query.cpp b/src/flags/query.cpp new file mode 100644 index 000000000..6d72d9c49 --- /dev/null +++ b/src/flags/query.cpp @@ -0,0 +1,14 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. +#include "flags/query.hpp" + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +// DEFINE_bool(cartesian_product_enabled, true, "Enable cartesian product expansion."); Moved to run_time_configurable diff --git a/src/flags/query.hpp b/src/flags/query.hpp new file mode 100644 index 000000000..bee2509a4 --- /dev/null +++ b/src/flags/query.hpp @@ -0,0 +1,16 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. +#pragma once + +#include "gflags/gflags.h" + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +// DECLARE_bool(cartesian_product_enabled); Moved to run_time_configurable diff --git a/src/flags/run_time_configurable.cpp b/src/flags/run_time_configurable.cpp index a861769d8..a42ebd3d0 100644 --- a/src/flags/run_time_configurable.cpp +++ b/src/flags/run_time_configurable.cpp @@ -19,6 +19,7 @@ #include "flags/bolt.hpp" #include "flags/general.hpp" #include "flags/log_level.hpp" +#include "flags/query.hpp" #include "spdlog/cfg/helpers-inl.h" #include "spdlog/spdlog.h" #include "utils/exceptions.hpp" @@ -49,6 +50,10 @@ DEFINE_double(query_execution_timeout_sec, 600, "Maximum allowed query execution time. Queries exceeding this " "limit will be aborted. Value of 0 means no limit."); +// Query plan flags +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_bool(cartesian_product_enabled, true, "Enable cartesian product expansion."); + namespace { // Bolt server name constexpr auto kServerNameSettingKey = "server.name"; @@ -65,9 +70,15 @@ constexpr auto kLogLevelGFlagsKey = "log_level"; constexpr auto kLogToStderrSettingKey = "log.to_stderr"; constexpr auto kLogToStderrGFlagsKey = "also_log_to_stderr"; +constexpr auto kCartesianProductEnabledSettingKey = "cartesian-product-enabled"; +constexpr auto kCartesianProductEnabledGFlagsKey = "cartesian-product-enabled"; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) std::atomic<double> execution_timeout_sec_; // Local cache-like thing +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +std::atomic<bool> cartesian_product_enabled_{true}; // Local cache-like thing + auto ToLLEnum(std::string_view val) { const auto ll_enum = memgraph::flags::LogLevelToEnum(val); if (!ll_enum) { @@ -171,6 +182,10 @@ void Initialize() { } }, ValidBoolStr); + + register_flag( + kCartesianProductEnabledGFlagsKey, kCartesianProductEnabledSettingKey, !kRestore, + [](const std::string &val) { cartesian_product_enabled_ = val == "true"; }, ValidBoolStr); } std::string GetServerName() { @@ -182,4 +197,6 @@ std::string GetServerName() { double GetExecutionTimeout() { return execution_timeout_sec_; } +bool GetCartesianProductEnabled() { return cartesian_product_enabled_; } + } // namespace memgraph::flags::run_time diff --git a/src/flags/run_time_configurable.hpp b/src/flags/run_time_configurable.hpp index a45258853..944a0539f 100644 --- a/src/flags/run_time_configurable.hpp +++ b/src/flags/run_time_configurable.hpp @@ -35,4 +35,11 @@ std::string GetServerName(); */ double GetExecutionTimeout(); +/** + * @brief Get the cartesian product enabled value + * + * @return bool + */ +bool GetCartesianProductEnabled(); + } // namespace memgraph::flags::run_time diff --git a/src/query/plan/cost_estimator.hpp b/src/query/plan/cost_estimator.hpp index 3169d2e2f..47da0a23b 100644 --- a/src/query/plan/cost_estimator.hpp +++ b/src/query/plan/cost_estimator.hpp @@ -38,6 +38,14 @@ struct Scope { std::unordered_map<std::string, SymbolStatistics> symbol_stats; }; +struct CostEstimation { + // expense of running the query + double cost; + + // expected number of rows + double cardinality; +}; + /** * Query plan execution time cost estimator, for comparing and choosing optimal * execution plans. @@ -271,12 +279,12 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { } bool PreVisit(Union &op) override { - double left_cost = EstimateCostOnBranch(&op.left_op_); - double right_cost = EstimateCostOnBranch(&op.right_op_); + CostEstimation left_estimation = EstimateCostOnBranch(&op.left_op_); + CostEstimation right_estimation = EstimateCostOnBranch(&op.right_op_); // the number of hits in the previous operator should be the joined number of results of both parts of the union - cardinality_ *= (left_cost + right_cost); - IncrementCost(CostParam::kUnion); + cost_ = left_estimation.cost + right_estimation.cost; + cardinality_ = left_estimation.cardinality + right_estimation.cardinality; return false; } @@ -303,11 +311,57 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { // Estimate cost on the subquery branch independently, use a copy auto &last_scope = scopes_.back(); - double subquery_cost = EstimateCostOnBranch(&op.subquery_, last_scope); - subquery_cost = !utils::ApproxEqualDecimal(subquery_cost, 0.0) ? subquery_cost : 1; - cardinality_ *= subquery_cost; + CostEstimation subquery_estimation = EstimateCostOnBranch(&op.subquery_, last_scope); + double subquery_cost = !utils::ApproxEqualDecimal(subquery_estimation.cost, 0.0) ? subquery_estimation.cost : 1; + IncrementCost(subquery_cost); - IncrementCost(CostParam::kSubquery); + double subquery_cardinality = + !utils::ApproxEqualDecimal(subquery_estimation.cardinality, 0.0) ? subquery_estimation.cardinality : 1; + cardinality_ *= subquery_cardinality; + + return false; + } + + bool PreVisit(Cartesian &op) override { + // Get the cost of the main branch + op.left_op_->Accept(*this); + + // add cost from the right branch and multiply cardinalities + CostEstimation right_cost_estimation = EstimateCostOnBranch(&op.right_op_); + cost_ += right_cost_estimation.cost; + double right_cardinality = + !utils::ApproxEqualDecimal(right_cost_estimation.cardinality, 0.0) ? right_cost_estimation.cardinality : 1; + cardinality_ *= right_cardinality; + + return false; + } + + bool PreVisit(IndexedJoin &op) override { + // Get the cost of the main branch + op.main_branch_->Accept(*this); + + // add cost from the right branch and multiply cardinalities + CostEstimation right_cost_estimation = EstimateCostOnBranch(&op.sub_branch_); + IncrementCost(right_cost_estimation.cost); + + double right_cardinality = + !utils::ApproxEqualDecimal(right_cost_estimation.cardinality, 0.0) ? right_cost_estimation.cardinality : 1; + cardinality_ *= right_cardinality; + + return false; + } + + bool PreVisit(HashJoin &op) override { + // Get the cost of the main branch + op.left_op_->Accept(*this); + + // add cost from the right branch and multiply cardinalities + CostEstimation right_cost_estimation = EstimateCostOnBranch(&op.right_op_); + IncrementCost(right_cost_estimation.cost); + + double right_cardinality = + !utils::ApproxEqualDecimal(right_cost_estimation.cardinality, 0.0) ? right_cost_estimation.cardinality : 1; + cardinality_ *= right_cardinality; return false; } @@ -339,16 +393,16 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { void IncrementCost(double param) { cost_ += param * cardinality_; } - double EstimateCostOnBranch(std::shared_ptr<LogicalOperator> *branch) { + CostEstimation EstimateCostOnBranch(std::shared_ptr<LogicalOperator> *branch) { CostEstimator<TDbAccessor> cost_estimator(db_accessor_, table_, parameters); (*branch)->Accept(cost_estimator); - return cost_estimator.cost(); + return CostEstimation{.cost = cost_estimator.cost(), .cardinality = cost_estimator.cardinality()}; } - double EstimateCostOnBranch(std::shared_ptr<LogicalOperator> *branch, Scope scope) { + CostEstimation EstimateCostOnBranch(std::shared_ptr<LogicalOperator> *branch, Scope scope) { CostEstimator<TDbAccessor> cost_estimator(db_accessor_, table_, parameters, scope); (*branch)->Accept(cost_estimator); - return cost_estimator.cost(); + return CostEstimation{.cost = cost_estimator.cost(), .cardinality = cost_estimator.cardinality()}; } // converts an optional ScanAll range bound into a property value diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 227e523fb..6170dce51 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -130,6 +130,8 @@ extern const Event ForeachOperator; extern const Event EmptyResultOperator; extern const Event EvaluatePatternFilterOperator; extern const Event ApplyOperator; +extern const Event IndexedJoinOperator; +extern const Event HashJoinOperator; } // namespace memgraph::metrics namespace memgraph::query::plan { @@ -5160,4 +5162,200 @@ void Apply::ApplyCursor::Reset() { pull_input_ = true; } +IndexedJoin::IndexedJoin(const std::shared_ptr<LogicalOperator> main_branch, + const std::shared_ptr<LogicalOperator> sub_branch) + : main_branch_(main_branch ? main_branch : std::make_shared<Once>()), sub_branch_(sub_branch) {} + +WITHOUT_SINGLE_INPUT(IndexedJoin); + +bool IndexedJoin::Accept(HierarchicalLogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + main_branch_->Accept(visitor) && sub_branch_->Accept(visitor); + } + return visitor.PostVisit(*this); +} + +UniqueCursorPtr IndexedJoin::MakeCursor(utils::MemoryResource *mem) const { + memgraph::metrics::IncrementCounter(memgraph::metrics::IndexedJoinOperator); + + return MakeUniqueCursorPtr<IndexedJoinCursor>(mem, *this, mem); +} + +IndexedJoin::IndexedJoinCursor::IndexedJoinCursor(const IndexedJoin &self, utils::MemoryResource *mem) + : self_(self), main_branch_(self.main_branch_->MakeCursor(mem)), sub_branch_(self.sub_branch_->MakeCursor(mem)) {} + +std::vector<Symbol> IndexedJoin::ModifiedSymbols(const SymbolTable &table) const { + // Since Apply is the Cartesian product, modified symbols are combined from + // both execution branches. + auto symbols = main_branch_->ModifiedSymbols(table); + auto sub_branch_symbols = sub_branch_->ModifiedSymbols(table); + symbols.insert(symbols.end(), sub_branch_symbols.begin(), sub_branch_symbols.end()); + return symbols; +} + +bool IndexedJoin::IndexedJoinCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("IndexedJoin"); + + while (true) { + if (pull_input_ && !main_branch_->Pull(frame, context)) { + return false; + }; + + if (sub_branch_->Pull(frame, context)) { + // if successful, next Pull from this should not pull_input_ + pull_input_ = false; + return true; + } + // failed to pull from subquery cursor + // skip that row + pull_input_ = true; + sub_branch_->Reset(); + } +} + +void IndexedJoin::IndexedJoinCursor::Shutdown() { + main_branch_->Shutdown(); + sub_branch_->Shutdown(); +} + +void IndexedJoin::IndexedJoinCursor::Reset() { + main_branch_->Reset(); + sub_branch_->Reset(); + pull_input_ = true; +} + +std::vector<Symbol> HashJoin::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = left_op_->ModifiedSymbols(table); + auto right = right_op_->ModifiedSymbols(table); + symbols.insert(symbols.end(), right.begin(), right.end()); + return symbols; +} + +bool HashJoin::Accept(HierarchicalLogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + left_op_->Accept(visitor) && right_op_->Accept(visitor); + } + return visitor.PostVisit(*this); +} + +WITHOUT_SINGLE_INPUT(HashJoin); + +namespace { + +class HashJoinCursor : public Cursor { + public: + HashJoinCursor(const HashJoin &self, utils::MemoryResource *mem) + : self_(self), + left_op_cursor_(self.left_op_->MakeCursor(mem)), + right_op_cursor_(self_.right_op_->MakeCursor(mem)), + hashtable_(mem), + right_op_frame_(mem) { + MG_ASSERT(left_op_cursor_ != nullptr, "HashJoinCursor: Missing left operator cursor."); + MG_ASSERT(right_op_cursor_ != nullptr, "HashJoinCursor: Missing right operator cursor."); + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("HashJoin"); + + if (!hash_join_initialized_) { + InitializeHashJoin(frame, context); + hash_join_initialized_ = true; + } + + // If left_op yielded zero results, there is no cartesian product. + if (hashtable_.empty()) { + return false; + } + + auto restore_frame = [&frame, &context](const auto &symbols, const auto &restore_from) { + for (const auto &symbol : symbols) { + frame[symbol] = restore_from[symbol.position()]; + if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) { + context.frame_change_collector->ResetTrackingValue(symbol.name()); + } + } + }; + + if (!common_value_found_) { + // Pull from the right_op until there’s a mergeable frame + while (true) { + auto pulled = right_op_cursor_->Pull(frame, context); + if (!pulled) return false; + + // Check if the join value from the pulled frame is shared with any left frames + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::View::OLD); + auto right_value = self_.hash_join_condition_->expression1_->Accept(evaluator); + if (hashtable_.contains(right_value)) { + // If so, finish pulling for now and proceed to joining the pulled frame + right_op_frame_.assign(frame.elems().begin(), frame.elems().end()); + common_value_found_ = true; + common_value = right_value; + left_op_frame_it_ = hashtable_[common_value].begin(); + break; + } + } + } else { + // Restore the right frame ahead of restoring the left frame + restore_frame(self_.right_symbols_, right_op_frame_); + } + + restore_frame(self_.left_symbols_, *left_op_frame_it_); + + left_op_frame_it_++; + // When all left frames with the common value have been joined, move on to pulling and joining the next right frame + if (common_value_found_ && left_op_frame_it_ == hashtable_[common_value].end()) { + common_value_found_ = false; + } + + return true; + } + + void Shutdown() override { + left_op_cursor_->Shutdown(); + right_op_cursor_->Shutdown(); + } + + void Reset() override { + left_op_cursor_->Reset(); + right_op_cursor_->Reset(); + hashtable_.clear(); + right_op_frame_.clear(); + left_op_frame_it_ = {}; + hash_join_initialized_ = false; + common_value_found_ = false; + } + + private: + void InitializeHashJoin(Frame &frame, ExecutionContext &context) { + // Pull all left_op_ frames + while (left_op_cursor_->Pull(frame, context)) { + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::View::OLD); + auto left_value = self_.hash_join_condition_->expression2_->Accept(evaluator); + if (left_value.type() != TypedValue::Type::Null) { + hashtable_[left_value].emplace_back(frame.elems().begin(), frame.elems().end()); + } + } + } + + const HashJoin &self_; + const UniqueCursorPtr left_op_cursor_; + const UniqueCursorPtr right_op_cursor_; + utils::pmr::unordered_map<TypedValue, utils::pmr::vector<utils::pmr::vector<TypedValue>>, TypedValue::Hash, + TypedValue::BoolEqual> + hashtable_; + utils::pmr::vector<TypedValue> right_op_frame_; + utils::pmr::vector<utils::pmr::vector<TypedValue>>::iterator left_op_frame_it_; + bool hash_join_initialized_{false}; + bool common_value_found_{false}; + TypedValue common_value; +}; +} // namespace + +UniqueCursorPtr HashJoin::MakeCursor(utils::MemoryResource *mem) const { + memgraph::metrics::IncrementCounter(memgraph::metrics::HashJoinOperator); + return MakeUniqueCursorPtr<HashJoinCursor>(mem, *this, mem); +} + } // namespace memgraph::query::plan diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index c3cf80042..1b65708b3 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -129,6 +129,8 @@ class Foreach; class EmptyResult; class EvaluatePatternFilter; class Apply; +class IndexedJoin; +class HashJoin; using LogicalOperatorCompositeVisitor = utils::CompositeVisitor<Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange, @@ -136,7 +138,7 @@ using LogicalOperatorCompositeVisitor = ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv, - Foreach, EmptyResult, EvaluatePatternFilter, Apply>; + Foreach, EmptyResult, EvaluatePatternFilter, Apply, IndexedJoin, HashJoin>; using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>; @@ -1537,6 +1539,12 @@ class EdgeUniquenessFilter : public memgraph::query::plan::LogicalOperator { std::shared_ptr<LogicalOperator> input() const override { return input_; } void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + std::string ToString() const override { + return fmt::format("EdgeUniquenessFilter {{{0} : {1}}}", + utils::IterableToString(previous_symbols_, ", ", [](const auto &sym) { return sym.name(); }), + expand_symbol_.name()); + } + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; Symbol expand_symbol_; std::vector<Symbol> previous_symbols_; @@ -2477,6 +2485,97 @@ class Apply : public memgraph::query::plan::LogicalOperator { }; }; +/// Applies symbols from both join branches +class IndexedJoin : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + IndexedJoin() {} + + IndexedJoin(std::shared_ptr<LogicalOperator> main_branch, std::shared_ptr<LogicalOperator> sub_branch); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource * /*unused*/) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable & /*unused*/) const override; + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator> /*unused*/) override; + + std::shared_ptr<memgraph::query::plan::LogicalOperator> main_branch_; + std::shared_ptr<memgraph::query::plan::LogicalOperator> sub_branch_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<IndexedJoin>(); + object->main_branch_ = main_branch_ ? main_branch_->Clone(storage) : nullptr; + object->sub_branch_ = sub_branch_ ? sub_branch_->Clone(storage) : nullptr; + return object; + } + + private: + class IndexedJoinCursor : public Cursor { + public: + IndexedJoinCursor(const IndexedJoin &, utils::MemoryResource *); + bool Pull(Frame & /*unused*/, ExecutionContext & /*unused*/) override; + void Shutdown() override; + void Reset() override; + + private: + const IndexedJoin &self_; + UniqueCursorPtr main_branch_; + UniqueCursorPtr sub_branch_; + bool pull_input_{true}; + }; +}; + +/// Operator for producing the hash join of two input branches +class HashJoin : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + HashJoin() {} + /** Construct the operator with left input branch and right input branch. */ + HashJoin(const std::shared_ptr<LogicalOperator> &left_op, const std::vector<Symbol> &left_symbols, + const std::shared_ptr<LogicalOperator> &right_op, const std::vector<Symbol> &right_symbols, + EqualOperator *hash_join_condition) + : left_op_(left_op), + left_symbols_(left_symbols), + right_op_(right_op), + right_symbols_(right_symbols), + hash_join_condition_(hash_join_condition) {} + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator>) override; + + std::shared_ptr<memgraph::query::plan::LogicalOperator> left_op_; + std::vector<Symbol> left_symbols_; + std::shared_ptr<memgraph::query::plan::LogicalOperator> right_op_; + std::vector<Symbol> right_symbols_; + EqualOperator *hash_join_condition_; + + std::string ToString() const override { + return fmt::format("HashJoin {{{} : {}}}", + utils::IterableToString(left_symbols_, ", ", [](const auto &sym) { return sym.name(); }), + utils::IterableToString(right_symbols_, ", ", [](const auto &sym) { return sym.name(); })); + } + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<HashJoin>(); + object->left_op_ = left_op_ ? left_op_->Clone(storage) : nullptr; + object->left_symbols_ = left_symbols_; + object->right_op_ = right_op_ ? right_op_->Clone(storage) : nullptr; + object->right_symbols_ = right_symbols_; + object->hash_join_condition_ = hash_join_condition_ ? hash_join_condition_->Clone(storage) : nullptr; + return object; + } +}; + } // namespace plan } // namespace query } // namespace memgraph diff --git a/src/query/plan/operator_type_info.cpp b/src/query/plan/operator_type_info.cpp index efedc9b04..3b3ffe14e 100644 --- a/src/query/plan/operator_type_info.cpp +++ b/src/query/plan/operator_type_info.cpp @@ -148,4 +148,10 @@ constexpr utils::TypeInfo query::plan::Foreach::kType{utils::TypeId::FOREACH, "F constexpr utils::TypeInfo query::plan::Apply::kType{utils::TypeId::APPLY, "Apply", &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::IndexedJoin::kType{utils::TypeId::INDEXED_JOIN, "IndexedJoin", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::HashJoin::kType{utils::TypeId::HASH_JOIN, "HashJoin", + &query::plan::LogicalOperator::kType}; } // namespace memgraph diff --git a/src/query/plan/planner.hpp b/src/query/plan/planner.hpp index 443680c37..12b40a912 100644 --- a/src/query/plan/planner.hpp +++ b/src/query/plan/planner.hpp @@ -22,6 +22,7 @@ #include "query/plan/preprocess.hpp" #include "query/plan/pretty_print.hpp" #include "query/plan/rewrite/index_lookup.hpp" +#include "query/plan/rewrite/join.hpp" #include "query/plan/rule_based_planner.hpp" #include "query/plan/variable_start_planner.hpp" #include "query/plan/vertex_count_cache.hpp" @@ -43,7 +44,10 @@ class PostProcessor final { template <class TPlanningContext> std::unique_ptr<LogicalOperator> Rewrite(std::unique_ptr<LogicalOperator> plan, TPlanningContext *context) { - return RewriteWithIndexLookup(std::move(plan), context->symbol_table, context->ast_storage, context->db); + auto index_lookup_plan = + RewriteWithIndexLookup(std::move(plan), context->symbol_table, context->ast_storage, context->db); + return RewriteWithJoinRewriter(std::move(index_lookup_plan), context->symbol_table, context->ast_storage, + context->db); } template <class TVertexCounts> diff --git a/src/query/plan/preprocess.cpp b/src/query/plan/preprocess.cpp index 0a8a33e0a..8b1689796 100644 --- a/src/query/plan/preprocess.cpp +++ b/src/query/plan/preprocess.cpp @@ -19,6 +19,7 @@ #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast_visitor.hpp" +#include "query/frontend/semantic/symbol_table.hpp" #include "query/plan/preprocess.hpp" #include "utils/typeinfo.hpp" @@ -55,35 +56,126 @@ void ForEachPattern(Pattern &pattern, std::function<void(NodeAtom *)> base, // want to start expanding. std::vector<Expansion> NormalizePatterns(const SymbolTable &symbol_table, const std::vector<Pattern *> &patterns) { std::vector<Expansion> expansions; + ExpansionGroupId unknown_expansion_group_id = ExpansionGroupId::FromInt(-1); auto ignore_node = [&](auto *) {}; - auto collect_expansion = [&](auto *prev_node, auto *edge, auto *current_node) { - UsedSymbolsCollector collector(symbol_table); - if (edge->IsVariable()) { - if (edge->lower_bound_) edge->lower_bound_->Accept(collector); - if (edge->upper_bound_) edge->upper_bound_->Accept(collector); - if (edge->filter_lambda_.expression) edge->filter_lambda_.expression->Accept(collector); - // Remove symbols which are bound by lambda arguments. - collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_edge)); - collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_node)); - if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH || edge->type_ == EdgeAtom::Type::ALL_SHORTEST_PATHS) { - collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_edge)); - collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_node)); - } - } - expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, false, collector.symbols_, current_node}); - }; for (const auto &pattern : patterns) { if (pattern->atoms_.size() == 1U) { auto *node = utils::Downcast<NodeAtom>(pattern->atoms_[0]); DMG_ASSERT(node, "First pattern atom is not a node"); - expansions.emplace_back(Expansion{node}); + expansions.emplace_back(Expansion{.node1 = node, .expansion_group_id = unknown_expansion_group_id}); } else { + auto collect_expansion = [&](auto *prev_node, auto *edge, auto *current_node) { + UsedSymbolsCollector collector(symbol_table); + if (edge->IsVariable()) { + if (edge->lower_bound_) edge->lower_bound_->Accept(collector); + if (edge->upper_bound_) edge->upper_bound_->Accept(collector); + if (edge->filter_lambda_.expression) edge->filter_lambda_.expression->Accept(collector); + // Remove symbols which are bound by lambda arguments. + collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_edge)); + collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_node)); + if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH || + edge->type_ == EdgeAtom::Type::ALL_SHORTEST_PATHS) { + collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_edge)); + collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_node)); + } + } + expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, false, collector.symbols_, current_node, + unknown_expansion_group_id}); + }; ForEachPattern(*pattern, ignore_node, collect_expansion); } } return expansions; } +void AssignExpansionGroupIds(std::vector<Expansion> &expansions, Matching &matching, const SymbolTable &symbol_table) { + ExpansionGroupId next_expansion_group_id = ExpansionGroupId::FromUint(matching.number_of_expansion_groups + 1); + + auto assign_expansion_group_id = [&matching, &next_expansion_group_id](Symbol symbol, Expansion &expansion) { + ExpansionGroupId expansion_group_id_to_assign = next_expansion_group_id; + if (matching.node_symbol_to_expansion_group_id.contains(symbol)) { + expansion_group_id_to_assign = matching.node_symbol_to_expansion_group_id[symbol]; + } + + if (expansion.expansion_group_id.AsInt() == -1 || + expansion_group_id_to_assign.AsInt() < expansion.expansion_group_id.AsInt()) { + expansion.expansion_group_id = expansion_group_id_to_assign; + } + + matching.node_symbol_to_expansion_group_id[symbol] = expansion.expansion_group_id; + }; + + for (auto &expansion : expansions) { + const auto &node1_sym = symbol_table.at(*expansion.node1->identifier_); + assign_expansion_group_id(node1_sym, expansion); + + if (expansion.edge) { + const auto &edge_sym = symbol_table.at(*expansion.edge->identifier_); + const auto &node2_sym = symbol_table.at(*expansion.node2->identifier_); + + assign_expansion_group_id(edge_sym, expansion); + assign_expansion_group_id(node2_sym, expansion); + } + + matching.number_of_expansion_groups = matching.number_of_expansion_groups < expansion.expansion_group_id.AsUint() + ? expansion.expansion_group_id.AsUint() + : matching.number_of_expansion_groups; + next_expansion_group_id = ExpansionGroupId::FromUint(matching.number_of_expansion_groups + 1); + } + + // By the time we finished assigning expansions, no expansion should have its expansion group ID unassigned + for (const auto &expansion : matching.expansions) { + MG_ASSERT(expansion.expansion_group_id.AsInt() != -1, "Expansion group ID is not assigned to the pattern!"); + } +} + +void CollectEdgeSymbols(std::vector<Expansion> &expansions, Matching &matching, const SymbolTable &symbol_table) { + std::unordered_set<Symbol> edge_symbols; + for (auto &expansion : expansions) { + if (expansion.edge) { + const auto &edge_sym = symbol_table.at(*expansion.edge->identifier_); + // Fill edge symbols for Cyphermorphism. + edge_symbols.insert(edge_sym); + } + } + + if (!edge_symbols.empty()) { + matching.edge_symbols.emplace_back(edge_symbols); + } +} + +void CollectExpansionSymbols(std::vector<Expansion> &expansions, Matching &matching, const SymbolTable &symbol_table) { + for (auto &expansion : expansions) { + // Map node1 symbol to expansion + const auto &node1_sym = symbol_table.at(*expansion.node1->identifier_); + matching.expansion_symbols.insert(node1_sym); + + if (expansion.edge) { + const auto &edge_sym = symbol_table.at(*expansion.edge->identifier_); + matching.expansion_symbols.insert(edge_sym); + + const auto &node2_sym = symbol_table.at(*expansion.node2->identifier_); + matching.expansion_symbols.insert(node2_sym); + } + } +} + +void AddExpansionsToMatching(std::vector<Expansion> &expansions, Matching &matching, const SymbolTable &symbol_table) { + for (auto &expansion : expansions) { + // Matching may already have some expansions, so offset our index. + const size_t expansion_ix = matching.expansions.size(); + const auto &node1_sym = symbol_table.at(*expansion.node1->identifier_); + matching.node_symbol_to_expansions[node1_sym].insert(expansion_ix); + + if (expansion.edge) { + const auto &node2_sym = symbol_table.at(*expansion.node2->identifier_); + matching.node_symbol_to_expansions[node2_sym].insert(expansion_ix); + } + + matching.expansions.push_back(expansion); + } +} + auto SplitExpressionOnAnd(Expression *expression) { // TODO: Think about converting all filtering expression into CNF to improve // the granularity of filters which can be stand alone. @@ -487,32 +579,21 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr, const SymbolTable &symbol_ // were in a Where clause). void AddMatching(const std::vector<Pattern *> &patterns, Where *where, SymbolTable &symbol_table, AstStorage &storage, Matching &matching) { - auto expansions = NormalizePatterns(symbol_table, patterns); - std::unordered_set<Symbol> edge_symbols; - for (const auto &expansion : expansions) { - // Matching may already have some expansions, so offset our index. - const size_t expansion_ix = matching.expansions.size(); - // Map node1 symbol to expansion - const auto &node1_sym = symbol_table.at(*expansion.node1->identifier_); - matching.node_symbol_to_expansions[node1_sym].insert(expansion_ix); - // Add node1 to all symbols. - matching.expansion_symbols.insert(node1_sym); - if (expansion.edge) { - const auto &edge_sym = symbol_table.at(*expansion.edge->identifier_); - // Fill edge symbols for Cyphermorphism. - edge_symbols.insert(edge_sym); - // Map node2 symbol to expansion - const auto &node2_sym = symbol_table.at(*expansion.node2->identifier_); - matching.node_symbol_to_expansions[node2_sym].insert(expansion_ix); - // Add edge and node2 to all symbols - matching.expansion_symbols.insert(edge_sym); - matching.expansion_symbols.insert(node2_sym); - } - matching.expansions.push_back(expansion); - } - if (!edge_symbols.empty()) { - matching.edge_symbols.emplace_back(edge_symbols); - } + std::vector<Expansion> expansions = NormalizePatterns(symbol_table, patterns); + + // At this point, all of the expansions have the expansion group id of -1 + // By the time the assigning is done, all the expansions should have their expansion group id adjusted + AssignExpansionGroupIds(expansions, matching, symbol_table); + + // Add edge symbols for every expansion to ensure edge uniqueness + CollectEdgeSymbols(expansions, matching, symbol_table); + + // Add all the symbols found in these expansions + CollectExpansionSymbols(expansions, matching, symbol_table); + + // Matching is of reference type and needs to append the expansions + AddExpansionsToMatching(expansions, matching, symbol_table); + for (auto *const pattern : patterns) { matching.filters.CollectPatternFilters(*pattern, symbol_table, storage); if (pattern->identifier_->user_declared_) { diff --git a/src/query/plan/preprocess.hpp b/src/query/plan/preprocess.hpp index 684c6e534..4f46cc0f0 100644 --- a/src/query/plan/preprocess.hpp +++ b/src/query/plan/preprocess.hpp @@ -103,6 +103,36 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor { bool in_exists{false}; }; +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define PREPROCESS_DEFINE_ID_TYPE(name) \ + class name final { \ + private: \ + explicit name(uint64_t id) : id_(id) {} \ + \ + public: \ + /* Default constructor to allow serialization or preallocation. */ \ + name() = default; \ + \ + static name FromUint(uint64_t id) { return name(id); } \ + static name FromInt(int64_t id) { return name(utils::MemcpyCast<uint64_t>(id)); } \ + uint64_t AsUint() const { return id_; } \ + int64_t AsInt() const { return utils::MemcpyCast<int64_t>(id_); } \ + \ + private: \ + uint64_t id_; \ + }; \ + static_assert(std::is_trivially_copyable<name>::value, "query::plan::" #name " must be trivially copyable!"); \ + inline bool operator==(const name &first, const name &second) { return first.AsUint() == second.AsUint(); } \ + inline bool operator!=(const name &first, const name &second) { return first.AsUint() != second.AsUint(); } \ + inline bool operator<(const name &first, const name &second) { return first.AsUint() < second.AsUint(); } \ + inline bool operator>(const name &first, const name &second) { return first.AsUint() > second.AsUint(); } \ + inline bool operator<=(const name &first, const name &second) { return first.AsUint() <= second.AsUint(); } \ + inline bool operator>=(const name &first, const name &second) { return first.AsUint() >= second.AsUint(); } + +PREPROCESS_DEFINE_ID_TYPE(ExpansionGroupId); + +#undef STORAGE_DEFINE_ID_TYPE + /// Normalized representation of a pattern that needs to be matched. struct Expansion { /// The first node in the expansion, it can be a single node. @@ -119,6 +149,8 @@ struct Expansion { /// Optional node at the other end of an edge. If the expansion /// contains an edge, then this node is required. NodeAtom *node2 = nullptr; + // ExpansionGroupId represents a distinct part of the matching which is not tied to any other symbols. + ExpansionGroupId expansion_group_id = ExpansionGroupId(); }; struct FilterMatching; @@ -394,6 +426,10 @@ struct Matching { Filters filters; /// Maps node symbols to expansions which bind them. std::unordered_map<Symbol, std::set<size_t>> node_symbol_to_expansions{}; + /// Tracker of the total number of expansion groups for correct assigning of expansion group IDs + size_t number_of_expansion_groups{0}; + /// Maps every node symbol to its expansion group ID + std::unordered_map<Symbol, ExpansionGroupId> node_symbol_to_expansion_group_id{}; /// Maps named path symbols to a vector of Symbols that define its pattern. std::unordered_map<Symbol, std::vector<Symbol>> named_paths{}; /// All node and edge symbols across all expansions (from all matches). diff --git a/src/query/plan/pretty_print.cpp b/src/query/plan/pretty_print.cpp index 47fb8dc9d..8315cf175 100644 --- a/src/query/plan/pretty_print.cpp +++ b/src/query/plan/pretty_print.cpp @@ -101,7 +101,6 @@ PRE_VISIT(SetProperties); PRE_VISIT(SetLabels); PRE_VISIT(RemoveProperty); PRE_VISIT(RemoveLabels); -PRE_VISIT(EdgeUniquenessFilter); PRE_VISIT(Accumulate); PRE_VISIT(EmptyResult); PRE_VISIT(EvaluatePatternFilter); @@ -172,6 +171,13 @@ bool PlanPrinter::PreVisit(query::plan::Cartesian &op) { return false; } +bool PlanPrinter::PreVisit(query::plan::HashJoin &op) { + WithPrintLn([&](auto &out) { out << "* " << op.ToString(); }); + Branch(*op.right_op_); + op.left_op_->Accept(*this); + return false; +} + bool PlanPrinter::PreVisit(query::plan::Foreach &op) { WithPrintLn([](auto &out) { out << "* Foreach"; }); Branch(*op.update_clauses_); @@ -188,12 +194,24 @@ bool PlanPrinter::PreVisit(query::plan::Filter &op) { return false; } +bool PlanPrinter::PreVisit(query::plan::EdgeUniquenessFilter &op) { + WithPrintLn([&](auto &out) { out << "* " << op.ToString(); }); + return true; +} + bool PlanPrinter::PreVisit(query::plan::Apply &op) { WithPrintLn([](auto &out) { out << "* Apply"; }); Branch(*op.subquery_); op.input_->Accept(*this); return false; } + +bool PlanPrinter::PreVisit(query::plan::IndexedJoin &op) { + WithPrintLn([](auto &out) { out << "* IndexedJoin"; }); + Branch(*op.sub_branch_); + op.main_branch_->Accept(*this); + return false; +} #undef PRE_VISIT bool PlanPrinter::DefaultPreVisit() { @@ -879,6 +897,20 @@ bool PlanToJsonVisitor::PreVisit(Cartesian &op) { return false; } +bool PlanToJsonVisitor::PreVisit(HashJoin &op) { + json self; + self["name"] = "HashJoin"; + + op.left_op_->Accept(*this); + self["left_op"] = PopOutput(); + + op.right_op_->Accept(*this); + self["right_op"] = PopOutput(); + + output_ = std::move(self); + return false; +} + bool PlanToJsonVisitor::PreVisit(Foreach &op) { json self; self["name"] = "Foreach"; @@ -921,6 +953,20 @@ bool PlanToJsonVisitor::PreVisit(Apply &op) { return false; } +bool PlanToJsonVisitor::PreVisit(IndexedJoin &op) { + json self; + self["name"] = "IndexedJoin"; + + op.main_branch_->Accept(*this); + self["left"] = PopOutput(); + + op.sub_branch_->Accept(*this); + self["right"] = PopOutput(); + + output_ = std::move(self); + return false; +} + } // namespace impl } // namespace memgraph::query::plan diff --git a/src/query/plan/pretty_print.hpp b/src/query/plan/pretty_print.hpp index 9b220b726..645fe17a5 100644 --- a/src/query/plan/pretty_print.hpp +++ b/src/query/plan/pretty_print.hpp @@ -80,6 +80,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(Merge &) override; bool PreVisit(Optional &) override; bool PreVisit(Cartesian &) override; + bool PreVisit(HashJoin &) override; bool PreVisit(EmptyResult &) override; bool PreVisit(Produce &) override; @@ -96,6 +97,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(LoadCsv &) override; bool PreVisit(Foreach &) override; bool PreVisit(Apply & /*unused*/) override; + bool PreVisit(IndexedJoin & /*unused*/) override; bool Visit(Once &) override; @@ -192,6 +194,8 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(EdgeUniquenessFilter &) override; bool PreVisit(Cartesian &) override; bool PreVisit(Apply & /*unused*/) override; + bool PreVisit(HashJoin &) override; + bool PreVisit(IndexedJoin & /*unused*/) override; bool PreVisit(ScanAll &) override; bool PreVisit(ScanAllByLabel &) override; diff --git a/src/query/plan/rewrite/index_lookup.cpp b/src/query/plan/rewrite/index_lookup.cpp index 532539a99..e584864a5 100644 --- a/src/query/plan/rewrite/index_lookup.cpp +++ b/src/query/plan/rewrite/index_lookup.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -22,30 +22,53 @@ DEFINE_VALIDATED_int64(query_vertex_count_to_expand_existing, 10, namespace memgraph::query::plan::impl { -Expression *RemoveAndExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove) { - auto *and_op = utils::Downcast<AndOperator>(expr); - if (!and_op) return expr; - if (utils::Contains(exprs_to_remove, and_op)) { - return nullptr; +ExpressionRemovalResult RemoveExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove) { + if (utils::Contains(exprs_to_remove, expr)) { + return ExpressionRemovalResult{.trimmed_expression = nullptr, .did_remove = true}; } + + auto *and_op = utils::Downcast<AndOperator>(expr); + + // currently we are processing expressions by dividing them into and disjoint expressions + // no work needed if there is no multiple and expressions + if (!and_op) return ExpressionRemovalResult{.trimmed_expression = expr}; + + // and operation is fully contained inside the expressions to remove + if (utils::Contains(exprs_to_remove, and_op)) { + return ExpressionRemovalResult{.trimmed_expression = nullptr, .did_remove = true}; + } + + bool did_remove = false; if (utils::Contains(exprs_to_remove, and_op->expression1_)) { and_op->expression1_ = nullptr; + did_remove = true; } if (utils::Contains(exprs_to_remove, and_op->expression2_)) { and_op->expression2_ = nullptr; + did_remove = true; } - and_op->expression1_ = RemoveAndExpressions(and_op->expression1_, exprs_to_remove); - and_op->expression2_ = RemoveAndExpressions(and_op->expression2_, exprs_to_remove); + + auto removal1 = RemoveExpressions(and_op->expression1_, exprs_to_remove); + and_op->expression1_ = removal1.trimmed_expression; + did_remove = did_remove || removal1.did_remove; + + auto removal2 = RemoveExpressions(and_op->expression2_, exprs_to_remove); + and_op->expression2_ = removal2.trimmed_expression; + did_remove = did_remove || removal2.did_remove; + if (!and_op->expression1_ && !and_op->expression2_) { - return nullptr; + return ExpressionRemovalResult{.trimmed_expression = nullptr, .did_remove = did_remove}; } + if (and_op->expression1_ && !and_op->expression2_) { - return and_op->expression1_; + return ExpressionRemovalResult{.trimmed_expression = and_op->expression1_, .did_remove = did_remove}; } + if (and_op->expression2_ && !and_op->expression1_) { - return and_op->expression2_; + return ExpressionRemovalResult{.trimmed_expression = and_op->expression2_, .did_remove = did_remove}; } - return and_op; + + return ExpressionRemovalResult{.trimmed_expression = and_op, .did_remove = did_remove}; } } // namespace memgraph::query::plan::impl diff --git a/src/query/plan/rewrite/index_lookup.hpp b/src/query/plan/rewrite/index_lookup.hpp index e10d14b82..610e8a61b 100644 --- a/src/query/plan/rewrite/index_lookup.hpp +++ b/src/query/plan/rewrite/index_lookup.hpp @@ -34,9 +34,14 @@ namespace memgraph::query::plan { namespace impl { +struct ExpressionRemovalResult { + Expression *trimmed_expression; + bool did_remove{false}; +}; + // Return the new root expression after removing the given expressions from the // given expression tree. -Expression *RemoveAndExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove); +ExpressionRemovalResult RemoveExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove); template <class TDbAccessor> class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { @@ -61,10 +66,31 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { // free the memory. bool PostVisit(Filter &op) override { prev_ops_.pop_back(); - op.expression_ = RemoveAndExpressions(op.expression_, filter_exprs_for_removal_); - if (!op.expression_ || utils::Contains(filter_exprs_for_removal_, op.expression_)) { + ExpressionRemovalResult removal = RemoveExpressions(op.expression_, filter_exprs_for_removal_); + op.expression_ = removal.trimmed_expression; + + // edge uniqueness filter comes always before filter in plan generation + LogicalOperator *input = op.input().get(); + LogicalOperator *parent = &op; + while (input->GetTypeInfo() == EdgeUniquenessFilter::kType) { + parent = input; + input = input->input().get(); + } + bool is_child_cartesian = input->GetTypeInfo() == Cartesian::kType; + + if (is_child_cartesian && removal.did_remove) { + // if we removed something from filter in front of a Cartesian, then we are doing a join from + // 2 different branches + auto *cartesian = dynamic_cast<Cartesian *>(input); + auto indexed_join = std::make_shared<IndexedJoin>(cartesian->left_op_, cartesian->right_op_); + parent->set_input(indexed_join); + } + + if (!op.expression_) { + // if we emptied all the expressions from the filter, then we don't need this operator anymore SetOnParent(op.input()); } + return true; } @@ -183,12 +209,34 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { bool PreVisit(Cartesian &op) override { prev_ops_.push_back(&op); RewriteBranch(&op.left_op_); - RewriteBranch(&op.right_op_); + + // we add the symbols that we encountered in the left part of the cartesian + // the reason for that is that in right part of the cartesian, we could be + // possibly using an indexed operation instead of a scan all + additional_bound_symbols_.insert(op.left_symbols_.begin(), op.left_symbols_.end()); + op.right_op_->Accept(*this); + return false; } bool PostVisit(Cartesian &) override { prev_ops_.pop_back(); + + // clear cartesian symbols as we exited the cartesian operator + additional_bound_symbols_.clear(); + + return true; + } + + bool PreVisit(IndexedJoin &op) override { + prev_ops_.push_back(&op); + RewriteBranch(&op.main_branch_); + RewriteBranch(&op.sub_branch_); + return false; + } + + bool PostVisit(IndexedJoin & /*unused*/) override { + prev_ops_.pop_back(); return true; } @@ -488,6 +536,9 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { std::unordered_set<Expression *> filter_exprs_for_removal_; std::vector<LogicalOperator *> prev_ops_; + // additional symbols that are present from other non-main branches but have influence on indexing + std::unordered_set<Symbol> additional_bound_symbols_; + struct LabelPropertyIndex { LabelIx label; // FilterInfo with PropertyFilter. @@ -505,7 +556,22 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { new_root_ = input; return; } - prev_ops_.back()->set_input(input); + + auto *parent = prev_ops_.back(); + if (parent->HasSingleInput()) { + parent->set_input(input); + return; + } + + if (parent->GetTypeInfo() == Cartesian::kType) { + auto *parent_cartesian = dynamic_cast<Cartesian *>(parent); + parent_cartesian->right_op_ = input; + parent_cartesian->right_symbols_ = input->ModifiedSymbols(*symbol_table_); + return; + } + + // if we're sure that we want to set on parent, this should never happen + LOG_FATAL("Error during index rewriting of the query!"); } void RewriteBranch(std::shared_ptr<LogicalOperator> *branch) { @@ -535,10 +601,10 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { } // Finds the label-property combination. The first criteria based on number of vertices indexed -> if one index has - // 10x less than the other one, always choose the smaller one. Otherwise, choose the index with smallest average group - // size based on key distribution. If average group size is equal, choose the index that has distribution closer to - // uniform distribution. Conditions based on average group size and key distribution can be only taken into account if - // the user has run `ANALYZE GRAPH` query before If the index cannot be found, nullopt is returned. + // 10x less than the other one, always choose the smaller one. Otherwise, choose the index with smallest average + // group size based on key distribution. If average group size is equal, choose the index that has distribution + // closer to uniform distribution. Conditions based on average group size and key distribution can be only taken + // into account if the user has run `ANALYZE GRAPH` query before If the index cannot be found, nullopt is returned. std::optional<LabelPropertyIndex> FindBestLabelPropertyIndex(const Symbol &symbol, const std::unordered_set<Symbol> &bound_symbols) { auto are_bound = [&bound_symbols](const auto &used_symbols) { @@ -551,10 +617,10 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { }; /* - * Comparator function between two indices. If new index has >= 10x vertices than the existing, it cannot be better. - * If it is <= 10x in number of vertices, check average group size of property values. The index with smaller - * average group size is better. If the average group size is the same, choose the one closer to the uniform - * distribution + * Comparator function between two indices. If new index has >= 10x vertices than the existing, it cannot be + * better. If it is <= 10x in number of vertices, check average group size of property values. The index with + * smaller average group size is better. If the average group size is the same, choose the one closer to the + * uniform distribution * @param found: Current best label-property index. * @param new_stats: Label-property index candidate. * @param vertex_count: New index's number of vertices. @@ -633,8 +699,12 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { const auto &input = scan.input(); const auto &node_symbol = scan.output_symbol_; const auto &view = scan.view_; + const auto &modified_symbols = scan.ModifiedSymbols(*symbol_table_); + std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(), modified_symbols.end()); + bound_symbols.insert(additional_bound_symbols_.begin(), additional_bound_symbols_.end()); + auto are_bound = [&bound_symbols](const auto &used_symbols) { for (const auto &used_symbol : used_symbols) { if (!utils::Contains(bound_symbols, used_symbol)) { diff --git a/src/query/plan/rewrite/join.hpp b/src/query/plan/rewrite/join.hpp new file mode 100644 index 000000000..c16f5b60d --- /dev/null +++ b/src/query/plan/rewrite/join.hpp @@ -0,0 +1,550 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +/// This file provides a plan rewriter which replaces `Filter` and `ScanAll` +/// operations with `ScanAllBy<Index>` if possible. The public entrypoint is +/// `RewriteWithIndexLookup`. + +#pragma once + +#include <algorithm> +#include <memory> +#include <optional> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include <gflags/gflags.h> + +#include "query/plan/operator.hpp" +#include "query/plan/preprocess.hpp" + +namespace memgraph::query::plan { + +namespace impl { + +template <class TDbAccessor> +class JoinRewriter final : public HierarchicalLogicalOperatorVisitor { + public: + JoinRewriter(SymbolTable *symbol_table, AstStorage *ast_storage, TDbAccessor *db) + : symbol_table_(symbol_table), ast_storage_(ast_storage), db_(db) {} + + using HierarchicalLogicalOperatorVisitor::PostVisit; + using HierarchicalLogicalOperatorVisitor::PreVisit; + using HierarchicalLogicalOperatorVisitor::Visit; + + bool Visit(Once &) override { return true; } + + bool PreVisit(Filter &op) override { + prev_ops_.push_back(&op); + filters_.CollectFilterExpression(op.expression_, *symbol_table_); + return true; + } + + // Remove no longer needed Filter in PostVisit, this should be the last thing + // Filter::Accept does, so it should be safe to remove the last reference and + // free the memory. + bool PostVisit(Filter &op) override { + prev_ops_.pop_back(); + + ExpressionRemovalResult removal = RemoveExpressions(op.expression_, filter_exprs_for_removal_); + op.expression_ = removal.trimmed_expression; + if (!op.expression_ || utils::Contains(filter_exprs_for_removal_, op.expression_)) { + SetOnParent(op.input()); + } + + return true; + } + + bool PreVisit(ScanAll &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PostVisit(ScanAll &scan) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Expand &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PostVisit(Expand & /*expand*/) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ExpandVariable &op) override { + prev_ops_.push_back(&op); + return true; + } + + // See if it might be better to do ScanAllBy<Index> of the destination and + // then do ExpandVariable to existing. + bool PostVisit(ExpandVariable &expand) override { + prev_ops_.pop_back(); + return true; + } + + // The following operators may only use index lookup in filters inside of + // their own branches. So we handle them all the same. + // * Input operator is visited with the current visitor. + // * Custom operator branches are visited with a new visitor. + + bool PreVisit(Merge &op) override { + prev_ops_.push_back(&op); + op.input()->Accept(*this); + RewriteBranch(&op.merge_match_); + return false; + } + + bool PostVisit(Merge &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Optional &op) override { + prev_ops_.push_back(&op); + op.input()->Accept(*this); + RewriteBranch(&op.optional_); + return false; + } + + bool PostVisit(Optional &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Cartesian &op) override { + prev_ops_.push_back(&op); + RewriteBranch(&op.left_op_); + cartesian_symbols_.insert(op.left_symbols_.begin(), op.left_symbols_.end()); + op.right_op_->Accept(*this); + return false; + } + + bool PostVisit(Cartesian &op) override { + prev_ops_.pop_back(); + auto hash_join = GenHashJoin(op); + if (hash_join) { + SetOnParent(std::move(hash_join)); + } + return true; + } + + bool PreVisit(IndexedJoin &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(IndexedJoin &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(HashJoin &op) override { + prev_ops_.push_back(&op); + RewriteBranch(&op.left_op_); + RewriteBranch(&op.right_op_); + return false; + } + + bool PostVisit(HashJoin &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Union &op) override { + prev_ops_.push_back(&op); + RewriteBranch(&op.left_op_); + RewriteBranch(&op.right_op_); + return false; + } + + bool PostVisit(Union &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(CreateNode &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(CreateNode &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(CreateExpand &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(CreateExpand &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ScanAllByLabel &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllByLabel &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ScanAllByLabelPropertyRange &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllByLabelPropertyRange &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ScanAllByLabelPropertyValue &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllByLabelPropertyValue &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ScanAllByLabelProperty &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllByLabelProperty &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ScanAllById &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllById &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(ConstructNamedPath &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ConstructNamedPath &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Produce &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Produce &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(EmptyResult &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(EmptyResult &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Delete &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Delete &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(SetProperty &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(SetProperty &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(SetProperties &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(SetProperties &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(SetLabels &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(SetLabels &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(RemoveProperty &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(RemoveProperty &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(RemoveLabels &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(RemoveLabels &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(EdgeUniquenessFilter &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(EdgeUniquenessFilter &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Accumulate &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Accumulate &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Aggregate &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Aggregate &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Skip &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Skip &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Limit &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Limit &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(OrderBy &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(OrderBy &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Unwind &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Unwind &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Distinct &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(Distinct &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(CallProcedure &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(CallProcedure &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Foreach &op) override { + prev_ops_.push_back(&op); + op.input()->Accept(*this); + RewriteBranch(&op.update_clauses_); + return false; + } + + bool PostVisit(Foreach &) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(EvaluatePatternFilter &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PostVisit(EvaluatePatternFilter & /*op*/) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(Apply &op) override { + prev_ops_.push_back(&op); + op.input()->Accept(*this); + RewriteBranch(&op.subquery_); + return false; + } + + bool PostVisit(Apply & /*op*/) override { + prev_ops_.pop_back(); + return true; + } + + bool PreVisit(LoadCsv &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PostVisit(LoadCsv & /*op*/) override { + prev_ops_.pop_back(); + return true; + } + + std::shared_ptr<LogicalOperator> new_root_; + + private: + SymbolTable *symbol_table_; + AstStorage *ast_storage_; + TDbAccessor *db_; + // Collected filters, pending for examination if they can be used for advanced + // lookup operations (by index, node ID, ...). + Filters filters_; + // Expressions which no longer need a plain Filter operator. + std::unordered_set<Expression *> filter_exprs_for_removal_; + std::vector<LogicalOperator *> prev_ops_; + std::unordered_set<Symbol> cartesian_symbols_; + + bool DefaultPreVisit() override { throw utils::NotYetImplemented("Operator not yet covered by JoinRewriter"); } + + void SetOnParent(const std::shared_ptr<LogicalOperator> &input) { + MG_ASSERT(input); + if (prev_ops_.empty()) { + MG_ASSERT(!new_root_); + new_root_ = input; + return; + } + prev_ops_.back()->set_input(input); + } + + void RewriteBranch(std::shared_ptr<LogicalOperator> *branch) { + JoinRewriter<TDbAccessor> rewriter(symbol_table_, ast_storage_, db_); + (*branch)->Accept(rewriter); + if (rewriter.new_root_) { + *branch = rewriter.new_root_; + } + } + + std::unique_ptr<HashJoin> GenHashJoin(const Cartesian &cartesian) { + const auto &left_op = cartesian.left_op_; + const auto &left_symbols = cartesian.left_symbols_; + const auto &right_op = cartesian.right_op_; + const auto &right_symbols = cartesian.right_symbols_; + + auto modified_symbols = cartesian.ModifiedSymbols(*symbol_table_); + modified_symbols.insert(modified_symbols.end(), left_symbols.begin(), left_symbols.end()); + + std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(), modified_symbols.end()); + auto are_bound = [&bound_symbols](const auto &used_symbols) { + for (const auto &used_symbol : used_symbols) { + if (!utils::Contains(bound_symbols, used_symbol)) { + return false; + } + } + return true; + }; + + for (const auto &filter : filters_) { + if (filter.type != FilterInfo::Type::Property) { + continue; + } + + if (filter.property_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) { + continue; + } + + if (filter.property_filter->type_ != PropertyFilter::Type::EQUAL) { + continue; + } + + if (filter.property_filter->value_->GetTypeInfo() != PropertyLookup::kType) { + continue; + } + auto *rhs_lookup = static_cast<PropertyLookup *>(filter.property_filter->value_); + + auto *join_condition = static_cast<EqualOperator *>(filter.expression); + auto lhs_symbol = filter.property_filter->symbol_; + auto lhs_property = filter.property_filter->property_; + auto rhs_symbol = symbol_table_->at(*static_cast<Identifier *>(rhs_lookup->expression_)); + auto rhs_property = rhs_lookup->property_; + filter_exprs_for_removal_.insert(filter.expression); + filters_.EraseFilter(filter); + return std::make_unique<HashJoin>(left_op, left_symbols, right_op, right_symbols, join_condition); + } + + return nullptr; + } +}; + +} // namespace impl + +template <class TDbAccessor> +std::unique_ptr<LogicalOperator> RewriteWithJoinRewriter(std::unique_ptr<LogicalOperator> root_op, + SymbolTable *symbol_table, AstStorage *ast_storage, + TDbAccessor *db) { + impl::JoinRewriter<TDbAccessor> rewriter(symbol_table, ast_storage, db); + root_op->Accept(rewriter); + if (rewriter.new_root_) { + // This shouldn't happen in real use cases because, as JoinRewriter removes Filter operations, they cannot be the + // root operator. In case we somehow missed this, raise NotYetImplemented instead of a MG_ASSERT crashing the + // application. + throw utils::NotYetImplemented("A Filter operator cannot be JoinRewriter's root"); + } + return root_op; +} + +} // namespace memgraph::query::plan diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index c02a7eaa6..cd223dd8e 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -606,6 +606,9 @@ std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std:: right_op->OutputSymbols(symbol_table)); } +Symbol GetSymbol(NodeAtom *atom, const SymbolTable &symbol_table) { return symbol_table.at(*atom->identifier_); } +Symbol GetSymbol(EdgeAtom *atom, const SymbolTable &symbol_table) { return symbol_table.at(*atom->identifier_); } + } // namespace impl } // namespace memgraph::query::plan diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index 2ecf54ee2..dac9462ad 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -15,8 +15,7 @@ #include <optional> #include <variant> -#include "gflags/gflags.h" - +#include "flags/run_time_configurable.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast_visitor.hpp" #include "query/plan/operator.hpp" @@ -90,6 +89,9 @@ bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols, cons std::unordered_set<Symbol> GetSubqueryBoundSymbols(const std::vector<SingleQueryPart> &single_query_parts, SymbolTable &symbol_table, AstStorage &storage); +Symbol GetSymbol(NodeAtom *atom, const SymbolTable &symbol_table); +Symbol GetSymbol(EdgeAtom *atom, const SymbolTable &symbol_table); + /// Utility function for iterating pattern atoms and accumulating a result. /// /// Each pattern is of the form `NodeAtom (, EdgeAtom, NodeAtom)*`. Therefore, @@ -436,8 +438,8 @@ class RuleBasedPlanner { // regular match. auto last_op = GenFilters(std::move(input_op), bound_symbols, filters, storage, symbol_table); - last_op = HandleExpansion(std::move(last_op), matching, symbol_table, storage, bound_symbols, - match_context.new_symbols, named_paths, filters, match_context.view); + last_op = HandleExpansions(std::move(last_op), matching, symbol_table, storage, bound_symbols, + match_context.new_symbols, named_paths, filters, match_context.view); MG_ASSERT(named_paths.empty(), "Expected to generate all named paths"); // We bound all named path symbols, so just add them to new_symbols. @@ -475,32 +477,198 @@ class RuleBasedPlanner { return std::make_unique<plan::Merge>(std::move(input_op), std::move(on_match), std::move(on_create)); } - std::unique_ptr<LogicalOperator> HandleExpansion(std::unique_ptr<LogicalOperator> last_op, const Matching &matching, - const SymbolTable &symbol_table, AstStorage &storage, - std::unordered_set<Symbol> &bound_symbols, - std::vector<Symbol> &new_symbols, - std::unordered_map<Symbol, std::vector<Symbol>> &named_paths, - Filters &filters, storage::View view) { + std::unique_ptr<LogicalOperator> HandleExpansions(std::unique_ptr<LogicalOperator> last_op, const Matching &matching, + const SymbolTable &symbol_table, AstStorage &storage, + std::unordered_set<Symbol> &bound_symbols, + std::vector<Symbol> &new_symbols, + std::unordered_map<Symbol, std::vector<Symbol>> &named_paths, + Filters &filters, storage::View view) { + if (flags::run_time::GetCartesianProductEnabled()) { + return HandleExpansionsWithCartesian(std::move(last_op), matching, symbol_table, storage, bound_symbols, + new_symbols, named_paths, filters, view); + } + + return HandleExpansionsWithoutCartesian(std::move(last_op), matching, symbol_table, storage, bound_symbols, + new_symbols, named_paths, filters, view); + } + + std::unique_ptr<LogicalOperator> HandleExpansionsWithCartesian( + std::unique_ptr<LogicalOperator> last_op, const Matching &matching, const SymbolTable &symbol_table, + AstStorage &storage, std::unordered_set<Symbol> &bound_symbols, std::vector<Symbol> &new_symbols, + std::unordered_map<Symbol, std::vector<Symbol>> &named_paths, Filters &filters, storage::View view) { + if (matching.expansions.empty()) { + return last_op; + } + + std::set<ExpansionGroupId> all_expansion_groups; for (const auto &expansion : matching.expansions) { - const auto &node1_symbol = symbol_table.at(*expansion.node1->identifier_); - if (bound_symbols.insert(node1_symbol).second) { - // We have just bound this symbol, so generate ScanAll which fills it. - last_op = std::make_unique<ScanAll>(std::move(last_op), node1_symbol, view); - new_symbols.emplace_back(node1_symbol); + all_expansion_groups.insert(expansion.expansion_group_id); + } - last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); - last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths); - last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); - } else if (named_paths.size() == 1U) { - last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); - last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths); - last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); + std::set<ExpansionGroupId> visited_expansion_groups; + + last_op = + GenerateExpansionOnAlreadySeenSymbols(std::move(last_op), matching, visited_expansion_groups, symbol_table, + storage, bound_symbols, new_symbols, named_paths, filters, view); + + // We want to create separate branches of scan operators for each expansion group group of patterns + // Whenever there are 2 scan branches, they will be joined with a Cartesian operator + + // New symbols from the opposite branch + // We need to see what are cross new symbols in order to check for edge uniqueness for cross branch of same matching + // Since one matching needs to comfort to Cyphermorphism + std::vector<Symbol> cross_branch_new_symbols; + bool initial_expansion_done = false; + for (const auto &expansion : matching.expansions) { + if (visited_expansion_groups.contains(expansion.expansion_group_id)) { + continue; } - if (expansion.edge) { - last_op = GenExpand(std::move(last_op), expansion, symbol_table, bound_symbols, matching, storage, filters, - named_paths, new_symbols, view); + std::unique_ptr<LogicalOperator> starting_expansion_operator = nullptr; + if (!initial_expansion_done) { + starting_expansion_operator = std::move(last_op); + initial_expansion_done = true; } + std::vector<Symbol> starting_symbols; + if (starting_expansion_operator) { + starting_symbols = starting_expansion_operator->ModifiedSymbols(symbol_table); + } + std::vector<Symbol> new_expansion_group_symbols; + std::unordered_set<Symbol> new_bound_symbols{starting_symbols.begin(), starting_symbols.end()}; + std::unique_ptr<LogicalOperator> expansion_group = GenerateExpansionGroup( + std::move(starting_expansion_operator), matching, symbol_table, storage, new_bound_symbols, + new_expansion_group_symbols, named_paths, filters, view, expansion.expansion_group_id); + + visited_expansion_groups.insert(expansion.expansion_group_id); + + new_symbols.insert(new_symbols.end(), new_expansion_group_symbols.begin(), new_expansion_group_symbols.end()); + bound_symbols.insert(new_bound_symbols.begin(), new_bound_symbols.end()); + + // If we just started and have no beginning operator, make the beginning operator and transfer cross symbols + // for next iteration + bool started_matching_operators = !last_op; + bool has_more_expansions = visited_expansion_groups.size() < all_expansion_groups.size(); + if (started_matching_operators) { + last_op = std::move(expansion_group); + if (has_more_expansions) { + cross_branch_new_symbols = new_expansion_group_symbols; + } + continue; + } + + // if there is already a last operator, then we have 2 branches that we can merge into cartesian + last_op = GenerateCartesian(std::move(last_op), std::move(expansion_group), symbol_table); + + // additionally, check for Cyphermorphism of the previous branch with new bound symbols + for (const auto &new_symbol : cross_branch_new_symbols) { + if (new_symbol.type_ == Symbol::Type::EDGE) { + last_op = EnsureCyphermorphism(std::move(last_op), new_symbol, matching, new_bound_symbols); + } + } + + last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); + + // we aggregate all the so far new symbols so we can test them in the next iteration against the new + // expansion group + if (has_more_expansions) { + cross_branch_new_symbols.insert(cross_branch_new_symbols.end(), new_expansion_group_symbols.begin(), + new_expansion_group_symbols.end()); + } + } + + MG_ASSERT(visited_expansion_groups.size() == all_expansion_groups.size(), + "Did not create expansions for all expansion group expansions in the planner!"); + + return last_op; + } + + std::unique_ptr<LogicalOperator> HandleExpansionsWithoutCartesian( + std::unique_ptr<LogicalOperator> last_op, const Matching &matching, const SymbolTable &symbol_table, + AstStorage &storage, std::unordered_set<Symbol> &bound_symbols, std::vector<Symbol> &new_symbols, + std::unordered_map<Symbol, std::vector<Symbol>> &named_paths, Filters &filters, storage::View view) { + for (const auto &expansion : matching.expansions) { + last_op = GenerateOperatorsForExpansion(std::move(last_op), matching, expansion, symbol_table, storage, + bound_symbols, new_symbols, named_paths, filters, view); + } + + return last_op; + } + + std::unique_ptr<LogicalOperator> GenerateExpansionOnAlreadySeenSymbols( + std::unique_ptr<LogicalOperator> last_op, const Matching &matching, + std::set<ExpansionGroupId> &visited_expansion_groups, SymbolTable symbol_table, AstStorage &storage, + std::unordered_set<Symbol> &bound_symbols, std::vector<Symbol> &new_symbols, + std::unordered_map<Symbol, std::vector<Symbol>> &named_paths, Filters &filters, storage::View view) { + bool added_new_expansions = true; + while (added_new_expansions) { + added_new_expansions = false; + for (const auto &expansion : matching.expansions) { + // We want to create separate matching branch operators for each expansion group group of patterns + if (visited_expansion_groups.contains(expansion.expansion_group_id)) { + continue; + } + + bool src_node_already_seen = bound_symbols.contains(impl::GetSymbol(expansion.node1, symbol_table)); + bool edge_already_seen = + expansion.edge && bound_symbols.contains(impl::GetSymbol(expansion.edge, symbol_table)); + bool dest_node_already_seen = + expansion.edge && bound_symbols.contains(impl::GetSymbol(expansion.node2, symbol_table)); + + if (src_node_already_seen || edge_already_seen || dest_node_already_seen) { + last_op = GenerateExpansionGroup(std::move(last_op), matching, symbol_table, storage, bound_symbols, + new_symbols, named_paths, filters, view, expansion.expansion_group_id); + visited_expansion_groups.insert(expansion.expansion_group_id); + added_new_expansions = true; + break; + } + } + } + + return last_op; + } + + std::unique_ptr<LogicalOperator> GenerateExpansionGroup( + std::unique_ptr<LogicalOperator> last_op, const Matching &matching, const SymbolTable &symbol_table, + AstStorage &storage, std::unordered_set<Symbol> &bound_symbols, std::vector<Symbol> &new_symbols, + std::unordered_map<Symbol, std::vector<Symbol>> &named_paths, Filters &filters, storage::View view, + ExpansionGroupId expansion_group_id) { + for (size_t i = 0, size = matching.expansions.size(); i < size; i++) { + const auto &expansion = matching.expansions[i]; + + if (expansion.expansion_group_id != expansion_group_id) { + continue; + } + + // When we picked a pattern to expand, we expand it through the end + last_op = GenerateOperatorsForExpansion(std::move(last_op), matching, expansion, symbol_table, storage, + bound_symbols, new_symbols, named_paths, filters, view); + } + return last_op; + } + + std::unique_ptr<LogicalOperator> GenerateOperatorsForExpansion( + std::unique_ptr<LogicalOperator> last_op, const Matching &matching, const Expansion &expansion, + const SymbolTable &symbol_table, AstStorage &storage, std::unordered_set<Symbol> &bound_symbols, + std::vector<Symbol> &new_symbols, std::unordered_map<Symbol, std::vector<Symbol>> &named_paths, Filters &filters, + storage::View view) { + const auto &node1_symbol = symbol_table.at(*expansion.node1->identifier_); + if (bound_symbols.insert(node1_symbol).second) { + // We have just bound this symbol, so generate ScanAll which fills it. + last_op = std::make_unique<ScanAll>(std::move(last_op), node1_symbol, view); + new_symbols.emplace_back(node1_symbol); + + last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); + last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths); + last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); + } else if (named_paths.size() == 1U) { + last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); + last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths); + last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); + } + + if (expansion.edge) { + last_op = GenExpand(std::move(last_op), expansion, symbol_table, bound_symbols, matching, storage, filters, + named_paths, new_symbols, view); } return last_op; @@ -591,9 +759,25 @@ class RuleBasedPlanner { new_symbols.emplace_back(node_symbol); } + last_op = EnsureCyphermorphism(std::move(last_op), edge_symbol, matching, bound_symbols); + + last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); + last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths); + last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); + + return last_op; + } + + std::unique_ptr<LogicalOperator> EnsureCyphermorphism(std::unique_ptr<LogicalOperator> last_op, + const Symbol &edge_symbol, const Matching &matching, + const std::unordered_set<Symbol> &bound_symbols) { // Ensure Cyphermorphism (different edge symbols always map to // different edges). for (const auto &edge_symbols : matching.edge_symbols) { + if (edge_symbols.size() <= 1) { + // nothing to test edge uniqueness with + continue; + } if (edge_symbols.find(edge_symbol) == edge_symbols.end()) { continue; } @@ -609,10 +793,6 @@ class RuleBasedPlanner { } } - last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); - last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths); - last_op = GenFilters(std::move(last_op), bound_symbols, filters, storage, symbol_table); - return last_op; } @@ -667,6 +847,14 @@ class RuleBasedPlanner { return last_op; } + std::unique_ptr<LogicalOperator> GenerateCartesian(std::unique_ptr<LogicalOperator> left, + std::unique_ptr<LogicalOperator> right, + const SymbolTable &symbol_table) { + auto left_symbols = left->ModifiedSymbols(symbol_table); + auto right_symbols = right->ModifiedSymbols(symbol_table); + return std::make_unique<Cartesian>(std::move(left), left_symbols, std::move(right), right_symbols); + } + std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator> last_op, const std::unordered_set<Symbol> &bound_symbols, Filters &filters, AstStorage &storage, const SymbolTable &symbol_table) { @@ -692,8 +880,8 @@ class RuleBasedPlanner { std::unordered_map<Symbol, std::vector<Symbol>> named_paths; - last_op = HandleExpansion(std::move(last_op), matching, symbol_table, storage, expand_symbols, new_symbols, - named_paths, filters, storage::View::OLD); + last_op = HandleExpansions(std::move(last_op), matching, symbol_table, storage, expand_symbols, new_symbols, + named_paths, filters, storage::View::OLD); last_op = std::make_unique<Limit>(std::move(last_op), storage.Create<PrimitiveLiteral>(1)); diff --git a/src/utils/event_counter.cpp b/src/utils/event_counter.cpp index fc6d59585..a7f4d30fb 100644 --- a/src/utils/event_counter.cpp +++ b/src/utils/event_counter.cpp @@ -54,9 +54,11 @@ M(ForeachOperator, Operator, "Number of times Foreach operator was used.") \ M(EvaluatePatternFilterOperator, Operator, "Number of times EvaluatePatternFilter operator was used.") \ M(ApplyOperator, Operator, "Number of times ApplyOperator operator was used.") \ + M(IndexedJoinOperator, Operator, "Number of times IndexedJoin operator was used.") \ + M(HashJoinOperator, Operator, "Number of times HashJoin operator was used.") \ \ M(ActiveLabelIndices, Index, "Number of active label indices in the system.") \ - M(ActiveLabelPropertyIndices, Index, "Number of active label property indices in the system<.") \ + M(ActiveLabelPropertyIndices, Index, "Number of active label property indices in the system.") \ \ M(StreamsCreated, Stream, "Number of Streams created.") \ M(MessagesConsumed, Stream, "Number of consumed streamed messages.") \ diff --git a/src/utils/typeinfo.hpp b/src/utils/typeinfo.hpp index 679c75ff3..682b5ac55 100644 --- a/src/utils/typeinfo.hpp +++ b/src/utils/typeinfo.hpp @@ -65,6 +65,8 @@ enum class TypeId : uint64_t { LOAD_CSV, FOREACH, APPLY, + INDEXED_JOIN, + HASH_JOIN, // Replication REP_APPEND_DELTAS_REQ, diff --git a/tests/drivers/run.sh b/tests/drivers/run.sh index e6a12a89e..d82b81ea9 100755 --- a/tests/drivers/run.sh +++ b/tests/drivers/run.sh @@ -30,6 +30,7 @@ binary_dir="$DIR/../../build" # Start memgraph. $binary_dir/memgraph \ + --cartesian-product-enabled=false \ --data-directory=$tmpdir \ --query-execution-timeout-sec=5 \ --bolt-session-inactivity-timeout=10 \ diff --git a/tests/e2e/analyze_graph/optimize_indexes.py b/tests/e2e/analyze_graph/optimize_indexes.py index 6358e4ddd..7bede8bd4 100644 --- a/tests/e2e/analyze_graph/optimize_indexes.py +++ b/tests/e2e/analyze_graph/optimize_indexes.py @@ -318,7 +318,10 @@ def test_given_supernode_when_expanding_then_expand_other_way_around(memgraph): f" |\\ On Create", f" | * CreateExpand (n)<-[anon3:HAS_REL_TO]-(s)", f" | * Once", - f" * ScanAllByLabel (n :Node)", + f" * Cartesian {{s : n}}", + f" |\\ ", + f" | * ScanAllByLabel (n :Node)", + f" | * Once", f" * ScanAllByLabelPropertyValue (s :SuperNode {{id}})", f" * Once", ] @@ -329,15 +332,27 @@ def test_given_supernode_when_expanding_then_expand_other_way_around(memgraph): memgraph.execute("analyze graph;") - expected_explain = [ - x.replace(f" | * Expand (s)-[anon3:HAS_REL_TO]->(n)", f" | * Expand (n)<-[anon3:HAS_REL_TO]-(s)") - for x in expected_explain + expected_explain_after_analysis = [ + f" * EmptyResult", + f" * Merge", + f" |\\ On Match", + f" | * Expand (n)<-[anon3:HAS_REL_TO]-(s)", + f" | * Once", + f" |\\ On Create", + f" | * CreateExpand (n)<-[anon3:HAS_REL_TO]-(s)", + f" | * Once", + f" * Cartesian {{n : s}}", + f" |\\ ", + f" | * ScanAllByLabelPropertyValue (s :SuperNode {{id}})", + f" | * Once", + f" * ScanAllByLabel (n :Node)", + f" * Once", ] result_with_analysis = list(memgraph.execute_and_fetch(query)) result_with_analysis = [x[QUERY_PLAN] for x in result_with_analysis] - assert expected_explain == result_with_analysis + assert expected_explain_after_analysis == result_with_analysis def test_given_supernode_when_subquery_then_carry_information_to_subquery(memgraph): @@ -373,7 +388,10 @@ def test_given_supernode_when_subquery_then_carry_information_to_subquery(memgra f" | | * Once", f" | * Produce {{n, s}}", f" | * Once", - f" * ScanAllByLabel (n :Node)", + f" * Cartesian {{s : n}}", + f" |\\ ", + f" | * ScanAllByLabel (n :Node)", + f" | * Once", f" * ScanAllByLabelPropertyValue (s :SuperNode {{id}})", f" * Once", ] @@ -384,14 +402,33 @@ def test_given_supernode_when_subquery_then_carry_information_to_subquery(memgra memgraph.execute("analyze graph;") - expected_explain = [ - x.replace(f" | | * Expand (s)-[anon3:HAS_REL_TO]->(n)", f" | | * Expand (n)<-[anon3:HAS_REL_TO]-(s)") - for x in expected_explain + expected_explain_after_analysis = [ + f" * Produce {{0}}", + f" * Accumulate", + f" * Accumulate", + f" * Apply", + f" |\\ ", + f" | * EmptyResult", + f" | * Merge", + f" | |\\ On Match", + f" | | * Expand (n)<-[anon3:HAS_REL_TO]-(s)", + f" | | * Once", + f" | |\\ On Create", + f" | | * CreateExpand (n)<-[anon3:HAS_REL_TO]-(s)", + f" | | * Once", + f" | * Produce {{n, s}}", + f" | * Once", + f" * Cartesian {{n : s}}", + f" |\\ ", + f" | * ScanAllByLabelPropertyValue (s :SuperNode {{id}})", + f" | * Once", + f" * ScanAllByLabel (n :Node)", + f" * Once", ] result_with_analysis = list(memgraph.execute_and_fetch(query)) result_with_analysis = [x[QUERY_PLAN] for x in result_with_analysis] - assert expected_explain == result_with_analysis + assert expected_explain_after_analysis == result_with_analysis def test_given_supernode_when_subquery_and_union_then_carry_information(memgraph): @@ -427,7 +464,10 @@ def test_given_supernode_when_subquery_and_union_then_carry_information(memgraph f" | | | * Once", f" | | * Produce {{n, s}}", f" | | * Once", - f" | * ScanAllByLabel (n :Node)", + f" | * Cartesian {{s : n}}", + f" | |\\ ", + f" | | * ScanAllByLabel (n :Node)", + f" | | * Once", f" | * ScanAllByLabelPropertyValue (s :SuperNode {{id}})", f" | * Once", f" * Produce {{s}}", @@ -445,7 +485,10 @@ def test_given_supernode_when_subquery_and_union_then_carry_information(memgraph f" | | * Once", f" | * Produce {{n, s}}", f" | * Once", - f" * ScanAllByLabel (n :Node)", + f" * Cartesian {{s : n}}", + f" |\\ ", + f" | * ScanAllByLabel (n :Node)", + f" | * Once", f" * ScanAllByLabelPropertyValue (s :SuperNode {{id}})", f" * Once", ] @@ -456,18 +499,56 @@ def test_given_supernode_when_subquery_and_union_then_carry_information(memgraph memgraph.execute("analyze graph;") - expected_explain = [ - x.replace(f" | | * Expand (s)-[anon3:HAS_REL_TO]->(n)", f" | | * Expand (n)<-[anon3:HAS_REL_TO]-(s)") - for x in expected_explain - ] - expected_explain = [ - x.replace(f" | | | * Expand (s)-[anon7:HAS_REL_TO]->(n)", f" | | | * Expand (n)<-[anon7:HAS_REL_TO]-(s)") - for x in expected_explain + expected_explain_after_analysis = [ + f" * Union {{s : s}}", + f" |\\ ", + f" | * Produce {{s}}", + f" | * Accumulate", + f" | * Accumulate", + f" | * Apply", + f" | |\\ ", + f" | | * EmptyResult", + f" | | * Merge", + f" | | |\\ On Match", + f" | | | * Expand (n)<-[anon7:HAS_REL_TO]-(s)", + f" | | | * Once", + f" | | |\\ On Create", + f" | | | * CreateExpand (n)<-[anon7:HAS_REL_TO]-(s)", + f" | | | * Once", + f" | | * Produce {{n, s}}", + f" | | * Once", + f" | * Cartesian {{n : s}}", + f" | |\\ ", + f" | | * ScanAllByLabelPropertyValue (s :SuperNode {{id}})", + f" | | * Once", + f" | * ScanAllByLabel (n :Node)", + f" | * Once", + f" * Produce {{s}}", + f" * Accumulate", + f" * Accumulate", + f" * Apply", + f" |\\ ", + f" | * EmptyResult", + f" | * Merge", + f" | |\\ On Match", + f" | | * Expand (n)<-[anon3:HAS_REL_TO]-(s)", + f" | | * Once", + f" | |\\ On Create", + f" | | * CreateExpand (n)<-[anon3:HAS_REL_TO]-(s)", + f" | | * Once", + f" | * Produce {{n, s}}", + f" | * Once", + f" * Cartesian {{n : s}}", + f" |\\ ", + f" | * ScanAllByLabelPropertyValue (s :SuperNode {{id}})", + f" | * Once", + f" * ScanAllByLabel (n :Node)", + f" * Once", ] result_with_analysis = list(memgraph.execute_and_fetch(query)) result_with_analysis = [x[QUERY_PLAN] for x in result_with_analysis] - assert expected_explain == result_with_analysis + assert expected_explain_after_analysis == result_with_analysis def test_given_empty_graph_when_analyzing_graph_return_zero_degree(memgraph): diff --git a/tests/e2e/configuration/default_config.py b/tests/e2e/configuration/default_config.py index 8c056cab3..13f286909 100644 --- a/tests/e2e/configuration/default_config.py +++ b/tests/e2e/configuration/default_config.py @@ -65,6 +65,7 @@ startup_config_dict = { "1800", "Time in seconds after which inactive Bolt sessions will be closed.", ), + "cartesian_product_enabled": ("true", "true", "Enable cartesian product expansion."), "data_directory": ("mg_data", "mg_data", "Path to directory in which to save all permanent data."), "data_recovery_on_startup": ( "false", diff --git a/tests/gql_behave/README.md b/tests/gql_behave/README.md index 31599881c..77e60190e 100644 --- a/tests/gql_behave/README.md +++ b/tests/gql_behave/README.md @@ -4,9 +4,8 @@ Python script used to run graph query language behavior tests against Memgraph. To run the script please execute: ``` -cd memgraph/tests +cd memgraph/tests/gql_behave source ve3/bin/activate -cd gql_behave ./run.py --help ./run.py memgraph_V1 ``` diff --git a/tests/gql_behave/tests/memgraph_V1/features/cartesian.feature b/tests/gql_behave/tests/memgraph_V1/features/cartesian.feature index 7cc127884..809a3d73a 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/cartesian.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/cartesian.feature @@ -171,3 +171,283 @@ Feature: Cartesian MATCH (a)-[]->() MATCH (a:B) MATCH (a:C) RETURN a """ Then the result should be empty + + Scenario: Multiple match with WHERE x = y 01 + Given an empty graph + And having executed + """ + CREATE (:A {id: 1}), (:A {id: 2}), (:B {id: 1}) + """ + When executing query: + """ + MATCH (a:A) MATCH (b:B) WHERE a.id = b.id RETURN a, b + """ + Then the result should be: + | a | b | + | (:A {id: 1}) | (:B {id: 1}) | + + Scenario: Multiple match with WHERE x = y 02 + Given an empty graph + And having executed + """ + CREATE (:A {id: 1}), (:A {id: 2}), (:B {id: 1}), (:B {id: 2}) + """ + When executing query: + """ + MATCH (a:A) MATCH (b:B) WHERE a.id = b.id RETURN a, b + """ + Then the result should be: + | a | b | + | (:A {id: 1}) | (:B {id: 1}) | + | (:A {id: 2}) | (:B {id: 2}) | + + Scenario: Multiple match with WHERE x = y 03 + Given an empty graph + And having executed + """ + CREATE (:A {prop: 1, id: 1}), (:A {prop: 2, id: 2}), (:A {prop: 1, id: 2}), (:B {prop: 2, id: 3}) + """ + When executing query: + """ + MATCH (a) MATCH (b) WHERE a.prop = b.prop RETURN a, b + """ + Then the result should be: + | a | b | + | (:A {id: 1, prop: 1}) | (:A {id: 1, prop: 1}) | + | (:A {id: 1, prop: 1}) | (:A {id: 2, prop: 1}) | + | (:A {id: 2, prop: 2}) | (:A {id: 2, prop: 2}) | + | (:A {id: 2, prop: 2}) | (:B {id: 3, prop: 2}) | + | (:A {id: 2, prop: 1}) | (:A {id: 1, prop: 1}) | + | (:A {id: 2, prop: 1}) | (:A {id: 2, prop: 1}) | + | (:B {id: 3, prop: 2}) | (:A {id: 2, prop: 2}) | + | (:B {id: 3, prop: 2}) | (:B {id: 3, prop: 2}) | + + Scenario: Multiple match with WHERE x = y 04 + Given an empty graph + And having executed + """ + CREATE (:A {prop: 1, id: 1}), (:A {prop: 2, id: 2}), (:A {prop: 1, id: 2}), (:B {prop: 2, id: 3}) + """ + When executing query: + """ + MATCH (a:A) MATCH (b:A) WHERE a.prop = b.prop RETURN a, b + """ + Then the result should be: + | a | b | + | (:A {id: 1, prop: 1}) | (:A {id: 1, prop: 1}) | + | (:A {id: 1, prop: 1}) | (:A {id: 2, prop: 1}) | + | (:A {id: 2, prop: 2}) | (:A {id: 2, prop: 2}) | + | (:A {id: 2, prop: 1}) | (:A {id: 1, prop: 1}) | + | (:A {id: 2, prop: 1}) | (:A {id: 2, prop: 1}) | + + Scenario: Multiple match with WHERE x = y 05 + Given an empty graph + And having executed + """ + CREATE (:A {prop: 1, id: 1}), (:A {prop: 2, id: 2}), (:A {prop: 1, id: 2}), (:A {prop: 2, id: 3}) + """ + When executing query: + """ + MATCH (a:A) MATCH (b:A) WHERE a.prop = b.id RETURN a, b + """ + Then the result should be: + | a | b | + | (:A {id: 1, prop: 1}) | (:A {id: 1, prop: 1}) | + | (:A {id: 2, prop: 1}) | (:A {id: 1, prop: 1}) | + | (:A {id: 2, prop: 2}) | (:A {id: 2, prop: 2}) | + | (:A {id: 3, prop: 2}) | (:A {id: 2, prop: 2}) | + | (:A {id: 2, prop: 2}) | (:A {id: 2, prop: 1}) | + | (:A {id: 3, prop: 2}) | (:A {id: 2, prop: 1}) | + + Scenario: Multiple match with WHERE x = y 06 + Given an empty graph + And having executed + """ + CREATE (:A {prop: 1, id: 1}), (:A {prop: 2, id: 2}), (:A {prop: 1, id: 2}), (:A {prop: 2, id: 3}) + """ + When executing query: + """ + MATCH (a:A) MATCH (b:A) WHERE a.id = b.prop RETURN a, b + """ + Then the result should be: + | a | b | + | (:A {id: 1, prop: 1}) | (:A {id: 1, prop: 1}) | + | (:A {id: 2, prop: 2}) | (:A {id: 2, prop: 2}) | + | (:A {id: 2, prop: 1}) | (:A {id: 2, prop: 2}) | + | (:A {id: 1, prop: 1}) | (:A {id: 2, prop: 1}) | + | (:A {id: 2, prop: 2}) | (:A {id: 3, prop: 2}) | + | (:A {id: 2, prop: 1}) | (:A {id: 3, prop: 2}) | + + Scenario: Multiple match with WHERE x = y 07: nothing on the left side + Given an empty graph + And having executed + """ + CREATE (:B {id: 1}) + """ + When executing query: + """ + MATCH (a:A) MATCH (b:B) WHERE a.id = b.id RETURN a, b + """ + Then the result should be empty + + Scenario: Multiple match with WHERE x = y 08: nothing on the right side + Given an empty graph + And having executed + """ + CREATE (:A {id: 1}), (:A {id: 2}) + """ + When executing query: + """ + MATCH (a:A) MATCH (b:B) WHERE a.id = b.id RETURN a, b + """ + Then the result should be empty + + Scenario: Multiple match with WHERE x = y 09: sides never equal + Given an empty graph + And having executed + """ + CREATE (:A {id: 1}), (:B {id: 2}) + """ + When executing query: + """ + MATCH (a:A) MATCH (b:B) WHERE a.id = b.id RETURN a, b + """ + Then the result should be empty + + Scenario: Multiple match + with 01 + Given an empty graph + And having executed + """ + CREATE (:A {id: 1}), (:A {id: 2}), (:B {id: 1}) + """ + When executing query: + """ + MATCH (a:A) WITH a MATCH (b:B) WHERE a.id = b.id RETURN a, b + """ + Then the result should be: + | a | b | + | (:A {id: 1}) | (:B {id: 1}) | + + Scenario: Multiple match + with 02 + Given an empty graph + And having executed + """ + CREATE (:A {id: 1}), (:A {id: 2}), (:B {id: 1}) + """ + When executing query: + """ + MATCH (a:A) WITH a.id as id MATCH (a:A) return a; + """ + Then the result should be: + | a | + | (:A {id: 1}) | + | (:A {id: 2}) | + | (:A {id: 1}) | + | (:A {id: 2}) | + + Scenario: Multiple match + with 03 + Given an empty graph + And having executed + """ + CREATE (:A {id: 1})-[:TYPE]->(:B {id: 1}), (:A {id: 2})-[:TYPE]->(:B {id: 2}) + """ + When executing query: + """ + MATCH (a:A) WITH a.id as id MATCH (a)-[:TYPE]->(b) return a, b; + """ + Then the result should be: + | a | b | + | (:A {id: 1}) | (:B {id: 1}) | + | (:A {id: 2}) | (:B {id: 2}) | + | (:A {id: 1}) | (:B {id: 1}) | + | (:A {id: 2}) | (:B {id: 2}) | + + Scenario: Multiple match + with 04 + Given an empty graph + And having executed + """ + CREATE (:A {id: 1})-[:TYPE]->(:B {id: 1}), (:A {id: 2})-[:TYPE]->(:B {id: 2}) + """ + When executing query: + """ + MATCH (a:A) WITH a MATCH (a)-[:TYPE]->(b) return a, b; + """ + Then the result should be: + | a | b | + | (:A {id: 1}) | (:B {id: 1}) | + | (:A {id: 2}) | (:B {id: 2}) | + + Scenario: Multiple match + with 05 + Given an empty graph + And having executed + """ + CREATE (:A {id: 1})-[:TYPE]->(:B {id: 1}), (:A {id: 2})-[:TYPE]->(:B {id: 2}) + """ + When executing query: + """ + MATCH (a:A) WITH a MATCH (c:A {id: 1}), (a)-[:TYPE]->(b) return a, b; + """ + Then the result should be: + | a | b | + | (:A {id: 1}) | (:B {id: 1}) | + | (:A {id: 2}) | (:B {id: 2}) | + + Scenario: Double match with Cyphermorphism + Given an empty graph + And having executed + """ + CREATE (:A {id: 1})-[:TYPE]->(:B {id: 1}), (:A {id: 2})-[:TYPE]->(:B {id: 2}) + """ + When executing query: + """ + MATCH (a)-->(b), (c)-->(d) RETURN a, b, c, d + """ + Then the result should be: + | a | b | c | d | + | (:A {id: 1}) | (:B {id: 1}) | (:A {id: 2}) | (:B {id: 2}) | + | (:A {id: 2}) | (:B {id: 2}) | (:A {id: 1}) | (:B {id: 1}) | + + Scenario: Triple match with Cyphermorphism empty result + Given an empty graph + And having executed + """ + CREATE (:A {id: 1})-[:TYPE]->(:B {id: 1}), (:A {id: 2})-[:TYPE]->(:B {id: 2}) + """ + When executing query: + """ + MATCH (a)-->(b), (c)-->(d), (e)-->(f) RETURN a, b, c, d, e, f + """ + Then the result should be empty + + Scenario: Triple match with Cyphermorphism yields result + Given an empty graph + And having executed + """ + CREATE (:A {id: 1})-[:TYPE]->(:B {id: 1}), (:A {id: 2})-[:TYPE]->(:B {id: 2}), (:A {id: 3})-[:TYPE]->(:B {id: 3}) + """ + When executing query: + """ + MATCH (a)-->(b), (c)-->(d), (e)-->(f) RETURN a, b, c, d, e, f + """ + Then the result should be: + | a | b | c | d | e | f | + | (:A {id: 1}) | (:B {id: 1}) | (:A {id: 2}) | (:B {id: 2}) | (:A {id: 3}) | (:B {id: 3}) | + | (:A {id: 1}) | (:B {id: 1}) | (:A {id: 3}) | (:B {id: 3}) | (:A {id: 2}) | (:B {id: 2}) | + | (:A {id: 2}) | (:B {id: 2}) | (:A {id: 1}) | (:B {id: 1}) | (:A {id: 3}) | (:B {id: 3}) | + | (:A {id: 2}) | (:B {id: 2}) | (:A {id: 3}) | (:B {id: 3}) | (:A {id: 1}) | (:B {id: 1}) | + | (:A {id: 3}) | (:B {id: 3}) | (:A {id: 2}) | (:B {id: 2}) | (:A {id: 1}) | (:B {id: 1}) | + | (:A {id: 3}) | (:B {id: 3}) | (:A {id: 1}) | (:B {id: 1}) | (:A {id: 2}) | (:B {id: 2}) | + + Scenario: Same cyphermorphism group in 3 matches + Given an empty graph + And having executed + """ + CREATE (:A)-[:TYPE]->(:B)-[:TYPE]->(:C)-[:TYPE]->(:D)-[:TYPE]->(:E) + """ + When executing query: + """ + MATCH (a:A)-->(b), (d)-->(e), (c)<--(b), (d)<--(c) RETURN a, b, c, d, e + """ + Then the result should be: + | a | b | c | d | e | + | (:A) | (:B) | (:C) | (:D) | (:E) | diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 3b920c125..69be6b490 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -529,16 +529,16 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec #define MAP(...) \ this->storage.template Create<memgraph::query::MapLiteral>( \ std::unordered_map<memgraph::query::PropertyIx, memgraph::query::Expression *>{__VA_ARGS__}) -#define PROPERTY_PAIR(dba, property_name) std::make_pair(property_name, dba.NameToProperty(property_name)) -#define PROPERTY_LOOKUP(dba, ...) memgraph::query::test_common::GetPropertyLookup(this->storage, dba, __VA_ARGS__) -#define PARAMETER_LOOKUP(token_position) \ - this->storage.template Create<memgraph::query::ParameterLookup>((token_position)) -#define NEXPR(name, expr) this->storage.template Create<memgraph::query::NamedExpression>((name), (expr)) #define MAP_PROJECTION(map_variable, elements) \ this->storage.template Create<memgraph::query::MapProjectionLiteral>( \ (memgraph::query::Expression *){map_variable}, \ std::unordered_map<memgraph::query::PropertyIx, memgraph::query::Expression *>{elements}) +#define PROPERTY_PAIR(dba, property_name) std::make_pair(property_name, dba.NameToProperty(property_name)) +#define PROPERTY_LOOKUP(dba, ...) memgraph::query::test_common::GetPropertyLookup(this->storage, dba, __VA_ARGS__) #define ALL_PROPERTIES_LOOKUP(expr) memgraph::query::test_common::GetAllPropertiesLookup(this->storage, expr) +#define PARAMETER_LOOKUP(token_position) \ + this->storage.template Create<memgraph::query::ParameterLookup>((token_position)) +#define NEXPR(name, expr) this->storage.template Create<memgraph::query::NamedExpression>((name), (expr)) // AS is alternative to NEXPR which does not initialize NamedExpression with // Expression. It should be used with RETURN or WITH. For example: // RETURN(IDENT("n"), AS("n")) vs. RETURN(NEXPR("n", IDENT("n"))). diff --git a/tests/unit/query_plan.cpp b/tests/unit/query_plan.cpp index 3324b4fc5..910ebdc54 100644 --- a/tests/unit/query_plan.cpp +++ b/tests/unit/query_plan.cpp @@ -280,8 +280,92 @@ TYPED_TEST(TestPlanner, MatchMultiPattern) { MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m")), PATTERN(NODE("j"), EDGE("e"), NODE("i"))), RETURN("n"))); // We expect the expansions after the first to have a uniqueness filter in a // single MATCH clause. - CheckPlan<TypeParam>(query, this->storage, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), ExpectExpand(), + std::list<BaseOpChecker *> left_cartesian_ops{new ExpectScanAll(), new ExpectExpand()}; + std::list<BaseOpChecker *> right_cartesian_ops{new ExpectScanAll(), new ExpectExpand()}; + + CheckPlan<TypeParam>(query, this->storage, ExpectCartesian(left_cartesian_ops, right_cartesian_ops), ExpectEdgeUniquenessFilter(), ExpectProduce()); + + DeleteListContent(&left_cartesian_ops); + DeleteListContent(&right_cartesian_ops); +} + +TYPED_TEST(TestPlanner, MatchMultiPatternWithHashJoin) { + // Test MATCH (a:label)-[r1]->(b), (c:label)-[r2]->(d) WHERE c.id = a.id return a, b, c, d; + FakeDbAccessor dba; + const auto label_name = "label"; + const auto property = PROPERTY_PAIR(dba, "id"); + + auto *query = QUERY( + SINGLE_QUERY(MATCH(PATTERN(NODE("a", label_name), EDGE("r1"), NODE("b")), + PATTERN(NODE("c", label_name), EDGE("r2"), NODE("d"))), + WHERE(EQ(PROPERTY_LOOKUP(dba, "c", property.second), PROPERTY_LOOKUP(dba, "a", property.second))), + RETURN("a", "b", "c", "d"))); + + std::list<BaseOpChecker *> left_indexed_join_ops{new ExpectScanAll(), new ExpectFilter(), new ExpectExpand()}; + std::list<BaseOpChecker *> right_indexed_join_ops{new ExpectScanAll(), new ExpectFilter(), new ExpectExpand()}; + + auto symbol_table = memgraph::query::MakeSymbolTable(query); + auto planner = MakePlanner<TypeParam>(&dba, this->storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectHashJoin(left_indexed_join_ops, right_indexed_join_ops), + ExpectEdgeUniquenessFilter(), ExpectProduce()); + + DeleteListContent(&left_indexed_join_ops); + DeleteListContent(&right_indexed_join_ops); +} + +TYPED_TEST(TestPlanner, MatchMultiPatternWithAsymmetricHashJoin) { + // Test MATCH (a:label)-[r1]->(b), (c:label)-[r2]->(d) WHERE c.id = a.id2 return a, b, c, d; + FakeDbAccessor dba; + const auto label_name = "label"; + const auto property_1 = PROPERTY_PAIR(dba, "id"); + const auto property_2 = PROPERTY_PAIR(dba, "id2"); + + auto *query = QUERY(SINGLE_QUERY( + MATCH(PATTERN(NODE("a", label_name), EDGE("r1"), NODE("b")), + PATTERN(NODE("c", label_name), EDGE("r2"), NODE("d"))), + WHERE(EQ(PROPERTY_LOOKUP(dba, "c", property_1.second), PROPERTY_LOOKUP(dba, "a", property_2.second))), + RETURN("a", "b", "c", "d"))); + + std::list<BaseOpChecker *> left_indexed_join_ops{new ExpectScanAll(), new ExpectFilter(), new ExpectExpand()}; + std::list<BaseOpChecker *> right_indexed_join_ops{new ExpectScanAll(), new ExpectFilter(), new ExpectExpand()}; + + auto symbol_table = memgraph::query::MakeSymbolTable(query); + auto planner = MakePlanner<TypeParam>(&dba, this->storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectHashJoin(left_indexed_join_ops, right_indexed_join_ops), + ExpectEdgeUniquenessFilter(), ExpectProduce()); + + DeleteListContent(&left_indexed_join_ops); + DeleteListContent(&right_indexed_join_ops); +} + +TYPED_TEST(TestPlanner, MatchMultiPatternWithIndexJoin) { + // Test MATCH (a:label)-[r1]->(b), (c:label)-[r2]->(d) WHERE c.id = a.id return a, b, c, d; + FakeDbAccessor dba; + const auto label_name = "label"; + const auto label = dba.Label(label_name); + const auto property = PROPERTY_PAIR(dba, "id"); + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, property.second, 1); + + auto *query = QUERY( + SINGLE_QUERY(MATCH(PATTERN(NODE("a", label_name), EDGE("r1"), NODE("b")), + PATTERN(NODE("c", label_name), EDGE("r2"), NODE("d"))), + WHERE(EQ(PROPERTY_LOOKUP(dba, "c", property.second), PROPERTY_LOOKUP(dba, "a", property.second))), + RETURN("a", "b", "c", "d"))); + + auto c_prop = PROPERTY_LOOKUP(dba, "c", property); + std::list<BaseOpChecker *> left_indexed_join_ops{new ExpectScanAllByLabel(), new ExpectExpand()}; + std::list<BaseOpChecker *> right_indexed_join_ops{new ExpectScanAllByLabelPropertyValue(label, property, c_prop), + new ExpectExpand()}; + + auto symbol_table = memgraph::query::MakeSymbolTable(query); + auto planner = MakePlanner<TypeParam>(&dba, this->storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectIndexedJoin(left_indexed_join_ops, right_indexed_join_ops), + ExpectEdgeUniquenessFilter(), ExpectProduce()); + + DeleteListContent(&left_indexed_join_ops); + DeleteListContent(&right_indexed_join_ops); } TYPED_TEST(TestPlanner, MatchMultiPatternSameStart) { @@ -318,10 +402,16 @@ TYPED_TEST(TestPlanner, MultiMatch) { MATCH(PATTERN(node_j, edge_e, node_i, edge_f, node_h)), RETURN("n"))); auto symbol_table = memgraph::query::MakeSymbolTable(query); auto planner = MakePlanner<TypeParam>(&dba, this->storage, symbol_table, query); + // Multiple MATCH clauses form a Cartesian product, so the uniqueness should // not cross MATCH boundaries. - CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), ExpectExpand(), - ExpectExpand(), ExpectEdgeUniquenessFilter(), ExpectProduce()); + std::list<BaseOpChecker *> left_cartesian_ops{new ExpectScanAll(), new ExpectExpand()}; + std::list<BaseOpChecker *> right_cartesian_ops{new ExpectScanAll(), new ExpectExpand(), new ExpectExpand(), + new ExpectEdgeUniquenessFilter()}; + CheckPlan(planner.plan(), symbol_table, ExpectCartesian(left_cartesian_ops, right_cartesian_ops), ExpectProduce()); + + DeleteListContent(&left_cartesian_ops); + DeleteListContent(&right_cartesian_ops); } TYPED_TEST(TestPlanner, MultiMatchSameStart) { @@ -690,9 +780,19 @@ TYPED_TEST(TestPlanner, MatchCrossReferenceVariable) { auto n_prop = PROPERTY_LOOKUP(dba, "n", prop.second); std::get<0>(node_m->properties_)[this->storage.GetPropertyIx(prop.first)] = n_prop; auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(node_n), PATTERN(node_m)), RETURN("n"))); + // We expect both ScanAll to come before filters (2 are joined into one), // because they need to populate the symbol values. - CheckPlan<TypeParam>(query, this->storage, ExpectScanAll(), ExpectScanAll(), ExpectFilter(), ExpectProduce()); + // They are combined in a Cartesian (rewritten to HashJoin) to generate values from both symbols respectively and + // independently + std::list<BaseOpChecker *> left_hash_join_ops{new ExpectScanAll()}; + std::list<BaseOpChecker *> right_hash_join_ops{new ExpectScanAll()}; + + CheckPlan<TypeParam>(query, this->storage, ExpectHashJoin(left_hash_join_ops, right_hash_join_ops), ExpectFilter(), + ExpectProduce()); + + DeleteListContent(&left_hash_join_ops); + DeleteListContent(&right_hash_join_ops); } TYPED_TEST(TestPlanner, MatchWhereBeforeExpand) { @@ -759,10 +859,15 @@ TYPED_TEST(TestPlanner, MultiMatchWhere) { auto prop = dba.Property("prop"); auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), MATCH(PATTERN(NODE("l"))), WHERE(LESS(PROPERTY_LOOKUP(dba, "n", prop), LITERAL(42))), RETURN("n"))); - // Even though WHERE is in the second MATCH clause, we expect Filter to come - // before second ScanAll, since it only uses the value from first ScanAll. - CheckPlan<TypeParam>(query, this->storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), ExpectScanAll(), - ExpectProduce()); + + // The 2 match expansions need to be separated with a cartesian so they can be generated independently + std::list<BaseOpChecker *> left_cartesian_ops{new ExpectScanAll(), new ExpectFilter(), new ExpectExpand()}; + std::list<BaseOpChecker *> right_cartesian_ops{new ExpectScanAll()}; + + CheckPlan<TypeParam>(query, this->storage, ExpectCartesian(left_cartesian_ops, right_cartesian_ops), ExpectProduce()); + + DeleteListContent(&left_cartesian_ops); + DeleteListContent(&right_cartesian_ops); } TYPED_TEST(TestPlanner, MatchOptionalMatchWhere) { @@ -1078,8 +1183,14 @@ TYPED_TEST(TestPlanner, MultiPropertyIndexScan) { RETURN("n", "m"))); auto symbol_table = memgraph::query::MakeSymbolTable(query); auto planner = MakePlanner<TypeParam>(&dba, this->storage, symbol_table, query); - CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label1, prop1, lit_1), - ExpectScanAllByLabelPropertyValue(label2, prop2, lit_2), ExpectProduce()); + + std::list<BaseOpChecker *> left_cartesian_ops{new ExpectScanAllByLabelPropertyValue(label1, prop1, lit_1)}; + std::list<BaseOpChecker *> right_cartesian_ops{new ExpectScanAllByLabelPropertyValue(label2, prop2, lit_2)}; + + CheckPlan(planner.plan(), symbol_table, ExpectCartesian(left_cartesian_ops, right_cartesian_ops), ExpectProduce()); + + DeleteListContent(&left_cartesian_ops); + DeleteListContent(&right_cartesian_ops); } TYPED_TEST(TestPlanner, WhereIndexedLabelPropertyRange) { @@ -1172,9 +1283,36 @@ TYPED_TEST(TestPlanner, SecondPropertyIndex) { WHERE(EQ(m_prop, n_prop)), RETURN("n"))); auto symbol_table = memgraph::query::MakeSymbolTable(query); auto planner = MakePlanner<TypeParam>(&dba, this->storage, symbol_table, query); - CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabel(), - // Note: We are scanning for m, therefore property should equal n_prop. - ExpectScanAllByLabelPropertyValue(label, property, n_prop), ExpectProduce()); + + // Note: We are scanning for m, therefore property should equal n_prop. + std::list<BaseOpChecker *> left_index_join_ops{new ExpectScanAllByLabel()}; + std::list<BaseOpChecker *> right_index_join_ops{new ExpectScanAllByLabelPropertyValue(label, property, n_prop)}; + + CheckPlan(planner.plan(), symbol_table, ExpectIndexedJoin(left_index_join_ops, right_index_join_ops), + ExpectProduce()); + + DeleteListContent(&left_index_join_ops); + DeleteListContent(&right_index_join_ops); +} + +TYPED_TEST(TestPlanner, UnableToUseSecondPropertyIndex) { + // Test MATCH (n :label), (m :label) WHERE m.property = n.property RETURN n + FakeDbAccessor dba; + auto property = PROPERTY_PAIR(dba, "property"); + auto n_prop = PROPERTY_LOOKUP(dba, "n", property); + auto m_prop = PROPERTY_LOOKUP(dba, "m", property); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", "label")), PATTERN(NODE("m", "label"))), + WHERE(EQ(m_prop, n_prop)), RETURN("n"))); + auto symbol_table = memgraph::query::MakeSymbolTable(query); + auto planner = MakePlanner<TypeParam>(&dba, this->storage, symbol_table, query); + + std::list<BaseOpChecker *> left_index_join_ops{new ExpectScanAll(), new ExpectFilter()}; + std::list<BaseOpChecker *> right_index_join_ops{new ExpectScanAll(), new ExpectFilter()}; + + CheckPlan(planner.plan(), symbol_table, ExpectHashJoin(left_index_join_ops, right_index_join_ops), ExpectProduce()); + + DeleteListContent(&left_index_join_ops); + DeleteListContent(&right_index_join_ops); } TYPED_TEST(TestPlanner, ReturnSumGroupByAll) { diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp index 0a8b4d3ab..37695b581 100644 --- a/tests/unit/query_plan_checker.hpp +++ b/tests/unit/query_plan_checker.hpp @@ -114,6 +114,16 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { return false; } + bool PreVisit(HashJoin &op) override { + CheckOp(op); + return false; + } + + bool PreVisit(IndexedJoin &op) override { + CheckOp(op); + return false; + } + bool PreVisit(Apply &op) override { CheckOp(op); op.input()->Accept(*this); @@ -405,8 +415,7 @@ class ExpectScanAllByLabelProperty : public OpChecker<ScanAllByLabelProperty> { class ExpectCartesian : public OpChecker<Cartesian> { public: - ExpectCartesian(const std::list<std::unique_ptr<BaseOpChecker>> &left, - const std::list<std::unique_ptr<BaseOpChecker>> &right) + ExpectCartesian(const std::list<BaseOpChecker *> &left, const std::list<BaseOpChecker *> &right) : left_(left), right_(right) {} void ExpectOp(Cartesian &op, const SymbolTable &symbol_table) override { @@ -419,8 +428,46 @@ class ExpectCartesian : public OpChecker<Cartesian> { } private: - const std::list<std::unique_ptr<BaseOpChecker>> &left_; - const std::list<std::unique_ptr<BaseOpChecker>> &right_; + const std::list<BaseOpChecker *> &left_; + const std::list<BaseOpChecker *> &right_; +}; + +class ExpectHashJoin : public OpChecker<HashJoin> { + public: + ExpectHashJoin(const std::list<BaseOpChecker *> &left, const std::list<BaseOpChecker *> &right) + : left_(left), right_(right) {} + + void ExpectOp(HashJoin &op, const SymbolTable &symbol_table) override { + ASSERT_TRUE(op.left_op_); + PlanChecker left_checker(left_, symbol_table); + op.left_op_->Accept(left_checker); + ASSERT_TRUE(op.right_op_); + PlanChecker right_checker(right_, symbol_table); + op.right_op_->Accept(right_checker); + } + + private: + const std::list<BaseOpChecker *> &left_; + const std::list<BaseOpChecker *> &right_; +}; + +class ExpectIndexedJoin : public OpChecker<IndexedJoin> { + public: + ExpectIndexedJoin(const std::list<BaseOpChecker *> &main_branch, const std::list<BaseOpChecker *> &sub_branch) + : main_branch_(main_branch), sub_branch_(sub_branch) {} + + void ExpectOp(IndexedJoin &op, const SymbolTable &symbol_table) override { + ASSERT_TRUE(op.main_branch_); + PlanChecker main_branch_checker(main_branch_, symbol_table); + op.main_branch_->Accept(main_branch_checker); + ASSERT_TRUE(op.sub_branch_); + PlanChecker sub_branch_checker(sub_branch_, symbol_table); + op.sub_branch_->Accept(sub_branch_checker); + } + + private: + const std::list<BaseOpChecker *> &main_branch_; + const std::list<BaseOpChecker *> &sub_branch_; }; class ExpectCallProcedure : public OpChecker<CallProcedure> { diff --git a/tests/unit/query_plan_operator_to_string.cpp b/tests/unit/query_plan_operator_to_string.cpp index cc442f535..9d49033b8 100644 --- a/tests/unit/query_plan_operator_to_string.cpp +++ b/tests/unit/query_plan_operator_to_string.cpp @@ -310,7 +310,7 @@ TYPED_TEST(OperatorToStringTest, EdgeUniquenessFilter) { std::vector<memgraph::storage::EdgeTypeId>{}, false, memgraph::storage::View::OLD); last_op = std::make_shared<EdgeUniquenessFilter>(last_op, edge2_sym, std::vector<Symbol>{edge1_sym}); - std::string expected_string{"EdgeUniquenessFilter"}; + std::string expected_string{"EdgeUniquenessFilter {edge1 : edge2}"}; EXPECT_EQ(last_op->ToString(), expected_string); } @@ -495,3 +495,16 @@ TYPED_TEST(OperatorToStringTest, Apply) { std::string expected_string{"Apply"}; EXPECT_EQ(last_op.ToString(), expected_string); } + +TYPED_TEST(OperatorToStringTest, HashJoin) { + Symbol lhs_sym = this->GetSymbol("node1"); + Symbol rhs_sym = this->GetSymbol("node2"); + + std::shared_ptr<LogicalOperator> lhs_match = std::make_shared<ScanAll>(nullptr, lhs_sym); + std::shared_ptr<LogicalOperator> rhs_match = std::make_shared<ScanAll>(nullptr, rhs_sym); + std::shared_ptr<LogicalOperator> last_op = std::make_shared<HashJoin>( + lhs_match, std::vector<Symbol>{lhs_sym}, rhs_match, std::vector<Symbol>{rhs_sym}, nullptr); + + std::string expected_string{"HashJoin {node1 : node2}"}; + EXPECT_EQ(last_op->ToString(), expected_string); +}