From 956705293e7019a65100cb9e218ffd27185f65b3 Mon Sep 17 00:00:00 2001 From: gvolfing <gabor.volfinger@memgraph.io> Date: Mon, 11 Mar 2024 09:56:00 +0100 Subject: [PATCH] Add logical operator for id based edge scans --- src/query/plan/hint_provider.hpp | 3 ++ src/query/plan/operator.cpp | 42 +++++++++++++--- src/query/plan/operator.hpp | 49 ++++++++++++++----- src/query/plan/operator_type_info.cpp | 5 ++ src/query/plan/pretty_print.cpp | 15 ++++++ src/query/plan/pretty_print.hpp | 2 + .../plan/rewrite/edge_type_index_lookup.hpp | 9 ++++ src/utils/event_counter.cpp | 3 +- tests/unit/query_plan_checker.hpp | 2 + 9 files changed, 111 insertions(+), 19 deletions(-) diff --git a/src/query/plan/hint_provider.hpp b/src/query/plan/hint_provider.hpp index 3c8510561..4c981bfb0 100644 --- a/src/query/plan/hint_provider.hpp +++ b/src/query/plan/hint_provider.hpp @@ -117,6 +117,9 @@ class PlanHintsProvider final : public HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByEdgeType & /*unused*/) override { return true; } bool PostVisit(ScanAllByEdgeType & /*unused*/) override { return true; } + bool PreVisit(ScanAllByEdgeId & /*unused*/) override { return true; } + bool PostVisit(ScanAllByEdgeId & /*unused*/) override { return true; } + bool PreVisit(ConstructNamedPath & /*unused*/) override { return true; } bool PostVisit(ConstructNamedPath & /*unused*/) override { return true; } diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 7cd506050..6cbc8de4d 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -106,6 +106,7 @@ extern const Event ScanAllByLabelPropertyValueOperator; extern const Event ScanAllByLabelPropertyOperator; extern const Event ScanAllByIdOperator; extern const Event ScanAllByEdgeTypeOperator; +extern const Event ScanAllByEdgeIdOperator; extern const Event ExpandOperator; extern const Event ExpandVariableOperator; extern const Event ConstructNamedPathOperator; @@ -521,7 +522,7 @@ class ScanAllCursor : public Cursor { template <typename TEdgesFun> class ScanAllByEdgeTypeCursor : public Cursor { public: - explicit ScanAllByEdgeTypeCursor(const ScanAllByEdgeType &self, Symbol output_symbol, UniqueCursorPtr input_cursor, + explicit ScanAllByEdgeTypeCursor(const ScanAll &self, Symbol output_symbol, UniqueCursorPtr input_cursor, storage::View view, TEdgesFun get_edges, const char *op_name) : self_(self), output_symbol_(std::move(output_symbol)), @@ -561,7 +562,7 @@ class ScanAllByEdgeTypeCursor : public Cursor { } private: - const ScanAllByEdgeType &self_; + const ScanAll &self_; const Symbol output_symbol_; const UniqueCursorPtr input_cursor_; storage::View view_; @@ -613,10 +614,7 @@ UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource *mem) const { ScanAllByEdgeType::ScanAllByEdgeType(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::EdgeTypeId edge_type, storage::View view) - : input_(input ? input : std::make_shared<Once>()), - output_symbol_(std::move(output_symbol)), - view_(view), - edge_type_(edge_type) {} + : ScanAll(input, output_symbol, view), edge_type_(edge_type) {} ACCEPT_WITH_INPUT(ScanAllByEdgeType) @@ -782,6 +780,38 @@ UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const { view_, std::move(vertices), "ScanAllById"); } +ScanAllByEdgeId::ScanAllByEdgeId(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, + Expression *expression, storage::View view) + : ScanAll(input, output_symbol, view), expression_(expression) { + MG_ASSERT(expression); +} + +ACCEPT_WITH_INPUT(ScanAllByEdgeId) + +UniqueCursorPtr ScanAllByEdgeId::MakeCursor(utils::MemoryResource *mem) const { + memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByEdgeIdOperator); + + auto edges = [this](Frame &frame, ExecutionContext &context) -> std::optional<std::vector<EdgeAccessor>> { + auto *db = context.db_accessor; + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_); + auto value = expression_->Accept(evaluator); + if (!value.IsNumeric()) return std::nullopt; + int64_t id = value.IsInt() ? value.ValueInt() : value.ValueDouble(); + if (value.IsDouble() && id != value.ValueDouble()) return std::nullopt; + auto maybe_edge = db->FindEdge(storage::Gid::FromInt(id), view_); + if (!maybe_edge) return std::nullopt; + return std::vector<EdgeAccessor>{*maybe_edge}; + }; + return MakeUniqueCursorPtr<ScanAllByEdgeTypeCursor<decltype(edges)>>( + mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(edges), "ScanAllByEdgeId"); +} + +std::vector<Symbol> ScanAllByEdgeId::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.emplace_back(output_symbol_); + return symbols; +} + namespace { bool CheckExistingNode(const VertexAccessor &new_node, const Symbol &existing_node_sym, Frame &frame) { const TypedValue &existing_node = frame[existing_node_sym]; diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 6563c2bb0..2abd9744f 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -100,6 +100,7 @@ class ScanAllByLabelPropertyValue; class ScanAllByLabelProperty; class ScanAllById; class ScanAllByEdgeType; +class ScanAllByEdgeId; class Expand; class ExpandVariable; class ConstructNamedPath; @@ -133,13 +134,12 @@ class IndexedJoin; class HashJoin; class RollUpApply; -using LogicalOperatorCompositeVisitor = - utils::CompositeVisitor<Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange, - ScanAllByLabelPropertyValue, ScanAllByLabelProperty, ScanAllById, ScanAllByEdgeType, Expand, - ExpandVariable, ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, - SetLabels, RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, - Limit, OrderBy, Merge, Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv, - Foreach, EmptyResult, EvaluatePatternFilter, Apply, IndexedJoin, HashJoin, RollUpApply>; +using LogicalOperatorCompositeVisitor = utils::CompositeVisitor< + Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange, ScanAllByLabelPropertyValue, + ScanAllByLabelProperty, ScanAllById, ScanAllByEdgeType, ScanAllByEdgeId, Expand, ExpandVariable, ConstructNamedPath, + Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, EdgeUniquenessFilter, + Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, + LoadCsv, Foreach, EmptyResult, EvaluatePatternFilter, Apply, IndexedJoin, HashJoin, RollUpApply>; using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>; @@ -593,7 +593,7 @@ class ScanAllByLabel : public memgraph::query::plan::ScanAll { } }; -class ScanAllByEdgeType : public memgraph::query::plan::LogicalOperator { +class ScanAllByEdgeType : public memgraph::query::plan::ScanAll { public: static const utils::TypeInfo kType; const utils::TypeInfo &GetTypeInfo() const override { return kType; } @@ -613,10 +613,6 @@ class ScanAllByEdgeType : public memgraph::query::plan::LogicalOperator { return fmt::format("ScanAllByEdgeType ({} :{})", output_symbol_.name(), dba_->EdgeTypeToName(edge_type_)); } - std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; - Symbol output_symbol_; - storage::View view_; - storage::EdgeTypeId edge_type_; std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { @@ -819,6 +815,35 @@ class ScanAllById : public memgraph::query::plan::ScanAll { return object; } }; +class ScanAllByEdgeId : public memgraph::query::plan::ScanAll { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + ScanAllByEdgeId() = default; + ScanAllByEdgeId(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression, + storage::View view = storage::View::OLD); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::string ToString() const override { return fmt::format("ScanAllByEdgeId ({})", output_symbol_.name()); } + + Expression *expression_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<ScanAllByEdgeId>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->output_symbol_ = output_symbol_; + object->view_ = view_; + object->expression_ = expression_ ? expression_->Clone(storage) : nullptr; + return object; + } +}; struct ExpandCommon { static const utils::TypeInfo kType; diff --git a/src/query/plan/operator_type_info.cpp b/src/query/plan/operator_type_info.cpp index 6b0a28313..ac75581d1 100644 --- a/src/query/plan/operator_type_info.cpp +++ b/src/query/plan/operator_type_info.cpp @@ -49,9 +49,14 @@ constexpr utils::TypeInfo query::plan::ScanAllByLabelProperty::kType{ constexpr utils::TypeInfo query::plan::ScanAllById::kType{utils::TypeId::SCAN_ALL_BY_ID, "ScanAllById", &query::plan::ScanAll::kType}; + +// TODO check if the kType is correct here. constexpr utils::TypeInfo query::plan::ScanAllByEdgeType::kType{utils::TypeId::SCAN_ALL_BY_EDGE_TYPE, "ScanAllByEdgeType", &query::plan::ScanAll::kType}; +constexpr utils::TypeInfo query::plan::ScanAllByEdgeId::kType{utils::TypeId::SCAN_ALL_BY_ID, "ScanAllByEdgeId", + &query::plan::ScanAll::kType}; + constexpr utils::TypeInfo query::plan::ExpandCommon::kType{utils::TypeId::EXPAND_COMMON, "ExpandCommon", nullptr}; constexpr utils::TypeInfo query::plan::Expand::kType{utils::TypeId::EXPAND, "Expand", diff --git a/src/query/plan/pretty_print.cpp b/src/query/plan/pretty_print.cpp index eeb0c15b5..93daf9e57 100644 --- a/src/query/plan/pretty_print.cpp +++ b/src/query/plan/pretty_print.cpp @@ -83,6 +83,11 @@ bool PlanPrinter::PreVisit(query::plan::ScanAllByEdgeType &op) { return true; } +bool PlanPrinter::PreVisit(query::plan::ScanAllByEdgeId &op) { + WithPrintLn([&op](auto &out) { out << "* " << op.ToString(); }); + return true; +} + bool PlanPrinter::PreVisit(query::plan::Expand &op) { op.dba_ = dba_; WithPrintLn([&op](auto &out) { out << "* " << op.ToString(); }); @@ -484,6 +489,16 @@ bool PlanToJsonVisitor::PreVisit(ScanAllByEdgeType &op) { return false; } +bool PlanToJsonVisitor::PreVisit(ScanAllByEdgeId &op) { + json self; + self["name"] = "ScanAllByEdgeId"; + self["output_symbol"] = ToJson(op.output_symbol_); + op.input_->Accept(*this); + self["input"] = PopOutput(); + output_ = std::move(self); + return false; +} + bool PlanToJsonVisitor::PreVisit(CreateNode &op) { json self; self["name"] = "CreateNode"; diff --git a/src/query/plan/pretty_print.hpp b/src/query/plan/pretty_print.hpp index d62ae6bf2..9b418aa3e 100644 --- a/src/query/plan/pretty_print.hpp +++ b/src/query/plan/pretty_print.hpp @@ -68,6 +68,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelProperty &) override; bool PreVisit(ScanAllById &) override; bool PreVisit(ScanAllByEdgeType &) override; + bool PreVisit(ScanAllByEdgeId &) override; bool PreVisit(Expand &) override; bool PreVisit(ExpandVariable &) override; @@ -206,6 +207,7 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelProperty &) override; bool PreVisit(ScanAllById &) override; bool PreVisit(ScanAllByEdgeType &) override; + bool PreVisit(ScanAllByEdgeId &) override; bool PreVisit(EmptyResult &) override; bool PreVisit(Produce &) override; diff --git a/src/query/plan/rewrite/edge_type_index_lookup.hpp b/src/query/plan/rewrite/edge_type_index_lookup.hpp index ed8666513..eedf36ae0 100644 --- a/src/query/plan/rewrite/edge_type_index_lookup.hpp +++ b/src/query/plan/rewrite/edge_type_index_lookup.hpp @@ -254,6 +254,15 @@ class EdgeTypeIndexRewriter final : public HierarchicalLogicalOperatorVisitor { return true; } + bool PreVisit(ScanAllByEdgeId &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(ScanAllByEdgeId &) override { + prev_ops_.pop_back(); + return true; + } + bool PreVisit(ConstructNamedPath &op) override { prev_ops_.push_back(&op); return true; diff --git a/src/utils/event_counter.cpp b/src/utils/event_counter.cpp index 54ff4ed5c..8ad705906 100644 --- a/src/utils/event_counter.cpp +++ b/src/utils/event_counter.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -27,6 +27,7 @@ M(ScanAllByLabelPropertyOperator, Operator, "Number of times ScanAllByLabelProperty operator was used.") \ M(ScanAllByIdOperator, Operator, "Number of times ScanAllById operator was used.") \ M(ScanAllByEdgeTypeOperator, Operator, "Number of times ScanAllByEdgeTypeOperator operator was used.") \ + M(ScanAllByEdgeIdOperator, Operator, "Number of times ScanAllByEdgeIdOperator operator was used.") \ M(ExpandOperator, Operator, "Number of times Expand operator was used.") \ M(ExpandVariableOperator, Operator, "Number of times ExpandVariable operator was used.") \ M(ConstructNamedPathOperator, Operator, "Number of times ConstructNamedPath operator was used.") \ diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp index 6eef3841a..aee241b55 100644 --- a/tests/unit/query_plan_checker.hpp +++ b/tests/unit/query_plan_checker.hpp @@ -66,6 +66,7 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { PRE_VISIT(ScanAllByLabelPropertyRange); PRE_VISIT(ScanAllByLabelProperty); PRE_VISIT(ScanAllByEdgeType); + PRE_VISIT(ScanAllByEdgeId); PRE_VISIT(ScanAllById); PRE_VISIT(Expand); PRE_VISIT(ExpandVariable); @@ -172,6 +173,7 @@ using ExpectDelete = OpChecker<Delete>; using ExpectScanAll = OpChecker<ScanAll>; using ExpectScanAllByLabel = OpChecker<ScanAllByLabel>; using ExpectScanAllByEdgeType = OpChecker<ScanAllByEdgeType>; +using ExpectScanAllByEdgeId = OpChecker<ScanAllByEdgeId>; using ExpectScanAllById = OpChecker<ScanAllById>; using ExpectExpand = OpChecker<Expand>; using ExpectConstructNamedPath = OpChecker<ConstructNamedPath>;