// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // License, and you may not use this file except in compliance with the Business Source License. // // As of the Change Date specified in that file, in accordance with // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. #include "query/plan/operator.hpp" #include <algorithm> #include <cctype> #include <cstdint> #include <limits> #include <optional> #include <queue> #include <random> #include <string> #include <tuple> #include <type_traits> #include <unordered_map> #include <unordered_set> #include <utility> #include <cppitertools/chain.hpp> #include <cppitertools/imap.hpp> #include "memory/query_memory_control.hpp" #include "query/common.hpp" #include "spdlog/spdlog.h" #include "csv/parsing.hpp" #include "license/license.hpp" #include "query/auth_checker.hpp" #include "query/context.hpp" #include "query/db_accessor.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/graph.hpp" #include "query/interpret/eval.hpp" #include "query/path.hpp" #include "query/plan/scoped_profile.hpp" #include "query/procedure/cypher_types.hpp" #include "query/procedure/mg_procedure_impl.hpp" #include "query/procedure/module.hpp" #include "query/typed_value.hpp" #include "storage/v2/property_value.hpp" #include "storage/v2/view.hpp" #include "utils/algorithm.hpp" #include "utils/event_counter.hpp" #include "utils/exceptions.hpp" #include "utils/fnv.hpp" #include "utils/java_string_formatter.hpp" #include "utils/likely.hpp" #include "utils/logging.hpp" #include "utils/memory.hpp" #include "utils/memory_tracker.hpp" #include "utils/message.hpp" #include "utils/on_scope_exit.hpp" 
#include "utils/pmr/deque.hpp"
#include "utils/pmr/list.hpp"
#include "utils/pmr/unordered_map.hpp"
#include "utils/pmr/unordered_set.hpp"
#include "utils/pmr/vector.hpp"
#include "utils/readable_size.hpp"
#include "utils/string.hpp"
#include "utils/temporal.hpp"
#include "utils/typeinfo.hpp"

// Macro for the default implementation of LogicalOperator::Accept for an
// operator with a single `input_` operator. It calls PreVisit on this
// operator and, if that returns true, forwards the visitor to its `input_`
// (throwing QueryRuntimeException when `input_` is null). The result of
// PostVisit is returned.
// NOLINTNEXTLINE
#define ACCEPT_WITH_INPUT(class_name)                                                                    \
  bool class_name::Accept(HierarchicalLogicalOperatorVisitor &visitor) {                                 \
    if (visitor.PreVisit(*this)) {                                                                       \
      if (input_ == nullptr) {                                                                           \
        throw QueryRuntimeException(                                                                     \
            "The query couldn't be executed due to the unexpected null value in " #class_name            \
            " operator. To learn more about operators visit https://memgr.ph/query-operators!");         \
      }                                                                                                  \
      input_->Accept(visitor);                                                                           \
    }                                                                                                    \
    return visitor.PostVisit(*this);                                                                     \
  }

// Macro implementing the single-input accessors for operators that do not
// have exactly one input: HasSingleInput() returns false, while input() and
// set_input() report a fatal error through LOG_FATAL.
#define WITHOUT_SINGLE_INPUT(class_name)                         \
  bool class_name::HasSingleInput() const { return false; }      \
  std::shared_ptr<LogicalOperator> class_name::input() const {   \
    LOG_FATAL("Operator " #class_name " has no single input!"); \
  }                                                              \
  void class_name::set_input(std::shared_ptr<LogicalOperator>) { \
    LOG_FATAL("Operator " #class_name " has no single input!"); \
  }

// Per-operator usage counters. They are defined elsewhere; the MakeCursor
// implementations in this file increment them.
namespace memgraph::metrics {
extern const Event OnceOperator;
extern const Event CreateNodeOperator;
extern const Event CreateExpandOperator;
extern const Event ScanAllOperator;
extern const Event ScanAllByLabelOperator;
extern const Event ScanAllByLabelPropertyRangeOperator;
extern const Event ScanAllByLabelPropertyValueOperator;
extern const Event ScanAllByLabelPropertyOperator;
extern const Event ScanAllByIdOperator;
extern const Event ExpandOperator;
extern const Event ExpandVariableOperator;
extern const Event ConstructNamedPathOperator;
extern const Event FilterOperator;
extern const Event ProduceOperator;
extern const Event DeleteOperator;
extern const Event SetPropertyOperator;
extern const Event SetPropertiesOperator;
extern const Event SetLabelsOperator;
extern const Event RemovePropertyOperator;
extern const Event RemoveLabelsOperator;
extern const Event EdgeUniquenessFilterOperator;
extern const Event AccumulateOperator;
extern const Event AggregateOperator;
extern const Event SkipOperator;
extern const Event LimitOperator;
extern const Event OrderByOperator;
extern const Event MergeOperator;
extern const Event OptionalOperator;
extern const Event UnwindOperator;
extern const Event DistinctOperator;
extern const Event UnionOperator;
extern const Event CartesianOperator;
extern const Event CallProcedureOperator;
extern const Event ForeachOperator;
extern const Event EmptyResultOperator;
extern const Event EvaluatePatternFilterOperator;
extern const Event ApplyOperator;
extern const Event IndexedJoinOperator;
extern const Event HashJoinOperator;
}  // namespace memgraph::metrics

namespace memgraph::query::plan {

// Shorthand for the memory tracker's OutOfMemoryExceptionEnabler; cursors
// create one at the start of Pull for the duration of that call.
using OOMExceptionEnabler = utils::MemoryTracker::OutOfMemoryExceptionEnabler;

namespace {

// Custom equality function for a vector of typed values.
// Used in unordered_maps in Aggregate and Distinct operators.
// Both vectors must have the same size (asserted); elements are compared
// pairwise with TypedValue::BoolEqual.
struct TypedValueVectorEqual {
  template <class TAllocator>
  bool operator()(const std::vector<TypedValue, TAllocator> &left,
                  const std::vector<TypedValue, TAllocator> &right) const {
    MG_ASSERT(left.size() == right.size(),
              "TypedValueVector comparison should only be done over vectors "
              "of the same size");
    return std::equal(left.begin(), left.end(), right.begin(), TypedValue::BoolEqual{});
  }
};

// Returns boolean result of evaluating filter expression. Null is treated as
// false. Other non boolean values raise a QueryRuntimeException.
bool EvaluateFilter(ExpressionEvaluator &evaluator, Expression *filter) {
  TypedValue result = filter->Accept(evaluator);
  // Null is treated like false.
  if (result.IsNull()) return false;
  if (result.type() != TypedValue::Type::Bool)
    throw QueryRuntimeException("Filter expression must evaluate to bool or null, got {}.", result.type());
  return result.ValueBool();
}

// Uses the object's address as the profiling key, so each operator instance
// gets its own key. The static_assert guarantees the pointer fits in 64 bits.
template <typename T>
uint64_t ComputeProfilingKey(const T *obj) {
  static_assert(sizeof(T *) == sizeof(uint64_t));
  return reinterpret_cast<uint64_t>(obj);
}

// Throws HintedAbortError carrying the reason whenever MustAbort(context)
// reports anything other than AbortReason::NO_ABORT.
inline void AbortCheck(ExecutionContext const &context) {
  if (auto const reason = MustAbort(context); reason != AbortReason::NO_ABORT) throw HintedAbortError(reason);
}

}  // namespace

// Declares a ScopedProfile named `profile` keyed by `this`; expects an
// ExecutionContext named `context` to be in scope.
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define SCOPED_PROFILE_OP(name) ScopedProfile profile{ComputeProfilingKey(this), name, &context};

// Same as SCOPED_PROFILE_OP, but passes `ref` (an object rather than a name)
// to ScopedProfile.
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define SCOPED_PROFILE_OP_BY_REF(ref) ScopedProfile profile{ComputeProfilingKey(this), ref, &context};

// Succeeds exactly once, then reports exhaustion until Reset() is called.
// The frame is not touched.
bool Once::OnceCursor::Pull(Frame &, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("Once");

  if (!did_pull_) {
    did_pull_ = true;
    return true;
  }
  return false;
}

UniqueCursorPtr Once::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::OnceOperator);

  return MakeUniqueCursorPtr<OnceCursor>(mem);
}

WITHOUT_SINGLE_INPUT(Once);

void Once::OnceCursor::Shutdown() {}

// Re-arms the cursor so the next Pull succeeds again.
void Once::OnceCursor::Reset() { did_pull_ = false; }

// A null `input` defaults to a Once operator.
CreateNode::CreateNode(const std::shared_ptr<LogicalOperator> &input, NodeCreationInfo node_info)
    : input_(input ? input : std::make_shared<Once>()), node_info_(std::move(node_info)) {}

// Creates a vertex on this GraphDb. Returns a reference to vertex placed on the
// frame.
VertexAccessor &CreateLocalVertex(const NodeCreationInfo &node_info, Frame *frame, ExecutionContext &context, std::vector<storage::LabelId> &labels, ExpressionEvaluator &evaluator) { auto &dba = *context.db_accessor; auto new_node = dba.InsertVertex(); context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; for (const auto &label : labels) { auto maybe_error = std::invoke([&] { return new_node.AddLabel(label); }); if (maybe_error.HasError()) { switch (maybe_error.GetError()) { case storage::Error::SERIALIZATION_ERROR: throw TransactionSerializationException(); case storage::Error::DELETED_OBJECT: throw QueryRuntimeException("Trying to set a label on a deleted node."); case storage::Error::VERTEX_HAS_EDGES: case storage::Error::PROPERTIES_DISABLED: case storage::Error::NONEXISTENT_OBJECT: throw QueryRuntimeException("Unexpected error when setting a label."); } } context.execution_stats[ExecutionStats::Key::CREATED_LABELS] += 1; } // TODO: PropsSetChecked allocates a PropertyValue, make it use context.memory // when we update PropertyValue with custom allocator. 
std::map<storage::PropertyId, storage::PropertyValue> properties; if (const auto *node_info_properties = std::get_if<PropertiesMapList>(&node_info.properties)) { for (const auto &[key, value_expression] : *node_info_properties) { properties.emplace(key, value_expression->Accept(evaluator)); } } else { auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info.properties)); for (const auto &[key, value] : property_map.ValueMap()) { properties.emplace(dba.NameToProperty(key), value); } } MultiPropsInitChecked(&new_node, properties); (*frame)[node_info.symbol] = new_node; return (*frame)[node_info.symbol].ValueVertex(); } ACCEPT_WITH_INPUT(CreateNode) UniqueCursorPtr CreateNode::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::CreateNodeOperator); return MakeUniqueCursorPtr<CreateNodeCursor>(mem, *this, mem); } std::vector<Symbol> CreateNode::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); symbols.emplace_back(node_info_.symbol); return symbols; } CreateNode::CreateNodeCursor::CreateNodeCursor(const CreateNode &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} bool CreateNode::CreateNodeCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("CreateNode"); ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); if (input_cursor_->Pull(frame, context)) { // we have to resolve the labels before we can check for permissions std::vector<storage::LabelId> labels; for (auto label : self_.node_info_.labels) { if (const auto *label_atom = std::get_if<storage::LabelId>(&label)) { labels.emplace_back(*label_atom); } else { labels.emplace_back( context.db_accessor->NameToLabel(std::get<Expression *>(label)->Accept(evaluator).ValueString())); } } #ifdef MG_ENTERPRISE if 
(license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(labels, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) { throw QueryRuntimeException("Vertex not created due to not having enough permission!"); } #endif auto created_vertex = CreateLocalVertex(self_.node_info_, &frame, context, labels, evaluator); if (context.trigger_context_collector) { context.trigger_context_collector->RegisterCreatedObject(created_vertex); } return true; } return false; } void CreateNode::CreateNodeCursor::Shutdown() { input_cursor_->Shutdown(); } void CreateNode::CreateNodeCursor::Reset() { input_cursor_->Reset(); } CreateExpand::CreateExpand(NodeCreationInfo node_info, EdgeCreationInfo edge_info, const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node) : node_info_(std::move(node_info)), edge_info_(std::move(edge_info)), input_(input ? input : std::make_shared<Once>()), input_symbol_(std::move(input_symbol)), existing_node_(existing_node) {} ACCEPT_WITH_INPUT(CreateExpand) UniqueCursorPtr CreateExpand::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::CreateNodeOperator); return MakeUniqueCursorPtr<CreateExpandCursor>(mem, *this, mem); } std::vector<Symbol> CreateExpand::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); symbols.emplace_back(node_info_.symbol); symbols.emplace_back(edge_info_.symbol); return symbols; } CreateExpand::CreateExpandCursor::CreateExpandCursor(const CreateExpand &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} namespace { EdgeAccessor CreateEdge(const EdgeCreationInfo &edge_info, DbAccessor *dba, VertexAccessor *from, VertexAccessor *to, Frame *frame, ExpressionEvaluator *evaluator) { auto maybe_edge = dba->InsertEdge(from, to, edge_info.edge_type); if (maybe_edge.HasValue()) { auto &edge = *maybe_edge; 
std::map<storage::PropertyId, storage::PropertyValue> properties; if (const auto *edge_info_properties = std::get_if<PropertiesMapList>(&edge_info.properties)) { for (const auto &[key, value_expression] : *edge_info_properties) { properties.emplace(key, value_expression->Accept(*evaluator)); } } else { auto property_map = evaluator->Visit(*std::get<ParameterLookup *>(edge_info.properties)); for (const auto &[key, value] : property_map.ValueMap()) { properties.emplace(dba->NameToProperty(key), value); } } if (!properties.empty()) MultiPropsInitChecked(&edge, properties); (*frame)[edge_info.symbol] = edge; } else { switch (maybe_edge.GetError()) { case storage::Error::SERIALIZATION_ERROR: throw TransactionSerializationException(); case storage::Error::DELETED_OBJECT: throw QueryRuntimeException("Trying to create an edge on a deleted node."); case storage::Error::VERTEX_HAS_EDGES: case storage::Error::PROPERTIES_DISABLED: case storage::Error::NONEXISTENT_OBJECT: throw QueryRuntimeException("Unexpected error when creating an edge."); } } return *maybe_edge; } } // namespace bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(self_); if (!input_cursor_->Pull(frame, context)) return false; ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); std::vector<storage::LabelId> labels; for (auto label : self_.node_info_.labels) { if (const auto *label_atom = std::get_if<storage::LabelId>(&label)) { labels.emplace_back(*label_atom); } else { labels.emplace_back( context.db_accessor->NameToLabel(std::get<Expression *>(label)->Accept(evaluator).ValueString())); } } #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast()) { const auto fine_grained_permission = self_.existing_node_ ? 
memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE : memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE; if (context.auth_checker && !(context.auth_checker->Has(self_.edge_info_.edge_type, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE) && context.auth_checker->Has(labels, fine_grained_permission))) { throw QueryRuntimeException("Edge not created due to not having enough permission!"); } } #endif // get the origin vertex TypedValue &vertex_value = frame[self_.input_symbol_]; ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); auto &v1 = vertex_value.ValueVertex(); // get the destination vertex (possibly an existing node) auto &v2 = OtherVertex(frame, context, labels, evaluator); // create an edge between the two nodes auto *dba = context.db_accessor; auto created_edge = [&] { switch (self_.edge_info_.direction) { case EdgeAtom::Direction::IN: return CreateEdge(self_.edge_info_, dba, &v2, &v1, &frame, &evaluator); case EdgeAtom::Direction::OUT: // in the case of an undirected CreateExpand we choose an arbitrary // direction. 
this is used in the MERGE clause // it is not allowed in the CREATE clause, and the semantic // checker needs to ensure it doesn't reach this point case EdgeAtom::Direction::BOTH: return CreateEdge(self_.edge_info_, dba, &v1, &v2, &frame, &evaluator); } }(); context.execution_stats[ExecutionStats::Key::CREATED_EDGES] += 1; if (context.trigger_context_collector) { context.trigger_context_collector->RegisterCreatedObject(created_edge); } return true; } void CreateExpand::CreateExpandCursor::Shutdown() { input_cursor_->Shutdown(); } void CreateExpand::CreateExpandCursor::Reset() { input_cursor_->Reset(); } VertexAccessor &CreateExpand::CreateExpandCursor::OtherVertex(Frame &frame, ExecutionContext &context, std::vector<storage::LabelId> &labels, ExpressionEvaluator &evaluator) { if (self_.existing_node_) { TypedValue &dest_node_value = frame[self_.node_info_.symbol]; ExpectType(self_.node_info_.symbol, dest_node_value, TypedValue::Type::Vertex); return dest_node_value.ValueVertex(); } else { auto &created_vertex = CreateLocalVertex(self_.node_info_, &frame, context, labels, evaluator); if (context.trigger_context_collector) { context.trigger_context_collector->RegisterCreatedObject(created_vertex); } return created_vertex; } } template <class TVerticesFun> class ScanAllCursor : public Cursor { public: explicit ScanAllCursor(const ScanAll &self, Symbol output_symbol, UniqueCursorPtr input_cursor, storage::View view, TVerticesFun get_vertices, const char *op_name) : self_(self), output_symbol_(std::move(output_symbol)), input_cursor_(std::move(input_cursor)), view_(view), get_vertices_(std::move(get_vertices)), op_name_(op_name) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(self_); AbortCheck(context); while (!vertices_ || vertices_it_.value() == vertices_end_it_.value()) { if (!input_cursor_->Pull(frame, context)) return false; // We need a getter function, because in case of exhausting a 
lazy // iterable, we cannot simply reset it by calling begin(). auto next_vertices = get_vertices_(frame, context); if (!next_vertices) continue; // Since vertices iterator isn't nothrow_move_assignable, we have to use // the roundabout assignment + emplace, instead of simple: // vertices _ = get_vertices_(frame, context); vertices_.emplace(std::move(next_vertices.value())); vertices_it_.emplace(vertices_.value().begin()); vertices_end_it_.emplace(vertices_.value().end()); } #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !FindNextVertex(context)) { return false; } #endif frame[output_symbol_] = *vertices_it_.value(); ++vertices_it_.value(); return true; } #ifdef MG_ENTERPRISE bool FindNextVertex(const ExecutionContext &context) { while (vertices_it_.value() != vertices_end_it_.value()) { if (context.auth_checker->Has(*vertices_it_.value(), view_, memgraph::query::AuthQuery::FineGrainedPrivilege::READ)) { return true; } ++vertices_it_.value(); } return false; } #endif void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); vertices_ = std::nullopt; vertices_it_ = std::nullopt; vertices_end_it_ = std::nullopt; } private: const ScanAll &self_; const Symbol output_symbol_; const UniqueCursorPtr input_cursor_; storage::View view_; TVerticesFun get_vertices_; std::optional<typename std::result_of<TVerticesFun(Frame &, ExecutionContext &)>::type::value_type> vertices_; std::optional<decltype(vertices_.value().begin())> vertices_it_; std::optional<decltype(vertices_.value().end())> vertices_end_it_; const char *op_name_; }; ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::View view) : input_(input ? 
input : std::make_shared<Once>()), output_symbol_(std::move(output_symbol)), view_(view) {} ACCEPT_WITH_INPUT(ScanAll) UniqueCursorPtr ScanAll::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllOperator); auto vertices = [this](Frame &, ExecutionContext &context) { auto *db = context.db_accessor; return std::make_optional(db->Vertices(view_)); }; return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(vertices), "ScanAll"); } std::vector<Symbol> ScanAll::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); symbols.emplace_back(output_symbol_); return symbols; } ScanAllByLabel::ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::LabelId label, storage::View view) : ScanAll(input, output_symbol, view), label_(label) {} ACCEPT_WITH_INPUT(ScanAllByLabel) UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByLabelOperator); auto vertices = [this](Frame &, ExecutionContext &context) { auto *db = context.db_accessor; return std::make_optional(db->Vertices(view_, label_)); }; return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(vertices), "ScanAllByLabel"); } // TODO(buda): Implement ScanAllByLabelProperty operator to iterate over // vertices that have the label and some value for the given property. 
ScanAllByLabelPropertyRange::ScanAllByLabelPropertyRange(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::LabelId label, storage::PropertyId property, std::string property_name, std::optional<Bound> lower_bound, std::optional<Bound> upper_bound, storage::View view) : ScanAll(input, output_symbol, view), label_(label), property_(property), property_name_(std::move(property_name)), lower_bound_(lower_bound), upper_bound_(upper_bound) { MG_ASSERT(lower_bound_ || upper_bound_, "Only one bound can be left out"); } ACCEPT_WITH_INPUT(ScanAllByLabelPropertyRange) UniqueCursorPtr ScanAllByLabelPropertyRange::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByLabelPropertyRangeOperator); auto vertices = [this](Frame &frame, ExecutionContext &context) -> std::optional<decltype(context.db_accessor->Vertices(view_, label_, property_, std::nullopt, std::nullopt))> { auto *db = context.db_accessor; ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_); auto convert = [&evaluator](const auto &bound) -> std::optional<utils::Bound<storage::PropertyValue>> { if (!bound) return std::nullopt; const auto &value = bound->value()->Accept(evaluator); try { const auto &property_value = storage::PropertyValue(value); switch (property_value.type()) { case storage::PropertyValue::Type::Bool: case storage::PropertyValue::Type::List: case storage::PropertyValue::Type::Map: // Prevent indexed lookup with something that would fail if we did // the original filter with `operator<`. Note, for some reason, // Cypher does not support comparing boolean values. 
throw QueryRuntimeException("Invalid type {} for '<'.", value.type()); case storage::PropertyValue::Type::Null: case storage::PropertyValue::Type::Int: case storage::PropertyValue::Type::Double: case storage::PropertyValue::Type::String: case storage::PropertyValue::Type::TemporalData: // These are all fine, there's also Point, Date and Time data types // which were added to Cypher, but we don't have support for those // yet. return std::make_optional(utils::Bound<storage::PropertyValue>(property_value, bound->type())); } } catch (const TypedValueException &) { throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type()); } }; auto maybe_lower = convert(lower_bound_); auto maybe_upper = convert(upper_bound_); // If any bound is null, then the comparison would result in nulls. This // is treated as not satisfying the filter, so return no vertices. if (maybe_lower && maybe_lower->value().IsNull()) return std::nullopt; if (maybe_upper && maybe_upper->value().IsNull()) return std::nullopt; return std::make_optional(db->Vertices(view_, label_, property_, maybe_lower, maybe_upper)); }; return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>( mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(vertices), "ScanAllByLabelPropertyRange"); } ScanAllByLabelPropertyValue::ScanAllByLabelPropertyValue(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::LabelId label, storage::PropertyId property, std::string property_name, Expression *expression, storage::View view) : ScanAll(input, output_symbol, view), label_(label), property_(property), property_name_(std::move(property_name)), expression_(expression) { DMG_ASSERT(expression, "Expression is not optional."); } ACCEPT_WITH_INPUT(ScanAllByLabelPropertyValue) UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByLabelPropertyValueOperator); auto vertices = 
[this](Frame &frame, ExecutionContext &context) -> std::optional<decltype(context.db_accessor->Vertices(view_, label_, property_, storage::PropertyValue()))> { auto *db = context.db_accessor; ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_); auto value = expression_->Accept(evaluator); if (value.IsNull()) return std::nullopt; if (!value.IsPropertyValue()) { throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type()); } return std::make_optional(db->Vertices(view_, label_, property_, storage::PropertyValue(value))); }; return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>( mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(vertices), "ScanAllByLabelPropertyValue"); } ScanAllByLabelProperty::ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::LabelId label, storage::PropertyId property, std::string property_name, storage::View view) : ScanAll(input, output_symbol, view), label_(label), property_(property), property_name_(std::move(property_name)) {} ACCEPT_WITH_INPUT(ScanAllByLabelProperty) UniqueCursorPtr ScanAllByLabelProperty::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByLabelPropertyOperator); auto vertices = [this](Frame &frame, ExecutionContext &context) { auto *db = context.db_accessor; return std::make_optional(db->Vertices(view_, label_, property_)); }; return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(vertices), "ScanAllByLabelProperty"); } ScanAllById::ScanAllById(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression, storage::View view) : ScanAll(input, output_symbol, view), expression_(expression) { MG_ASSERT(expression); } ACCEPT_WITH_INPUT(ScanAllById) UniqueCursorPtr 
ScanAllById::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByIdOperator); auto vertices = [this](Frame &frame, ExecutionContext &context) -> std::optional<std::vector<VertexAccessor>> { auto *db = context.db_accessor; ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_); auto value = expression_->Accept(evaluator); if (!value.IsNumeric()) return std::nullopt; int64_t id = value.IsInt() ? value.ValueInt() : value.ValueDouble(); if (value.IsDouble() && id != value.ValueDouble()) return std::nullopt; auto maybe_vertex = db->FindVertex(storage::Gid::FromInt(id), view_); if (!maybe_vertex) return std::nullopt; return std::vector<VertexAccessor>{*maybe_vertex}; }; return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(vertices), "ScanAllById"); } namespace { bool CheckExistingNode(const VertexAccessor &new_node, const Symbol &existing_node_sym, Frame &frame) { const TypedValue &existing_node = frame[existing_node_sym]; if (existing_node.IsNull()) return false; ExpectType(existing_node_sym, existing_node, TypedValue::Type::Vertex); return existing_node.ValueVertex() == new_node; } template <class TEdgesResult> auto UnwrapEdgesResult(storage::Result<TEdgesResult> &&result) { if (result.HasError()) { switch (result.GetError()) { case storage::Error::DELETED_OBJECT: throw QueryRuntimeException("Trying to get relationships of a deleted node."); case storage::Error::NONEXISTENT_OBJECT: throw query::QueryRuntimeException("Trying to get relationships from a node that doesn't exist."); case storage::Error::VERTEX_HAS_EDGES: case storage::Error::SERIALIZATION_ERROR: case storage::Error::PROPERTIES_DISABLED: throw QueryRuntimeException("Unexpected error when accessing relationships."); } } return std::move(*result); } } // namespace Expand::Expand(const 
std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol, Symbol edge_symbol, EdgeAtom::Direction direction, const std::vector<storage::EdgeTypeId> &edge_types, bool existing_node, storage::View view) : input_(input ? input : std::make_shared<Once>()), input_symbol_(std::move(input_symbol)), common_{node_symbol, edge_symbol, direction, edge_types, existing_node}, view_(view) {} ACCEPT_WITH_INPUT(Expand) UniqueCursorPtr Expand::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ExpandOperator); return MakeUniqueCursorPtr<ExpandCursor>(mem, *this, mem); } std::vector<Symbol> Expand::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); symbols.emplace_back(common_.node_symbol); symbols.emplace_back(common_.edge_symbol); return symbols; } Expand::ExpandCursor::ExpandCursor(const Expand &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} Expand::ExpandCursor::ExpandCursor(const Expand &self, int64_t input_degree, int64_t existing_node_degree, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)), prev_input_degree_(input_degree), prev_existing_degree_(existing_node_degree) {} bool Expand::ExpandCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(self_); // A helper function for expanding a node from an edge. 
auto pull_node = [this, &frame](const EdgeAccessor &new_edge, EdgeAtom::Direction direction) { if (self_.common_.existing_node) return; switch (direction) { case EdgeAtom::Direction::IN: frame[self_.common_.node_symbol] = new_edge.From(); break; case EdgeAtom::Direction::OUT: frame[self_.common_.node_symbol] = new_edge.To(); break; case EdgeAtom::Direction::BOTH: LOG_FATAL("Must indicate exact expansion direction here"); } }; while (true) { AbortCheck(context); // attempt to get a value from the incoming edges if (in_edges_ && *in_edges_it_ != in_edges_->end()) { auto edge = *(*in_edges_it_)++; #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge.From(), self_.view_, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { continue; } #endif frame[self_.common_.edge_symbol] = edge; pull_node(edge, EdgeAtom::Direction::IN); return true; } // attempt to get a value from the outgoing edges if (out_edges_ && *out_edges_it_ != out_edges_->end()) { auto edge = *(*out_edges_it_)++; // when expanding in EdgeAtom::Direction::BOTH directions // we should do only one expansion for cycles, and it was // already done in the block above if (self_.common_.direction == EdgeAtom::Direction::BOTH && edge.IsCycle()) continue; #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge.To(), self_.view_, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { continue; } #endif frame[self_.common_.edge_symbol] = edge; pull_node(edge, EdgeAtom::Direction::OUT); return true; } // If we are here, either the edges have not been initialized, // or they have been exhausted. Attempt to initialize the edges. 
if (!InitEdges(frame, context)) return false; // we have re-initialized the edges, continue with the loop } } void Expand::ExpandCursor::Shutdown() { input_cursor_->Shutdown(); } void Expand::ExpandCursor::Reset() { input_cursor_->Reset(); in_edges_ = std::nullopt; in_edges_it_ = std::nullopt; out_edges_ = std::nullopt; out_edges_it_ = std::nullopt; } ExpansionInfo Expand::ExpandCursor::GetExpansionInfo(Frame &frame) { TypedValue &vertex_value = frame[self_.input_symbol_]; if (vertex_value.IsNull()) { return ExpansionInfo{}; } ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); auto &vertex = vertex_value.ValueVertex(); auto direction = self_.common_.direction; if (!self_.common_.existing_node) { return ExpansionInfo{.input_node = vertex, .direction = direction}; } TypedValue &existing_node = frame[self_.common_.node_symbol]; if (existing_node.IsNull()) { return ExpansionInfo{.input_node = vertex, .direction = direction}; } ExpectType(self_.common_.node_symbol, existing_node, TypedValue::Type::Vertex); auto &existing_vertex = existing_node.ValueVertex(); // -1 and -1 -> normal expansion // -1 and expanded -> can't happen // expanded and -1 -> reverse // expanded and expanded -> see if can reverse if ((prev_input_degree_ == -1 && prev_existing_degree_ == -1) || prev_input_degree_ < prev_existing_degree_) { return ExpansionInfo{.input_node = vertex, .direction = direction, .existing_node = existing_vertex}; } auto new_direction = direction; switch (new_direction) { case EdgeAtom::Direction::IN: new_direction = EdgeAtom::Direction::OUT; break; case EdgeAtom::Direction::OUT: new_direction = EdgeAtom::Direction::IN; break; default: new_direction = EdgeAtom::Direction::BOTH; break; } return ExpansionInfo{ .input_node = existing_vertex, .direction = new_direction, .existing_node = vertex, .reversed = true}; } bool Expand::ExpandCursor::InitEdges(Frame &frame, ExecutionContext &context) { // Input Vertex could be null if it is created by a failed 
optional match. In // those cases we skip that input pull and continue with the next. while (true) { if (!input_cursor_->Pull(frame, context)) return false; expansion_info_ = GetExpansionInfo(frame); if (!expansion_info_.input_node) { continue; } auto vertex = *expansion_info_.input_node; auto direction = expansion_info_.direction; int64_t num_expanded_first = -1; if (direction == EdgeAtom::Direction::IN || direction == EdgeAtom::Direction::BOTH) { if (self_.common_.existing_node) { if (expansion_info_.existing_node) { auto existing_node = *expansion_info_.existing_node; auto edges_result = UnwrapEdgesResult(vertex.InEdges(self_.view_, self_.common_.edge_types, existing_node)); in_edges_.emplace(edges_result.edges); num_expanded_first = edges_result.expanded_count; } } else { auto edges_result = UnwrapEdgesResult(vertex.InEdges(self_.view_, self_.common_.edge_types)); in_edges_.emplace(edges_result.edges); num_expanded_first = edges_result.expanded_count; } if (in_edges_) { in_edges_it_.emplace(in_edges_->begin()); } } int64_t num_expanded_second = -1; if (direction == EdgeAtom::Direction::OUT || direction == EdgeAtom::Direction::BOTH) { if (self_.common_.existing_node) { if (expansion_info_.existing_node) { auto existing_node = *expansion_info_.existing_node; auto edges_result = UnwrapEdgesResult(vertex.OutEdges(self_.view_, self_.common_.edge_types, existing_node)); out_edges_.emplace(edges_result.edges); num_expanded_second = edges_result.expanded_count; } } else { auto edges_result = UnwrapEdgesResult(vertex.OutEdges(self_.view_, self_.common_.edge_types)); out_edges_.emplace(edges_result.edges); num_expanded_second = edges_result.expanded_count; } if (out_edges_) { out_edges_it_.emplace(out_edges_->begin()); } } if (!expansion_info_.existing_node) { return true; } num_expanded_first = num_expanded_first == -1 ? 0 : num_expanded_first; num_expanded_second = num_expanded_second == -1 ? 
0 : num_expanded_second; int64_t total_expanded_edges = num_expanded_first + num_expanded_second; if (!expansion_info_.reversed) { prev_input_degree_ = total_expanded_edges; } else { prev_existing_degree_ = total_expanded_edges; } return true; } } ExpandVariable::ExpandVariable(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol, Symbol edge_symbol, EdgeAtom::Type type, EdgeAtom::Direction direction, const std::vector<storage::EdgeTypeId> &edge_types, bool is_reverse, Expression *lower_bound, Expression *upper_bound, bool existing_node, ExpansionLambda filter_lambda, std::optional<ExpansionLambda> weight_lambda, std::optional<Symbol> total_weight) : input_(input ? input : std::make_shared<Once>()), input_symbol_(std::move(input_symbol)), common_{node_symbol, edge_symbol, direction, edge_types, existing_node}, type_(type), is_reverse_(is_reverse), lower_bound_(lower_bound), upper_bound_(upper_bound), filter_lambda_(std::move(filter_lambda)), weight_lambda_(std::move(weight_lambda)), total_weight_(std::move(total_weight)) { DMG_ASSERT(type_ == EdgeAtom::Type::DEPTH_FIRST || type_ == EdgeAtom::Type::BREADTH_FIRST || type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH || type_ == EdgeAtom::Type::ALL_SHORTEST_PATHS, "ExpandVariable can only be used with breadth first, depth first, " "weighted shortest path or all shortest paths type"); DMG_ASSERT(!(type_ == EdgeAtom::Type::BREADTH_FIRST && is_reverse), "Breadth first expansion can't be reversed"); } ACCEPT_WITH_INPUT(ExpandVariable) std::vector<Symbol> ExpandVariable::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); symbols.emplace_back(common_.node_symbol); symbols.emplace_back(common_.edge_symbol); return symbols; } namespace { /** * Helper function that returns an iterable over * <EdgeAtom::Direction, EdgeAccessor> pairs * for the given params. * * @param vertex - The vertex to expand from. * @param direction - Expansion direction. 
All directions (IN, OUT, BOTH) * are supported. * @param memory - Used to allocate the result. * @return See above. */ auto ExpandFromVertex(const VertexAccessor &vertex, EdgeAtom::Direction direction, const std::vector<storage::EdgeTypeId> &edge_types, utils::MemoryResource *memory, DbAccessor *db_accessor) { // wraps an EdgeAccessor into a pair <accessor, direction> auto wrapper = [](EdgeAtom::Direction direction, auto &&edges) { return iter::imap([direction](const auto &edge) { return std::make_pair(edge, direction); }, std::forward<decltype(edges)>(edges)); }; storage::View view = storage::View::OLD; utils::pmr::vector<decltype(wrapper(direction, vertex.InEdges(view, edge_types).GetValue().edges))> chain_elements( memory); if (direction != EdgeAtom::Direction::OUT) { auto edges = UnwrapEdgesResult(vertex.InEdges(view, edge_types)).edges; if (edges.begin() != edges.end()) { chain_elements.emplace_back(wrapper(EdgeAtom::Direction::IN, std::move(edges))); } } if (direction != EdgeAtom::Direction::IN) { auto edges = UnwrapEdgesResult(vertex.OutEdges(view, edge_types)).edges; if (edges.begin() != edges.end()) { chain_elements.emplace_back(wrapper(EdgeAtom::Direction::OUT, std::move(edges))); } } // TODO: Investigate whether itertools perform heap allocation? 
return iter::chain.from_iterable(std::move(chain_elements)); } } // namespace class ExpandVariableCursor : public Cursor { public: ExpandVariableCursor(const ExpandVariable &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)), edges_(mem), edges_it_(mem) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(self_); ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); while (true) { if (Expand(frame, context)) return true; if (PullInput(frame, context)) { // if lower bound is zero we also yield empty paths if (lower_bound_ == 0) { auto &start_vertex = frame[self_.input_symbol_].ValueVertex(); if (!self_.common_.existing_node) { frame[self_.common_.node_symbol] = start_vertex; return true; } if (CheckExistingNode(start_vertex, self_.common_.node_symbol, frame)) { return true; } } // if lower bound is not zero, we just continue, the next // loop iteration will attempt to expand and we're good } else return false; // else continue with the loop, try to expand again // because we succesfully pulled from the input } } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); edges_.clear(); edges_it_.clear(); } private: const ExpandVariable &self_; const UniqueCursorPtr input_cursor_; // bounds. 
in the cursor they are not optional but set to // default values if missing in the ExpandVariable operator // initialize to arbitrary values, they should only be used // after a successful pull from the input int64_t upper_bound_{-1}; int64_t lower_bound_{-1}; // a stack of edge iterables corresponding to the level/depth of // the expansion currently being Pulled using ExpandEdges = decltype(ExpandFromVertex(std::declval<VertexAccessor>(), EdgeAtom::Direction::IN, self_.common_.edge_types, utils::NewDeleteResource(), std::declval<DbAccessor *>())); utils::pmr::vector<ExpandEdges> edges_; // an iterator indicating the position in the corresponding edges_ element utils::pmr::vector<decltype(edges_.begin()->begin())> edges_it_; /** * Helper function that Pulls from the input vertex and * makes iteration over it's edges possible. * * @return If the Pull succeeded. If not, this VariableExpandCursor * is exhausted. */ bool PullInput(Frame &frame, ExecutionContext &context) { // Input Vertex could be null if it is created by a failed optional match. // In those cases we skip that input pull and continue with the next. while (true) { AbortCheck(context); if (!input_cursor_->Pull(frame, context)) return false; TypedValue &vertex_value = frame[self_.input_symbol_]; // Null check due to possible failed optional match. if (vertex_value.IsNull()) continue; ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); auto &vertex = vertex_value.ValueVertex(); // Evaluate the upper and lower bounds. ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); auto calc_bound = [&evaluator](auto &bound) { auto value = EvaluateInt(&evaluator, bound, "Variable expansion bound"); if (value < 0) throw QueryRuntimeException("Variable expansion bound must be a non-negative integer."); return value; }; lower_bound_ = self_.lower_bound_ ? calc_bound(self_.lower_bound_) : 1; upper_bound_ = self_.upper_bound_ ? 
calc_bound(self_.upper_bound_) : std::numeric_limits<int64_t>::max(); if (upper_bound_ > 0) { auto *memory = edges_.get_allocator().GetMemoryResource(); edges_.emplace_back( ExpandFromVertex(vertex, self_.common_.direction, self_.common_.edge_types, memory, context.db_accessor)); edges_it_.emplace_back(edges_.back().begin()); } if (self_.filter_lambda_.accumulated_path_symbol) { // Add initial vertex of path to the accumulated path frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(vertex); } // reset the frame value to an empty edge list auto *pull_memory = context.evaluation_context.memory; frame[self_.common_.edge_symbol] = TypedValue::TVector(pull_memory); return true; } } // Helper function for appending an edge to the list on the frame. void AppendEdge(const EdgeAccessor &new_edge, utils::pmr::vector<TypedValue> *edges_on_frame) { // We are placing an edge on the frame. It is possible that there already // exists an edge on the frame for this level. If so first remove it. DMG_ASSERT(edges_.size() > 0, "Edges are empty"); if (self_.is_reverse_) { // TODO: This is innefficient, we should look into replacing // vector with something else for TypedValue::List. size_t diff = edges_on_frame->size() - std::min(edges_on_frame->size(), edges_.size() - 1U); if (diff > 0U) edges_on_frame->erase(edges_on_frame->begin(), edges_on_frame->begin() + diff); edges_on_frame->emplace(edges_on_frame->begin(), new_edge); } else { edges_on_frame->resize(std::min(edges_on_frame->size(), edges_.size() - 1U)); edges_on_frame->emplace_back(new_edge); } } /** * Performs a single expansion for the current state of this * VariableExpansionCursor. * * @return True if the expansion was a success and this Cursor's * consumer can consume it. False if the expansion failed. In that * case no more expansions are available from the current input * vertex and another Pull from the input cursor should be performed. 
*/ bool Expand(Frame &frame, ExecutionContext &context) { ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); // Some expansions might not be valid due to edge uniqueness and // existing_node criterions, so expand in a loop until either the input // vertex is exhausted or a valid variable-length expansion is available. while (true) { AbortCheck(context); // pop from the stack while there is stuff to pop and the current // level is exhausted while (!edges_.empty() && edges_it_.back() == edges_.back().end()) { edges_.pop_back(); edges_it_.pop_back(); } // check if we exhausted everything, if so return false if (edges_.empty()) return false; // we use this a lot auto &edges_on_frame = frame[self_.common_.edge_symbol].ValueList(); // it is possible that edges_on_frame does not contain as many // elements as edges_ due to edge-uniqueness (when a whole layer // gets exhausted but no edges are valid). for that reason only // pop from edges_on_frame if they contain enough elements if (self_.is_reverse_) { auto diff = edges_on_frame.size() - std::min(edges_on_frame.size(), edges_.size()); if (diff > 0) { edges_on_frame.erase(edges_on_frame.begin(), edges_on_frame.begin() + diff); } } else { edges_on_frame.resize(std::min(edges_on_frame.size(), edges_.size())); } // if we are here, we have a valid stack, // get the edge, increase the relevant iterator auto current_edge = *edges_it_.back()++; // Check edge-uniqueness. bool found_existing = std::any_of(edges_on_frame.begin(), edges_on_frame.end(), [¤t_edge](const TypedValue &edge) { return current_edge.first == edge.ValueEdge(); }); if (found_existing) continue; VertexAccessor current_vertex = current_edge.second == EdgeAtom::Direction::IN ? 
current_edge.first.From() : current_edge.first.To(); #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(current_edge.first, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(current_vertex, storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { continue; } #endif AppendEdge(current_edge.first, &edges_on_frame); if (!self_.common_.existing_node) { frame[self_.common_.node_symbol] = current_vertex; } // Skip expanding out of filtered expansion. frame[self_.filter_lambda_.inner_edge_symbol] = current_edge.first; frame[self_.filter_lambda_.inner_node_symbol] = current_vertex; if (self_.filter_lambda_.accumulated_path_symbol) { MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(), "Accumulated path must be path"); Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath(); // Shrink the accumulated path including current level if necessary while (accumulated_path.size() >= edges_on_frame.size()) { accumulated_path.Shrink(); } accumulated_path.Expand(current_edge.first); accumulated_path.Expand(current_vertex); } if (self_.filter_lambda_.expression && !EvaluateFilter(evaluator, self_.filter_lambda_.expression)) continue; // we are doing depth-first search, so place the current // edge's expansions onto the stack, if we should continue to expand if (upper_bound_ > static_cast<int64_t>(edges_.size())) { auto *memory = edges_.get_allocator().GetMemoryResource(); edges_.emplace_back(ExpandFromVertex(current_vertex, self_.common_.direction, self_.common_.edge_types, memory, context.db_accessor)); edges_it_.emplace_back(edges_.back().begin()); } if (self_.common_.existing_node && !CheckExistingNode(current_vertex, self_.common_.node_symbol, frame)) continue; // We only yield true if we satisfy the lower bound. 
if (static_cast<int64_t>(edges_on_frame.size()) >= lower_bound_) { return true; } } } }; class STShortestPathCursor : public query::plan::Cursor { public: STShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input()->MakeCursor(mem)) { MG_ASSERT(self_.common_.existing_node, "s-t shortest path algorithm should only " "be used when `existing_node` flag is " "set!"); } bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("STShortestPath"); ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); while (input_cursor_->Pull(frame, context)) { const auto &source_tv = frame[self_.input_symbol_]; const auto &sink_tv = frame[self_.common_.node_symbol]; // It is possible that source or sink vertex is Null due to optional // matching. if (source_tv.IsNull() || sink_tv.IsNull()) continue; const auto &source = source_tv.ValueVertex(); const auto &sink = sink_tv.ValueVertex(); int64_t lower_bound = self_.lower_bound_ ? EvaluateInt(&evaluator, self_.lower_bound_, "Min depth in breadth-first expansion") : 1; int64_t upper_bound = self_.upper_bound_ ? 
EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in breadth-first expansion") : std::numeric_limits<int64_t>::max(); if (upper_bound < 1 || lower_bound > upper_bound) continue; if (FindPath(*context.db_accessor, source, sink, lower_bound, upper_bound, &frame, &evaluator, context)) { return true; } } return false; } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); } private: const ExpandVariable &self_; UniqueCursorPtr input_cursor_; using VertexEdgeMapT = utils::pmr::unordered_map<VertexAccessor, std::optional<EdgeAccessor>>; void ReconstructPath(const VertexAccessor &midpoint, const VertexEdgeMapT &in_edge, const VertexEdgeMapT &out_edge, Frame *frame, utils::MemoryResource *pull_memory) { utils::pmr::vector<TypedValue> result(pull_memory); auto last_vertex = midpoint; while (true) { const auto &last_edge = in_edge.at(last_vertex); if (!last_edge) break; last_vertex = last_edge->From() == last_vertex ? last_edge->To() : last_edge->From(); result.emplace_back(*last_edge); } std::reverse(result.begin(), result.end()); last_vertex = midpoint; while (true) { const auto &last_edge = out_edge.at(last_vertex); if (!last_edge) break; last_vertex = last_edge->From() == last_vertex ? 
last_edge->To() : last_edge->From(); result.emplace_back(*last_edge); } frame->at(self_.common_.edge_symbol) = std::move(result); } bool ShouldExpand(const VertexAccessor &vertex, const EdgeAccessor &edge, Frame *frame, ExpressionEvaluator *evaluator) { if (!self_.filter_lambda_.expression) return true; frame->at(self_.filter_lambda_.inner_node_symbol) = vertex; frame->at(self_.filter_lambda_.inner_edge_symbol) = edge; TypedValue result = self_.filter_lambda_.expression->Accept(*evaluator); if (result.IsNull()) return false; if (result.IsBool()) return result.ValueBool(); throw QueryRuntimeException("Expansion condition must evaluate to boolean or null"); } bool FindPath(const DbAccessor &dba, const VertexAccessor &source, const VertexAccessor &sink, int64_t lower_bound, int64_t upper_bound, Frame *frame, ExpressionEvaluator *evaluator, const ExecutionContext &context) { using utils::Contains; if (source == sink) return false; // We expand from both directions, both from the source and the sink. // Expansions meet at the middle of the path if it exists. This should // perform better for real-world like graphs where the expansion front // grows exponentially, effectively reducing the exponent by half. auto *pull_memory = evaluator->GetMemoryResource(); // Holds vertices at the current level of expansion from the source // (sink). utils::pmr::vector<VertexAccessor> source_frontier(pull_memory); utils::pmr::vector<VertexAccessor> sink_frontier(pull_memory); // Holds vertices we can expand to from `source_frontier` // (`sink_frontier`). utils::pmr::vector<VertexAccessor> source_next(pull_memory); utils::pmr::vector<VertexAccessor> sink_next(pull_memory); // Maps each vertex we visited expanding from the source (sink) to the // edge used. Necessary for path reconstruction. 
VertexEdgeMapT in_edge(pull_memory); VertexEdgeMapT out_edge(pull_memory); size_t current_length = 0; source_frontier.emplace_back(source); in_edge[source] = std::nullopt; sink_frontier.emplace_back(sink); out_edge[sink] = std::nullopt; while (true) { AbortCheck(context); // Top-down step (expansion from the source). ++current_length; if (current_length > upper_bound) return false; for (const auto &vertex : source_frontier) { if (self_.common_.direction != EdgeAtom::Direction::IN) { auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : out_edges) { #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge.To(), storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { continue; } #endif if (ShouldExpand(edge.To(), edge, frame, evaluator) && !Contains(in_edge, edge.To())) { in_edge.emplace(edge.To(), edge); if (Contains(out_edge, edge.To())) { if (current_length >= lower_bound) { ReconstructPath(edge.To(), in_edge, out_edge, frame, pull_memory); return true; } else { return false; } } source_next.push_back(edge.To()); } } } if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : in_edges) { #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge.From(), storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { continue; } #endif if (ShouldExpand(edge.From(), edge, frame, evaluator) && !Contains(in_edge, edge.From())) { in_edge.emplace(edge.From(), edge); if (Contains(out_edge, 
edge.From())) { if (current_length >= lower_bound) { ReconstructPath(edge.From(), in_edge, out_edge, frame, pull_memory); return true; } else { return false; } } source_next.push_back(edge.From()); } } } } if (source_next.empty()) return false; source_frontier.clear(); std::swap(source_frontier, source_next); // Bottom-up step (expansion from the sink). ++current_length; if (current_length > upper_bound) return false; // When expanding from the sink we have to be careful which edge // endpoint we pass to `should_expand`, because everything is // reversed. for (const auto &vertex : sink_frontier) { if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : out_edges) { #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge.To(), storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { continue; } #endif if (ShouldExpand(vertex, edge, frame, evaluator) && !Contains(out_edge, edge.To())) { out_edge.emplace(edge.To(), edge); if (Contains(in_edge, edge.To())) { if (current_length >= lower_bound) { ReconstructPath(edge.To(), in_edge, out_edge, frame, pull_memory); return true; } else { return false; } } sink_next.push_back(edge.To()); } } } if (self_.common_.direction != EdgeAtom::Direction::IN) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : in_edges) { #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge.From(), storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { 
continue; } #endif if (ShouldExpand(vertex, edge, frame, evaluator) && !Contains(out_edge, edge.From())) { out_edge.emplace(edge.From(), edge); if (Contains(in_edge, edge.From())) { if (current_length >= lower_bound) { ReconstructPath(edge.From(), in_edge, out_edge, frame, pull_memory); return true; } else { return false; } } sink_next.push_back(edge.From()); } } } } if (sink_next.empty()) return false; sink_frontier.clear(); std::swap(sink_frontier, sink_next); } } }; class SingleSourceShortestPathCursor : public query::plan::Cursor { public: SingleSourceShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input()->MakeCursor(mem)), processed_(mem), to_visit_next_(mem), to_visit_current_(mem) { MG_ASSERT(!self_.common_.existing_node, "Single source shortest path algorithm " "should not be used when `existing_node` " "flag is set, s-t shortest path algorithm " "should be used instead!"); } bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("SingleSourceShortestPath"); ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); // for the given (edge, vertex) pair checks if they satisfy the // "where" condition. if so, places them in the to_visit_ structure. 
auto expand_pair = [this, &evaluator, &frame, &context](EdgeAccessor edge, VertexAccessor vertex) -> bool { // if we already processed the given vertex it doesn't get expanded if (processed_.find(vertex) != processed_.end()) return false; #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(vertex, storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { return false; } #endif frame[self_.filter_lambda_.inner_edge_symbol] = edge; frame[self_.filter_lambda_.inner_node_symbol] = vertex; std::optional<Path> curr_acc_path = std::nullopt; if (self_.filter_lambda_.accumulated_path_symbol) { MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(), "Accumulated path must have Path type"); Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath(); accumulated_path.Expand(edge); accumulated_path.Expand(vertex); curr_acc_path = accumulated_path; } if (self_.filter_lambda_.expression) { TypedValue result = self_.filter_lambda_.expression->Accept(evaluator); switch (result.type()) { case TypedValue::Type::Null: return true; case TypedValue::Type::Bool: if (!result.ValueBool()) return true; break; default: throw QueryRuntimeException("Expansion condition must evaluate to boolean or null."); } } to_visit_next_.emplace_back(edge, vertex, std::move(curr_acc_path)); processed_.emplace(vertex, edge); return true; }; auto restore_frame_state_after_expansion = [this, &frame](bool was_expanded) { if (was_expanded && self_.filter_lambda_.accumulated_path_symbol) { frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink(); } }; // populates the to_visit_next_ structure with expansions // from the given vertex. skips expansions that don't satisfy // the "where" condition. 
auto expand_from_vertex = [this, &expand_pair, &restore_frame_state_after_expansion](const auto &vertex) { if (self_.common_.direction != EdgeAtom::Direction::IN) { auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : out_edges) { bool was_expanded = expand_pair(edge, edge.To()); restore_frame_state_after_expansion(was_expanded); } } if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : in_edges) { bool was_expanded = expand_pair(edge, edge.From()); restore_frame_state_after_expansion(was_expanded); } } }; // do it all in a loop because we skip some elements while (true) { AbortCheck(context); // if we have nothing to visit on the current depth, switch to next if (to_visit_current_.empty()) to_visit_current_.swap(to_visit_next_); // if current is still empty, it means both are empty, so pull from // input if (to_visit_current_.empty()) { if (!input_cursor_->Pull(frame, context)) return false; to_visit_current_.clear(); to_visit_next_.clear(); processed_.clear(); const auto &vertex_value = frame[self_.input_symbol_]; // it is possible that the vertex is Null due to optional matching if (vertex_value.IsNull()) continue; lower_bound_ = self_.lower_bound_ ? EvaluateInt(&evaluator, self_.lower_bound_, "Min depth in breadth-first expansion") : 1; upper_bound_ = self_.upper_bound_ ? 
EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in breadth-first expansion") : std::numeric_limits<int64_t>::max(); if (upper_bound_ < 1 || lower_bound_ > upper_bound_) continue; const auto &vertex = vertex_value.ValueVertex(); processed_.emplace(vertex, std::nullopt); if (self_.filter_lambda_.accumulated_path_symbol) { // Add initial vertex of path to the accumulated path frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(vertex); } expand_from_vertex(vertex); // go back to loop start and see if we expanded anything continue; } // take the next expansion from the queue auto [curr_edge, curr_vertex, curr_acc_path] = to_visit_current_.back(); to_visit_current_.pop_back(); // create the frame value for the edges auto *pull_memory = context.evaluation_context.memory; utils::pmr::vector<TypedValue> edge_list(pull_memory); edge_list.emplace_back(curr_edge); auto last_vertex = curr_vertex; while (true) { const EdgeAccessor &last_edge = edge_list.back().ValueEdge(); last_vertex = last_edge.From() == last_vertex ? 
last_edge.To() : last_edge.From(); // origin_vertex must be in processed const auto &previous_edge = processed_.find(last_vertex)->second; if (!previous_edge) break; edge_list.emplace_back(previous_edge.value()); } // expand only if what we've just expanded is less then max depth if (static_cast<int64_t>(edge_list.size()) < upper_bound_) { if (self_.filter_lambda_.accumulated_path_symbol) { MG_ASSERT(curr_acc_path.has_value(), "Expected non-null accumulated path"); frame[self_.filter_lambda_.accumulated_path_symbol.value()] = std::move(curr_acc_path.value()); } expand_from_vertex(curr_vertex); } if (static_cast<int64_t>(edge_list.size()) < lower_bound_) continue; frame[self_.common_.node_symbol] = curr_vertex; // place edges on the frame in the correct order std::reverse(edge_list.begin(), edge_list.end()); frame[self_.common_.edge_symbol] = std::move(edge_list); return true; } } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); processed_.clear(); to_visit_next_.clear(); to_visit_current_.clear(); } private: const ExpandVariable &self_; const UniqueCursorPtr input_cursor_; // Depth bounds. Calculated on each pull from the input, the initial value // is irrelevant. int64_t lower_bound_{-1}; int64_t upper_bound_{-1}; // maps vertices to the edge they got expanded from. it is an optional // edge because the root does not get expanded from anything. // contains visited vertices as well as those scheduled to be visited. 
utils::pmr::unordered_map<VertexAccessor, std::optional<EdgeAccessor>> processed_; // edge, vertex we have yet to visit, for current and next depth and their accumulated paths utils::pmr::vector<std::tuple<EdgeAccessor, VertexAccessor, std::optional<Path>>> to_visit_next_; utils::pmr::vector<std::tuple<EdgeAccessor, VertexAccessor, std::optional<Path>>> to_visit_current_; }; namespace { void CheckWeightType(TypedValue current_weight, utils::MemoryResource *memory) { if (current_weight.IsNull()) { return; } if (!current_weight.IsNumeric() && !current_weight.IsDuration()) { throw QueryRuntimeException("Calculated weight must be numeric or a Duration, got {}.", current_weight.type()); } const auto is_valid_numeric = [&] { return current_weight.IsNumeric() && (current_weight >= TypedValue(0, memory)).ValueBool(); }; const auto is_valid_duration = [&] { return current_weight.IsDuration() && (current_weight >= TypedValue(utils::Duration(0), memory)).ValueBool(); }; if (!is_valid_numeric() && !is_valid_duration()) { throw QueryRuntimeException("Calculated weight must be non-negative!"); } } void ValidateWeightTypes(const TypedValue &lhs, const TypedValue &rhs) { if ((lhs.IsNumeric() && rhs.IsNumeric()) || (lhs.IsDuration() && rhs.IsDuration())) { return; } throw QueryRuntimeException(utils::MessageWithLink( "All weights should be of the same type, either numeric or a Duration. 
Please update the weight " "expression or the filter expression.", "https://memgr.ph/wsp")); } TypedValue CalculateNextWeight(const std::optional<memgraph::query::plan::ExpansionLambda> &weight_lambda, const TypedValue &total_weight, ExpressionEvaluator evaluator) { if (!weight_lambda) { return {}; } auto *memory = evaluator.GetMemoryResource(); TypedValue current_weight = weight_lambda->expression->Accept(evaluator); CheckWeightType(current_weight, memory); if (total_weight.IsNull()) { return current_weight; } ValidateWeightTypes(current_weight, total_weight); return TypedValue(current_weight, memory) + total_weight; } } // namespace class ExpandWeightedShortestPathCursor : public query::plan::Cursor { public: ExpandWeightedShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input_->MakeCursor(mem)), total_cost_(mem), previous_(mem), yielded_vertices_(mem), pq_(mem) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("ExpandWeightedShortestPath"); ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); auto create_state = [this](const VertexAccessor &vertex, int64_t depth) { return std::make_pair(vertex, upper_bound_set_ ? depth : 0); }; // For the given (edge, vertex, weight, depth) tuple checks if they // satisfy the "where" condition. if so, places them in the priority // queue. 
auto expand_pair = [this, &evaluator, &frame, &create_state](const EdgeAccessor &edge, const VertexAccessor &vertex, const TypedValue &total_weight, int64_t depth) { frame[self_.weight_lambda_->inner_edge_symbol] = edge; frame[self_.weight_lambda_->inner_node_symbol] = vertex; TypedValue next_weight = CalculateNextWeight(self_.weight_lambda_, total_weight, evaluator); std::optional<Path> curr_acc_path = std::nullopt; if (self_.filter_lambda_.expression) { frame[self_.filter_lambda_.inner_edge_symbol] = edge; frame[self_.filter_lambda_.inner_node_symbol] = vertex; if (self_.filter_lambda_.accumulated_path_symbol) { MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(), "Accumulated path must be path"); Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath(); accumulated_path.Expand(edge); accumulated_path.Expand(vertex); curr_acc_path = accumulated_path; if (self_.filter_lambda_.accumulated_weight_symbol) { frame[self_.filter_lambda_.accumulated_weight_symbol.value()] = next_weight; } } if (!EvaluateFilter(evaluator, self_.filter_lambda_.expression)) return; } auto next_state = create_state(vertex, depth); auto found_it = total_cost_.find(next_state); if (found_it != total_cost_.end() && (found_it->second.IsNull() || (found_it->second <= next_weight).ValueBool())) return; pq_.emplace(next_weight, depth + 1, vertex, edge, curr_acc_path); }; auto restore_frame_state_after_expansion = [this, &frame]() { if (self_.filter_lambda_.accumulated_path_symbol) { frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink(); } }; // Populates the priority queue structure with expansions // from the given vertex. skips expansions that don't satisfy // the "where" condition. 
auto expand_from_vertex = [this, &context, &expand_pair, &restore_frame_state_after_expansion]( const VertexAccessor &vertex, const TypedValue &weight, int64_t depth) { if (self_.common_.direction != EdgeAtom::Direction::IN) { auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : out_edges) { #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(edge.To(), storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { continue; } #endif expand_pair(edge, edge.To(), weight, depth); restore_frame_state_after_expansion(); } } if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : in_edges) { #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(edge.From(), storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { continue; } #endif expand_pair(edge, edge.From(), weight, depth); restore_frame_state_after_expansion(); } } }; while (true) { AbortCheck(context); if (pq_.empty()) { if (!input_cursor_->Pull(frame, context)) return false; const auto &vertex_value = frame[self_.input_symbol_]; if (vertex_value.IsNull()) continue; auto vertex = vertex_value.ValueVertex(); if (self_.common_.existing_node) { const auto &node = frame[self_.common_.node_symbol]; // Due to optional matching the existing node could be null. // Skip expansion for such nodes. 
if (node.IsNull()) continue; } std::optional<Path> curr_acc_path; if (self_.filter_lambda_.accumulated_path_symbol) { // Add initial vertex of path to the accumulated path curr_acc_path = Path(vertex); frame[self_.filter_lambda_.accumulated_path_symbol.value()] = curr_acc_path.value(); } if (self_.upper_bound_) { upper_bound_ = EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in weighted shortest path expansion"); upper_bound_set_ = true; } else { upper_bound_ = std::numeric_limits<int64_t>::max(); upper_bound_set_ = false; } if (upper_bound_ < 1) throw QueryRuntimeException( "Maximum depth in weighted shortest path expansion must be at " "least 1."); frame[self_.weight_lambda_->inner_edge_symbol] = TypedValue(); frame[self_.weight_lambda_->inner_node_symbol] = vertex; TypedValue current_weight = CalculateNextWeight(self_.weight_lambda_, /* total_weight */ TypedValue(), evaluator); // Clear existing data structures. previous_.clear(); total_cost_.clear(); yielded_vertices_.clear(); pq_.emplace(current_weight, 0, vertex, std::nullopt, curr_acc_path); // We are adding the starting vertex to the set of yielded vertices // because we don't want to yield paths that end with the starting // vertex. yielded_vertices_.insert(vertex); } while (!pq_.empty()) { AbortCheck(context); auto [current_weight, current_depth, current_vertex, current_edge, curr_acc_path] = pq_.top(); pq_.pop(); auto current_state = create_state(current_vertex, current_depth); // Check if the vertex has already been processed. if (total_cost_.find(current_state) != total_cost_.end()) { continue; } previous_.emplace(current_state, current_edge); total_cost_.emplace(current_state, current_weight); // Expand only if what we've just expanded is less than max depth. 
if (current_depth < upper_bound_) { if (self_.filter_lambda_.accumulated_path_symbol) { frame[self_.filter_lambda_.accumulated_path_symbol.value()] = std::move(curr_acc_path.value()); } expand_from_vertex(current_vertex, current_weight, current_depth); } // If we yielded a path for a vertex already, make the expansion but // don't return the path again. if (yielded_vertices_.find(current_vertex) != yielded_vertices_.end()) continue; // Reconstruct the path. auto last_vertex = current_vertex; auto last_depth = current_depth; auto *pull_memory = context.evaluation_context.memory; utils::pmr::vector<TypedValue> edge_list(pull_memory); while (true) { // Origin_vertex must be in previous. const auto &previous_edge = previous_.find(create_state(last_vertex, last_depth))->second; if (!previous_edge) break; last_vertex = previous_edge->From() == last_vertex ? previous_edge->To() : previous_edge->From(); last_depth--; edge_list.emplace_back(previous_edge.value()); } // Place destination node on the frame, handle existence flag. if (self_.common_.existing_node) { const auto &node = frame[self_.common_.node_symbol]; if ((node != TypedValue(current_vertex, pull_memory)).ValueBool()) { continue; } // Prevent expanding other paths, because we found the // shortest to existing node. ClearQueue(); } else { frame[self_.common_.node_symbol] = current_vertex; } if (!self_.is_reverse_) { // Place edges on the frame in the correct order. std::reverse(edge_list.begin(), edge_list.end()); } frame[self_.common_.edge_symbol] = std::move(edge_list); frame[self_.total_weight_.value()] = current_weight; yielded_vertices_.insert(current_vertex); return true; } } } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); previous_.clear(); total_cost_.clear(); yielded_vertices_.clear(); ClearQueue(); } private: const ExpandVariable &self_; const UniqueCursorPtr input_cursor_; // Upper bound on the path length. 
int64_t upper_bound_{-1}; bool upper_bound_set_{false}; struct WspStateHash { size_t operator()(const std::pair<VertexAccessor, int64_t> &key) const { return utils::HashCombine<VertexAccessor, int64_t>{}(key.first, key.second); } }; // Maps vertices to weights they got in expansion. utils::pmr::unordered_map<std::pair<VertexAccessor, int64_t>, TypedValue, WspStateHash> total_cost_; // Maps vertices to edges used to reach them. utils::pmr::unordered_map<std::pair<VertexAccessor, int64_t>, std::optional<EdgeAccessor>, WspStateHash> previous_; // Keeps track of vertices for which we yielded a path already. utils::pmr::unordered_set<VertexAccessor> yielded_vertices_; // Priority queue comparator. Keep lowest weight on top of the queue. class PriorityQueueComparator { public: bool operator()( const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, std::optional<Path>> &lhs, const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, std::optional<Path>> &rhs) { const auto &lhs_weight = std::get<0>(lhs); const auto &rhs_weight = std::get<0>(rhs); // Null defines minimum value for all types if (lhs_weight.IsNull()) { return false; } if (rhs_weight.IsNull()) { return true; } ValidateWeightTypes(lhs_weight, rhs_weight); return (lhs_weight > rhs_weight).ValueBool(); } }; std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, std::optional<Path>>, utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, std::optional<Path>>>, PriorityQueueComparator> pq_; void ClearQueue() { while (!pq_.empty()) pq_.pop(); } }; class ExpandAllShortestPathsCursor : public query::plan::Cursor { public: ExpandAllShortestPathsCursor(const ExpandVariable &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input_->MakeCursor(mem)), visited_cost_(mem), total_cost_(mem), next_edges_(mem), traversal_stack_(mem), pq_(mem) {} bool Pull(Frame &frame, 
ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("ExpandAllShortestPathsCursor"); ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); auto *memory = context.evaluation_context.memory; auto create_state = [this](const VertexAccessor &vertex, int64_t depth) { return std::make_pair(vertex, upper_bound_set_ ? depth : 0); }; // For the given (edge, direction, weight, depth) tuple checks if they // satisfy the "where" condition. if so, places them in the priority // queue. auto expand_vertex = [this, &evaluator, &frame](const EdgeAccessor &edge, const EdgeAtom::Direction direction, const TypedValue &total_weight, int64_t depth) { auto const &next_vertex = direction == EdgeAtom::Direction::IN ? edge.From() : edge.To(); // Evaluate current weight frame[self_.weight_lambda_->inner_edge_symbol] = edge; frame[self_.weight_lambda_->inner_node_symbol] = next_vertex; TypedValue next_weight = CalculateNextWeight(self_.weight_lambda_, total_weight, evaluator); // If filter expression exists, evaluate filter std::optional<Path> curr_acc_path = std::nullopt; if (self_.filter_lambda_.expression) { frame[self_.filter_lambda_.inner_edge_symbol] = edge; frame[self_.filter_lambda_.inner_node_symbol] = next_vertex; if (self_.filter_lambda_.accumulated_path_symbol) { MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(), "Accumulated path must be path"); Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath(); accumulated_path.Expand(edge); accumulated_path.Expand(next_vertex); curr_acc_path = accumulated_path; if (self_.filter_lambda_.accumulated_weight_symbol) { frame[self_.filter_lambda_.accumulated_weight_symbol.value()] = next_weight; } } if (!EvaluateFilter(evaluator, self_.filter_lambda_.expression)) return; } auto found_it = visited_cost_.find(next_vertex); // Check if the vertex has already been 
processed. if (found_it != visited_cost_.end()) { auto weight = found_it->second; if (weight.IsNull() || (next_weight <= weight).ValueBool()) { // Has been visited, but now found a shorter path visited_cost_[next_vertex] = next_weight; } else { // Continue and do not expand if current weight is larger return; } } else { visited_cost_[next_vertex] = next_weight; } DirectedEdge directed_edge = {edge, direction, next_weight}; pq_.emplace(next_weight, depth + 1, next_vertex, directed_edge, curr_acc_path); }; auto restore_frame_state_after_expansion = [this, &frame]() { if (self_.filter_lambda_.accumulated_path_symbol) { frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink(); } }; // Populates the priority queue structure with expansions // from the given vertex. skips expansions that don't satisfy // the "where" condition. auto expand_from_vertex = [this, &expand_vertex, &context, &restore_frame_state_after_expansion]( const VertexAccessor &vertex, const TypedValue &weight, int64_t depth) { if (self_.common_.direction != EdgeAtom::Direction::IN) { auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : out_edges) { #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(edge.To(), storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { continue; } #endif expand_vertex(edge, EdgeAtom::Direction::OUT, weight, depth); restore_frame_state_after_expansion(); } } if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : in_edges) { #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && 
!(context.auth_checker->Has(edge.From(), storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { continue; } #endif expand_vertex(edge, EdgeAtom::Direction::IN, weight, depth); restore_frame_state_after_expansion(); } } }; std::optional<VertexAccessor> start_vertex; auto create_path = [this, &frame, &memory]() { auto ¤t_level = traversal_stack_.back(); auto &edges_on_frame = frame[self_.common_.edge_symbol].ValueList(); // Clean out the current stack if (current_level.empty()) { if (!edges_on_frame.empty()) { if (!self_.is_reverse_) edges_on_frame.erase(edges_on_frame.end()); else edges_on_frame.erase(edges_on_frame.begin()); } traversal_stack_.pop_back(); return false; } auto [current_edge, current_edge_direction, current_weight] = current_level.back(); current_level.pop_back(); // Edges order depends on direction of expansion if (!self_.is_reverse_) edges_on_frame.emplace_back(current_edge); else edges_on_frame.emplace(edges_on_frame.begin(), current_edge); auto next_vertex = current_edge_direction == EdgeAtom::Direction::IN ? 
current_edge.From() : current_edge.To(); frame[self_.total_weight_.value()] = current_weight; if (next_edges_.find({next_vertex, traversal_stack_.size()}) != next_edges_.end()) { auto next_vertex_edges = next_edges_[{next_vertex, traversal_stack_.size()}]; traversal_stack_.emplace_back(std::move(next_vertex_edges)); } else { // Signal the end of iteration utils::pmr::list<DirectedEdge> empty(memory); traversal_stack_.emplace_back(std::move(empty)); } if ((current_weight > visited_cost_.at(next_vertex)).ValueBool()) return false; // Place destination node on the frame, handle existence flag if (self_.common_.existing_node) { const auto &node = frame[self_.common_.node_symbol]; ExpectType(self_.common_.node_symbol, node, TypedValue::Type::Vertex); if (node.ValueVertex() != next_vertex) return false; } else { frame[self_.common_.node_symbol] = next_vertex; } return true; }; auto create_DFS_traversal_tree = [this, &context, &memory, &frame, &create_state, &expand_from_vertex]() { while (!pq_.empty()) { AbortCheck(context); auto [current_weight, current_depth, current_vertex, directed_edge, acc_path] = pq_.top(); pq_.pop(); const auto &[current_edge, direction, weight] = directed_edge; auto current_state = create_state(current_vertex, current_depth); auto position = total_cost_.find(current_state); if (position != total_cost_.end()) { if ((position->second < current_weight).ValueBool()) continue; } else { total_cost_.emplace(current_state, current_weight); if (current_depth < upper_bound_) { if (self_.filter_lambda_.accumulated_path_symbol) { DMG_ASSERT(acc_path.has_value(), "Path must be already filled in AllShortestPath DFS traversals"); frame[self_.filter_lambda_.accumulated_path_symbol.value()] = std::move(acc_path.value()); } expand_from_vertex(current_vertex, current_weight, current_depth); } } // Searching for a previous vertex in the expansion auto prev_vertex = direction == EdgeAtom::Direction::IN ? 
current_edge.To() : current_edge.From(); // Update the parent if (next_edges_.find({prev_vertex, current_depth - 1}) == next_edges_.end()) { utils::pmr::list<DirectedEdge> empty(memory); next_edges_[{prev_vertex, current_depth - 1}] = std::move(empty); } next_edges_.at({prev_vertex, current_depth - 1}).emplace_back(directed_edge); } }; // upper_bound_set is used when storing visited edges, because with an upper bound we also consider suboptimal paths // if they are shorter in depth if (self_.upper_bound_) { upper_bound_ = EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in all shortest path expansion"); upper_bound_set_ = true; } else { upper_bound_ = std::numeric_limits<int64_t>::max(); upper_bound_set_ = false; } // Check if upper bound is valid if (upper_bound_ < 1) { throw QueryRuntimeException("Maximum depth in all shortest paths expansion must be at least 1."); } // On first Pull run, traversal stack and priority queue are empty, so we start a pulling stream // and create a DFS traversal tree (main part of algorithm). Then we return the first path // created from the DFS traversal tree (basically a DFS algorithm). // On each subsequent Pull run, paths are created from the traversal stack and returned. while (true) { // Check if there is an external error. AbortCheck(context); // The algorithm is run all at once by create_DFS_traversal_tree, after which we // traverse the tree iteratively by preserving the traversal state on stack. while (!traversal_stack_.empty()) { if (create_path()) return true; } // If priority queue is empty start new pulling stream. if (pq_.empty()) { // Finish if there is nothing to pull if (!input_cursor_->Pull(frame, context)) return false; const auto &vertex_value = frame[self_.input_symbol_]; if (vertex_value.IsNull()) continue; start_vertex = vertex_value.ValueVertex(); if (self_.common_.existing_node) { const auto &node = frame[self_.common_.node_symbol]; // Due to optional matching the existing node could be null. 
// Skip expansion for such nodes. if (node.IsNull()) continue; } // Clear existing data structures. visited_cost_.clear(); next_edges_.clear(); traversal_stack_.clear(); total_cost_.clear(); if (self_.filter_lambda_.accumulated_path_symbol) { // Add initial vertex of path to the accumulated path frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(*start_vertex); } frame[self_.weight_lambda_->inner_edge_symbol] = TypedValue(); frame[self_.weight_lambda_->inner_node_symbol] = *start_vertex; TypedValue current_weight = CalculateNextWeight(self_.weight_lambda_, /* total_weight */ TypedValue(), evaluator); expand_from_vertex(*start_vertex, current_weight, 0); visited_cost_.emplace(*start_vertex, 0); frame[self_.common_.edge_symbol] = TypedValue::TVector(memory); } // Create a DFS traversal tree from the start node create_DFS_traversal_tree(); // DFS traversal tree is create, if (start_vertex && next_edges_.find({*start_vertex, 0}) != next_edges_.end()) { auto start_vertex_edges = next_edges_[{*start_vertex, 0}]; traversal_stack_.emplace_back(std::move(start_vertex_edges)); } } } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); visited_cost_.clear(); next_edges_.clear(); traversal_stack_.clear(); total_cost_.clear(); ClearQueue(); } private: const ExpandVariable &self_; const UniqueCursorPtr input_cursor_; // Upper bound on the path length. int64_t upper_bound_{-1}; bool upper_bound_set_{false}; struct AspStateHash { size_t operator()(const std::pair<VertexAccessor, int64_t> &key) const { return utils::HashCombine<VertexAccessor, int64_t>{}(key.first, key.second); } }; using DirectedEdge = std::tuple<EdgeAccessor, EdgeAtom::Direction, TypedValue>; using NextEdgesState = std::pair<VertexAccessor, int64_t>; // Maps vertices to minimum weights they got in expansion. utils::pmr::unordered_map<VertexAccessor, TypedValue> visited_cost_; // Maps vertices to weights they got in expansion. 
utils::pmr::unordered_map<NextEdgesState, TypedValue, AspStateHash> total_cost_; // Maps the vertex with the potential expansion edge. utils::pmr::unordered_map<NextEdgesState, utils::pmr::list<DirectedEdge>, AspStateHash> next_edges_; // Stack indicating the traversal level. utils::pmr::list<utils::pmr::list<DirectedEdge>> traversal_stack_; // Priority queue comparator. Keep lowest weight on top of the queue. class PriorityQueueComparator { public: bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>> &lhs, const std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>> &rhs) { const auto &lhs_weight = std::get<0>(lhs); const auto &rhs_weight = std::get<0>(rhs); // Null defines minimum value for all types if (lhs_weight.IsNull()) { return false; } if (rhs_weight.IsNull()) { return true; } ValidateWeightTypes(lhs_weight, rhs_weight); return (lhs_weight > rhs_weight).ValueBool(); } }; // Priority queue - core element of the algorithm. 
// Stores: {weight, depth, next vertex, edge and direction} std::priority_queue< std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>>, utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>>>, PriorityQueueComparator> pq_; void ClearQueue() { while (!pq_.empty()) pq_.pop(); } }; UniqueCursorPtr ExpandVariable::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ExpandVariableOperator); switch (type_) { case EdgeAtom::Type::BREADTH_FIRST: if (common_.existing_node) { return MakeUniqueCursorPtr<STShortestPathCursor>(mem, *this, mem); } else { return MakeUniqueCursorPtr<SingleSourceShortestPathCursor>(mem, *this, mem); } case EdgeAtom::Type::DEPTH_FIRST: return MakeUniqueCursorPtr<ExpandVariableCursor>(mem, *this, mem); case EdgeAtom::Type::WEIGHTED_SHORTEST_PATH: return MakeUniqueCursorPtr<ExpandWeightedShortestPathCursor>(mem, *this, mem); case EdgeAtom::Type::ALL_SHORTEST_PATHS: return MakeUniqueCursorPtr<ExpandAllShortestPathsCursor>(mem, *this, mem); case EdgeAtom::Type::SINGLE: LOG_FATAL("ExpandVariable should not be planned for a single expansion!"); } } class ConstructNamedPathCursor : public Cursor { public: ConstructNamedPathCursor(ConstructNamedPath self, utils::MemoryResource *mem) : self_(std::move(self)), input_cursor_(self_.input()->MakeCursor(mem)) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("ConstructNamedPath"); if (!input_cursor_->Pull(frame, context)) return false; auto symbol_it = self_.path_elements_.begin(); DMG_ASSERT(symbol_it != self_.path_elements_.end(), "Named path must contain at least one node"); const auto &start_vertex = frame[*symbol_it++]; auto *pull_memory = context.evaluation_context.memory; // In an OPTIONAL MATCH everything could be Null. 
if (start_vertex.IsNull()) { frame[self_.path_symbol_] = TypedValue(pull_memory); return true; } DMG_ASSERT(start_vertex.IsVertex(), "First named path element must be a vertex"); query::Path path(start_vertex.ValueVertex(), pull_memory); // If the last path element symbol was for an edge list, then // the next symbol is a vertex and it should not append to the path // because // expansion already did it. bool last_was_edge_list = false; for (; symbol_it != self_.path_elements_.end(); symbol_it++) { const auto &expansion = frame[*symbol_it]; // We can have Null (OPTIONAL MATCH), a vertex, an edge, or an edge // list (variable expand or BFS). switch (expansion.type()) { case TypedValue::Type::Null: frame[self_.path_symbol_] = TypedValue(pull_memory); return true; case TypedValue::Type::Vertex: if (!last_was_edge_list) path.Expand(expansion.ValueVertex()); last_was_edge_list = false; break; case TypedValue::Type::Edge: path.Expand(expansion.ValueEdge()); break; case TypedValue::Type::List: { last_was_edge_list = true; // We need to expand all edges in the list and intermediary // vertices. 
const auto &edges = expansion.ValueList(); for (const auto &edge_value : edges) { const auto &edge = edge_value.ValueEdge(); const auto &from = edge.From(); if (path.vertices().back() == from) path.Expand(edge, edge.To()); else path.Expand(edge, from); } break; } default: LOG_FATAL("Unsupported type in named path construction"); break; } } frame[self_.path_symbol_] = path; return true; } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); } private: const ConstructNamedPath self_; const UniqueCursorPtr input_cursor_; }; ACCEPT_WITH_INPUT(ConstructNamedPath) UniqueCursorPtr ConstructNamedPath::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ConstructNamedPathOperator); return MakeUniqueCursorPtr<ConstructNamedPathCursor>(mem, *this, mem); } std::vector<Symbol> ConstructNamedPath::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); symbols.emplace_back(path_symbol_); return symbols; } Filter::Filter(const std::shared_ptr<LogicalOperator> &input, const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters, Expression *expression) : input_(input ? input : std::make_shared<Once>()), pattern_filters_(pattern_filters), expression_(expression) {} Filter::Filter(const std::shared_ptr<LogicalOperator> &input, const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters, Expression *expression, Filters all_filters) : input_(input ? 
input : std::make_shared<Once>()), pattern_filters_(pattern_filters), expression_(expression), all_filters_(std::move(all_filters)) {} bool Filter::Accept(HierarchicalLogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { input_->Accept(visitor); for (const auto &pattern_filter : pattern_filters_) { pattern_filter->Accept(visitor); } } return visitor.PostVisit(*this); } UniqueCursorPtr Filter::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::FilterOperator); return MakeUniqueCursorPtr<FilterCursor>(mem, *this, mem); } std::vector<Symbol> Filter::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } static std::vector<UniqueCursorPtr> MakeCursorVector(const std::vector<std::shared_ptr<LogicalOperator>> &ops, utils::MemoryResource *mem) { std::vector<UniqueCursorPtr> cursors; cursors.reserve(ops.size()); if (!ops.empty()) { for (const auto &op : ops) { cursors.push_back(op->MakeCursor(mem)); } } return cursors; } Filter::FilterCursor::FilterCursor(const Filter &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input_->MakeCursor(mem)), pattern_filter_cursors_(MakeCursorVector(self_.pattern_filters_, mem)) {} bool Filter::FilterCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(self_); // Like all filters, newly set values should not affect filtering of old // nodes and edges. 
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD, context.frame_change_collector); while (input_cursor_->Pull(frame, context)) { for (const auto &pattern_filter_cursor : pattern_filter_cursors_) { pattern_filter_cursor->Pull(frame, context); } if (EvaluateFilter(evaluator, self_.expression_)) return true; } return false; } void Filter::FilterCursor::Shutdown() { input_cursor_->Shutdown(); } void Filter::FilterCursor::Reset() { input_cursor_->Reset(); } EvaluatePatternFilter::EvaluatePatternFilter(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol) : input_(input), output_symbol_(std::move(output_symbol)) {} ACCEPT_WITH_INPUT(EvaluatePatternFilter); UniqueCursorPtr EvaluatePatternFilter::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::EvaluatePatternFilterOperator); return MakeUniqueCursorPtr<EvaluatePatternFilterCursor>(mem, *this, mem); } EvaluatePatternFilter::EvaluatePatternFilterCursor::EvaluatePatternFilterCursor(const EvaluatePatternFilter &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input_->MakeCursor(mem)) {} std::vector<Symbol> EvaluatePatternFilter::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } bool EvaluatePatternFilter::EvaluatePatternFilterCursor::Pull(Frame &frame, ExecutionContext &context) { SCOPED_PROFILE_OP("EvaluatePatternFilter"); std::function<void(TypedValue *)> function = [&frame, self = this->self_, input_cursor = this->input_cursor_.get(), &context](TypedValue *return_value) { OOMExceptionEnabler oom_exception; input_cursor->Reset(); *return_value = TypedValue(input_cursor->Pull(frame, context), context.evaluation_context.memory); }; frame[self_.output_symbol_] = TypedValue(std::move(function)); return true; } void EvaluatePatternFilter::EvaluatePatternFilterCursor::Shutdown() { input_cursor_->Shutdown(); } void 
EvaluatePatternFilter::EvaluatePatternFilterCursor::Reset() { input_cursor_->Reset(); } Produce::Produce(const std::shared_ptr<LogicalOperator> &input, const std::vector<NamedExpression *> &named_expressions) : input_(input ? input : std::make_shared<Once>()), named_expressions_(named_expressions) {} ACCEPT_WITH_INPUT(Produce) UniqueCursorPtr Produce::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ProduceOperator); return MakeUniqueCursorPtr<ProduceCursor>(mem, *this, mem); } std::vector<Symbol> Produce::OutputSymbols(const SymbolTable &symbol_table) const { std::vector<Symbol> symbols; for (const auto &named_expr : named_expressions_) { symbols.emplace_back(symbol_table.at(*named_expr)); } return symbols; } std::vector<Symbol> Produce::ModifiedSymbols(const SymbolTable &table) const { return OutputSymbols(table); } Produce::ProduceCursor::ProduceCursor(const Produce &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input_->MakeCursor(mem)) {} bool Produce::ProduceCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(self_); if (input_cursor_->Pull(frame, context)) { // Produce should always yield the latest results. 
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW, context.frame_change_collector); for (auto *named_expr : self_.named_expressions_) { if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(named_expr->name_)) { context.frame_change_collector->ResetTrackingValue(named_expr->name_); } named_expr->Accept(evaluator); } return true; } return false; } void Produce::ProduceCursor::Shutdown() { input_cursor_->Shutdown(); } void Produce::ProduceCursor::Reset() { input_cursor_->Reset(); } Delete::Delete(const std::shared_ptr<LogicalOperator> &input_, const std::vector<Expression *> &expressions, bool detach_) : input_(input_), expressions_(expressions), detach_(detach_) {} ACCEPT_WITH_INPUT(Delete) UniqueCursorPtr Delete::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::DeleteOperator); return MakeUniqueCursorPtr<DeleteCursor>(mem, *this, mem); } std::vector<Symbol> Delete::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } Delete::DeleteCursor::DeleteCursor(const Delete &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input_->MakeCursor(mem)) {} void Delete::DeleteCursor::UpdateDeleteBuffer(Frame &frame, ExecutionContext &context) { // Delete should get the latest information, this way it is also possible // to delete newly added nodes and edges. 
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); auto *pull_memory = context.evaluation_context.memory; // collect expressions results so edges can get deleted before vertices // this is necessary because an edge that gets deleted could block vertex // deletion utils::pmr::vector<TypedValue> expression_results(pull_memory); expression_results.reserve(self_.expressions_.size()); for (Expression *expression : self_.expressions_) { expression_results.emplace_back(expression->Accept(evaluator)); } auto vertex_auth_checker = [&context](const VertexAccessor &va) -> bool { #ifdef MG_ENTERPRISE return !(license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(va, storage::View::NEW, query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)); #else return true; #endif }; auto edge_auth_checker = [&context](const EdgeAccessor &ea) -> bool { #ifdef MG_ENTERPRISE return !( license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(ea, query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE) && context.auth_checker->Has(ea.To(), storage::View::NEW, query::AuthQuery::FineGrainedPrivilege::UPDATE) && context.auth_checker->Has(ea.From(), storage::View::NEW, query::AuthQuery::FineGrainedPrivilege::UPDATE))); #else return true; #endif }; for (TypedValue &expression_result : expression_results) { AbortCheck(context); switch (expression_result.type()) { case TypedValue::Type::Vertex: { auto va = expression_result.ValueVertex(); if (vertex_auth_checker(va)) { buffer_.nodes.push_back(va); } else { throw QueryRuntimeException("Vertex not deleted due to not having enough permission!"); } break; } case TypedValue::Type::Edge: { auto ea = expression_result.ValueEdge(); if (edge_auth_checker(ea)) { buffer_.edges.push_back(ea); } else { throw QueryRuntimeException("Edge not deleted due to not having enough 
permission!"); } break; } case TypedValue::Type::Path: { auto path = expression_result.ValuePath(); #ifdef MG_ENTERPRISE auto edges_res = std::any_of(path.edges().cbegin(), path.edges().cend(), [&edge_auth_checker](const auto &ea) { return !edge_auth_checker(ea); }); auto vertices_res = std::any_of(path.vertices().cbegin(), path.vertices().cend(), [&vertex_auth_checker](const auto &va) { return !vertex_auth_checker(va); }); if (edges_res || vertices_res) { throw QueryRuntimeException( "Path not deleted due to not having enough permission on all edges and vertices on the path!"); } #endif buffer_.nodes.insert(buffer_.nodes.begin(), path.vertices().begin(), path.vertices().end()); buffer_.edges.insert(buffer_.edges.begin(), path.edges().begin(), path.edges().end()); } case TypedValue::Type::Null: break; default: throw QueryRuntimeException("Edges, vertices and paths can be deleted."); } } } bool Delete::DeleteCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("Delete"); if (delete_executed_) { return false; } if (input_cursor_->Pull(frame, context)) { UpdateDeleteBuffer(frame, context); return true; } auto &dba = *context.db_accessor; auto res = dba.DetachDelete(std::move(buffer_.nodes), std::move(buffer_.edges), self_.detach_); if (res.HasError()) { switch (res.GetError()) { case storage::Error::SERIALIZATION_ERROR: throw TransactionSerializationException(); case storage::Error::VERTEX_HAS_EDGES: throw RemoveAttachedVertexException(); case storage::Error::DELETED_OBJECT: case storage::Error::PROPERTIES_DISABLED: case storage::Error::NONEXISTENT_OBJECT: throw QueryRuntimeException("Unexpected error when deleting a node."); } } if (*res) { context.execution_stats[ExecutionStats::Key::DELETED_NODES] += static_cast<int64_t>((*res)->first.size()); context.execution_stats[ExecutionStats::Key::DELETED_EDGES] += static_cast<int64_t>((*res)->second.size()); } // Update deleted objects for triggers if 
(context.trigger_context_collector && *res) { for (const auto &node : (*res)->first) { context.trigger_context_collector->RegisterDeletedObject(node); } if (context.trigger_context_collector->ShouldRegisterDeletedObject<query::EdgeAccessor>()) { for (const auto &edge : (*res)->second) { context.trigger_context_collector->RegisterDeletedObject(edge); } } } delete_executed_ = true; return false; } void Delete::DeleteCursor::Shutdown() { input_cursor_->Shutdown(); } void Delete::DeleteCursor::Reset() { input_cursor_->Reset(); delete_executed_ = false; } SetProperty::SetProperty(const std::shared_ptr<LogicalOperator> &input, storage::PropertyId property, PropertyLookup *lhs, Expression *rhs) : input_(input), property_(property), lhs_(lhs), rhs_(rhs) {} ACCEPT_WITH_INPUT(SetProperty) UniqueCursorPtr SetProperty::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::SetPropertyOperator); return MakeUniqueCursorPtr<SetPropertyCursor>(mem, *this, mem); } std::vector<Symbol> SetProperty::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } SetProperty::SetPropertyCursor::SetPropertyCursor(const SetProperty &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("SetProperty"); if (!input_cursor_->Pull(frame, context)) return false; // Set, just like Create needs to see the latest changes. 
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); TypedValue lhs = self_.lhs_->expression_->Accept(evaluator); TypedValue rhs = self_.rhs_->Accept(evaluator); switch (lhs.type()) { case TypedValue::Type::Vertex: { #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(lhs.ValueVertex(), storage::View::NEW, memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) { throw QueryRuntimeException("Vertex property not set due to not having enough permission!"); } #endif auto old_value = PropsSetChecked(&lhs.ValueVertex(), self_.property_, rhs); context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; if (context.trigger_context_collector) { // rhs cannot be moved because it was created with the allocator that is only valid during current pull context.trigger_context_collector->RegisterSetObjectProperty(lhs.ValueVertex(), self_.property_, TypedValue{std::move(old_value)}, TypedValue{rhs}); } break; } case TypedValue::Type::Edge: { #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(lhs.ValueEdge(), memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) { throw QueryRuntimeException("Edge property not set due to not having enough permission!"); } #endif auto old_value = PropsSetChecked(&lhs.ValueEdge(), self_.property_, rhs); context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; if (context.trigger_context_collector) { // rhs cannot be moved because it was created with the allocator that is only valid // during current pull context.trigger_context_collector->RegisterSetObjectProperty(lhs.ValueEdge(), self_.property_, TypedValue{std::move(old_value)}, TypedValue{rhs}); } break; } case TypedValue::Type::Null: // Skip setting properties on Null (can occur in optional match). 
break; case TypedValue::Type::Map: // Semantically modifying a map makes sense, but it's not supported due // to all the copying we do (when PropertyValue -> TypedValue and in // ExpressionEvaluator). So even though we set a map property here, that // is never visible to the user and it's not stored. // TODO: fix above described bug default: throw QueryRuntimeException("Properties can only be set on edges and vertices."); } return true; } void SetProperty::SetPropertyCursor::Shutdown() { input_cursor_->Shutdown(); } void SetProperty::SetPropertyCursor::Reset() { input_cursor_->Reset(); } SetProperties::SetProperties(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Expression *rhs, Op op) : input_(input), input_symbol_(std::move(input_symbol)), rhs_(rhs), op_(op) {} ACCEPT_WITH_INPUT(SetProperties) UniqueCursorPtr SetProperties::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::SetPropertiesOperator); return MakeUniqueCursorPtr<SetPropertiesCursor>(mem, *this, mem); } std::vector<Symbol> SetProperties::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } SetProperties::SetPropertiesCursor::SetPropertiesCursor(const SetProperties &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} namespace { template <typename T> concept AccessorWithProperties = requires(T value, storage::PropertyId property_id, storage::PropertyValue property_value, std::map<storage::PropertyId, storage::PropertyValue> properties) { { value.ClearProperties() } -> std::same_as<storage::Result<std::map<storage::PropertyId, storage::PropertyValue>>>; {value.SetProperty(property_id, property_value)}; {value.UpdateProperties(properties)}; }; /// Helper function that sets the given values on either a Vertex or an Edge. 
/// /// @tparam TRecordAccessor Either RecordAccessor<Vertex> or /// RecordAccessor<Edge> template <AccessorWithProperties TRecordAccessor> void SetPropertiesOnRecord(TRecordAccessor *record, const TypedValue &rhs, SetProperties::Op op, ExecutionContext *context, std::unordered_map<std::string, storage::PropertyId> &cached_name_id) { using PropertiesMap = std::map<storage::PropertyId, storage::PropertyValue>; std::optional<PropertiesMap> old_values; const bool should_register_change = context->trigger_context_collector && context->trigger_context_collector->ShouldRegisterObjectPropertyChange<TRecordAccessor>(); if (op == SetProperties::Op::REPLACE) { auto maybe_value = record->ClearProperties(); if (maybe_value.HasError()) { switch (maybe_value.GetError()) { case storage::Error::DELETED_OBJECT: throw QueryRuntimeException("Trying to set properties on a deleted graph element."); case storage::Error::SERIALIZATION_ERROR: throw TransactionSerializationException(); case storage::Error::PROPERTIES_DISABLED: throw QueryRuntimeException("Can't set property because properties on edges are disabled."); case storage::Error::VERTEX_HAS_EDGES: case storage::Error::NONEXISTENT_OBJECT: throw QueryRuntimeException("Unexpected error when setting properties."); } } if (should_register_change) { old_values.emplace(std::move(*maybe_value)); } } auto get_props = [](const auto &record) { auto maybe_props = record.Properties(storage::View::NEW); if (maybe_props.HasError()) { switch (maybe_props.GetError()) { case storage::Error::DELETED_OBJECT: throw QueryRuntimeException("Trying to get properties from a deleted object."); case storage::Error::NONEXISTENT_OBJECT: throw query::QueryRuntimeException("Trying to get properties from an object that doesn't exist."); case storage::Error::SERIALIZATION_ERROR: case storage::Error::VERTEX_HAS_EDGES: case storage::Error::PROPERTIES_DISABLED: throw QueryRuntimeException("Unexpected error when getting properties."); } } return *maybe_props; }; auto 
register_set_property = [&](auto &&returned_old_value, auto key, auto &&new_value) { auto old_value = [&]() -> storage::PropertyValue { if (!old_values) { return std::forward<decltype(returned_old_value)>(returned_old_value); } if (auto it = old_values->find(key); it != old_values->end()) { return std::move(it->second); } return {}; }(); context->trigger_context_collector->RegisterSetObjectProperty( *record, key, TypedValue(std::move(old_value)), TypedValue(std::forward<decltype(new_value)>(new_value))); }; auto update_props = [&, record](PropertiesMap &new_properties) { auto updated_properties = UpdatePropertiesChecked(record, new_properties); // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) context->execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += new_properties.size(); if (should_register_change) { for (const auto &[id, old_value, new_value] : updated_properties) { register_set_property(std::move(old_value), id, std::move(new_value)); } } }; switch (rhs.type()) { case TypedValue::Type::Edge: { PropertiesMap new_properties = get_props(rhs.ValueEdge()); update_props(new_properties); break; } case TypedValue::Type::Vertex: { PropertiesMap new_properties = get_props(rhs.ValueVertex()); update_props(new_properties); break; } case TypedValue::Type::Map: { PropertiesMap new_properties; for (const auto &[string_key, value] : rhs.ValueMap()) { storage::PropertyId property_id; if (auto it = cached_name_id.find(std::string(string_key)); it != cached_name_id.end()) [[likely]] { property_id = it->second; } else { property_id = context->db_accessor->NameToProperty(string_key); cached_name_id.emplace(string_key, property_id); } new_properties.emplace(property_id, value); } update_props(new_properties); break; } default: throw QueryRuntimeException( "Right-hand side in SET expression must be a node, an edge or a " "map."); } if (should_register_change && old_values) { // register removed properties for (auto &[property_id, 
property_value] : *old_values) { context->trigger_context_collector->RegisterRemovedObjectProperty(*record, property_id, TypedValue(std::move(property_value))); } } } } // namespace bool SetProperties::SetPropertiesCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("SetProperties"); if (!input_cursor_->Pull(frame, context)) return false; TypedValue &lhs = frame[self_.input_symbol_]; // Set, just like Create needs to see the latest changes. ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); TypedValue rhs = self_.rhs_->Accept(evaluator); switch (lhs.type()) { case TypedValue::Type::Vertex: #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(lhs.ValueVertex(), storage::View::NEW, memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) { throw QueryRuntimeException("Vertex properties not set due to not having enough permission!"); } #endif SetPropertiesOnRecord(&lhs.ValueVertex(), rhs, self_.op_, &context, cached_name_id_); break; case TypedValue::Type::Edge: #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(lhs.ValueEdge(), memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) { throw QueryRuntimeException("Edge properties not set due to not having enough permission!"); } #endif SetPropertiesOnRecord(&lhs.ValueEdge(), rhs, self_.op_, &context, cached_name_id_); break; case TypedValue::Type::Null: // Skip setting properties on Null (can occur in optional match). 
break; default: throw QueryRuntimeException("Properties can only be set on edges and vertices."); } return true; } void SetProperties::SetPropertiesCursor::Shutdown() { input_cursor_->Shutdown(); } void SetProperties::SetPropertiesCursor::Reset() { input_cursor_->Reset(); } SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, const std::vector<std::variant<storage::LabelId, query::Expression *>> &labels) : input_(input), input_symbol_(std::move(input_symbol)), labels_(labels) {} SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, const std::vector<storage::LabelId> &labels) : input_(input), input_symbol_(std::move(input_symbol)) { labels_.reserve(labels.size()); for (const auto &label : labels) { labels_.emplace_back(label); } } ACCEPT_WITH_INPUT(SetLabels) UniqueCursorPtr SetLabels::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::SetLabelsOperator); return MakeUniqueCursorPtr<SetLabelsCursor>(mem, *this, mem); } std::vector<Symbol> SetLabels::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } SetLabels::SetLabelsCursor::SetLabelsCursor(const SetLabels &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("SetLabels"); ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); std::vector<storage::LabelId> labels; for (const auto &label : self_.labels_) { if (std::holds_alternative<storage::LabelId>(label)) { labels.push_back(std::get<storage::LabelId>(label)); } else { labels.push_back( context.db_accessor->NameToLabel(std::get<query::Expression *>(label)->Accept(evaluator).ValueString())); } } #ifdef MG_ENTERPRISE if 
(license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(labels, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) { throw QueryRuntimeException("Couldn't set label due to not having enough permission!"); } #endif if (!input_cursor_->Pull(frame, context)) return false; TypedValue &vertex_value = frame[self_.input_symbol_]; // Skip setting labels on Null (can occur in optional match). if (vertex_value.IsNull()) return true; ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); auto &vertex = vertex_value.ValueVertex(); #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(vertex, storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) { throw QueryRuntimeException("Couldn't set label due to not having enough permission!"); } #endif for (auto label : labels) { auto maybe_value = vertex.AddLabel(label); if (maybe_value.HasError()) { switch (maybe_value.GetError()) { case storage::Error::SERIALIZATION_ERROR: throw TransactionSerializationException(); case storage::Error::DELETED_OBJECT: throw QueryRuntimeException("Trying to set a label on a deleted node."); case storage::Error::VERTEX_HAS_EDGES: case storage::Error::PROPERTIES_DISABLED: case storage::Error::NONEXISTENT_OBJECT: throw QueryRuntimeException("Unexpected error when setting a label."); } } if (context.trigger_context_collector && *maybe_value) { context.trigger_context_collector->RegisterSetVertexLabel(vertex, label); } } return true; } void SetLabels::SetLabelsCursor::Shutdown() { input_cursor_->Shutdown(); } void SetLabels::SetLabelsCursor::Reset() { input_cursor_->Reset(); } RemoveProperty::RemoveProperty(const std::shared_ptr<LogicalOperator> &input, storage::PropertyId property, PropertyLookup *lhs) : input_(input), property_(property), lhs_(lhs) {} ACCEPT_WITH_INPUT(RemoveProperty) UniqueCursorPtr 
RemoveProperty::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::RemovePropertyOperator); return MakeUniqueCursorPtr<RemovePropertyCursor>(mem, *this, mem); } std::vector<Symbol> RemoveProperty::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } RemoveProperty::RemovePropertyCursor::RemovePropertyCursor(const RemoveProperty &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} bool RemoveProperty::RemovePropertyCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("RemoveProperty"); if (!input_cursor_->Pull(frame, context)) return false; // Remove, just like Delete needs to see the latest changes. ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); TypedValue lhs = self_.lhs_->expression_->Accept(evaluator); auto remove_prop = [property = self_.property_, &context](auto *record) { auto maybe_old_value = record->RemoveProperty(property); if (maybe_old_value.HasError()) { switch (maybe_old_value.GetError()) { case storage::Error::DELETED_OBJECT: throw QueryRuntimeException("Trying to remove a property on a deleted graph element."); case storage::Error::SERIALIZATION_ERROR: throw TransactionSerializationException(); case storage::Error::PROPERTIES_DISABLED: throw QueryRuntimeException( "Can't remove property because properties on edges are " "disabled."); case storage::Error::VERTEX_HAS_EDGES: case storage::Error::NONEXISTENT_OBJECT: throw QueryRuntimeException("Unexpected error when removing property."); } } if (context.trigger_context_collector) { context.trigger_context_collector->RegisterRemovedObjectProperty(*record, property, TypedValue(std::move(*maybe_old_value))); } }; switch (lhs.type()) { case TypedValue::Type::Vertex: #ifdef MG_ENTERPRISE if 
(license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(lhs.ValueVertex(), storage::View::NEW, memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) { throw QueryRuntimeException("Vertex property not removed due to not having enough permission!"); } #endif remove_prop(&lhs.ValueVertex()); break; case TypedValue::Type::Edge: #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(lhs.ValueEdge(), memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) { throw QueryRuntimeException("Edge property not removed due to not having enough permission!"); } #endif remove_prop(&lhs.ValueEdge()); break; case TypedValue::Type::Null: // Skip removing properties on Null (can occur in optional match). break; default: throw QueryRuntimeException("Properties can only be removed from vertices and edges."); } return true; } void RemoveProperty::RemovePropertyCursor::Shutdown() { input_cursor_->Shutdown(); } void RemoveProperty::RemovePropertyCursor::Reset() { input_cursor_->Reset(); } RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, const std::vector<std::variant<storage::LabelId, query::Expression *>> &labels) : input_(input), input_symbol_(std::move(input_symbol)), labels_(labels) {} RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, const std::vector<storage::LabelId> &labels) : input_(input), input_symbol_(std::move(input_symbol)) { labels_.reserve(labels.size()); for (const auto &label : labels) { labels_.push_back(label); } } ACCEPT_WITH_INPUT(RemoveLabels) UniqueCursorPtr RemoveLabels::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::RemoveLabelsOperator); return MakeUniqueCursorPtr<RemoveLabelsCursor>(mem, *this, mem); } std::vector<Symbol> RemoveLabels::ModifiedSymbols(const SymbolTable &table) 
const { return input_->ModifiedSymbols(table); } RemoveLabels::RemoveLabelsCursor::RemoveLabelsCursor(const RemoveLabels &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("RemoveLabels"); ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); std::vector<storage::LabelId> labels; for (const auto &label : self_.labels_) { if (std::holds_alternative<storage::LabelId>(label)) { labels.push_back(std::get<storage::LabelId>(label)); } else { labels.push_back( context.db_accessor->NameToLabel(std::get<query::Expression *>(label)->Accept(evaluator).ValueString())); } } #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(labels, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) { throw QueryRuntimeException("Couldn't remove label due to not having enough permission!"); } #endif if (!input_cursor_->Pull(frame, context)) return false; TypedValue &vertex_value = frame[self_.input_symbol_]; // Skip removing labels on Null (can occur in optional match). 
if (vertex_value.IsNull()) return true; ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); auto &vertex = vertex_value.ValueVertex(); #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !context.auth_checker->Has(vertex, storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) { throw QueryRuntimeException("Couldn't remove label due to not having enough permission!"); } #endif for (auto label : labels) { auto maybe_value = vertex.RemoveLabel(label); if (maybe_value.HasError()) { switch (maybe_value.GetError()) { case storage::Error::SERIALIZATION_ERROR: throw TransactionSerializationException(); case storage::Error::DELETED_OBJECT: throw QueryRuntimeException("Trying to remove labels from a deleted node."); case storage::Error::VERTEX_HAS_EDGES: case storage::Error::PROPERTIES_DISABLED: case storage::Error::NONEXISTENT_OBJECT: throw QueryRuntimeException("Unexpected error when removing labels from a node."); } } context.execution_stats[ExecutionStats::Key::DELETED_LABELS] += 1; if (context.trigger_context_collector && *maybe_value) { context.trigger_context_collector->RegisterRemovedVertexLabel(vertex, label); } } return true; } void RemoveLabels::RemoveLabelsCursor::Shutdown() { input_cursor_->Shutdown(); } void RemoveLabels::RemoveLabelsCursor::Reset() { input_cursor_->Reset(); } EdgeUniquenessFilter::EdgeUniquenessFilter(const std::shared_ptr<LogicalOperator> &input, Symbol expand_symbol, const std::vector<Symbol> &previous_symbols) : input_(input), expand_symbol_(std::move(expand_symbol)), previous_symbols_(previous_symbols) {} ACCEPT_WITH_INPUT(EdgeUniquenessFilter) UniqueCursorPtr EdgeUniquenessFilter::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::EdgeUniquenessFilterOperator); return MakeUniqueCursorPtr<EdgeUniquenessFilterCursor>(mem, *this, mem); } std::vector<Symbol> 
EdgeUniquenessFilter::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } EdgeUniquenessFilter::EdgeUniquenessFilterCursor::EdgeUniquenessFilterCursor(const EdgeUniquenessFilter &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)) {} namespace { /** * Returns true if: * - a and b are either edge or edge-list values, and there * is at least one matching edge in the two values */ bool ContainsSameEdge(const TypedValue &a, const TypedValue &b) { auto compare_to_list = [](const TypedValue &list, const TypedValue &other) { for (const TypedValue &list_elem : list.ValueList()) if (ContainsSameEdge(list_elem, other)) return true; return false; }; if (a.type() == TypedValue::Type::List) return compare_to_list(a, b); if (b.type() == TypedValue::Type::List) return compare_to_list(b, a); return a.ValueEdge() == b.ValueEdge(); } } // namespace bool EdgeUniquenessFilter::EdgeUniquenessFilterCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("EdgeUniquenessFilter"); auto expansion_ok = [&]() { const auto &expand_value = frame[self_.expand_symbol_]; for (const auto &previous_symbol : self_.previous_symbols_) { const auto &previous_value = frame[previous_symbol]; // This shouldn't raise a TypedValueException, because the planner // makes sure these are all of the expected type. In case they are not // an error should be raised long before this code is executed. if (ContainsSameEdge(previous_value, expand_value)) return false; } return true; }; while (input_cursor_->Pull(frame, context)) if (expansion_ok()) return true; return false; } void EdgeUniquenessFilter::EdgeUniquenessFilterCursor::Shutdown() { input_cursor_->Shutdown(); } void EdgeUniquenessFilter::EdgeUniquenessFilterCursor::Reset() { input_cursor_->Reset(); } EmptyResult::EmptyResult(const std::shared_ptr<LogicalOperator> &input) : input_(input ? 
input : std::make_shared<Once>()) {} ACCEPT_WITH_INPUT(EmptyResult) std::vector<Symbol> EmptyResult::OutputSymbols(const SymbolTable &) const { // NOLINT(hicpp-named-parameter) return {}; } std::vector<Symbol> EmptyResult::ModifiedSymbols(const SymbolTable &) const { // NOLINT(hicpp-named-parameter) return {}; } class EmptyResultCursor : public Cursor { public: EmptyResultCursor(const EmptyResult &self, utils::MemoryResource *mem) : input_cursor_(self.input_->MakeCursor(mem)) {} bool Pull(Frame &frame, ExecutionContext &context) override { SCOPED_PROFILE_OP("EmptyResult"); if (!pulled_all_input_) { while (input_cursor_->Pull(frame, context)) { AbortCheck(context); } pulled_all_input_ = true; } return false; } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); pulled_all_input_ = false; } private: const UniqueCursorPtr input_cursor_; bool pulled_all_input_{false}; }; UniqueCursorPtr EmptyResult::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::EmptyResultOperator); return MakeUniqueCursorPtr<EmptyResultCursor>(mem, *this, mem); } Accumulate::Accumulate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &symbols, bool advance_command) : input_(input), symbols_(symbols), advance_command_(advance_command) {} ACCEPT_WITH_INPUT(Accumulate) std::vector<Symbol> Accumulate::ModifiedSymbols(const SymbolTable &) const { return symbols_; } class AccumulateCursor : public Cursor { public: AccumulateCursor(const Accumulate &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)), cache_(mem) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("Accumulate"); auto &dba = *context.db_accessor; // cache all the input if (!pulled_all_input_) { while (input_cursor_->Pull(frame, context)) { utils::pmr::vector<TypedValue> 
row(cache_.get_allocator().GetMemoryResource()); row.reserve(self_.symbols_.size()); for (const Symbol &symbol : self_.symbols_) row.emplace_back(frame[symbol]); cache_.emplace_back(std::move(row)); } pulled_all_input_ = true; cache_it_ = cache_.begin(); if (self_.advance_command_) dba.AdvanceCommand(); } AbortCheck(context); if (cache_it_ == cache_.end()) return false; auto row_it = (cache_it_++)->begin(); for (const Symbol &symbol : self_.symbols_) { if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) { context.frame_change_collector->ResetTrackingValue(symbol.name()); } frame[symbol] = *row_it++; } return true; } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); cache_.clear(); cache_it_ = cache_.begin(); pulled_all_input_ = false; } private: const Accumulate &self_; const UniqueCursorPtr input_cursor_; utils::pmr::deque<utils::pmr::vector<TypedValue>> cache_; decltype(cache_.begin()) cache_it_ = cache_.begin(); bool pulled_all_input_{false}; }; UniqueCursorPtr Accumulate::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::AccumulateOperator); return MakeUniqueCursorPtr<AccumulateCursor>(mem, *this, mem); } Aggregate::Aggregate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Aggregate::Element> &aggregations, const std::vector<Expression *> &group_by, const std::vector<Symbol> &remember) : input_(input ? input : std::make_shared<Once>()), aggregations_(aggregations), group_by_(group_by), remember_(remember) {} ACCEPT_WITH_INPUT(Aggregate) std::vector<Symbol> Aggregate::ModifiedSymbols(const SymbolTable &) const { auto symbols = remember_; for (const auto &elem : aggregations_) symbols.push_back(elem.output_sym); return symbols; } namespace { /** Returns the default TypedValue for an Aggregation element. 
* This value is valid both for returning when where are no inputs * to the aggregation op, and for initializing an aggregation result * when there are */ TypedValue DefaultAggregationOpValue(const Aggregate::Element &element, utils::MemoryResource *memory) { switch (element.op) { case Aggregation::Op::MIN: case Aggregation::Op::MAX: case Aggregation::Op::AVG: return TypedValue(memory); case Aggregation::Op::COUNT: case Aggregation::Op::SUM: return TypedValue(0, memory); case Aggregation::Op::COLLECT_LIST: return TypedValue(TypedValue::TVector(memory)); case Aggregation::Op::COLLECT_MAP: return TypedValue(TypedValue::TMap(memory)); case Aggregation::Op::PROJECT: return TypedValue(query::Graph(memory)); } } } // namespace class AggregateCursor : public Cursor { public: AggregateCursor(const Aggregate &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input_->MakeCursor(mem)), aggregation_(mem), reused_group_by_(self.group_by_.size(), mem) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(self_); if (!pulled_all_input_) { if (!ProcessAll(&frame, &context) && !self_.group_by_.empty()) return false; pulled_all_input_ = true; aggregation_it_ = aggregation_.begin(); if (aggregation_.empty()) { auto *pull_memory = context.evaluation_context.memory; // place default aggregation values on the frame for (const auto &elem : self_.aggregations_) { frame[elem.output_sym] = DefaultAggregationOpValue(elem, pull_memory); if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(elem.output_sym.name())) { context.frame_change_collector->ResetTrackingValue(elem.output_sym.name()); } } // place null as remember values on the frame for (const Symbol &remember_sym : self_.remember_) { frame[remember_sym] = TypedValue(pull_memory); if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(remember_sym.name())) { 
context.frame_change_collector->ResetTrackingValue(remember_sym.name()); } } return true; } } if (aggregation_it_ == aggregation_.end()) return false; // place aggregation values on the frame auto aggregation_values_it = aggregation_it_->second.values_.begin(); for (const auto &aggregation_elem : self_.aggregations_) frame[aggregation_elem.output_sym] = *aggregation_values_it++; // place remember values on the frame auto remember_values_it = aggregation_it_->second.remember_.begin(); for (const Symbol &remember_sym : self_.remember_) frame[remember_sym] = *remember_values_it++; aggregation_it_++; return true; } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); aggregation_.clear(); aggregation_it_ = aggregation_.begin(); pulled_all_input_ = false; } private: // Data structure for a single aggregation cache. // Does NOT include the group-by values since those are a key in the // aggregation map. The vectors in an AggregationValue contain one element for // each aggregation in this LogicalOp. struct AggregationValue { explicit AggregationValue(utils::MemoryResource *mem) : counts_(mem), values_(mem), remember_(mem), unique_values_(mem) {} // how many input rows have been aggregated in respective values_ element so // far // TODO: The counting value type should be changed to an unsigned type once // TypedValue can support signed integer values larger than 64bits so that // precision isn't lost. utils::pmr::vector<int64_t> counts_; // aggregated values. Initially Null (until at least one input row with a // valid value gets processed) utils::pmr::vector<TypedValue> values_; // remember values. 
utils::pmr::vector<TypedValue> remember_; using TSet = utils::pmr::unordered_set<TypedValue, TypedValue::Hash, TypedValue::BoolEqual>; utils::pmr::vector<TSet> unique_values_; }; const Aggregate &self_; const UniqueCursorPtr input_cursor_; // storage for aggregated data // map key is the vector of group-by values // map value is an AggregationValue struct utils::pmr::unordered_map<utils::pmr::vector<TypedValue>, AggregationValue, // use FNV collection hashing specialized for a // vector of TypedValues utils::FnvCollection<utils::pmr::vector<TypedValue>, TypedValue, TypedValue::Hash>, // custom equality TypedValueVectorEqual> aggregation_; // this is a for object reuse, to avoid re-allocating this buffer utils::pmr::vector<TypedValue> reused_group_by_; // iterator over the accumulated cache decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin(); // this LogicalOp pulls all from the input on it's first pull // this switch tracks if this has been performed bool pulled_all_input_{false}; /** * Pulls from the input operator until exhausted and aggregates the * results. If the input operator is not provided, a single call * to ProcessOne is issued. * * Accumulation automatically groups the results so that `aggregation_` * cache cardinality depends on number of * aggregation results, and not on the number of inputs. 
*/ bool ProcessAll(Frame *frame, ExecutionContext *context) { ExpressionEvaluator evaluator(frame, context->symbol_table, context->evaluation_context, context->db_accessor, storage::View::NEW); bool pulled = false; while (input_cursor_->Pull(*frame, *context)) { ProcessOne(*frame, &evaluator); pulled = true; } if (!pulled) return false; // post processing for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) { switch (self_.aggregations_[pos].op) { case Aggregation::Op::AVG: { // calculate AVG aggregations (so far they have only been summed) for (auto &kv : aggregation_) { AggregationValue &agg_value = kv.second; auto count = agg_value.counts_[pos]; auto *pull_memory = context->evaluation_context.memory; if (count > 0) { agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory); } } break; } case Aggregation::Op::COUNT: { // Copy counts to be the value for (auto &kv : aggregation_) { AggregationValue &agg_value = kv.second; agg_value.values_[pos] = agg_value.counts_[pos]; } break; } case Aggregation::Op::MIN: case Aggregation::Op::MAX: case Aggregation::Op::SUM: case Aggregation::Op::COLLECT_LIST: case Aggregation::Op::COLLECT_MAP: case Aggregation::Op::PROJECT: break; } } return true; } /** * Performs a single accumulation. */ void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) { // Preallocated group_by, since most of the time the aggregation key won't be unique reused_group_by_.clear(); evaluator->ResetPropertyLookupCache(); for (Expression *expression : self_.group_by_) { reused_group_by_.emplace_back(expression->Accept(*evaluator)); } auto *mem = aggregation_.get_allocator().GetMemoryResource(); auto res = aggregation_.try_emplace(reused_group_by_, mem); auto &agg_value = res.first->second; if (res.second /*was newly inserted*/) EnsureInitialized(frame, &agg_value); Update(evaluator, &agg_value); } /** Ensures the new AggregationValue has been initialized. 
This means * that the value vectors are filled with an appropriate number of Nulls, * counts are set to 0 and remember values are remembered. */ void EnsureInitialized(const Frame &frame, AggregateCursor::AggregationValue *agg_value) const { if (!agg_value->values_.empty()) return; const auto num_of_aggregations = self_.aggregations_.size(); agg_value->values_.reserve(num_of_aggregations); agg_value->unique_values_.reserve(num_of_aggregations); auto *mem = agg_value->values_.get_allocator().GetMemoryResource(); for (const auto &agg_elem : self_.aggregations_) { agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem)); agg_value->unique_values_.emplace_back(AggregationValue::TSet(mem)); } agg_value->counts_.resize(num_of_aggregations, 0); agg_value->remember_.reserve(self_.remember_.size()); for (const Symbol &remember_sym : self_.remember_) { agg_value->remember_.push_back(frame[remember_sym]); } } /** Updates the given AggregationValue with new data. Assumes that * the AggregationValue has been initialized */ void Update(ExpressionEvaluator *evaluator, AggregateCursor::AggregationValue *agg_value) { DMG_ASSERT(self_.aggregations_.size() == agg_value->values_.size(), "Expected as much AggregationValue.values_ as there are " "aggregations."); DMG_ASSERT(self_.aggregations_.size() == agg_value->counts_.size(), "Expected as much AggregationValue.counts_ as there are " "aggregations."); auto count_it = agg_value->counts_.begin(); auto value_it = agg_value->values_.begin(); auto unique_values_it = agg_value->unique_values_.begin(); auto agg_elem_it = self_.aggregations_.begin(); const auto counts_end = agg_value->counts_.end(); for (; count_it != counts_end; ++count_it, ++value_it, ++unique_values_it, ++agg_elem_it) { // COUNT(*) is the only case where input expression is optional // handle it here auto input_expr_ptr = agg_elem_it->value; if (!input_expr_ptr) { *count_it += 1; // value is deferred to post-processing continue; } TypedValue input_value = 
input_expr_ptr->Accept(*evaluator); // Aggregations skip Null input values. if (input_value.IsNull()) continue; const auto &agg_op = agg_elem_it->op; if (agg_elem_it->distinct) { auto insert_result = unique_values_it->insert(input_value); if (!insert_result.second) { continue; } } *count_it += 1; if (*count_it == 1) { // first value, nothing to aggregate. check type, set and continue. switch (agg_op) { case Aggregation::Op::MIN: case Aggregation::Op::MAX: EnsureOkForMinMax(input_value); *value_it = std::move(input_value); break; case Aggregation::Op::SUM: case Aggregation::Op::AVG: EnsureOkForAvgSum(input_value); *value_it = std::move(input_value); break; case Aggregation::Op::COUNT: // value is deferred to post-processing break; case Aggregation::Op::COLLECT_LIST: value_it->ValueList().push_back(std::move(input_value)); break; case Aggregation::Op::PROJECT: { EnsureOkForProject(input_value); value_it->ValueGraph().Expand(input_value.ValuePath()); break; } case Aggregation::Op::COLLECT_MAP: auto key = agg_elem_it->key->Accept(*evaluator); if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string."); value_it->ValueMap().emplace(key.ValueString(), std::move(input_value)); break; } continue; } // aggregation of existing values switch (agg_op) { case Aggregation::Op::COUNT: // value is deferred to post-processing break; case Aggregation::Op::MIN: { EnsureOkForMinMax(input_value); try { TypedValue comparison_result = input_value < *value_it; // since we skip nulls we either have a valid comparison, or // an exception was just thrown above // safe to assume a bool TypedValue if (comparison_result.ValueBool()) *value_it = std::move(input_value); } catch (const TypedValueException &) { throw QueryRuntimeException("Unable to get MIN of '{}' and '{}'.", input_value.type(), value_it->type()); } break; } case Aggregation::Op::MAX: { // all comments as for Op::Min EnsureOkForMinMax(input_value); try { TypedValue comparison_result = 
input_value > *value_it; if (comparison_result.ValueBool()) *value_it = std::move(input_value); } catch (const TypedValueException &) { throw QueryRuntimeException("Unable to get MAX of '{}' and '{}'.", input_value.type(), value_it->type()); } break; } case Aggregation::Op::AVG: // for averaging we sum first and divide by count once all // the input has been processed case Aggregation::Op::SUM: EnsureOkForAvgSum(input_value); *value_it = *value_it + input_value; break; case Aggregation::Op::COLLECT_LIST: value_it->ValueList().push_back(std::move(input_value)); break; case Aggregation::Op::PROJECT: { EnsureOkForProject(input_value); value_it->ValueGraph().Expand(input_value.ValuePath()); break; } case Aggregation::Op::COLLECT_MAP: auto key = agg_elem_it->key->Accept(*evaluator); if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string."); value_it->ValueMap().emplace(key.ValueString(), std::move(input_value)); break; } // end switch over Aggregation::Op enum } // end loop over all aggregations } /** Checks if the given TypedValue is legal in MIN and MAX. If not * an appropriate exception is thrown. */ void EnsureOkForMinMax(const TypedValue &value) const { switch (value.type()) { case TypedValue::Type::Bool: case TypedValue::Type::Int: case TypedValue::Type::Double: case TypedValue::Type::String: return; default: throw QueryRuntimeException( "Only boolean, numeric and string values are allowed in " "MIN and MAX aggregations."); } } /** Checks if the given TypedValue is legal in AVG and SUM. If not * an appropriate exception is thrown. */ void EnsureOkForAvgSum(const TypedValue &value) const { switch (value.type()) { case TypedValue::Type::Int: case TypedValue::Type::Double: return; default: throw QueryRuntimeException("Only numeric values allowed in SUM and AVG aggregations."); } } /** Checks if the given TypedValue is legal in PROJECT and PROJECT_TRANSITIVE. If not * an appropriate exception is thrown. 
*/ // NOLINTNEXTLINE(readability-convert-member-functions-to-static) void EnsureOkForProject(const TypedValue &value) const { switch (value.type()) { case TypedValue::Type::Path: return; default: throw QueryRuntimeException("Only path values allowed in PROJECT aggregation."); } } }; UniqueCursorPtr Aggregate::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::AggregateOperator); return MakeUniqueCursorPtr<AggregateCursor>(mem, *this, mem); } Skip::Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression) : input_(input), expression_(expression) {} ACCEPT_WITH_INPUT(Skip) UniqueCursorPtr Skip::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::SkipOperator); return MakeUniqueCursorPtr<SkipCursor>(mem, *this, mem); } std::vector<Symbol> Skip::OutputSymbols(const SymbolTable &symbol_table) const { // Propagate this to potential Produce. return input_->OutputSymbols(symbol_table); } std::vector<Symbol> Skip::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } Skip::SkipCursor::SkipCursor(const Skip &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input_->MakeCursor(mem)) {} bool Skip::SkipCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("Skip"); while (input_cursor_->Pull(frame, context)) { if (to_skip_ == -1) { // First successful pull from the input, evaluate the skip expression. // The skip expression doesn't contain identifiers so graph view // parameter is not important. 
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); TypedValue to_skip = self_.expression_->Accept(evaluator); if (to_skip.type() != TypedValue::Type::Int) throw QueryRuntimeException("Number of elements to skip must be an integer."); to_skip_ = to_skip.ValueInt(); if (to_skip_ < 0) throw QueryRuntimeException("Number of elements to skip must be non-negative."); } if (skipped_++ < to_skip_) continue; return true; } return false; } void Skip::SkipCursor::Shutdown() { input_cursor_->Shutdown(); } void Skip::SkipCursor::Reset() { input_cursor_->Reset(); to_skip_ = -1; skipped_ = 0; } Limit::Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression) : input_(input), expression_(expression) {} ACCEPT_WITH_INPUT(Limit) UniqueCursorPtr Limit::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::LimitOperator); return MakeUniqueCursorPtr<LimitCursor>(mem, *this, mem); } std::vector<Symbol> Limit::OutputSymbols(const SymbolTable &symbol_table) const { // Propagate this to potential Produce. return input_->OutputSymbols(symbol_table); } std::vector<Symbol> Limit::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } Limit::LimitCursor::LimitCursor(const Limit &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input_->MakeCursor(mem)) {} bool Limit::LimitCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("Limit"); // We need to evaluate the limit expression before the first input Pull // because it might be 0 and thereby we shouldn't Pull from input at all. // We can do this before Pulling from the input because the limit expression // is not allowed to contain any identifiers. if (limit_ == -1) { // Limit expression doesn't contain identifiers so graph view is not // important. 
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); TypedValue limit = self_.expression_->Accept(evaluator); if (limit.type() != TypedValue::Type::Int) throw QueryRuntimeException("Limit on number of returned elements must be an integer."); limit_ = limit.ValueInt(); if (limit_ < 0) throw QueryRuntimeException("Limit on number of returned elements must be non-negative."); } // check we have not exceeded the limit before pulling if (pulled_++ >= limit_) return false; return input_cursor_->Pull(frame, context); } void Limit::LimitCursor::Shutdown() { input_cursor_->Shutdown(); } void Limit::LimitCursor::Reset() { input_cursor_->Reset(); limit_ = -1; pulled_ = 0; } OrderBy::OrderBy(const std::shared_ptr<LogicalOperator> &input, const std::vector<SortItem> &order_by, const std::vector<Symbol> &output_symbols) : input_(input), output_symbols_(output_symbols) { // split the order_by vector into two vectors of orderings and expressions std::vector<Ordering> ordering; ordering.reserve(order_by.size()); order_by_.reserve(order_by.size()); for (const auto &ordering_expression_pair : order_by) { ordering.emplace_back(ordering_expression_pair.ordering); order_by_.emplace_back(ordering_expression_pair.expression); } compare_ = TypedValueVectorCompare(ordering); } ACCEPT_WITH_INPUT(OrderBy) std::vector<Symbol> OrderBy::OutputSymbols(const SymbolTable &symbol_table) const { // Propagate this to potential Produce. 
return input_->OutputSymbols(symbol_table); } std::vector<Symbol> OrderBy::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } class OrderByCursor : public Cursor { public: OrderByCursor(const OrderBy &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_.input_->MakeCursor(mem)), cache_(mem) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(self_); if (!did_pull_all_) { ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); auto *mem = cache_.get_allocator().GetMemoryResource(); while (input_cursor_->Pull(frame, context)) { // collect the order_by elements utils::pmr::vector<TypedValue> order_by(mem); order_by.reserve(self_.order_by_.size()); for (auto expression_ptr : self_.order_by_) { order_by.emplace_back(expression_ptr->Accept(evaluator)); } // collect the output elements utils::pmr::vector<TypedValue> output(mem); output.reserve(self_.output_symbols_.size()); for (const Symbol &output_sym : self_.output_symbols_) output.emplace_back(frame[output_sym]); cache_.push_back(Element{std::move(order_by), std::move(output)}); } std::sort(cache_.begin(), cache_.end(), [this](const auto &pair1, const auto &pair2) { return self_.compare_(pair1.order_by, pair2.order_by); }); did_pull_all_ = true; cache_it_ = cache_.begin(); } if (cache_it_ == cache_.end()) return false; AbortCheck(context); // place the output values on the frame DMG_ASSERT(self_.output_symbols_.size() == cache_it_->remember.size(), "Number of values does not match the number of output symbols " "in OrderBy"); auto output_sym_it = self_.output_symbols_.begin(); for (const TypedValue &output : cache_it_->remember) { if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(output_sym_it->name())) { context.frame_change_collector->ResetTrackingValue(output_sym_it->name()); } 
frame[*output_sym_it++] = output; } cache_it_++; return true; } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); did_pull_all_ = false; cache_.clear(); cache_it_ = cache_.begin(); } private: struct Element { utils::pmr::vector<TypedValue> order_by; utils::pmr::vector<TypedValue> remember; }; const OrderBy &self_; const UniqueCursorPtr input_cursor_; bool did_pull_all_{false}; // a cache of elements pulled from the input // the cache is filled and sorted (only on first elem) on first Pull utils::pmr::vector<Element> cache_; // iterator over the cache_, maintains state between Pulls decltype(cache_.begin()) cache_it_ = cache_.begin(); }; UniqueCursorPtr OrderBy::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::OrderByOperator); return MakeUniqueCursorPtr<OrderByCursor>(mem, *this, mem); } Merge::Merge(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &merge_match, const std::shared_ptr<LogicalOperator> &merge_create) : input_(input ? input : std::make_shared<Once>()), merge_match_(merge_match), merge_create_(merge_create) {} bool Merge::Accept(HierarchicalLogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { input_->Accept(visitor) && merge_match_->Accept(visitor) && merge_create_->Accept(visitor); } return visitor.PostVisit(*this); } UniqueCursorPtr Merge::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::MergeOperator); return MakeUniqueCursorPtr<MergeCursor>(mem, *this, mem); } std::vector<Symbol> Merge::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); // Match and create branches should have the same symbols, so just take one // of them. 
auto my_symbols = merge_match_->OutputSymbols(table); symbols.insert(symbols.end(), my_symbols.begin(), my_symbols.end()); return symbols; } Merge::MergeCursor::MergeCursor(const Merge &self, utils::MemoryResource *mem) : input_cursor_(self.input_->MakeCursor(mem)), merge_match_cursor_(self.merge_match_->MakeCursor(mem)), merge_create_cursor_(self.merge_create_->MakeCursor(mem)) {} bool Merge::MergeCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("Merge"); while (true) { if (pull_input_) { if (input_cursor_->Pull(frame, context)) { // after a successful input from the input // reset merge_match (it's expand iterators maintain state) // and merge_create (could have a Once at the beginning) merge_match_cursor_->Reset(); merge_create_cursor_->Reset(); } else // input is exhausted, we're done return false; } // pull from the merge_match cursor if (merge_match_cursor_->Pull(frame, context)) { // if successful, next Pull from this should not pull_input_ pull_input_ = false; return true; } else { // failed to Pull from the merge_match cursor if (pull_input_) { // if we have just now pulled from the input // and failed to pull from merge_match, we should create return merge_create_cursor_->Pull(frame, context); } // We have exhausted merge_match_cursor_ after 1 or more successful // Pulls. Attempt next input_cursor_ pull pull_input_ = true; continue; } } } void Merge::MergeCursor::Shutdown() { input_cursor_->Shutdown(); merge_match_cursor_->Shutdown(); merge_create_cursor_->Shutdown(); } void Merge::MergeCursor::Reset() { input_cursor_->Reset(); merge_match_cursor_->Reset(); merge_create_cursor_->Reset(); pull_input_ = true; } Optional::Optional(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &optional, const std::vector<Symbol> &optional_symbols) : input_(input ? 
input : std::make_shared<Once>()), optional_(optional), optional_symbols_(optional_symbols) {} bool Optional::Accept(HierarchicalLogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { input_->Accept(visitor) && optional_->Accept(visitor); } return visitor.PostVisit(*this); } UniqueCursorPtr Optional::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::OptionalOperator); return MakeUniqueCursorPtr<OptionalCursor>(mem, *this, mem); } std::vector<Symbol> Optional::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); auto my_symbols = optional_->ModifiedSymbols(table); symbols.insert(symbols.end(), my_symbols.begin(), my_symbols.end()); return symbols; } Optional::OptionalCursor::OptionalCursor(const Optional &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)), optional_cursor_(self.optional_->MakeCursor(mem)) {} bool Optional::OptionalCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("Optional"); while (true) { if (pull_input_) { if (input_cursor_->Pull(frame, context)) { // after a successful input from the input // reset optional_ (it's expand iterators maintain state) optional_cursor_->Reset(); } else // input is exhausted, we're done return false; } // pull from the optional_ cursor if (optional_cursor_->Pull(frame, context)) { // if successful, next Pull from this should not pull_input_ pull_input_ = false; return true; } else { // failed to Pull from the merge_match cursor if (pull_input_) { // if we have just now pulled from the input // and failed to pull from optional_ so set the // optional symbols to Null, ensure next time the // input gets pulled and return true for (const Symbol &sym : self_.optional_symbols_) frame[sym] = TypedValue(context.evaluation_context.memory); pull_input_ = true; return true; } // we have exhausted optional_cursor_ after 1 or 
more successful Pulls // attempt next input_cursor_ pull pull_input_ = true; continue; } } } void Optional::OptionalCursor::Shutdown() { input_cursor_->Shutdown(); optional_cursor_->Shutdown(); } void Optional::OptionalCursor::Reset() { input_cursor_->Reset(); optional_cursor_->Reset(); pull_input_ = true; } Unwind::Unwind(const std::shared_ptr<LogicalOperator> &input, Expression *input_expression, Symbol output_symbol) : input_(input ? input : std::make_shared<Once>()), input_expression_(input_expression), output_symbol_(std::move(output_symbol)) {} ACCEPT_WITH_INPUT(Unwind) std::vector<Symbol> Unwind::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); symbols.emplace_back(output_symbol_); return symbols; } class UnwindCursor : public Cursor { public: UnwindCursor(const Unwind &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)), input_value_(mem) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("Unwind"); while (true) { AbortCheck(context); // if we reached the end of our list of values // pull from the input if (input_value_it_ == input_value_.end()) { if (!input_cursor_->Pull(frame, context)) return false; // successful pull from input, initialize value and iterator ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); TypedValue input_value = self_.input_expression_->Accept(evaluator); if (input_value.type() != TypedValue::Type::List) throw QueryRuntimeException("Argument of UNWIND must be a list, but '{}' was provided.", input_value.type()); // Copy the evaluted input_value_list to our vector. 
input_value_ = input_value.ValueList(); input_value_it_ = input_value_.begin(); } // if we reached the end of our list of values goto back to top if (input_value_it_ == input_value_.end()) continue; frame[self_.output_symbol_] = *input_value_it_++; if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_.output_symbol_.name_)) { context.frame_change_collector->ResetTrackingValue(self_.output_symbol_.name_); } return true; } } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); input_value_.clear(); input_value_it_ = input_value_.end(); } private: const Unwind &self_; const UniqueCursorPtr input_cursor_; // typed values we are unwinding and yielding utils::pmr::vector<TypedValue> input_value_; // current position in input_value_ decltype(input_value_)::iterator input_value_it_ = input_value_.end(); }; UniqueCursorPtr Unwind::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::UnwindOperator); return MakeUniqueCursorPtr<UnwindCursor>(mem, *this, mem); } class DistinctCursor : public Cursor { public: DistinctCursor(const Distinct &self, utils::MemoryResource *mem) : self_(self), input_cursor_(self.input_->MakeCursor(mem)), seen_rows_(mem) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("Distinct"); while (true) { if (!input_cursor_->Pull(frame, context)) return false; utils::pmr::vector<TypedValue> row(seen_rows_.get_allocator().GetMemoryResource()); row.reserve(self_.value_symbols_.size()); for (const auto &symbol : self_.value_symbols_) { row.emplace_back(frame.at(symbol)); } if (seen_rows_.insert(std::move(row)).second) { return true; } } } void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override { input_cursor_->Reset(); seen_rows_.clear(); } private: const Distinct &self_; const UniqueCursorPtr input_cursor_; // a set of already seen rows 
utils::pmr::unordered_set<utils::pmr::vector<TypedValue>, // use FNV collection hashing specialized for a // vector of TypedValue utils::FnvCollection<utils::pmr::vector<TypedValue>, TypedValue, TypedValue::Hash>, TypedValueVectorEqual> seen_rows_; }; Distinct::Distinct(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &value_symbols) : input_(input ? input : std::make_shared<Once>()), value_symbols_(value_symbols) {} ACCEPT_WITH_INPUT(Distinct) UniqueCursorPtr Distinct::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::DistinctOperator); return MakeUniqueCursorPtr<DistinctCursor>(mem, *this, mem); } std::vector<Symbol> Distinct::OutputSymbols(const SymbolTable &symbol_table) const { // Propagate this to potential Produce. return input_->OutputSymbols(symbol_table); } std::vector<Symbol> Distinct::ModifiedSymbols(const SymbolTable &table) const { return input_->ModifiedSymbols(table); } Union::Union(const std::shared_ptr<LogicalOperator> &left_op, const std::shared_ptr<LogicalOperator> &right_op, const std::vector<Symbol> &union_symbols, const std::vector<Symbol> &left_symbols, const std::vector<Symbol> &right_symbols) : left_op_(left_op), right_op_(right_op), union_symbols_(union_symbols), left_symbols_(left_symbols), right_symbols_(right_symbols) {} UniqueCursorPtr Union::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::UnionOperator); return MakeUniqueCursorPtr<Union::UnionCursor>(mem, *this, mem); } bool Union::Accept(HierarchicalLogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { if (left_op_->Accept(visitor)) { right_op_->Accept(visitor); } } return visitor.PostVisit(*this); } std::vector<Symbol> Union::OutputSymbols(const SymbolTable &) const { return union_symbols_; } std::vector<Symbol> Union::ModifiedSymbols(const SymbolTable &) const { return union_symbols_; } WITHOUT_SINGLE_INPUT(Union); 
Union::UnionCursor::UnionCursor(const Union &self, utils::MemoryResource *mem) : self_(self), left_cursor_(self.left_op_->MakeCursor(mem)), right_cursor_(self.right_op_->MakeCursor(mem)) {} bool Union::UnionCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(self_); utils::pmr::unordered_map<std::string, TypedValue> results(context.evaluation_context.memory); if (left_cursor_->Pull(frame, context)) { // collect values from the left child for (const auto &output_symbol : self_.left_symbols_) { results[output_symbol.name()] = frame[output_symbol]; if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(output_symbol.name())) { context.frame_change_collector->ResetTrackingValue(output_symbol.name()); } } } else if (right_cursor_->Pull(frame, context)) { // collect values from the right child for (const auto &output_symbol : self_.right_symbols_) { results[output_symbol.name()] = frame[output_symbol]; if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(output_symbol.name())) { context.frame_change_collector->ResetTrackingValue(output_symbol.name()); } } } else { return false; } // put collected values on frame under union symbols for (const auto &symbol : self_.union_symbols_) { frame[symbol] = results[symbol.name()]; if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) { context.frame_change_collector->ResetTrackingValue(symbol.name()); } } return true; } void Union::UnionCursor::Shutdown() { left_cursor_->Shutdown(); right_cursor_->Shutdown(); } void Union::UnionCursor::Reset() { left_cursor_->Reset(); right_cursor_->Reset(); } std::vector<Symbol> Cartesian::ModifiedSymbols(const SymbolTable &table) const { auto symbols = left_op_->ModifiedSymbols(table); auto right = right_op_->ModifiedSymbols(table); symbols.insert(symbols.end(), right.begin(), right.end()); return symbols; } bool 
Cartesian::Accept(HierarchicalLogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { left_op_->Accept(visitor) && right_op_->Accept(visitor); } return visitor.PostVisit(*this); } WITHOUT_SINGLE_INPUT(Cartesian); namespace { class CartesianCursor : public Cursor { public: CartesianCursor(const Cartesian &self, utils::MemoryResource *mem) : self_(self), left_op_frames_(mem), right_op_frame_(mem), left_op_cursor_(self.left_op_->MakeCursor(mem)), right_op_cursor_(self_.right_op_->MakeCursor(mem)) { MG_ASSERT(left_op_cursor_ != nullptr, "CartesianCursor: Missing left operator cursor."); MG_ASSERT(right_op_cursor_ != nullptr, "CartesianCursor: Missing right operator cursor."); } bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(self_); if (!cartesian_pull_initialized_) { // Pull all left_op frames. while (left_op_cursor_->Pull(frame, context)) { left_op_frames_.emplace_back(frame.elems().begin(), frame.elems().end()); } // We're setting the iterator to 'end' here so it pulls the right // cursor. left_op_frames_it_ = left_op_frames_.end(); cartesian_pull_initialized_ = true; } // If left operator yielded zero results there is no cartesian product. if (left_op_frames_.empty()) { return false; } auto restore_frame = [&frame, &context](const auto &symbols, const auto &restore_from) { for (const auto &symbol : symbols) { frame[symbol] = restore_from[symbol.position()]; if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) { context.frame_change_collector->ResetTrackingValue(symbol.name()); } } }; if (left_op_frames_it_ == left_op_frames_.end()) { // Advance right_op_cursor_. if (!right_op_cursor_->Pull(frame, context)) return false; right_op_frame_.assign(frame.elems().begin(), frame.elems().end()); left_op_frames_it_ = left_op_frames_.begin(); } else { // Make sure right_op_cursor last pulled results are on frame. 
restore_frame(self_.right_symbols_, right_op_frame_); } AbortCheck(context); restore_frame(self_.left_symbols_, *left_op_frames_it_); left_op_frames_it_++; return true; } void Shutdown() override { left_op_cursor_->Shutdown(); right_op_cursor_->Shutdown(); } void Reset() override { left_op_cursor_->Reset(); right_op_cursor_->Reset(); right_op_frame_.clear(); left_op_frames_.clear(); left_op_frames_it_ = left_op_frames_.end(); cartesian_pull_initialized_ = false; } private: const Cartesian &self_; utils::pmr::vector<utils::pmr::vector<TypedValue>> left_op_frames_; utils::pmr::vector<TypedValue> right_op_frame_; const UniqueCursorPtr left_op_cursor_; const UniqueCursorPtr right_op_cursor_; utils::pmr::vector<utils::pmr::vector<TypedValue>>::iterator left_op_frames_it_; bool cartesian_pull_initialized_{false}; }; } // namespace UniqueCursorPtr Cartesian::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::CartesianOperator); return MakeUniqueCursorPtr<CartesianCursor>(mem, *this, mem); } OutputTable::OutputTable(std::vector<Symbol> output_symbols, std::vector<std::vector<TypedValue>> rows) : output_symbols_(std::move(output_symbols)), callback_([rows](Frame *, ExecutionContext *) { return rows; }) {} OutputTable::OutputTable(std::vector<Symbol> output_symbols, std::function<std::vector<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback) : output_symbols_(std::move(output_symbols)), callback_(std::move(callback)) {} WITHOUT_SINGLE_INPUT(OutputTable); class OutputTableCursor : public Cursor { public: explicit OutputTableCursor(const OutputTable &self) : self_(self) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; if (!pulled_) { rows_ = self_.callback_(&frame, &context); for (const auto &row : rows_) { MG_ASSERT(row.size() == self_.output_symbols_.size(), "Wrong number of columns in row!"); } pulled_ = true; } if (current_row_ < rows_.size()) { for (size_t i 
= 0; i < self_.output_symbols_.size(); ++i) { frame[self_.output_symbols_[i]] = rows_[current_row_][i]; if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_.output_symbols_[i].name())) { context.frame_change_collector->ResetTrackingValue(self_.output_symbols_[i].name()); } } current_row_++; return true; } return false; } void Reset() override { pulled_ = false; current_row_ = 0; rows_.clear(); } void Shutdown() override {} private: const OutputTable &self_; size_t current_row_{0}; std::vector<std::vector<TypedValue>> rows_; bool pulled_{false}; }; UniqueCursorPtr OutputTable::MakeCursor(utils::MemoryResource *mem) const { return MakeUniqueCursorPtr<OutputTableCursor>(mem, *this); } OutputTableStream::OutputTableStream( std::vector<Symbol> output_symbols, std::function<std::optional<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback) : output_symbols_(std::move(output_symbols)), callback_(std::move(callback)) {} WITHOUT_SINGLE_INPUT(OutputTableStream); class OutputTableStreamCursor : public Cursor { public: explicit OutputTableStreamCursor(const OutputTableStream *self) : self_(self) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; const auto row = self_->callback_(&frame, &context); if (row) { MG_ASSERT(row->size() == self_->output_symbols_.size(), "Wrong number of columns in row!"); for (size_t i = 0; i < self_->output_symbols_.size(); ++i) { frame[self_->output_symbols_[i]] = row->at(i); if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_->output_symbols_[i].name())) { context.frame_change_collector->ResetTrackingValue(self_->output_symbols_[i].name()); } } return true; } return false; } // TODO(tsabolcec): Come up with better approach for handling `Reset()`. // One possibility is to implement a custom closure utility class with // `Reset()` method. 
void Reset() override { throw utils::NotYetImplemented("OutputTableStreamCursor::Reset"); } void Shutdown() override {} private: const OutputTableStream *self_; }; UniqueCursorPtr OutputTableStream::MakeCursor(utils::MemoryResource *mem) const { return MakeUniqueCursorPtr<OutputTableStreamCursor>(mem, this); } CallProcedure::CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name, std::vector<Expression *> args, std::vector<std::string> fields, std::vector<Symbol> symbols, Expression *memory_limit, size_t memory_scale, bool is_write, int64_t procedure_id, bool void_procedure) : input_(input ? input : std::make_shared<Once>()), procedure_name_(std::move(name)), arguments_(std::move(args)), result_fields_(std::move(fields)), result_symbols_(std::move(symbols)), memory_limit_(memory_limit), memory_scale_(memory_scale), is_write_(is_write), procedure_id_(procedure_id), void_procedure_(void_procedure) {} ACCEPT_WITH_INPUT(CallProcedure); std::vector<Symbol> CallProcedure::OutputSymbols(const SymbolTable &) const { return result_symbols_; } std::vector<Symbol> CallProcedure::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); symbols.insert(symbols.end(), result_symbols_.begin(), result_symbols_.end()); return symbols; } void CallProcedure::IncrementCounter(const std::string &procedure_name) { procedure_counters_.WithLock([&](auto &counters) { ++counters[procedure_name]; }); } std::unordered_map<std::string, int64_t> CallProcedure::GetAndResetCounters() { auto counters = procedure_counters_.Lock(); auto ret = std::move(*counters); counters->clear(); return ret; } namespace { void CallCustomProcedure(const std::string_view fully_qualified_procedure_name, const mgp_proc &proc, const std::vector<Expression *> &args, mgp_graph &graph, ExpressionEvaluator *evaluator, utils::MemoryResource *memory, std::optional<size_t> memory_limit, mgp_result *result, int64_t procedure_id, uint64_t transaction_id, const bool 
call_initializer = false) { static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>, "Expected mgp_value to use custom allocator and makes STL " "containers aware of that"); // Build and type check procedure arguments. mgp_list proc_args(memory); std::vector<TypedValue> args_list; args_list.reserve(args.size()); for (auto *expression : args) { args_list.emplace_back(expression->Accept(*evaluator)); } std::optional<query::Graph> subgraph; std::optional<query::SubgraphDbAccessor> db_acc; if (!args_list.empty() && args_list.front().type() == TypedValue::Type::Graph) { auto subgraph_value = args_list.front().ValueGraph(); subgraph = query::Graph(std::move(subgraph_value), subgraph_value.GetMemoryResource()); args_list.erase(args_list.begin()); db_acc = query::SubgraphDbAccessor(*std::get<query::DbAccessor *>(graph.impl), &*subgraph); graph.impl = &*db_acc; } procedure::ConstructArguments(args_list, proc, fully_qualified_procedure_name, proc_args, graph); if (call_initializer) { MG_ASSERT(proc.initializer); mgp_memory initializer_memory{memory}; proc.initializer.value()(&proc_args, &graph, &initializer_memory); } if (memory_limit) { SPDLOG_INFO("Running '{}' with memory limit of {}", fully_qualified_procedure_name, utils::GetReadableSize(*memory_limit)); // Only allocations which can leak memory are // our own mgp object allocations. 
Jemalloc can track // memory correctly, but some memory may not be released // immediately, so we want to give user info on leak still // considering our allocations utils::MemoryTrackingResource memory_tracking_resource{memory, *memory_limit}; // if we are already tracking, no harm no faul // if we are not tracking, we need to start now, with unlimited memory // for query, but limited for procedure // check if transaction is tracked currently, so we // can disable tracking on that arena if it is not // once we are done with procedure tracking bool is_transaction_tracked = memgraph::memory::IsTransactionTracked(transaction_id); if (!is_transaction_tracked) { // start tracking with unlimited limit on query // which is same as not being tracked at all memgraph::memory::TryStartTrackingOnTransaction(transaction_id, memgraph::memory::UNLIMITED_MEMORY); } memgraph::memory::StartTrackingCurrentThreadTransaction(transaction_id); // due to mgp_batch_read_proc and mgp_batch_write_proc // we can return to execution without exhausting whole // memory. Here we need to update tracking memgraph::memory::CreateOrContinueProcedureTracking(transaction_id, procedure_id, *memory_limit); mgp_memory proc_memory{&memory_tracking_resource}; MG_ASSERT(result->signature == &proc.results); utils::OnScopeExit on_scope_exit{[transaction_id = transaction_id]() { memgraph::memory::StopTrackingCurrentThreadTransaction(transaction_id); memgraph::memory::PauseProcedureTracking(transaction_id); }}; // TODO: What about cross library boundary exceptions? OMG C++?! proc.cb(&proc_args, &graph, result, &proc_memory); auto leaked_bytes = memory_tracking_resource.GetAllocatedBytes(); if (leaked_bytes > 0U) { spdlog::warn("Query procedure '{}' leaked {} *tracked* bytes", fully_qualified_procedure_name, leaked_bytes); } } else { // TODO: Add a tracking MemoryResource without limits, so that we report // memory leaks in procedure. 
mgp_memory proc_memory{memory}; MG_ASSERT(result->signature == &proc.results); // TODO: What about cross library boundary exceptions? OMG C++?! proc.cb(&proc_args, &graph, result, &proc_memory); } } } // namespace class CallProcedureCursor : public Cursor { const CallProcedure *self_; UniqueCursorPtr input_cursor_; mgp_result *result_; decltype(result_->rows.end()) result_row_it_{result_->rows.end()}; size_t result_signature_size_{0}; bool stream_exhausted{true}; bool call_initializer{false}; std::optional<std::function<void()>> cleanup_{std::nullopt}; public: CallProcedureCursor(const CallProcedure *self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_->input_->MakeCursor(mem)), // result_ needs to live throughout multiple Pull evaluations, until all // rows are produced. We don't use the memory dedicated for QueryExecution (and Frame), // but memory dedicated for procedure to wipe result_ and everything allocated in procedure all at once. result_(utils::Allocator<mgp_result>(self_->memory_resource) .new_object<mgp_result>(nullptr, self_->memory_resource)) { MG_ASSERT(self_->result_fields_.size() == self_->result_symbols_.size(), "Incorrectly constructed CallProcedure"); } bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(*self_); AbortCheck(context); auto skip_rows_with_deleted_values = [this]() { while (result_row_it_ != result_->rows.end() && result_row_it_->has_deleted_values) { ++result_row_it_; } }; // We need to fetch new procedure results after pulling from input. // TODO: Look into openCypher's distinction between procedures returning an // empty result set vs procedures which return `void`. We currently don't // have procedures registering what they return. // This `while` loop will skip over empty results. while (result_row_it_ == result_->rows.end()) { // It might be a good idea to resolve the procedure name once, at the // start. 
Unfortunately, this could deadlock if we tried to invoke a // procedure from a module (read lock) and reload a module (write lock) // inside the same execution thread. Also, our RWLock is set up so that // it's not possible for a single thread to request multiple read locks. // Builtin module registration in query/procedure/module.cpp depends on // this locking scheme. const auto &maybe_found = procedure::FindProcedure(procedure::gModuleRegistry, self_->procedure_name_, context.evaluation_context.memory); if (!maybe_found) { throw QueryRuntimeException("There is no procedure named '{}'.", self_->procedure_name_); } const auto &[module, proc] = *maybe_found; if (proc->info.is_write != self_->is_write_) { auto get_proc_type_str = [](bool is_write) { return is_write ? "write" : "read"; }; throw QueryRuntimeException("The procedure named '{}' was a {} procedure, but changed to be a {} procedure.", self_->procedure_name_, get_proc_type_str(self_->is_write_), get_proc_type_str(proc->info.is_write)); } if (!proc->info.is_batched) { stream_exhausted = true; } if (stream_exhausted) { if (!input_cursor_->Pull(frame, context)) { if (proc->cleanup) { proc->cleanup.value()(); } return false; } stream_exhausted = false; if (proc->initializer) { call_initializer = true; MG_ASSERT(proc->cleanup); proc->cleanup.value()(); } } if (!cleanup_ && proc->cleanup) [[unlikely]] { cleanup_.emplace(*proc->cleanup); } // Unpluging memory without calling destruct on each object since everything was allocated with this memory // resource self_->monotonic_memory.Release(); result_ = utils::Allocator<mgp_result>(self_->memory_resource).new_object<mgp_result>(nullptr, self_->memory_resource); const auto graph_view = proc->info.is_write ? 
storage::View::NEW : storage::View::OLD; ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, graph_view); result_->signature = &proc->results; result_->is_transactional = storage::IsTransactional(context.db_accessor->GetStorageMode()); // Use special memory as invoking procedure is complex // TODO: This will probably need to be changed when we add support for // generator like procedures which yield a new result on new query calls. auto *memory = self_->memory_resource; auto memory_limit = EvaluateMemoryLimit(evaluator, self_->memory_limit_, self_->memory_scale_); auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context); const auto transaction_id = context.db_accessor->GetTransactionId(); MG_ASSERT(transaction_id.has_value()); CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit, result_, self_->procedure_id_, transaction_id.value(), call_initializer); if (call_initializer) call_initializer = false; // Reset result_.signature to nullptr, because outside of this scope we // will no longer hold a lock on the `module`. If someone were to reload // it, the pointer would be invalid. result_signature_size_ = result_->signature->size(); result_->signature = nullptr; if (result_->error_msg) { memgraph::utils::MemoryTracker::OutOfMemoryExceptionBlocker blocker; throw QueryRuntimeException("{}: {}", self_->procedure_name_, *result_->error_msg); } result_row_it_ = result_->rows.begin(); if (!result_->is_transactional) { skip_rows_with_deleted_values(); } stream_exhausted = result_row_it_ == result_->rows.end(); } auto &values = result_row_it_->values; // Check that the row has all fields as required by the result signature. // C API guarantees that it's impossible to set fields which are not part of // the result record, but it does not gurantee that some may be missing. See // `mgp_result_record_insert`. 
if (values.size() != result_signature_size_) { throw QueryRuntimeException( "Procedure '{}' did not yield all fields as required by its " "signature.", self_->procedure_name_); } for (size_t i = 0; i < self_->result_fields_.size(); ++i) { std::string_view field_name(self_->result_fields_[i]); auto result_it = values.find(field_name); if (result_it == values.end()) { throw QueryRuntimeException("Procedure '{}' did not yield a record with '{}' field.", self_->procedure_name_, field_name); } frame[self_->result_symbols_[i]] = std::move(result_it->second); if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_->result_symbols_[i].name())) { context.frame_change_collector->ResetTrackingValue(self_->result_symbols_[i].name()); } } ++result_row_it_; if (!result_->is_transactional) { skip_rows_with_deleted_values(); } return true; } void Reset() override { self_->monotonic_memory.Release(); result_ = utils::Allocator<mgp_result>(self_->memory_resource).new_object<mgp_result>(nullptr, self_->memory_resource); if (cleanup_) { cleanup_.value()(); } } void Shutdown() override { self_->monotonic_memory.Release(); if (cleanup_) { cleanup_.value()(); } } }; class CallValidateProcedureCursor : public Cursor { const CallProcedure *self_; UniqueCursorPtr input_cursor_; public: CallValidateProcedureCursor(const CallProcedure *self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_->input_->MakeCursor(mem)) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("CallValidateProcedureCursor"); AbortCheck(context); if (!input_cursor_->Pull(frame, context)) { return false; } ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); const auto args = self_->arguments_; MG_ASSERT(args.size() == 3U); const auto predicate = args[0]->Accept(evaluator); const bool predicate_val = predicate.ValueBool(); if 
(predicate_val) [[unlikely]] { const auto &message = args[1]->Accept(evaluator); const auto &message_args = args[2]->Accept(evaluator); using TString = std::remove_cvref_t<decltype(message.ValueString())>; using TElement = std::remove_cvref_t<decltype(message_args.ValueList()[0])>; utils::JStringFormatter<TString, TElement> formatter; try { const auto &msg = formatter.FormatString(message.ValueString(), message_args.ValueList()); throw QueryRuntimeException(msg); } catch (const utils::JStringFormatException &e) { throw QueryRuntimeException(e.what()); } } return true; } void Reset() override { input_cursor_->Reset(); } void Shutdown() override {} }; UniqueCursorPtr CallProcedure::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::CallProcedureOperator); CallProcedure::IncrementCounter(procedure_name_); if (void_procedure_) { // Currently we do not support Call procedures that do not return // anything. This cursor is way too specific, but it provides a workaround // to ensure GraphQL compatibility until we start supporting truly void // procedures. return MakeUniqueCursorPtr<CallValidateProcedureCursor>(mem, this, mem); } return MakeUniqueCursorPtr<CallProcedureCursor>(mem, this, mem); } LoadCsv::LoadCsv(std::shared_ptr<LogicalOperator> input, Expression *file, bool with_header, bool ignore_bad, Expression *delimiter, Expression *quote, Expression *nullif, Symbol row_var) : input_(input ? 
input : (std::make_shared<Once>())), file_(file), with_header_(with_header), ignore_bad_(ignore_bad), delimiter_(delimiter), quote_(quote), nullif_(nullif), row_var_(std::move(row_var)) { MG_ASSERT(file_, "Something went wrong - '{}' member file_ shouldn't be a nullptr", __func__); } ACCEPT_WITH_INPUT(LoadCsv) class LoadCsvCursor; std::vector<Symbol> LoadCsv::OutputSymbols(const SymbolTable &sym_table) const { return {row_var_}; }; std::vector<Symbol> LoadCsv::ModifiedSymbols(const SymbolTable &sym_table) const { auto symbols = input_->ModifiedSymbols(sym_table); symbols.push_back(row_var_); return symbols; }; namespace { // copy-pasted from interpreter.cpp TypedValue EvaluateOptionalExpression(Expression *expression, ExpressionEvaluator *eval) { return expression ? expression->Accept(*eval) : TypedValue(); } auto ToOptionalString(ExpressionEvaluator *evaluator, Expression *expression) -> std::optional<utils::pmr::string> { const auto evaluated_expr = EvaluateOptionalExpression(expression, evaluator); if (evaluated_expr.IsString()) { return utils::pmr::string(evaluated_expr.ValueString(), utils::NewDeleteResource()); } return std::nullopt; }; TypedValue CsvRowToTypedList(csv::Reader::Row &row, std::optional<utils::pmr::string> &nullif) { auto *mem = row.get_allocator().GetMemoryResource(); auto typed_columns = utils::pmr::vector<TypedValue>(mem); typed_columns.reserve(row.size()); for (auto &column : row) { if (!nullif.has_value() || column != nullif.value()) { typed_columns.emplace_back(std::move(column)); } else { typed_columns.emplace_back(); } } return {std::move(typed_columns), mem}; } TypedValue CsvRowToTypedMap(csv::Reader::Row &row, csv::Reader::Header header, std::optional<utils::pmr::string> &nullif) { // a valid row has the same number of elements as the header auto *mem = row.get_allocator().GetMemoryResource(); utils::pmr::map<utils::pmr::string, TypedValue> m(mem); for (auto i = 0; i < row.size(); ++i) { if (!nullif.has_value() || row[i] != 
nullif.value()) { m.emplace(std::move(header[i]), std::move(row[i])); } else { m.emplace(std::piecewise_construct, std::forward_as_tuple(std::move(header[i])), std::forward_as_tuple()); } } return {std::move(m), mem}; } } // namespace class LoadCsvCursor : public Cursor { const LoadCsv *self_; const UniqueCursorPtr input_cursor_; bool did_pull_; std::optional<csv::Reader> reader_{}; std::optional<utils::pmr::string> nullif_; public: LoadCsvCursor(const LoadCsv *self, utils::MemoryResource *mem) : self_(self), input_cursor_(self_->input_->MakeCursor(mem)), did_pull_{false} {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP_BY_REF(*self_); AbortCheck(context); // ToDo(the-joksim): // - this is an ungodly hack because the pipeline of creating a plan // doesn't allow evaluating the expressions contained in self_->file_, // self_->delimiter_, and self_->quote_ earlier (say, in the interpreter.cpp) // without massacring the code even worse than I did here if (UNLIKELY(!reader_)) { reader_ = MakeReader(&context.evaluation_context); nullif_ = ParseNullif(&context.evaluation_context); } if (input_cursor_->Pull(frame, context)) { if (did_pull_) { throw QueryRuntimeException( "LOAD CSV can be executed only once, please check if the cardinality of the operator before LOAD CSV " "is " "1"); } did_pull_ = true; } auto row = reader_->GetNextRow(context.evaluation_context.memory); if (!row) { return false; } if (!reader_->HasHeader()) { frame[self_->row_var_] = CsvRowToTypedList(*row, nullif_); } else { frame[self_->row_var_] = CsvRowToTypedMap(*row, csv::Reader::Header(reader_->GetHeader(), context.evaluation_context.memory), nullif_); } if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_->row_var_.name())) { context.frame_change_collector->ResetTrackingValue(self_->row_var_.name()); } return true; } void Reset() override { input_cursor_->Reset(); } void Shutdown() override { 
input_cursor_->Shutdown(); } private: csv::Reader MakeReader(EvaluationContext *eval_context) { Frame frame(0); SymbolTable symbol_table; DbAccessor *dba = nullptr; auto evaluator = ExpressionEvaluator(&frame, symbol_table, *eval_context, dba, storage::View::OLD); auto maybe_file = ToOptionalString(&evaluator, self_->file_); auto maybe_delim = ToOptionalString(&evaluator, self_->delimiter_); auto maybe_quote = ToOptionalString(&evaluator, self_->quote_); // No need to check if maybe_file is std::nullopt, as the parser makes sure // we can't get a nullptr for the 'file_' member in the LoadCsv clause. // Note that the reader has to be given its own memory resource, as it // persists between pulls, so it can't use the evalutation context memory // resource. return csv::Reader( csv::CsvSource::Create(*maybe_file), csv::Reader::Config(self_->with_header_, self_->ignore_bad_, std::move(maybe_delim), std::move(maybe_quote)), utils::NewDeleteResource()); } std::optional<utils::pmr::string> ParseNullif(EvaluationContext *eval_context) { Frame frame(0); SymbolTable symbol_table; DbAccessor *dba = nullptr; auto evaluator = ExpressionEvaluator(&frame, symbol_table, *eval_context, dba, storage::View::OLD); return ToOptionalString(&evaluator, self_->nullif_); } }; UniqueCursorPtr LoadCsv::MakeCursor(utils::MemoryResource *mem) const { return MakeUniqueCursorPtr<LoadCsvCursor>(mem, this, mem); }; class ForeachCursor : public Cursor { public: explicit ForeachCursor(const Foreach &foreach, utils::MemoryResource *mem) : loop_variable_symbol_(foreach.loop_variable_symbol_), input_(foreach.input_->MakeCursor(mem)), updates_(foreach.update_clauses_->MakeCursor(mem)), expression(foreach.expression_) {} bool Pull(Frame &frame, ExecutionContext &context) override { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP(op_name_); if (!input_->Pull(frame, context)) { return false; } ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, 
storage::View::NEW); TypedValue expr_result = expression->Accept(evaluator); if (expr_result.IsNull()) { return true; } if (!expr_result.IsList()) { throw QueryRuntimeException("FOREACH expression must resolve to a list, but got '{}'.", expr_result.type()); } const auto &cache_ = expr_result.ValueList(); for (const auto &index : cache_) { frame[loop_variable_symbol_] = index; while (updates_->Pull(frame, context)) { } ResetUpdates(); } return true; } void Shutdown() override { input_->Shutdown(); } void ResetUpdates() { updates_->Reset(); } void Reset() override { input_->Reset(); ResetUpdates(); } private: const Symbol loop_variable_symbol_; const UniqueCursorPtr input_; const UniqueCursorPtr updates_; Expression *expression; const char *op_name_{"Foreach"}; }; Foreach::Foreach(std::shared_ptr<LogicalOperator> input, std::shared_ptr<LogicalOperator> updates, Expression *expr, Symbol loop_variable_symbol) : input_(input ? std::move(input) : std::make_shared<Once>()), update_clauses_(std::move(updates)), expression_(expr), loop_variable_symbol_(std::move(loop_variable_symbol)) {} UniqueCursorPtr Foreach::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ForeachOperator); return MakeUniqueCursorPtr<ForeachCursor>(mem, *this, mem); } std::vector<Symbol> Foreach::ModifiedSymbols(const SymbolTable &table) const { auto symbols = input_->ModifiedSymbols(table); symbols.emplace_back(loop_variable_symbol_); return symbols; } bool Foreach::Accept(HierarchicalLogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { input_->Accept(visitor); update_clauses_->Accept(visitor); } return visitor.PostVisit(*this); } Apply::Apply(const std::shared_ptr<LogicalOperator> input, const std::shared_ptr<LogicalOperator> subquery, bool subquery_has_return) : input_(input ? 
input : std::make_shared<Once>()), subquery_(subquery), subquery_has_return_(subquery_has_return) {} bool Apply::Accept(HierarchicalLogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { input_->Accept(visitor) && subquery_->Accept(visitor); } return visitor.PostVisit(*this); } UniqueCursorPtr Apply::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::ApplyOperator); return MakeUniqueCursorPtr<ApplyCursor>(mem, *this, mem); } Apply::ApplyCursor::ApplyCursor(const Apply &self, utils::MemoryResource *mem) : self_(self), input_(self.input_->MakeCursor(mem)), subquery_(self.subquery_->MakeCursor(mem)), subquery_has_return_(self.subquery_has_return_) {} std::vector<Symbol> Apply::ModifiedSymbols(const SymbolTable &table) const { // Since Apply is the Cartesian product, modified symbols are combined from // both execution branches. auto symbols = input_->ModifiedSymbols(table); auto subquery_symbols = subquery_->ModifiedSymbols(table); symbols.insert(symbols.end(), subquery_symbols.begin(), subquery_symbols.end()); return symbols; } bool Apply::ApplyCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("Apply"); while (true) { if (pull_input_ && !input_->Pull(frame, context)) { return false; }; if (subquery_->Pull(frame, context)) { // if successful, next Pull from this should not pull_input_ pull_input_ = false; return true; } // failed to pull from subquery cursor // skip that row pull_input_ = true; subquery_->Reset(); // don't skip row if no rows are returned from subquery, return input_ rows if (!subquery_has_return_) return true; } } void Apply::ApplyCursor::Shutdown() { input_->Shutdown(); subquery_->Shutdown(); } void Apply::ApplyCursor::Reset() { input_->Reset(); subquery_->Reset(); pull_input_ = true; } IndexedJoin::IndexedJoin(const std::shared_ptr<LogicalOperator> main_branch, const std::shared_ptr<LogicalOperator> sub_branch) : 
main_branch_(main_branch ? main_branch : std::make_shared<Once>()), sub_branch_(sub_branch) {} WITHOUT_SINGLE_INPUT(IndexedJoin); bool IndexedJoin::Accept(HierarchicalLogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { main_branch_->Accept(visitor) && sub_branch_->Accept(visitor); } return visitor.PostVisit(*this); } UniqueCursorPtr IndexedJoin::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::IndexedJoinOperator); return MakeUniqueCursorPtr<IndexedJoinCursor>(mem, *this, mem); } IndexedJoin::IndexedJoinCursor::IndexedJoinCursor(const IndexedJoin &self, utils::MemoryResource *mem) : self_(self), main_branch_(self.main_branch_->MakeCursor(mem)), sub_branch_(self.sub_branch_->MakeCursor(mem)) {} std::vector<Symbol> IndexedJoin::ModifiedSymbols(const SymbolTable &table) const { // Since Apply is the Cartesian product, modified symbols are combined from // both execution branches. auto symbols = main_branch_->ModifiedSymbols(table); auto sub_branch_symbols = sub_branch_->ModifiedSymbols(table); symbols.insert(symbols.end(), sub_branch_symbols.begin(), sub_branch_symbols.end()); return symbols; } bool IndexedJoin::IndexedJoinCursor::Pull(Frame &frame, ExecutionContext &context) { SCOPED_PROFILE_OP("IndexedJoin"); while (true) { if (pull_input_ && !main_branch_->Pull(frame, context)) { return false; }; if (sub_branch_->Pull(frame, context)) { // if successful, next Pull from this should not pull_input_ pull_input_ = false; return true; } // failed to pull from subquery cursor // skip that row pull_input_ = true; sub_branch_->Reset(); } } void IndexedJoin::IndexedJoinCursor::Shutdown() { main_branch_->Shutdown(); sub_branch_->Shutdown(); } void IndexedJoin::IndexedJoinCursor::Reset() { main_branch_->Reset(); sub_branch_->Reset(); pull_input_ = true; } std::vector<Symbol> HashJoin::ModifiedSymbols(const SymbolTable &table) const { auto symbols = left_op_->ModifiedSymbols(table); auto right = 
right_op_->ModifiedSymbols(table); symbols.insert(symbols.end(), right.begin(), right.end()); return symbols; } bool HashJoin::Accept(HierarchicalLogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { left_op_->Accept(visitor) && right_op_->Accept(visitor); } return visitor.PostVisit(*this); } WITHOUT_SINGLE_INPUT(HashJoin); namespace { class HashJoinCursor : public Cursor { public: HashJoinCursor(const HashJoin &self, utils::MemoryResource *mem) : self_(self), left_op_cursor_(self.left_op_->MakeCursor(mem)), right_op_cursor_(self_.right_op_->MakeCursor(mem)), hashtable_(mem), right_op_frame_(mem) { MG_ASSERT(left_op_cursor_ != nullptr, "HashJoinCursor: Missing left operator cursor."); MG_ASSERT(right_op_cursor_ != nullptr, "HashJoinCursor: Missing right operator cursor."); } bool Pull(Frame &frame, ExecutionContext &context) override { SCOPED_PROFILE_OP("HashJoin"); if (!hash_join_initialized_) { InitializeHashJoin(frame, context); hash_join_initialized_ = true; } // If left_op yielded zero results, there is no cartesian product. 
if (hashtable_.empty()) { return false; } auto restore_frame = [&frame, &context](const auto &symbols, const auto &restore_from) { for (const auto &symbol : symbols) { frame[symbol] = restore_from[symbol.position()]; if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) { context.frame_change_collector->ResetTrackingValue(symbol.name()); } } }; if (!common_value_found_) { // Pull from the right_op until there’s a mergeable frame while (true) { auto pulled = right_op_cursor_->Pull(frame, context); if (!pulled) return false; // Check if the join value from the pulled frame is shared with any left frames ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); auto right_value = self_.hash_join_condition_->expression2_->Accept(evaluator); if (hashtable_.contains(right_value)) { // If so, finish pulling for now and proceed to joining the pulled frame right_op_frame_.assign(frame.elems().begin(), frame.elems().end()); common_value_found_ = true; common_value = right_value; left_op_frame_it_ = hashtable_[common_value].begin(); break; } } } else { // Restore the right frame ahead of restoring the left frame restore_frame(self_.right_symbols_, right_op_frame_); } restore_frame(self_.left_symbols_, *left_op_frame_it_); left_op_frame_it_++; // When all left frames with the common value have been joined, move on to pulling and joining the next right // frame if (common_value_found_ && left_op_frame_it_ == hashtable_[common_value].end()) { common_value_found_ = false; } return true; } void Shutdown() override { left_op_cursor_->Shutdown(); right_op_cursor_->Shutdown(); } void Reset() override { left_op_cursor_->Reset(); right_op_cursor_->Reset(); hashtable_.clear(); right_op_frame_.clear(); left_op_frame_it_ = {}; hash_join_initialized_ = false; common_value_found_ = false; } private: void InitializeHashJoin(Frame &frame, ExecutionContext &context) { // Pull 
all left_op_ frames while (left_op_cursor_->Pull(frame, context)) { ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::OLD); auto left_value = self_.hash_join_condition_->expression1_->Accept(evaluator); if (left_value.type() != TypedValue::Type::Null) { hashtable_[left_value].emplace_back(frame.elems().begin(), frame.elems().end()); } } } const HashJoin &self_; const UniqueCursorPtr left_op_cursor_; const UniqueCursorPtr right_op_cursor_; utils::pmr::unordered_map<TypedValue, utils::pmr::vector<utils::pmr::vector<TypedValue>>, TypedValue::Hash, TypedValue::BoolEqual> hashtable_; utils::pmr::vector<TypedValue> right_op_frame_; utils::pmr::vector<utils::pmr::vector<TypedValue>>::iterator left_op_frame_it_; bool hash_join_initialized_{false}; bool common_value_found_{false}; TypedValue common_value; }; } // namespace UniqueCursorPtr HashJoin::MakeCursor(utils::MemoryResource *mem) const { memgraph::metrics::IncrementCounter(memgraph::metrics::HashJoinOperator); return MakeUniqueCursorPtr<HashJoinCursor>(mem, *this, mem); } } // namespace memgraph::query::plan