Add text search: * named property search * all-property search * regex search * aggregation over search results Text search works with: * non-parallel transactions * durability (WAL files and snapshots) * multitenancy
5795 lines
228 KiB
C++
5795 lines
228 KiB
C++
// Copyright 2024 Memgraph Ltd.
|
||
//
|
||
// Use of this software is governed by the Business Source License
|
||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||
// License, and you may not use this file except in compliance with the Business Source License.
|
||
//
|
||
// As of the Change Date specified in that file, in accordance with
|
||
// the Business Source License, use of this software will be governed
|
||
// by the Apache License, Version 2.0, included in the file
|
||
// licenses/APL.txt.
|
||
|
||
#include "query/plan/operator.hpp"
|
||
|
||
#include <algorithm>
|
||
#include <cctype>
|
||
#include <cstdint>
|
||
#include <limits>
|
||
#include <optional>
|
||
#include <queue>
|
||
#include <random>
|
||
#include <string>
|
||
#include <tuple>
|
||
#include <type_traits>
|
||
#include <unordered_map>
|
||
#include <unordered_set>
|
||
#include <utility>
|
||
|
||
#include <cppitertools/chain.hpp>
|
||
#include <cppitertools/imap.hpp>
|
||
#include "memory/query_memory_control.hpp"
|
||
#include "query/common.hpp"
|
||
#include "spdlog/spdlog.h"
|
||
|
||
#include "csv/parsing.hpp"
|
||
#include "flags/experimental.hpp"
|
||
#include "license/license.hpp"
|
||
#include "query/auth_checker.hpp"
|
||
#include "query/context.hpp"
|
||
#include "query/db_accessor.hpp"
|
||
#include "query/exceptions.hpp"
|
||
#include "query/frontend/ast/ast.hpp"
|
||
#include "query/frontend/semantic/symbol_table.hpp"
|
||
#include "query/graph.hpp"
|
||
#include "query/interpret/eval.hpp"
|
||
#include "query/path.hpp"
|
||
#include "query/plan/scoped_profile.hpp"
|
||
#include "query/procedure/cypher_types.hpp"
|
||
#include "query/procedure/mg_procedure_impl.hpp"
|
||
#include "query/procedure/module.hpp"
|
||
#include "query/typed_value.hpp"
|
||
#include "range/v3/all.hpp"
|
||
#include "storage/v2/id_types.hpp"
|
||
#include "storage/v2/property_value.hpp"
|
||
#include "storage/v2/view.hpp"
|
||
#include "utils/algorithm.hpp"
|
||
#include "utils/event_counter.hpp"
|
||
#include "utils/exceptions.hpp"
|
||
#include "utils/fnv.hpp"
|
||
#include "utils/java_string_formatter.hpp"
|
||
#include "utils/likely.hpp"
|
||
#include "utils/logging.hpp"
|
||
#include "utils/memory.hpp"
|
||
#include "utils/memory_tracker.hpp"
|
||
#include "utils/message.hpp"
|
||
#include "utils/on_scope_exit.hpp"
|
||
#include "utils/pmr/deque.hpp"
|
||
#include "utils/pmr/list.hpp"
|
||
#include "utils/pmr/unordered_map.hpp"
|
||
#include "utils/pmr/unordered_set.hpp"
|
||
#include "utils/pmr/vector.hpp"
|
||
#include "utils/readable_size.hpp"
|
||
#include "utils/string.hpp"
|
||
#include "utils/tag.hpp"
|
||
#include "utils/temporal.hpp"
|
||
#include "utils/typeinfo.hpp"
|
||
|
||
// Macro providing the default implementation of LogicalOperator::Accept for
// operators with a single input: the visitor pre-visits the operator, descends
// into its input_ operator when the pre-visit returns true, and finally
// post-visits the operator. A null input_ raises a QueryRuntimeException.
// NOLINTNEXTLINE
#define ACCEPT_WITH_INPUT(class_name) \
  bool class_name::Accept(HierarchicalLogicalOperatorVisitor &visitor) { \
    const bool should_descend = visitor.PreVisit(*this); \
    if (should_descend) { \
      if (!input_) { \
        throw QueryRuntimeException( \
            "The query couldn't be executed due to the unexpected null value in " #class_name \
            " operator. To learn more about operators visit https://memgr.ph/query-operators!"); \
      } \
      input_->Accept(visitor); \
    } \
    return visitor.PostVisit(*this); \
  }
|
||
|
||
// Macro defining the single-input accessors of an operator that has no single
// input: HasSingleInput() returns false, while input() and set_input() report
// "has no single input" through LOG_FATAL.
#define WITHOUT_SINGLE_INPUT(class_name) \
|
||
bool class_name::HasSingleInput() const { return false; } \
|
||
std::shared_ptr<LogicalOperator> class_name::input() const { \
|
||
LOG_FATAL("Operator " #class_name " has no single input!"); \
|
||
} \
|
||
void class_name::set_input(std::shared_ptr<LogicalOperator>) { \
|
||
LOG_FATAL("Operator " #class_name " has no single input!"); \
|
||
}
|
||
|
||
namespace memgraph::metrics {
|
||
extern const Event OnceOperator;
|
||
extern const Event CreateNodeOperator;
|
||
extern const Event CreateExpandOperator;
|
||
extern const Event ScanAllOperator;
|
||
extern const Event ScanAllByLabelOperator;
|
||
extern const Event ScanAllByLabelPropertyRangeOperator;
|
||
extern const Event ScanAllByLabelPropertyValueOperator;
|
||
extern const Event ScanAllByLabelPropertyOperator;
|
||
extern const Event ScanAllByIdOperator;
|
||
extern const Event ScanAllByEdgeTypeOperator;
|
||
extern const Event ExpandOperator;
|
||
extern const Event ExpandVariableOperator;
|
||
extern const Event ConstructNamedPathOperator;
|
||
extern const Event FilterOperator;
|
||
extern const Event ProduceOperator;
|
||
extern const Event DeleteOperator;
|
||
extern const Event SetPropertyOperator;
|
||
extern const Event SetPropertiesOperator;
|
||
extern const Event SetLabelsOperator;
|
||
extern const Event RemovePropertyOperator;
|
||
extern const Event RemoveLabelsOperator;
|
||
extern const Event EdgeUniquenessFilterOperator;
|
||
extern const Event AccumulateOperator;
|
||
extern const Event AggregateOperator;
|
||
extern const Event SkipOperator;
|
||
extern const Event LimitOperator;
|
||
extern const Event OrderByOperator;
|
||
extern const Event MergeOperator;
|
||
extern const Event OptionalOperator;
|
||
extern const Event UnwindOperator;
|
||
extern const Event DistinctOperator;
|
||
extern const Event UnionOperator;
|
||
extern const Event CartesianOperator;
|
||
extern const Event CallProcedureOperator;
|
||
extern const Event ForeachOperator;
|
||
extern const Event EmptyResultOperator;
|
||
extern const Event EvaluatePatternFilterOperator;
|
||
extern const Event ApplyOperator;
|
||
extern const Event IndexedJoinOperator;
|
||
extern const Event HashJoinOperator;
|
||
} // namespace memgraph::metrics
|
||
|
||
namespace memgraph::query::plan {
|
||
|
||
// Shorthand for the memory tracker's out-of-memory exception enabler; the
// cursors in this file instantiate it at the start of Pull.
using OOMExceptionEnabler = utils::MemoryTracker::OutOfMemoryExceptionEnabler;
|
||
|
||
namespace {
|
||
|
||
// Custom equality function for a vector of typed values.
|
||
// Used in unordered_maps in Aggregate and Distinct operators.
|
||
struct TypedValueVectorEqual {
|
||
template <class TAllocator>
|
||
bool operator()(const std::vector<TypedValue, TAllocator> &left,
|
||
const std::vector<TypedValue, TAllocator> &right) const {
|
||
MG_ASSERT(left.size() == right.size(),
|
||
"TypedValueVector comparison should only be done over vectors "
|
||
"of the same size");
|
||
return std::equal(left.begin(), left.end(), right.begin(), TypedValue::BoolEqual{});
|
||
}
|
||
};
|
||
|
||
// Returns boolean result of evaluating filter expression. Null is treated as
|
||
// false. Other non boolean values raise a QueryRuntimeException.
|
||
bool EvaluateFilter(ExpressionEvaluator &evaluator, Expression *filter) {
|
||
TypedValue result = filter->Accept(evaluator);
|
||
// Null is treated like false.
|
||
if (result.IsNull()) return false;
|
||
if (result.type() != TypedValue::Type::Bool)
|
||
throw QueryRuntimeException("Filter expression must evaluate to bool or null, got {}.", result.type());
|
||
return result.ValueBool();
|
||
}
|
||
|
||
// Derives a profiling key from the address of `obj`: the pointer value is
// reinterpreted as a 64-bit unsigned integer. Compilation fails on platforms
// where pointers are not 64 bits wide.
template <typename T>
uint64_t ComputeProfilingKey(const T *obj) {
  static_assert(sizeof(T *) == sizeof(uint64_t), "profiling keys require 64-bit pointers");
  const auto key = reinterpret_cast<uint64_t>(obj);
  return key;
}
|
||
|
||
// Throws HintedAbortError carrying the abort reason when MustAbort reports
// anything other than AbortReason::NO_ABORT for `context`; otherwise returns.
inline void AbortCheck(ExecutionContext const &context) {
  auto const reason = MustAbort(context);
  if (reason == AbortReason::NO_ABORT) return;
  throw HintedAbortError(reason);
}
|
||
|
||
std::vector<storage::LabelId> EvaluateLabels(const std::vector<StorageLabelType> &labels,
|
||
ExpressionEvaluator &evaluator, DbAccessor *dba) {
|
||
std::vector<storage::LabelId> result;
|
||
result.reserve(labels.size());
|
||
for (const auto &label : labels) {
|
||
if (const auto *label_atom = std::get_if<storage::LabelId>(&label)) {
|
||
result.emplace_back(*label_atom);
|
||
} else {
|
||
result.emplace_back(dba->NameToLabel(std::get<Expression *>(label)->Accept(evaluator).ValueString()));
|
||
}
|
||
}
|
||
return result;
|
||
}
|
||
|
||
} // namespace
|
||
|
||
// Declares a ScopedProfile named `profile`, keyed by ComputeProfilingKey(this)
// and given `name` as its second argument. Expects a variable called `context`
// to be in scope at the expansion site.
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
|
||
#define SCOPED_PROFILE_OP(name) ScopedProfile profile{ComputeProfilingKey(this), name, &context};
|
||
// Declares a ScopedProfile named `profile`, keyed by ComputeProfilingKey(this)
// and given `ref` as its second argument (call sites in this file pass the
// cursor's operator, self_). Expects a variable called `context` in scope.
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
|
||
#define SCOPED_PROFILE_OP_BY_REF(ref) ScopedProfile profile{ComputeProfilingKey(this), ref, &context};
|
||
|
||
// Succeeds exactly once: the first call returns true and every later call
// returns false until Reset() clears did_pull_. The frame is left untouched.
bool Once::OnceCursor::Pull(Frame &, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("Once");

  if (did_pull_) return false;

  did_pull_ = true;
  return true;
}
|
||
|
||
// Bumps the OnceOperator counter and returns a new OnceCursor created with `mem`.
UniqueCursorPtr Once::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::OnceOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<OnceCursor>(mem);
  return cursor;
}
|
||
|
||
// Once has no input operator: generates HasSingleInput(), input() and set_input() for it.
WITHOUT_SINGLE_INPUT(Once);
|
||
|
||
// Nothing to shut down: the cursor has no input cursor.
void Once::OnceCursor::Shutdown() {}
|
||
|
||
// Clears the pulled flag so that the next Pull succeeds again.
void Once::OnceCursor::Reset() { did_pull_ = false; }
|
||
|
||
// Builds a CreateNode operator for `node_info`. When `input` is null, a fresh
// Once operator is used as the input.
CreateNode::CreateNode(const std::shared_ptr<LogicalOperator> &input, NodeCreationInfo node_info)
    : input_(input != nullptr ? input : std::make_shared<Once>()), node_info_(std::move(node_info)) {}
|
||
|
||
// Creates a vertex on this GraphDb. Returns a reference to vertex placed on the
|
||
// frame.
|
||
VertexAccessor &CreateLocalVertex(const NodeCreationInfo &node_info, Frame *frame, ExecutionContext &context,
|
||
std::vector<storage::LabelId> &labels, ExpressionEvaluator &evaluator) {
|
||
auto &dba = *context.db_accessor;
|
||
auto new_node = dba.InsertVertex();
|
||
context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1;
|
||
for (const auto &label : labels) {
|
||
auto maybe_error = std::invoke([&] { return new_node.AddLabel(label); });
|
||
if (maybe_error.HasError()) {
|
||
switch (maybe_error.GetError()) {
|
||
case storage::Error::SERIALIZATION_ERROR:
|
||
throw TransactionSerializationException();
|
||
case storage::Error::DELETED_OBJECT:
|
||
throw QueryRuntimeException("Trying to set a label on a deleted node.");
|
||
case storage::Error::VERTEX_HAS_EDGES:
|
||
case storage::Error::PROPERTIES_DISABLED:
|
||
case storage::Error::NONEXISTENT_OBJECT:
|
||
throw QueryRuntimeException("Unexpected error when setting a label.");
|
||
}
|
||
}
|
||
context.execution_stats[ExecutionStats::Key::CREATED_LABELS] += 1;
|
||
}
|
||
// TODO: PropsSetChecked allocates a PropertyValue, make it use context.memory
|
||
// when we update PropertyValue with custom allocator.
|
||
std::map<storage::PropertyId, storage::PropertyValue> properties;
|
||
if (const auto *node_info_properties = std::get_if<PropertiesMapList>(&node_info.properties)) {
|
||
for (const auto &[key, value_expression] : *node_info_properties) {
|
||
properties.emplace(key, value_expression->Accept(evaluator));
|
||
}
|
||
} else {
|
||
auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info.properties));
|
||
for (const auto &[key, value] : property_map.ValueMap()) {
|
||
properties.emplace(dba.NameToProperty(key), value);
|
||
}
|
||
}
|
||
MultiPropsInitChecked(&new_node, properties);
|
||
|
||
if (flags::AreExperimentsEnabled(flags::Experiments::TEXT_SEARCH)) {
|
||
context.db_accessor->TextIndexAddVertex(new_node);
|
||
}
|
||
|
||
(*frame)[node_info.symbol] = new_node;
|
||
return (*frame)[node_info.symbol].ValueVertex();
|
||
}
|
||
|
||
// Generates CreateNode::Accept (pre-visit, visit input_, post-visit).
ACCEPT_WITH_INPUT(CreateNode)
|
||
|
||
// Bumps the CreateNodeOperator counter and returns a CreateNodeCursor for this
// operator, created with `mem`.
UniqueCursorPtr CreateNode::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::CreateNodeOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<CreateNodeCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> CreateNode::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
symbols.emplace_back(node_info_.symbol);
|
||
return symbols;
|
||
}
|
||
|
||
// Keeps a reference to the owning operator and creates the cursor of its input
// operator with `mem`.
CreateNode::CreateNodeCursor::CreateNodeCursor(const CreateNode &self, utils::MemoryResource *mem)
|
||
: self_(self), input_cursor_(self.input_->MakeCursor(mem)) {}
|
||
|
||
// Pulls one row from the input and creates a vertex for it. Returns false once
// the input is exhausted, true after a vertex was created and put on the frame.
//
// In enterprise builds with a valid license and an auth checker present, the
// CREATE_DELETE fine-grained privilege on the evaluated labels is required;
// otherwise QueryRuntimeException is thrown. The created vertex is registered
// with the trigger context collector when one is set.
bool CreateNode::CreateNodeCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("CreateNode");
  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::NEW);

  if (!input_cursor_->Pull(frame, context)) return false;

  // we have to resolve the labels before we can check for permissions
  auto labels = EvaluateLabels(self_.node_info_.labels, evaluator, context.db_accessor);

#ifdef MG_ENTERPRISE
  if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
      !context.auth_checker->Has(labels, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) {
    throw QueryRuntimeException("Vertex not created due to not having enough permission!");
  }
#endif

  // Bind to the vertex living on the frame instead of copying the accessor,
  // as CreateExpandCursor::OtherVertex does for the same call.
  auto &created_vertex = CreateLocalVertex(self_.node_info_, &frame, context, labels, evaluator);
  if (context.trigger_context_collector) {
    context.trigger_context_collector->RegisterCreatedObject(created_vertex);
  }
  return true;
}
|
||
|
||
// Forwards the shutdown to the input cursor.
void CreateNode::CreateNodeCursor::Shutdown() { input_cursor_->Shutdown(); }
|
||
|
||
// Forwards the reset to the input cursor; the cursor keeps no other state.
void CreateNode::CreateNodeCursor::Reset() { input_cursor_->Reset(); }
|
||
|
||
// Builds a CreateExpand operator that creates an edge described by `edge_info`
// from the vertex under `input_symbol` to the node described by `node_info`.
// `existing_node` tells whether that node is already on the frame. When `input`
// is null, a fresh Once operator is used as the input.
CreateExpand::CreateExpand(NodeCreationInfo node_info, EdgeCreationInfo edge_info,
                           const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node)
    : node_info_(std::move(node_info)),
      edge_info_(std::move(edge_info)),
      input_(input != nullptr ? input : std::make_shared<Once>()),
      input_symbol_(std::move(input_symbol)),
      existing_node_(existing_node) {}
|
||
|
||
// Generates CreateExpand::Accept (pre-visit, visit input_, post-visit).
ACCEPT_WITH_INPUT(CreateExpand)
|
||
|
||
// Bumps the CreateExpandOperator counter and returns a CreateExpandCursor for
// this operator, created with `mem`.
UniqueCursorPtr CreateExpand::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::CreateExpandOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<CreateExpandCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> CreateExpand::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
symbols.emplace_back(node_info_.symbol);
|
||
symbols.emplace_back(edge_info_.symbol);
|
||
return symbols;
|
||
}
|
||
|
||
// Keeps a reference to the owning operator and creates the cursor of its input
// operator with `mem`.
CreateExpand::CreateExpandCursor::CreateExpandCursor(const CreateExpand &self, utils::MemoryResource *mem)
|
||
: self_(self), input_cursor_(self.input_->MakeCursor(mem)) {}
|
||
|
||
namespace {
|
||
|
||
// Inserts an edge of type edge_info.edge_type from `from` to `to`, initializes
// its properties, stores it on the frame under edge_info.symbol and returns it.
//
// Properties come either from the PropertiesMapList in `edge_info` (each value
// expression is evaluated with `evaluator`) or from a map-valued parameter.
//
// Throws TransactionSerializationException on a serialization error and
// QueryRuntimeException on any other insertion failure.
EdgeAccessor CreateEdge(const EdgeCreationInfo &edge_info, DbAccessor *dba, VertexAccessor *from, VertexAccessor *to,
                        Frame *frame, ExpressionEvaluator *evaluator) {
  auto maybe_edge = dba->InsertEdge(from, to, edge_info.edge_type);
  if (!maybe_edge.HasValue()) {
    switch (maybe_edge.GetError()) {
      case storage::Error::SERIALIZATION_ERROR:
        throw TransactionSerializationException();
      case storage::Error::DELETED_OBJECT:
        throw QueryRuntimeException("Trying to create an edge on a deleted node.");
      case storage::Error::VERTEX_HAS_EDGES:
      case storage::Error::PROPERTIES_DISABLED:
      case storage::Error::NONEXISTENT_OBJECT:
        throw QueryRuntimeException("Unexpected error when creating an edge.");
    }
    // Only reachable for an error value the switch does not list. The result
    // holds no edge here, so it must never be dereferenced below.
    throw QueryRuntimeException("Unexpected error when creating an edge.");
  }

  auto &edge = *maybe_edge;
  std::map<storage::PropertyId, storage::PropertyValue> properties;
  if (const auto *edge_info_properties = std::get_if<PropertiesMapList>(&edge_info.properties)) {
    for (const auto &[key, value_expression] : *edge_info_properties) {
      properties.emplace(key, value_expression->Accept(*evaluator));
    }
  } else {
    // Properties supplied as a single map-valued parameter.
    auto property_map = evaluator->Visit(*std::get<ParameterLookup *>(edge_info.properties));
    for (const auto &[key, value] : property_map.ValueMap()) {
      properties.emplace(dba->NameToProperty(key), value);
    }
  }
  if (!properties.empty()) MultiPropsInitChecked(&edge, properties);

  (*frame)[edge_info.symbol] = edge;

  return *maybe_edge;
}
|
||
|
||
} // namespace
|
||
|
||
// Pulls one row from the input and creates an edge between the vertex under
// input_symbol_ and the other vertex (existing or newly created). Returns false
// once the input is exhausted, true after the edge was created.
//
// In enterprise builds with a valid license and an auth checker present, the
// CREATE_DELETE privilege on the edge type is required, plus UPDATE on the
// labels for an existing node or CREATE_DELETE for a new one; otherwise
// QueryRuntimeException is thrown.
bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP_BY_REF(self_);

  if (!input_cursor_->Pull(frame, context)) return false;
  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::NEW);
  auto labels = EvaluateLabels(self_.node_info_.labels, evaluator, context.db_accessor);

#ifdef MG_ENTERPRISE
  if (license::global_license_checker.IsEnterpriseValidFast()) {
    const auto fine_grained_permission = self_.existing_node_
                                             ? memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE
                                             : memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE;

    if (context.auth_checker) {
      const bool edge_type_allowed = context.auth_checker->Has(
          self_.edge_info_.edge_type, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE);
      if (!edge_type_allowed || !context.auth_checker->Has(labels, fine_grained_permission)) {
        throw QueryRuntimeException("Edge not created due to not having enough permission!");
      }
    }
  }
#endif
  // origin vertex, taken from the frame
  TypedValue &vertex_value = frame[self_.input_symbol_];
  ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex);
  auto &v1 = vertex_value.ValueVertex();

  // destination vertex (possibly an existing node)
  auto &v2 = OtherVertex(frame, context, labels, evaluator);

  auto *dba = context.db_accessor;

  // Only Direction::IN swaps the endpoints. OUT and BOTH go from v1 to v2: for
  // an undirected CreateExpand an arbitrary direction is chosen. That case is
  // used by the MERGE clause; it is not allowed in CREATE, and the semantic
  // checker needs to ensure it doesn't reach this point.
  const bool incoming = self_.edge_info_.direction == EdgeAtom::Direction::IN;
  auto *edge_from = incoming ? &v2 : &v1;
  auto *edge_to = incoming ? &v1 : &v2;
  auto created_edge = CreateEdge(self_.edge_info_, dba, edge_from, edge_to, &frame, &evaluator);

  context.execution_stats[ExecutionStats::Key::CREATED_EDGES] += 1;
  if (context.trigger_context_collector) {
    context.trigger_context_collector->RegisterCreatedObject(created_edge);
  }

  return true;
}
|
||
|
||
// Forwards the shutdown to the input cursor.
void CreateExpand::CreateExpandCursor::Shutdown() { input_cursor_->Shutdown(); }
|
||
|
||
// Forwards the reset to the input cursor; the cursor keeps no other state.
void CreateExpand::CreateExpandCursor::Reset() { input_cursor_->Reset(); }
|
||
|
||
// Returns the vertex at the other end of the expansion. For an existing node
// it is read from the frame (and must be a Vertex); otherwise a new vertex is
// created with CreateLocalVertex and registered with the trigger context
// collector when one is set.
VertexAccessor &CreateExpand::CreateExpandCursor::OtherVertex(Frame &frame, ExecutionContext &context,
                                                              std::vector<storage::LabelId> &labels,
                                                              ExpressionEvaluator &evaluator) {
  if (!self_.existing_node_) {
    auto &created_vertex = CreateLocalVertex(self_.node_info_, &frame, context, labels, evaluator);
    if (context.trigger_context_collector) {
      context.trigger_context_collector->RegisterCreatedObject(created_vertex);
    }
    return created_vertex;
  }

  TypedValue &dest_node_value = frame[self_.node_info_.symbol];
  ExpectType(self_.node_info_.symbol, dest_node_value, TypedValue::Type::Vertex);
  return dest_node_value.ValueVertex();
}
|
||
|
||
template <class TVerticesFun>
|
||
class ScanAllCursor : public Cursor {
|
||
public:
|
||
explicit ScanAllCursor(const ScanAll &self, Symbol output_symbol, UniqueCursorPtr input_cursor, storage::View view,
|
||
TVerticesFun get_vertices, const char *op_name)
|
||
: self_(self),
|
||
output_symbol_(std::move(output_symbol)),
|
||
input_cursor_(std::move(input_cursor)),
|
||
view_(view),
|
||
get_vertices_(std::move(get_vertices)),
|
||
op_name_(op_name) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP_BY_REF(self_);
|
||
|
||
AbortCheck(context);
|
||
|
||
while (!vertices_ || vertices_it_.value() == vertices_end_it_.value()) {
|
||
if (!input_cursor_->Pull(frame, context)) return false;
|
||
// We need a getter function, because in case of exhausting a lazy
|
||
// iterable, we cannot simply reset it by calling begin().
|
||
auto next_vertices = get_vertices_(frame, context);
|
||
if (!next_vertices) continue;
|
||
// Since vertices iterator isn't nothrow_move_assignable, we have to use
|
||
// the roundabout assignment + emplace, instead of simple:
|
||
// vertices _ = get_vertices_(frame, context);
|
||
vertices_.emplace(std::move(next_vertices.value()));
|
||
vertices_it_.emplace(vertices_.value().begin());
|
||
vertices_end_it_.emplace(vertices_.value().end());
|
||
}
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !FindNextVertex(context)) {
|
||
return false;
|
||
}
|
||
#endif
|
||
|
||
frame[output_symbol_] = *vertices_it_.value();
|
||
++vertices_it_.value();
|
||
return true;
|
||
}
|
||
|
||
#ifdef MG_ENTERPRISE
|
||
bool FindNextVertex(const ExecutionContext &context) {
|
||
while (vertices_it_.value() != vertices_end_it_.value()) {
|
||
if (context.auth_checker->Has(*vertices_it_.value(), view_,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ)) {
|
||
return true;
|
||
}
|
||
++vertices_it_.value();
|
||
}
|
||
return false;
|
||
}
|
||
#endif
|
||
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override {
|
||
input_cursor_->Reset();
|
||
vertices_ = std::nullopt;
|
||
vertices_it_ = std::nullopt;
|
||
vertices_end_it_ = std::nullopt;
|
||
}
|
||
|
||
private:
|
||
const ScanAll &self_;
|
||
const Symbol output_symbol_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
storage::View view_;
|
||
TVerticesFun get_vertices_;
|
||
std::optional<typename std::result_of<TVerticesFun(Frame &, ExecutionContext &)>::type::value_type> vertices_;
|
||
std::optional<decltype(vertices_.value().begin())> vertices_it_;
|
||
std::optional<decltype(vertices_.value().end())> vertices_end_it_;
|
||
const char *op_name_;
|
||
};
|
||
|
||
template <typename TEdgesFun>
|
||
class ScanAllByEdgeTypeCursor : public Cursor {
|
||
public:
|
||
explicit ScanAllByEdgeTypeCursor(const ScanAllByEdgeType &self, Symbol output_symbol, UniqueCursorPtr input_cursor,
|
||
storage::View view, TEdgesFun get_edges, const char *op_name)
|
||
: self_(self),
|
||
output_symbol_(std::move(output_symbol)),
|
||
input_cursor_(std::move(input_cursor)),
|
||
view_(view),
|
||
get_edges_(std::move(get_edges)),
|
||
op_name_(op_name) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP_BY_REF(self_);
|
||
|
||
AbortCheck(context);
|
||
|
||
while (!vertices_ || vertices_it_.value() == vertices_end_it_.value()) {
|
||
if (!input_cursor_->Pull(frame, context)) return false;
|
||
auto next_vertices = get_edges_(frame, context);
|
||
if (!next_vertices) continue;
|
||
|
||
vertices_.emplace(std::move(next_vertices.value()));
|
||
vertices_it_.emplace(vertices_.value().begin());
|
||
vertices_end_it_.emplace(vertices_.value().end());
|
||
}
|
||
|
||
frame[output_symbol_] = *vertices_it_.value();
|
||
++vertices_it_.value();
|
||
return true;
|
||
}
|
||
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override {
|
||
input_cursor_->Reset();
|
||
vertices_ = std::nullopt;
|
||
vertices_it_ = std::nullopt;
|
||
vertices_end_it_ = std::nullopt;
|
||
}
|
||
|
||
private:
|
||
const ScanAllByEdgeType &self_;
|
||
const Symbol output_symbol_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
storage::View view_;
|
||
TEdgesFun get_edges_;
|
||
std::optional<typename std::result_of<TEdgesFun(Frame &, ExecutionContext &)>::type::value_type> vertices_;
|
||
std::optional<decltype(vertices_.value().begin())> vertices_it_;
|
||
std::optional<decltype(vertices_.value().end())> vertices_end_it_;
|
||
const char *op_name_;
|
||
};
|
||
|
||
// Builds a ScanAll operator that emits vertices under `output_symbol` using
// `view`. When `input` is null, a fresh Once operator is used as the input.
ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::View view)
    : input_(input != nullptr ? input : std::make_shared<Once>()),
      output_symbol_(std::move(output_symbol)),
      view_(view) {}
|
||
|
||
// Generates ScanAll::Accept (pre-visit, visit input_, post-visit).
ACCEPT_WITH_INPUT(ScanAll)
|
||
|
||
// Bumps the ScanAllOperator counter and returns a ScanAllCursor whose vertex
// source is every vertex visible through view_.
UniqueCursorPtr ScanAll::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllOperator);

  auto vertices = [this](Frame &, ExecutionContext &context) {
    return std::make_optional(context.db_accessor->Vertices(view_));
  };
  return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, *this, output_symbol_, input_->MakeCursor(mem),
                                                                view_, std::move(vertices), "ScanAll");
}
|
||
|
||
std::vector<Symbol> ScanAll::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
symbols.emplace_back(output_symbol_);
|
||
return symbols;
|
||
}
|
||
|
||
// Builds a ScanAll restricted to vertices carrying `label`.
// `output_symbol` is taken by value, so it is moved into the base class
// instead of being copied a second time.
ScanAllByLabel::ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
                               storage::LabelId label, storage::View view)
    : ScanAll(input, std::move(output_symbol), view), label_(label) {}
|
||
|
||
// Generates ScanAllByLabel::Accept (pre-visit, visit input_, post-visit).
ACCEPT_WITH_INPUT(ScanAllByLabel)
|
||
|
||
// Bumps the ScanAllByLabelOperator counter and returns a ScanAllCursor whose
// vertex source is the vertices with label_ visible through view_.
UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByLabelOperator);

  auto vertices = [this](Frame &, ExecutionContext &context) {
    return std::make_optional(context.db_accessor->Vertices(view_, label_));
  };
  return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, *this, output_symbol_, input_->MakeCursor(mem),
                                                                view_, std::move(vertices), "ScanAllByLabel");
}
|
||
|
||
// Builds a ScanAllByEdgeType operator that emits edges of `edge_type` under
// `output_symbol` using `view`. When `input` is null, a fresh Once operator is
// used as the input.
ScanAllByEdgeType::ScanAllByEdgeType(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
                                     storage::EdgeTypeId edge_type, storage::View view)
    : input_(input != nullptr ? input : std::make_shared<Once>()),
      output_symbol_(std::move(output_symbol)),
      view_(view),
      edge_type_(edge_type) {}
|
||
|
||
// Generates ScanAllByEdgeType::Accept (pre-visit, visit input_, post-visit).
ACCEPT_WITH_INPUT(ScanAllByEdgeType)
|
||
|
||
// Bumps the ScanAllByEdgeTypeOperator counter and returns a
// ScanAllByEdgeTypeCursor whose edge source is the edges of edge_type_ visible
// through view_.
UniqueCursorPtr ScanAllByEdgeType::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByEdgeTypeOperator);

  auto edges = [this](Frame &, ExecutionContext &context) {
    return std::make_optional(context.db_accessor->Edges(view_, edge_type_));
  };

  return MakeUniqueCursorPtr<ScanAllByEdgeTypeCursor<decltype(edges)>>(
      mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(edges), "ScanAllByEdgeType");
}
|
||
|
||
std::vector<Symbol> ScanAllByEdgeType::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
symbols.emplace_back(output_symbol_);
|
||
return symbols;
|
||
}
|
||
|
||
// TODO(buda): Implement ScanAllByLabelProperty operator to iterate over
|
||
// vertices that have the label and some value for the given property.
|
||
|
||
// Builds a ScanAll restricted to vertices with `label` whose `property` lies
// within the given bounds. At least one of the two bounds must be present
// (asserted). `output_symbol` is taken by value, so it is moved into the base
// class instead of being copied a second time.
ScanAllByLabelPropertyRange::ScanAllByLabelPropertyRange(const std::shared_ptr<LogicalOperator> &input,
                                                         Symbol output_symbol, storage::LabelId label,
                                                         storage::PropertyId property, std::string property_name,
                                                         std::optional<Bound> lower_bound,
                                                         std::optional<Bound> upper_bound, storage::View view)
    : ScanAll(input, std::move(output_symbol), view),
      label_(label),
      property_(property),
      property_name_(std::move(property_name)),
      lower_bound_(lower_bound),
      upper_bound_(upper_bound) {
  MG_ASSERT(lower_bound_ || upper_bound_, "Only one bound can be left out");
}
|
||
|
||
// Generates ScanAllByLabelPropertyRange::Accept (pre-visit, visit input_, post-visit).
ACCEPT_WITH_INPUT(ScanAllByLabelPropertyRange)
|
||
|
||
// Bumps the ScanAllByLabelPropertyRangeOperator counter and returns a
// ScanAllCursor whose vertex source evaluates both bounds for the current frame
// and looks up the vertices with label_ whose property_ falls inside them.
//
// The source yields no vertices when a bound evaluates to Null. It throws
// QueryRuntimeException when a bound is a Bool, List or Map, or when it cannot
// be converted to a property value at all.
UniqueCursorPtr ScanAllByLabelPropertyRange::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByLabelPropertyRangeOperator);

  auto vertices = [this](Frame &frame, ExecutionContext &context)
      -> std::optional<decltype(context.db_accessor->Vertices(view_, label_, property_, std::nullopt, std::nullopt))> {
    ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_);

    // Turns an optional expression bound into an optional storage-level bound.
    auto to_storage_bound = [&evaluator](const auto &bound) -> std::optional<utils::Bound<storage::PropertyValue>> {
      if (!bound) return std::nullopt;
      const auto &value = bound->value()->Accept(evaluator);
      try {
        const auto &property_value = storage::PropertyValue(value);
        switch (property_value.type()) {
          case storage::PropertyValue::Type::Bool:
          case storage::PropertyValue::Type::List:
          case storage::PropertyValue::Type::Map:
            // Prevent indexed lookup with something that would fail if we did
            // the original filter with `operator<`. Note, for some reason,
            // Cypher does not support comparing boolean values.
            throw QueryRuntimeException("Invalid type {} for '<'.", value.type());
          case storage::PropertyValue::Type::Null:
          case storage::PropertyValue::Type::Int:
          case storage::PropertyValue::Type::Double:
          case storage::PropertyValue::Type::String:
          case storage::PropertyValue::Type::TemporalData:
            // These types are accepted as range bounds.
            return std::make_optional(utils::Bound<storage::PropertyValue>(property_value, bound->type()));
        }
      } catch (const TypedValueException &) {
        throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type());
      }
    };

    // Both bounds are converted before the Null checks below.
    auto lower = to_storage_bound(lower_bound_);
    auto upper = to_storage_bound(upper_bound_);

    // If any bound is null, then the comparison would result in nulls. This
    // is treated as not satisfying the filter, so return no vertices.
    const bool lower_is_null = lower && lower->value().IsNull();
    if (lower_is_null || (upper && upper->value().IsNull())) return std::nullopt;

    return std::make_optional(context.db_accessor->Vertices(view_, label_, property_, lower, upper));
  };
  return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(
      mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(vertices), "ScanAllByLabelPropertyRange");
}
|
||
|
||
// Builds a ScanAll restricted to vertices with `label` whose `property` equals
// the value of `expression`, which must not be null (checked by DMG_ASSERT).
// `output_symbol` is taken by value, so it is moved into the base class
// instead of being copied a second time.
ScanAllByLabelPropertyValue::ScanAllByLabelPropertyValue(const std::shared_ptr<LogicalOperator> &input,
                                                         Symbol output_symbol, storage::LabelId label,
                                                         storage::PropertyId property, std::string property_name,
                                                         Expression *expression, storage::View view)
    : ScanAll(input, std::move(output_symbol), view),
      label_(label),
      property_(property),
      property_name_(std::move(property_name)),
      expression_(expression) {
  DMG_ASSERT(expression, "Expression is not optional.");
}
|
||
|
||
// Generates ScanAllByLabelPropertyValue::Accept (pre-visit, visit input_, post-visit).
ACCEPT_WITH_INPUT(ScanAllByLabelPropertyValue)
|
||
|
||
// Bumps the ScanAllByLabelPropertyValueOperator counter and returns a
// ScanAllCursor whose vertex source evaluates expression_ for the current frame
// and looks up the vertices with label_ whose property_ equals that value.
//
// The source yields no vertices when the value is Null and throws
// QueryRuntimeException when the value cannot be used as a property value.
UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByLabelPropertyValueOperator);

  auto vertices = [this](Frame &frame, ExecutionContext &context)
      -> std::optional<decltype(context.db_accessor->Vertices(view_, label_, property_, storage::PropertyValue()))> {
    ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_);
    auto lookup_value = expression_->Accept(evaluator);
    if (lookup_value.IsNull()) return std::nullopt;
    if (lookup_value.IsPropertyValue()) {
      return std::make_optional(
          context.db_accessor->Vertices(view_, label_, property_, storage::PropertyValue(lookup_value)));
    }
    throw QueryRuntimeException("'{}' cannot be used as a property value.", lookup_value.type());
  };
  return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(
      mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(vertices), "ScanAllByLabelPropertyValue");
}
|
||
|
||
/// Constructs a scan restricted by label and property (no value constraint).
///
/// @param input          upstream operator, forwarded to the ScanAll base.
/// @param output_symbol  symbol forwarded to the ScanAll base.
/// @param label          label passed to the vertex lookup.
/// @param property       property passed to the vertex lookup.
/// @param property_name  name stored next to `property`.
/// @param view           storage view, forwarded to the ScanAll base.
ScanAllByLabelProperty::ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
                                               storage::LabelId label, storage::PropertyId property,
                                               std::string property_name, storage::View view)
    : ScanAll(input, output_symbol, view),
      label_(label),
      property_(property),
      property_name_(std::move(property_name)) {
}
|
||
|
||
// Expands ACCEPT_WITH_INPUT (defined outside this excerpt) for ScanAllByLabelProperty.
// NOTE(review): presumably generates the operator's visitor Accept() that also visits input_ — confirm.
ACCEPT_WITH_INPUT(ScanAllByLabelProperty)
|
||
|
||
/// Creates the cursor for this operator and bumps the operator's usage counter.
///
/// The cursor is a ScanAllCursor whose callback asks the database accessor for the vertices
/// matching `label_` and `property_` in `view_`; the frame is not consulted.
///
/// @param mem memory resource used for the cursor and the input cursor.
/// @return the newly created cursor.
UniqueCursorPtr ScanAllByLabelProperty::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByLabelPropertyOperator);

  auto fetch_vertices = [this](Frame & /*frame*/, ExecutionContext &context) {
    return std::make_optional(context.db_accessor->Vertices(view_, label_, property_));
  };

  return MakeUniqueCursorPtr<ScanAllCursor<decltype(fetch_vertices)>>(
      mem, *this, output_symbol_, input_->MakeCursor(mem), view_, std::move(fetch_vertices), "ScanAllByLabelProperty");
}
|
||
|
||
/// Constructs a scan that looks a vertex up by the id produced by `expression`.
///
/// @param input          upstream operator, forwarded to the ScanAll base.
/// @param output_symbol  symbol forwarded to the ScanAll base.
/// @param expression     expression producing the id; asserted to be non-null (MG_ASSERT).
/// @param view           storage view, forwarded to the ScanAll base.
ScanAllById::ScanAllById(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
                         Expression *expression, storage::View view)
    : ScanAll(input, output_symbol, view),
      expression_(expression) {
  MG_ASSERT(expression);
}
|
||
|
||
// Expands ACCEPT_WITH_INPUT (defined outside this excerpt) for ScanAllById.
// NOTE(review): presumably generates the operator's visitor Accept() that also visits input_ — confirm.
ACCEPT_WITH_INPUT(ScanAllById)
|
||
|
||
/// Creates the cursor for this operator and bumps the operator's usage counter.
///
/// The cursor's callback evaluates `expression_`, interprets the result as a vertex id and
/// returns the vertex with that id as a one-element vector. It returns no vertices when the
/// value is not numeric, is a double that is not an exact 64-bit integer, or no vertex is found.
///
/// @param mem memory resource used for the cursor and the input cursor.
/// @return the newly created cursor.
UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ScanAllByIdOperator);

  auto vertices = [this](Frame &frame, ExecutionContext &context) -> std::optional<std::vector<VertexAccessor>> {
    auto *db = context.db_accessor;
    ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_);
    auto value = expression_->Accept(evaluator);
    if (!value.IsNumeric()) return std::nullopt;

    // Integers are used as-is. Routing them through a `cond ? int : double` expression would
    // convert them to double first and silently round ids above 2^53.
    int64_t id = 0;
    if (value.IsInt()) {
      id = value.ValueInt();
    } else {
      const double raw_id = value.ValueDouble();
      // 2^63 is exactly representable as a double. Values outside [-2^63, 2^63) and NaN cannot
      // be converted to int64_t without undefined behaviour, and can never name a vertex.
      constexpr double kInt64Bound = 9223372036854775808.0;
      if (!(raw_id >= -kInt64Bound && raw_id < kInt64Bound)) return std::nullopt;
      id = static_cast<int64_t>(raw_id);
      // Reject doubles with a fractional part.
      if (static_cast<double>(id) != raw_id) return std::nullopt;
    }

    auto maybe_vertex = db->FindVertex(storage::Gid::FromInt(id), view_);
    if (!maybe_vertex) return std::nullopt;
    return std::vector<VertexAccessor>{*maybe_vertex};
  };

  return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, *this, output_symbol_, input_->MakeCursor(mem),
                                                                view_, std::move(vertices), "ScanAllById");
}
|
||
|
||
namespace {
|
||
bool CheckExistingNode(const VertexAccessor &new_node, const Symbol &existing_node_sym, Frame &frame) {
|
||
const TypedValue &existing_node = frame[existing_node_sym];
|
||
if (existing_node.IsNull()) return false;
|
||
ExpectType(existing_node_sym, existing_node, TypedValue::Type::Vertex);
|
||
return existing_node.ValueVertex() == new_node;
|
||
}
|
||
|
||
template <class TEdgesResult>
|
||
auto UnwrapEdgesResult(storage::Result<TEdgesResult> &&result) {
|
||
if (result.HasError()) {
|
||
switch (result.GetError()) {
|
||
case storage::Error::DELETED_OBJECT:
|
||
throw QueryRuntimeException("Trying to get relationships of a deleted node.");
|
||
case storage::Error::NONEXISTENT_OBJECT:
|
||
throw query::QueryRuntimeException("Trying to get relationships from a node that doesn't exist.");
|
||
case storage::Error::VERTEX_HAS_EDGES:
|
||
case storage::Error::SERIALIZATION_ERROR:
|
||
case storage::Error::PROPERTIES_DISABLED:
|
||
throw QueryRuntimeException("Unexpected error when accessing relationships.");
|
||
}
|
||
}
|
||
return std::move(*result);
|
||
}
|
||
|
||
} // namespace
|
||
|
||
Expand::Expand(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol,
|
||
Symbol edge_symbol, EdgeAtom::Direction direction, const std::vector<storage::EdgeTypeId> &edge_types,
|
||
bool existing_node, storage::View view)
|
||
: input_(input ? input : std::make_shared<Once>()),
|
||
input_symbol_(std::move(input_symbol)),
|
||
common_{node_symbol, edge_symbol, direction, edge_types, existing_node},
|
||
view_(view) {}
|
||
|
||
// Expands ACCEPT_WITH_INPUT (defined outside this excerpt) for Expand.
// NOTE(review): presumably generates the operator's visitor Accept() that also visits input_ — confirm.
ACCEPT_WITH_INPUT(Expand)
|
||
|
||
/// Creates an ExpandCursor for this operator and bumps the operator's usage counter.
///
/// @param mem memory resource used both to allocate the cursor and by the cursor itself.
UniqueCursorPtr Expand::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ExpandOperator);
  auto cursor = MakeUniqueCursorPtr<ExpandCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> Expand::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
symbols.emplace_back(common_.node_symbol);
|
||
symbols.emplace_back(common_.edge_symbol);
|
||
return symbols;
|
||
}
|
||
|
||
/// Creates a cursor over `self`, building the input cursor from `mem`.
Expand::ExpandCursor::ExpandCursor(const Expand &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self.input_->MakeCursor(mem)) {
}
|
||
|
||
/// Creates a cursor over `self` with preset degree bookkeeping.
///
/// @param input_degree          initial value of `prev_input_degree_`.
/// @param existing_node_degree  initial value of `prev_existing_degree_`.
/// @param mem                   memory resource used to build the input cursor.
Expand::ExpandCursor::ExpandCursor(const Expand &self, int64_t input_degree, int64_t existing_node_degree,
                                   utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self.input_->MakeCursor(mem)),
      prev_input_degree_(input_degree),
      prev_existing_degree_(existing_node_degree) {
}
|
||
|
||
/// Produces the next expansion: places one edge (and, unless the node already exists, its far
/// endpoint) on the frame.
///
/// Incoming edges are drained first, then outgoing ones; once both are exhausted the next input
/// row is fetched through InitEdges.
///
/// @return true when an edge was placed on the frame, false when the input is exhausted.
bool Expand::ExpandCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP_BY_REF(self_);

  // Writes the endpoint reached through `reached_via` to the frame. Nothing is written when the
  // node symbol is already bound (existing_node).
  auto bind_far_node = [this, &frame]<EdgeAtom::Direction direction>(const EdgeAccessor &reached_via,
                                                                     utils::tag_value<direction>) {
    if (self_.common_.existing_node) return;
    if constexpr (direction == EdgeAtom::Direction::IN) {
      frame[self_.common_.node_symbol] = reached_via.From();
    } else if constexpr (direction == EdgeAtom::Direction::OUT) {
      frame[self_.common_.node_symbol] = reached_via.To();
    } else {
      LOG_FATAL("Must indicate exact expansion direction here");
    }
  };

  for (;;) {
    AbortCheck(context);

    // First try the incoming edges.
    if (in_edges_ && *in_edges_it_ != in_edges_->end()) {
      auto &in_it = *in_edges_it_;
      auto incoming = *in_it++;
#ifdef MG_ENTERPRISE
      // Skip edges the user may not read, or whose source vertex the user may not read.
      if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
          (!context.auth_checker->Has(incoming, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) ||
           !context.auth_checker->Has(incoming.From(), self_.view_,
                                      memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
        continue;
      }
#endif

      frame[self_.common_.edge_symbol] = incoming;
      bind_far_node(incoming, utils::tag_v<EdgeAtom::Direction::IN>);
      return true;
    }

    // Then try the outgoing edges.
    if (out_edges_ && *out_edges_it_ != out_edges_->end()) {
      auto &out_it = *out_edges_it_;
      auto outgoing = *out_it++;
      // When expanding in both directions a cycle must be reported only once, and the
      // incoming-edge branch above has already done so.
      if (self_.common_.direction == EdgeAtom::Direction::BOTH && outgoing.IsCycle()) continue;
#ifdef MG_ENTERPRISE
      // Skip edges the user may not read, or whose destination vertex the user may not read.
      if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
          (!context.auth_checker->Has(outgoing, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) ||
           !context.auth_checker->Has(outgoing.To(), self_.view_,
                                      memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
        continue;
      }
#endif
      frame[self_.common_.edge_symbol] = outgoing;
      bind_far_node(outgoing, utils::tag_v<EdgeAtom::Direction::OUT>);
      return true;
    }

    // Reaching this point means the edges were never initialized or are used up.
    // Fetch the next input row; stop when there is none.
    if (!InitEdges(frame, context)) return false;

    // Edges were (re)initialized, go around again.
  }
}
|
||
|
||
/// Forwards the shutdown request to the input cursor.
void Expand::ExpandCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
/// Resets the input cursor and discards the cached edge containers and their iterators.
void Expand::ExpandCursor::Reset() {
  input_cursor_->Reset();

  in_edges_.reset();
  in_edges_it_.reset();

  out_edges_.reset();
  out_edges_it_.reset();
}
|
||
|
||
/// Decides, for the current frame, which vertex to expand from and in which direction.
///
/// - Null input vertex: returns an empty ExpansionInfo.
/// - Node symbol not pre-bound, or bound to Null: expands from the input vertex as planned.
/// - Both vertices bound: expands from the input vertex, or from the existing vertex with the
///   direction flipped (`reversed = true`), depending on the degrees recorded by earlier pulls.
ExpansionInfo Expand::ExpandCursor::GetExpansionInfo(Frame &frame) {
  TypedValue &input_value = frame[self_.input_symbol_];
  if (input_value.IsNull()) {
    return ExpansionInfo{};
  }
  ExpectType(self_.input_symbol_, input_value, TypedValue::Type::Vertex);
  auto &vertex = input_value.ValueVertex();

  const auto direction = self_.common_.direction;
  if (!self_.common_.existing_node) {
    return ExpansionInfo{.input_node = vertex, .direction = direction};
  }

  TypedValue &bound_value = frame[self_.common_.node_symbol];
  if (bound_value.IsNull()) {
    return ExpansionInfo{.input_node = vertex, .direction = direction};
  }
  ExpectType(self_.common_.node_symbol, bound_value, TypedValue::Type::Vertex);
  auto &existing_vertex = bound_value.ValueVertex();

  // Degree bookkeeping (input, existing):
  //   -1 and -1             -> normal expansion
  //   -1 and expanded       -> can't happen
  //   expanded and -1       -> reverse
  //   expanded and expanded -> reverse unless the input side recorded the smaller degree
  const bool nothing_recorded = prev_input_degree_ == -1 && prev_existing_degree_ == -1;
  const bool keep_orientation = nothing_recorded || prev_input_degree_ < prev_existing_degree_;
  if (keep_orientation) {
    return ExpansionInfo{.input_node = vertex, .direction = direction, .existing_node = existing_vertex};
  }

  // Expanding from the other endpoint means IN and OUT swap; anything else becomes BOTH.
  const auto flipped_direction = [direction] {
    if (direction == EdgeAtom::Direction::IN) return EdgeAtom::Direction::OUT;
    if (direction == EdgeAtom::Direction::OUT) return EdgeAtom::Direction::IN;
    return EdgeAtom::Direction::BOTH;
  }();

  return ExpansionInfo{
      .input_node = existing_vertex, .direction = flipped_direction, .existing_node = vertex, .reversed = true};
}
|
||
|
||
/// Pulls the next usable input row and loads the edges to iterate for it.
///
/// Rows whose input vertex is Null are skipped. For the chosen vertex the incoming and/or
/// outgoing edges (depending on the expansion direction) are fetched and their iterators placed
/// at the beginning. When both endpoints are bound, the number of edges inspected is recorded in
/// `prev_input_degree_` or `prev_existing_degree_` for use by later GetExpansionInfo calls.
///
/// @return false when the input cursor is exhausted, true otherwise.
/// @throws QueryRuntimeException (via UnwrapEdgesResult) when the storage reports an error.
bool Expand::ExpandCursor::InitEdges(Frame &frame, ExecutionContext &context) {
  // Input Vertex could be null if it is created by a failed optional match. In
  // those cases we skip that input pull and continue with the next.
  while (true) {
    if (!input_cursor_->Pull(frame, context)) return false;

    // Drop everything that belongs to the previous input row. Without this, a row that fetches
    // no edges (existing_node is requested but the node on the frame is Null) would rewind the
    // iterator of the previous row's container below and Pull would yield those stale edges again.
    in_edges_it_.reset();
    in_edges_.reset();
    out_edges_it_.reset();
    out_edges_.reset();

    expansion_info_ = GetExpansionInfo(frame);

    if (!expansion_info_.input_node) {
      continue;
    }

    auto vertex = *expansion_info_.input_node;
    auto direction = expansion_info_.direction;

    // -1 marks "no lookup performed in this direction".
    int64_t num_expanded_first = -1;
    if (direction == EdgeAtom::Direction::IN || direction == EdgeAtom::Direction::BOTH) {
      if (self_.common_.existing_node) {
        // Only edges between the two bound vertices are of interest; a Null existing node
        // leaves in_edges_ empty.
        if (expansion_info_.existing_node) {
          auto existing_node = *expansion_info_.existing_node;

          auto edges_result = UnwrapEdgesResult(vertex.InEdges(self_.view_, self_.common_.edge_types, existing_node));
          in_edges_.emplace(std::move(edges_result.edges));
          num_expanded_first = edges_result.expanded_count;
        }
      } else {
        auto edges_result = UnwrapEdgesResult(vertex.InEdges(self_.view_, self_.common_.edge_types));
        in_edges_.emplace(std::move(edges_result.edges));
        num_expanded_first = edges_result.expanded_count;
      }
      if (in_edges_) {
        in_edges_it_.emplace(in_edges_->begin());
      }
    }

    int64_t num_expanded_second = -1;
    if (direction == EdgeAtom::Direction::OUT || direction == EdgeAtom::Direction::BOTH) {
      if (self_.common_.existing_node) {
        if (expansion_info_.existing_node) {
          auto existing_node = *expansion_info_.existing_node;
          auto edges_result = UnwrapEdgesResult(vertex.OutEdges(self_.view_, self_.common_.edge_types, existing_node));
          out_edges_.emplace(std::move(edges_result.edges));
          num_expanded_second = edges_result.expanded_count;
        }
      } else {
        auto edges_result = UnwrapEdgesResult(vertex.OutEdges(self_.view_, self_.common_.edge_types));
        out_edges_.emplace(std::move(edges_result.edges));
        num_expanded_second = edges_result.expanded_count;
      }
      if (out_edges_) {
        out_edges_it_.emplace(out_edges_->begin());
      }
    }

    // Degree bookkeeping only matters when both endpoints are bound.
    if (!expansion_info_.existing_node) {
      return true;
    }

    num_expanded_first = num_expanded_first == -1 ? 0 : num_expanded_first;
    num_expanded_second = num_expanded_second == -1 ? 0 : num_expanded_second;
    int64_t total_expanded_edges = num_expanded_first + num_expanded_second;

    // Attribute the count to whichever vertex was actually expanded from.
    if (!expansion_info_.reversed) {
      prev_input_degree_ = total_expanded_edges;
    } else {
      prev_existing_degree_ = total_expanded_edges;
    }

    return true;
  }
}
|
||
|
||
ExpandVariable::ExpandVariable(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol,
|
||
Symbol edge_symbol, EdgeAtom::Type type, EdgeAtom::Direction direction,
|
||
const std::vector<storage::EdgeTypeId> &edge_types, bool is_reverse,
|
||
Expression *lower_bound, Expression *upper_bound, bool existing_node,
|
||
ExpansionLambda filter_lambda, std::optional<ExpansionLambda> weight_lambda,
|
||
std::optional<Symbol> total_weight)
|
||
: input_(input ? input : std::make_shared<Once>()),
|
||
input_symbol_(std::move(input_symbol)),
|
||
common_{node_symbol, edge_symbol, direction, edge_types, existing_node},
|
||
type_(type),
|
||
is_reverse_(is_reverse),
|
||
lower_bound_(lower_bound),
|
||
upper_bound_(upper_bound),
|
||
filter_lambda_(std::move(filter_lambda)),
|
||
weight_lambda_(std::move(weight_lambda)),
|
||
total_weight_(std::move(total_weight)) {
|
||
DMG_ASSERT(type_ == EdgeAtom::Type::DEPTH_FIRST || type_ == EdgeAtom::Type::BREADTH_FIRST ||
|
||
type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH || type_ == EdgeAtom::Type::ALL_SHORTEST_PATHS,
|
||
"ExpandVariable can only be used with breadth first, depth first, "
|
||
"weighted shortest path or all shortest paths type");
|
||
DMG_ASSERT(!(type_ == EdgeAtom::Type::BREADTH_FIRST && is_reverse), "Breadth first expansion can't be reversed");
|
||
}
|
||
|
||
// Expands ACCEPT_WITH_INPUT (defined outside this excerpt) for ExpandVariable.
// NOTE(review): presumably generates the operator's visitor Accept() that also visits input_ — confirm.
ACCEPT_WITH_INPUT(ExpandVariable)
|
||
|
||
std::vector<Symbol> ExpandVariable::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
symbols.emplace_back(common_.node_symbol);
|
||
symbols.emplace_back(common_.edge_symbol);
|
||
return symbols;
|
||
}
|
||
|
||
namespace {
|
||
|
||
/**
|
||
* Helper function that returns an iterable over
|
||
* <EdgeAtom::Direction, EdgeAccessor> pairs
|
||
* for the given params.
|
||
*
|
||
* @param vertex - The vertex to expand from.
|
||
* @param direction - Expansion direction. All directions (IN, OUT, BOTH)
|
||
* are supported.
|
||
* @param memory - Used to allocate the result.
|
||
* @return See above.
|
||
*/
|
||
auto ExpandFromVertex(const VertexAccessor &vertex, EdgeAtom::Direction direction,
|
||
const std::vector<storage::EdgeTypeId> &edge_types, utils::MemoryResource *memory,
|
||
DbAccessor *db_accessor) {
|
||
// wraps an EdgeAccessor into a pair <accessor, direction>
|
||
auto wrapper = [](EdgeAtom::Direction direction, auto &&edges) {
|
||
return iter::imap([direction](const auto &edge) { return std::make_pair(edge, direction); },
|
||
std::forward<decltype(edges)>(edges));
|
||
};
|
||
|
||
storage::View view = storage::View::OLD;
|
||
utils::pmr::vector<decltype(wrapper(direction, vertex.InEdges(view, edge_types).GetValue().edges))> chain_elements(
|
||
memory);
|
||
|
||
if (direction != EdgeAtom::Direction::OUT) {
|
||
auto edges = UnwrapEdgesResult(vertex.InEdges(view, edge_types)).edges;
|
||
if (!edges.empty()) {
|
||
chain_elements.emplace_back(wrapper(EdgeAtom::Direction::IN, std::move(edges)));
|
||
}
|
||
}
|
||
|
||
if (direction != EdgeAtom::Direction::IN) {
|
||
auto edges = UnwrapEdgesResult(vertex.OutEdges(view, edge_types)).edges;
|
||
if (!edges.empty()) {
|
||
chain_elements.emplace_back(wrapper(EdgeAtom::Direction::OUT, std::move(edges)));
|
||
}
|
||
}
|
||
|
||
// TODO: Investigate whether itertools perform heap allocation?
|
||
return iter::chain.from_iterable(std::move(chain_elements));
|
||
}
|
||
|
||
} // namespace
|
||
|
||
class ExpandVariableCursor : public Cursor {
|
||
public:
|
||
ExpandVariableCursor(const ExpandVariable &self, utils::MemoryResource *mem)
|
||
: self_(self), input_cursor_(self.input_->MakeCursor(mem)), edges_(mem), edges_it_(mem) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP_BY_REF(self_);
|
||
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::OLD);
|
||
while (true) {
|
||
if (Expand(frame, context)) return true;
|
||
|
||
if (PullInput(frame, context)) {
|
||
// if lower bound is zero we also yield empty paths
|
||
if (lower_bound_ == 0) {
|
||
auto &start_vertex = frame[self_.input_symbol_].ValueVertex();
|
||
if (!self_.common_.existing_node) {
|
||
frame[self_.common_.node_symbol] = start_vertex;
|
||
return true;
|
||
}
|
||
if (CheckExistingNode(start_vertex, self_.common_.node_symbol, frame)) {
|
||
return true;
|
||
}
|
||
}
|
||
// if lower bound is not zero, we just continue, the next
|
||
// loop iteration will attempt to expand and we're good
|
||
} else
|
||
return false;
|
||
// else continue with the loop, try to expand again
|
||
// because we succesfully pulled from the input
|
||
}
|
||
}
|
||
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override {
|
||
input_cursor_->Reset();
|
||
edges_.clear();
|
||
edges_it_.clear();
|
||
}
|
||
|
||
private:
|
||
const ExpandVariable &self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
// bounds. in the cursor they are not optional but set to
|
||
// default values if missing in the ExpandVariable operator
|
||
// initialize to arbitrary values, they should only be used
|
||
// after a successful pull from the input
|
||
int64_t upper_bound_{-1};
|
||
int64_t lower_bound_{-1};
|
||
|
||
// a stack of edge iterables corresponding to the level/depth of
|
||
// the expansion currently being Pulled
|
||
using ExpandEdges =
|
||
decltype(ExpandFromVertex(std::declval<VertexAccessor>(), EdgeAtom::Direction::IN, self_.common_.edge_types,
|
||
utils::NewDeleteResource(), std::declval<DbAccessor *>()));
|
||
|
||
utils::pmr::vector<ExpandEdges> edges_;
|
||
// an iterator indicating the position in the corresponding edges_ element
|
||
utils::pmr::vector<decltype(edges_.begin()->begin())> edges_it_;
|
||
|
||
/**
|
||
* Helper function that Pulls from the input vertex and
|
||
* makes iteration over it's edges possible.
|
||
*
|
||
* @return If the Pull succeeded. If not, this VariableExpandCursor
|
||
* is exhausted.
|
||
*/
|
||
bool PullInput(Frame &frame, ExecutionContext &context) {
|
||
// Input Vertex could be null if it is created by a failed optional match.
|
||
// In those cases we skip that input pull and continue with the next.
|
||
while (true) {
|
||
AbortCheck(context);
|
||
if (!input_cursor_->Pull(frame, context)) return false;
|
||
TypedValue &vertex_value = frame[self_.input_symbol_];
|
||
|
||
// Null check due to possible failed optional match.
|
||
if (vertex_value.IsNull()) continue;
|
||
|
||
ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex);
|
||
auto &vertex = vertex_value.ValueVertex();
|
||
|
||
// Evaluate the upper and lower bounds.
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::OLD);
|
||
auto calc_bound = [&evaluator](auto &bound) {
|
||
auto value = EvaluateInt(&evaluator, bound, "Variable expansion bound");
|
||
if (value < 0) throw QueryRuntimeException("Variable expansion bound must be a non-negative integer.");
|
||
return value;
|
||
};
|
||
|
||
lower_bound_ = self_.lower_bound_ ? calc_bound(self_.lower_bound_) : 1;
|
||
upper_bound_ = self_.upper_bound_ ? calc_bound(self_.upper_bound_) : std::numeric_limits<int64_t>::max();
|
||
|
||
if (upper_bound_ > 0) {
|
||
auto *memory = edges_.get_allocator().GetMemoryResource();
|
||
edges_.emplace_back(
|
||
ExpandFromVertex(vertex, self_.common_.direction, self_.common_.edge_types, memory, context.db_accessor));
|
||
edges_it_.emplace_back(edges_.back().begin());
|
||
}
|
||
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
// Add initial vertex of path to the accumulated path
|
||
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(vertex);
|
||
}
|
||
|
||
// reset the frame value to an empty edge list
|
||
if (frame[self_.common_.edge_symbol].IsList()) {
|
||
// Preserve the list capacity if possible
|
||
frame[self_.common_.edge_symbol].ValueList().clear();
|
||
} else {
|
||
auto *pull_memory = context.evaluation_context.memory;
|
||
frame[self_.common_.edge_symbol] = TypedValue::TVector(pull_memory);
|
||
}
|
||
|
||
return true;
|
||
}
|
||
}
|
||
|
||
// Helper function for appending an edge to the list on the frame.
|
||
void AppendEdge(const EdgeAccessor &new_edge, utils::pmr::vector<TypedValue> *edges_on_frame) {
|
||
// We are placing an edge on the frame. It is possible that there already
|
||
// exists an edge on the frame for this level. If so first remove it.
|
||
DMG_ASSERT(edges_.size() > 0, "Edges are empty");
|
||
if (self_.is_reverse_) {
|
||
// TODO: This is innefficient, we should look into replacing
|
||
// vector with something else for TypedValue::List.
|
||
size_t diff = edges_on_frame->size() - std::min(edges_on_frame->size(), edges_.size() - 1U);
|
||
if (diff > 0U) edges_on_frame->erase(edges_on_frame->begin(), edges_on_frame->begin() + diff);
|
||
edges_on_frame->emplace(edges_on_frame->begin(), new_edge);
|
||
} else {
|
||
edges_on_frame->resize(std::min(edges_on_frame->size(), edges_.size() - 1U));
|
||
edges_on_frame->emplace_back(new_edge);
|
||
}
|
||
}
|
||
|
||
/**
|
||
* Performs a single expansion for the current state of this
|
||
* VariableExpansionCursor.
|
||
*
|
||
* @return True if the expansion was a success and this Cursor's
|
||
* consumer can consume it. False if the expansion failed. In that
|
||
* case no more expansions are available from the current input
|
||
* vertex and another Pull from the input cursor should be performed.
|
||
*/
|
||
bool Expand(Frame &frame, ExecutionContext &context) {
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::OLD);
|
||
// Some expansions might not be valid due to edge uniqueness and
|
||
// existing_node criterions, so expand in a loop until either the input
|
||
// vertex is exhausted or a valid variable-length expansion is available.
|
||
while (true) {
|
||
AbortCheck(context);
|
||
// pop from the stack while there is stuff to pop and the current
|
||
// level is exhausted
|
||
while (!edges_.empty() && edges_it_.back() == edges_.back().end()) {
|
||
edges_.pop_back();
|
||
edges_it_.pop_back();
|
||
}
|
||
|
||
// check if we exhausted everything, if so return false
|
||
if (edges_.empty()) return false;
|
||
|
||
// we use this a lot
|
||
auto &edges_on_frame = frame[self_.common_.edge_symbol].ValueList();
|
||
|
||
// it is possible that edges_on_frame does not contain as many
|
||
// elements as edges_ due to edge-uniqueness (when a whole layer
|
||
// gets exhausted but no edges are valid). for that reason only
|
||
// pop from edges_on_frame if they contain enough elements
|
||
if (self_.is_reverse_) {
|
||
auto diff = edges_on_frame.size() - std::min(edges_on_frame.size(), edges_.size());
|
||
if (diff > 0) {
|
||
edges_on_frame.erase(edges_on_frame.begin(), edges_on_frame.begin() + diff);
|
||
}
|
||
} else {
|
||
edges_on_frame.resize(std::min(edges_on_frame.size(), edges_.size()));
|
||
}
|
||
|
||
// if we are here, we have a valid stack,
|
||
// get the edge, increase the relevant iterator
|
||
auto current_edge = *edges_it_.back()++;
|
||
// Check edge-uniqueness.
|
||
bool found_existing =
|
||
std::any_of(edges_on_frame.begin(), edges_on_frame.end(),
|
||
[¤t_edge](const TypedValue &edge) { return current_edge.first == edge.ValueEdge(); });
|
||
if (found_existing) continue;
|
||
|
||
VertexAccessor current_vertex =
|
||
current_edge.second == EdgeAtom::Direction::IN ? current_edge.first.From() : current_edge.first.To();
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
|
||
!(context.auth_checker->Has(current_edge.first, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) &&
|
||
context.auth_checker->Has(current_vertex, storage::View::OLD,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
|
||
continue;
|
||
}
|
||
#endif
|
||
AppendEdge(current_edge.first, &edges_on_frame);
|
||
|
||
if (!self_.common_.existing_node) {
|
||
frame[self_.common_.node_symbol] = current_vertex;
|
||
}
|
||
|
||
// Skip expanding out of filtered expansion.
|
||
frame[self_.filter_lambda_.inner_edge_symbol] = current_edge.first;
|
||
frame[self_.filter_lambda_.inner_node_symbol] = current_vertex;
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(),
|
||
"Accumulated path must be path");
|
||
Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath();
|
||
// Shrink the accumulated path including current level if necessary
|
||
while (accumulated_path.size() >= edges_on_frame.size()) {
|
||
accumulated_path.Shrink();
|
||
}
|
||
accumulated_path.Expand(current_edge.first);
|
||
accumulated_path.Expand(current_vertex);
|
||
}
|
||
if (self_.filter_lambda_.expression && !EvaluateFilter(evaluator, self_.filter_lambda_.expression)) continue;
|
||
|
||
// we are doing depth-first search, so place the current
|
||
// edge's expansions onto the stack, if we should continue to expand
|
||
if (upper_bound_ > static_cast<int64_t>(edges_.size())) {
|
||
auto *memory = edges_.get_allocator().GetMemoryResource();
|
||
edges_.emplace_back(ExpandFromVertex(current_vertex, self_.common_.direction, self_.common_.edge_types, memory,
|
||
context.db_accessor));
|
||
edges_it_.emplace_back(edges_.back().begin());
|
||
}
|
||
|
||
if (self_.common_.existing_node && !CheckExistingNode(current_vertex, self_.common_.node_symbol, frame)) continue;
|
||
|
||
// We only yield true if we satisfy the lower bound.
|
||
if (static_cast<int64_t>(edges_on_frame.size()) >= lower_bound_) {
|
||
return true;
|
||
}
|
||
}
|
||
}
|
||
};
|
||
|
||
class STShortestPathCursor : public query::plan::Cursor {
|
||
public:
|
||
STShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem)
|
||
: self_(self), input_cursor_(self_.input()->MakeCursor(mem)) {
|
||
MG_ASSERT(self_.common_.existing_node,
|
||
"s-t shortest path algorithm should only "
|
||
"be used when `existing_node` flag is "
|
||
"set!");
|
||
}
|
||
|
||
/// Pulls input rows until a path between the row's source and sink vertices is found.
///
/// Rows with a Null source or sink, or with bounds that admit no path (upper bound below 1 or
/// lower bound above upper bound), are skipped. Missing bounds default to 1 (lower) and the
/// largest int64_t (upper).
///
/// @return true when FindPath succeeded for some row, false when the input is exhausted.
bool Pull(Frame &frame, ExecutionContext &context) override {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("STShortestPath");

  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::OLD);
  while (input_cursor_->Pull(frame, context)) {
    const auto &source_tv = frame[self_.input_symbol_];
    const auto &sink_tv = frame[self_.common_.node_symbol];

    // Optional matching may leave either endpoint Null; such rows produce nothing.
    if (source_tv.IsNull() || sink_tv.IsNull()) continue;

    const auto &source = source_tv.ValueVertex();
    const auto &sink = sink_tv.ValueVertex();

    int64_t lower_bound = 1;
    if (self_.lower_bound_) {
      lower_bound = EvaluateInt(&evaluator, self_.lower_bound_, "Min depth in breadth-first expansion");
    }
    int64_t upper_bound = std::numeric_limits<int64_t>::max();
    if (self_.upper_bound_) {
      upper_bound = EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in breadth-first expansion");
    }

    const bool bounds_feasible = upper_bound >= 1 && lower_bound <= upper_bound;
    if (!bounds_feasible) continue;

    if (FindPath(*context.db_accessor, source, sink, lower_bound, upper_bound, &frame, &evaluator, context)) {
      return true;
    }
  }
  return false;
}
|
||
|
||
/// Forwards the shutdown request to the input cursor.
void Shutdown() override {
  input_cursor_->Shutdown();
}
|
||
|
||
/// Resets the input cursor.
void Reset() override {
  input_cursor_->Reset();
}
|
||
|
||
private:
|
||
const ExpandVariable &self_;
|
||
UniqueCursorPtr input_cursor_;
|
||
|
||
using VertexEdgeMapT = utils::pmr::unordered_map<VertexAccessor, std::optional<EdgeAccessor>>;
|
||
|
||
void ReconstructPath(const VertexAccessor &midpoint, const VertexEdgeMapT &in_edge, const VertexEdgeMapT &out_edge,
|
||
Frame *frame, utils::MemoryResource *pull_memory) {
|
||
utils::pmr::vector<TypedValue> result(pull_memory);
|
||
auto last_vertex = midpoint;
|
||
while (true) {
|
||
const auto &last_edge = in_edge.at(last_vertex);
|
||
if (!last_edge) break;
|
||
last_vertex = last_edge->From() == last_vertex ? last_edge->To() : last_edge->From();
|
||
result.emplace_back(*last_edge);
|
||
}
|
||
std::reverse(result.begin(), result.end());
|
||
last_vertex = midpoint;
|
||
while (true) {
|
||
const auto &last_edge = out_edge.at(last_vertex);
|
||
if (!last_edge) break;
|
||
last_vertex = last_edge->From() == last_vertex ? last_edge->To() : last_edge->From();
|
||
result.emplace_back(*last_edge);
|
||
}
|
||
frame->at(self_.common_.edge_symbol) = std::move(result);
|
||
}
|
||
|
||
bool ShouldExpand(const VertexAccessor &vertex, const EdgeAccessor &edge, Frame *frame,
|
||
ExpressionEvaluator *evaluator) {
|
||
if (!self_.filter_lambda_.expression) return true;
|
||
|
||
frame->at(self_.filter_lambda_.inner_node_symbol) = vertex;
|
||
frame->at(self_.filter_lambda_.inner_edge_symbol) = edge;
|
||
|
||
TypedValue result = self_.filter_lambda_.expression->Accept(*evaluator);
|
||
if (result.IsNull()) return false;
|
||
if (result.IsBool()) return result.ValueBool();
|
||
|
||
throw QueryRuntimeException("Expansion condition must evaluate to boolean or null");
|
||
}
|
||
|
||
bool FindPath(const DbAccessor &dba, const VertexAccessor &source, const VertexAccessor &sink, int64_t lower_bound,
|
||
int64_t upper_bound, Frame *frame, ExpressionEvaluator *evaluator, const ExecutionContext &context) {
|
||
using utils::Contains;
|
||
|
||
if (source == sink) return false;
|
||
|
||
// We expand from both directions, both from the source and the sink.
|
||
// Expansions meet at the middle of the path if it exists. This should
|
||
// perform better for real-world like graphs where the expansion front
|
||
// grows exponentially, effectively reducing the exponent by half.
|
||
|
||
auto *pull_memory = evaluator->GetMemoryResource();
|
||
// Holds vertices at the current level of expansion from the source
|
||
// (sink).
|
||
utils::pmr::vector<VertexAccessor> source_frontier(pull_memory);
|
||
utils::pmr::vector<VertexAccessor> sink_frontier(pull_memory);
|
||
|
||
// Holds vertices we can expand to from `source_frontier`
|
||
// (`sink_frontier`).
|
||
utils::pmr::vector<VertexAccessor> source_next(pull_memory);
|
||
utils::pmr::vector<VertexAccessor> sink_next(pull_memory);
|
||
|
||
// Maps each vertex we visited expanding from the source (sink) to the
|
||
// edge used. Necessary for path reconstruction.
|
||
VertexEdgeMapT in_edge(pull_memory);
|
||
VertexEdgeMapT out_edge(pull_memory);
|
||
|
||
size_t current_length = 0;
|
||
|
||
source_frontier.emplace_back(source);
|
||
in_edge[source] = std::nullopt;
|
||
sink_frontier.emplace_back(sink);
|
||
out_edge[sink] = std::nullopt;
|
||
|
||
while (true) {
|
||
AbortCheck(context);
|
||
// Top-down step (expansion from the source).
|
||
++current_length;
|
||
if (current_length > upper_bound) return false;
|
||
|
||
for (const auto &vertex : source_frontier) {
|
||
if (self_.common_.direction != EdgeAtom::Direction::IN) {
|
||
auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges;
|
||
for (const auto &edge : out_edges) {
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
|
||
!(context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) &&
|
||
context.auth_checker->Has(edge.To(), storage::View::OLD,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
|
||
continue;
|
||
}
|
||
#endif
|
||
|
||
if (ShouldExpand(edge.To(), edge, frame, evaluator) && !Contains(in_edge, edge.To())) {
|
||
in_edge.emplace(edge.To(), edge);
|
||
if (Contains(out_edge, edge.To())) {
|
||
if (current_length >= lower_bound) {
|
||
ReconstructPath(edge.To(), in_edge, out_edge, frame, pull_memory);
|
||
return true;
|
||
} else {
|
||
return false;
|
||
}
|
||
}
|
||
source_next.push_back(edge.To());
|
||
}
|
||
}
|
||
}
|
||
if (self_.common_.direction != EdgeAtom::Direction::OUT) {
|
||
auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges;
|
||
for (const auto &edge : in_edges) {
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
|
||
!(context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) &&
|
||
context.auth_checker->Has(edge.From(), storage::View::OLD,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
|
||
continue;
|
||
}
|
||
#endif
|
||
|
||
if (ShouldExpand(edge.From(), edge, frame, evaluator) && !Contains(in_edge, edge.From())) {
|
||
in_edge.emplace(edge.From(), edge);
|
||
if (Contains(out_edge, edge.From())) {
|
||
if (current_length >= lower_bound) {
|
||
ReconstructPath(edge.From(), in_edge, out_edge, frame, pull_memory);
|
||
return true;
|
||
} else {
|
||
return false;
|
||
}
|
||
}
|
||
source_next.push_back(edge.From());
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if (source_next.empty()) return false;
|
||
source_frontier.clear();
|
||
std::swap(source_frontier, source_next);
|
||
|
||
// Bottom-up step (expansion from the sink).
|
||
++current_length;
|
||
if (current_length > upper_bound) return false;
|
||
|
||
// When expanding from the sink we have to be careful which edge
|
||
// endpoint we pass to `should_expand`, because everything is
|
||
// reversed.
|
||
for (const auto &vertex : sink_frontier) {
|
||
if (self_.common_.direction != EdgeAtom::Direction::OUT) {
|
||
auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges;
|
||
for (const auto &edge : out_edges) {
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
|
||
!(context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) &&
|
||
context.auth_checker->Has(edge.To(), storage::View::OLD,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
|
||
continue;
|
||
}
|
||
#endif
|
||
if (ShouldExpand(vertex, edge, frame, evaluator) && !Contains(out_edge, edge.To())) {
|
||
out_edge.emplace(edge.To(), edge);
|
||
if (Contains(in_edge, edge.To())) {
|
||
if (current_length >= lower_bound) {
|
||
ReconstructPath(edge.To(), in_edge, out_edge, frame, pull_memory);
|
||
return true;
|
||
} else {
|
||
return false;
|
||
}
|
||
}
|
||
sink_next.push_back(edge.To());
|
||
}
|
||
}
|
||
}
|
||
if (self_.common_.direction != EdgeAtom::Direction::IN) {
|
||
auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges;
|
||
for (const auto &edge : in_edges) {
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
|
||
!(context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) &&
|
||
context.auth_checker->Has(edge.From(), storage::View::OLD,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
|
||
continue;
|
||
}
|
||
#endif
|
||
if (ShouldExpand(vertex, edge, frame, evaluator) && !Contains(out_edge, edge.From())) {
|
||
out_edge.emplace(edge.From(), edge);
|
||
if (Contains(in_edge, edge.From())) {
|
||
if (current_length >= lower_bound) {
|
||
ReconstructPath(edge.From(), in_edge, out_edge, frame, pull_memory);
|
||
return true;
|
||
} else {
|
||
return false;
|
||
}
|
||
}
|
||
sink_next.push_back(edge.From());
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if (sink_next.empty()) return false;
|
||
sink_frontier.clear();
|
||
std::swap(sink_frontier, sink_next);
|
||
}
|
||
}
|
||
};
|
||
|
||
class SingleSourceShortestPathCursor : public query::plan::Cursor {
|
||
public:
|
||
SingleSourceShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem)
|
||
: self_(self),
|
||
input_cursor_(self_.input()->MakeCursor(mem)),
|
||
processed_(mem),
|
||
to_visit_next_(mem),
|
||
to_visit_current_(mem) {
|
||
MG_ASSERT(!self_.common_.existing_node,
|
||
"Single source shortest path algorithm "
|
||
"should not be used when `existing_node` "
|
||
"flag is set, s-t shortest path algorithm "
|
||
"should be used instead!");
|
||
}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP("SingleSourceShortestPath");
|
||
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::OLD);
|
||
|
||
// for the given (edge, vertex) pair checks if they satisfy the
|
||
// "where" condition. if so, places them in the to_visit_ structure.
|
||
auto expand_pair = [this, &evaluator, &frame, &context](EdgeAccessor edge, VertexAccessor vertex) -> bool {
|
||
// if we already processed the given vertex it doesn't get expanded
|
||
if (processed_.find(vertex) != processed_.end()) return false;
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
|
||
!(context.auth_checker->Has(vertex, storage::View::OLD,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ) &&
|
||
context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
|
||
return false;
|
||
}
|
||
#endif
|
||
frame[self_.filter_lambda_.inner_edge_symbol] = edge;
|
||
frame[self_.filter_lambda_.inner_node_symbol] = vertex;
|
||
std::optional<Path> curr_acc_path = std::nullopt;
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(),
|
||
"Accumulated path must have Path type");
|
||
Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath();
|
||
accumulated_path.Expand(edge);
|
||
accumulated_path.Expand(vertex);
|
||
curr_acc_path = accumulated_path;
|
||
}
|
||
|
||
if (self_.filter_lambda_.expression) {
|
||
TypedValue result = self_.filter_lambda_.expression->Accept(evaluator);
|
||
switch (result.type()) {
|
||
case TypedValue::Type::Null:
|
||
return true;
|
||
case TypedValue::Type::Bool:
|
||
if (!result.ValueBool()) return true;
|
||
break;
|
||
default:
|
||
throw QueryRuntimeException("Expansion condition must evaluate to boolean or null.");
|
||
}
|
||
}
|
||
to_visit_next_.emplace_back(edge, vertex, std::move(curr_acc_path));
|
||
processed_.emplace(vertex, edge);
|
||
return true;
|
||
};
|
||
|
||
auto restore_frame_state_after_expansion = [this, &frame](bool was_expanded) {
|
||
if (was_expanded && self_.filter_lambda_.accumulated_path_symbol) {
|
||
frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink();
|
||
}
|
||
};
|
||
|
||
// populates the to_visit_next_ structure with expansions
|
||
// from the given vertex. skips expansions that don't satisfy
|
||
// the "where" condition.
|
||
auto expand_from_vertex = [this, &expand_pair, &restore_frame_state_after_expansion](const auto &vertex) {
|
||
if (self_.common_.direction != EdgeAtom::Direction::IN) {
|
||
auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges;
|
||
for (const auto &edge : out_edges) {
|
||
bool was_expanded = expand_pair(edge, edge.To());
|
||
restore_frame_state_after_expansion(was_expanded);
|
||
}
|
||
}
|
||
if (self_.common_.direction != EdgeAtom::Direction::OUT) {
|
||
auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges;
|
||
for (const auto &edge : in_edges) {
|
||
bool was_expanded = expand_pair(edge, edge.From());
|
||
restore_frame_state_after_expansion(was_expanded);
|
||
}
|
||
}
|
||
};
|
||
|
||
// do it all in a loop because we skip some elements
|
||
while (true) {
|
||
AbortCheck(context);
|
||
// if we have nothing to visit on the current depth, switch to next
|
||
if (to_visit_current_.empty()) to_visit_current_.swap(to_visit_next_);
|
||
|
||
// if current is still empty, it means both are empty, so pull from
|
||
// input
|
||
if (to_visit_current_.empty()) {
|
||
if (!input_cursor_->Pull(frame, context)) return false;
|
||
|
||
to_visit_current_.clear();
|
||
to_visit_next_.clear();
|
||
processed_.clear();
|
||
|
||
const auto &vertex_value = frame[self_.input_symbol_];
|
||
// it is possible that the vertex is Null due to optional matching
|
||
if (vertex_value.IsNull()) continue;
|
||
lower_bound_ = self_.lower_bound_
|
||
? EvaluateInt(&evaluator, self_.lower_bound_, "Min depth in breadth-first expansion")
|
||
: 1;
|
||
upper_bound_ = self_.upper_bound_
|
||
? EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in breadth-first expansion")
|
||
: std::numeric_limits<int64_t>::max();
|
||
|
||
if (upper_bound_ < 1 || lower_bound_ > upper_bound_) continue;
|
||
|
||
const auto &vertex = vertex_value.ValueVertex();
|
||
processed_.emplace(vertex, std::nullopt);
|
||
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
// Add initial vertex of path to the accumulated path
|
||
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(vertex);
|
||
}
|
||
|
||
expand_from_vertex(vertex);
|
||
|
||
// go back to loop start and see if we expanded anything
|
||
continue;
|
||
}
|
||
|
||
// take the next expansion from the queue
|
||
auto [curr_edge, curr_vertex, curr_acc_path] = to_visit_current_.back();
|
||
to_visit_current_.pop_back();
|
||
|
||
// create the frame value for the edges
|
||
auto *pull_memory = context.evaluation_context.memory;
|
||
utils::pmr::vector<TypedValue> edge_list(pull_memory);
|
||
edge_list.emplace_back(curr_edge);
|
||
auto last_vertex = curr_vertex;
|
||
while (true) {
|
||
const EdgeAccessor &last_edge = edge_list.back().ValueEdge();
|
||
last_vertex = last_edge.From() == last_vertex ? last_edge.To() : last_edge.From();
|
||
// origin_vertex must be in processed
|
||
const auto &previous_edge = processed_.find(last_vertex)->second;
|
||
if (!previous_edge) break;
|
||
|
||
edge_list.emplace_back(previous_edge.value());
|
||
}
|
||
|
||
// expand only if what we've just expanded is less then max depth
|
||
if (static_cast<int64_t>(edge_list.size()) < upper_bound_) {
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
MG_ASSERT(curr_acc_path.has_value(), "Expected non-null accumulated path");
|
||
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = std::move(curr_acc_path.value());
|
||
}
|
||
expand_from_vertex(curr_vertex);
|
||
}
|
||
|
||
if (static_cast<int64_t>(edge_list.size()) < lower_bound_) continue;
|
||
|
||
frame[self_.common_.node_symbol] = curr_vertex;
|
||
|
||
// place edges on the frame in the correct order
|
||
std::reverse(edge_list.begin(), edge_list.end());
|
||
frame[self_.common_.edge_symbol] = std::move(edge_list);
|
||
|
||
return true;
|
||
}
|
||
}
|
||
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override {
|
||
input_cursor_->Reset();
|
||
processed_.clear();
|
||
to_visit_next_.clear();
|
||
to_visit_current_.clear();
|
||
}
|
||
|
||
private:
|
||
const ExpandVariable &self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
|
||
// Depth bounds. Calculated on each pull from the input, the initial value
|
||
// is irrelevant.
|
||
int64_t lower_bound_{-1};
|
||
int64_t upper_bound_{-1};
|
||
|
||
// maps vertices to the edge they got expanded from. it is an optional
|
||
// edge because the root does not get expanded from anything.
|
||
// contains visited vertices as well as those scheduled to be visited.
|
||
utils::pmr::unordered_map<VertexAccessor, std::optional<EdgeAccessor>> processed_;
|
||
// edge, vertex we have yet to visit, for current and next depth and their accumulated paths
|
||
utils::pmr::vector<std::tuple<EdgeAccessor, VertexAccessor, std::optional<Path>>> to_visit_next_;
|
||
utils::pmr::vector<std::tuple<EdgeAccessor, VertexAccessor, std::optional<Path>>> to_visit_current_;
|
||
};
|
||
|
||
namespace {
|
||
|
||
void CheckWeightType(TypedValue current_weight, utils::MemoryResource *memory) {
|
||
if (current_weight.IsNull()) {
|
||
return;
|
||
}
|
||
|
||
if (!current_weight.IsNumeric() && !current_weight.IsDuration()) {
|
||
throw QueryRuntimeException("Calculated weight must be numeric or a Duration, got {}.", current_weight.type());
|
||
}
|
||
|
||
const auto is_valid_numeric = [&] {
|
||
return current_weight.IsNumeric() && (current_weight >= TypedValue(0, memory)).ValueBool();
|
||
};
|
||
|
||
const auto is_valid_duration = [&] {
|
||
return current_weight.IsDuration() && (current_weight >= TypedValue(utils::Duration(0), memory)).ValueBool();
|
||
};
|
||
|
||
if (!is_valid_numeric() && !is_valid_duration()) {
|
||
throw QueryRuntimeException("Calculated weight must be non-negative!");
|
||
}
|
||
}
|
||
|
||
void ValidateWeightTypes(const TypedValue &lhs, const TypedValue &rhs) {
|
||
if ((lhs.IsNumeric() && rhs.IsNumeric()) || (lhs.IsDuration() && rhs.IsDuration())) {
|
||
return;
|
||
}
|
||
throw QueryRuntimeException(utils::MessageWithLink(
|
||
"All weights should be of the same type, either numeric or a Duration. Please update the weight "
|
||
"expression or the filter expression.",
|
||
"https://memgr.ph/wsp"));
|
||
}
|
||
|
||
/// Evaluates the weight lambda and adds the result to `total_weight`.
///
/// @param weight_lambda lambda whose expression yields the weight of the
///        current expansion; when absent a default-constructed TypedValue is
///        returned.
/// @param total_weight weight accumulated so far; when it is Null the freshly
///        evaluated weight is returned unchanged.
/// @param evaluator evaluator the weight expression is run with.
/// @return the new total weight.
/// @throws QueryRuntimeException via `CheckWeightType` / `ValidateWeightTypes`
///         when the evaluated weight is invalid or its kind does not match
///         the kind of `total_weight`.
TypedValue CalculateNextWeight(const std::optional<memgraph::query::plan::ExpansionLambda> &weight_lambda,
                               const TypedValue &total_weight, ExpressionEvaluator evaluator) {
  if (!weight_lambda.has_value()) return {};

  auto *memory = evaluator.GetMemoryResource();
  TypedValue edge_weight = weight_lambda->expression->Accept(evaluator);
  CheckWeightType(edge_weight, memory);

  if (!total_weight.IsNull()) {
    ValidateWeightTypes(edge_weight, total_weight);
    return TypedValue(edge_weight, memory) + total_weight;
  }

  // Nothing accumulated yet, the evaluated weight is the total.
  return edge_weight;
}
|
||
|
||
} // namespace
|
||
|
||
class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
|
||
public:
|
||
ExpandWeightedShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem)
|
||
: self_(self),
|
||
input_cursor_(self_.input_->MakeCursor(mem)),
|
||
total_cost_(mem),
|
||
previous_(mem),
|
||
yielded_vertices_(mem),
|
||
pq_(mem) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP("ExpandWeightedShortestPath");
|
||
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::OLD);
|
||
|
||
auto create_state = [this](const VertexAccessor &vertex, int64_t depth) {
|
||
return std::make_pair(vertex, upper_bound_set_ ? depth : 0);
|
||
};
|
||
|
||
// For the given (edge, vertex, weight, depth) tuple checks if they
|
||
// satisfy the "where" condition. if so, places them in the priority
|
||
// queue.
|
||
auto expand_pair = [this, &evaluator, &frame, &create_state](const EdgeAccessor &edge, const VertexAccessor &vertex,
|
||
const TypedValue &total_weight, int64_t depth) {
|
||
frame[self_.weight_lambda_->inner_edge_symbol] = edge;
|
||
frame[self_.weight_lambda_->inner_node_symbol] = vertex;
|
||
TypedValue next_weight = CalculateNextWeight(self_.weight_lambda_, total_weight, evaluator);
|
||
|
||
std::optional<Path> curr_acc_path = std::nullopt;
|
||
if (self_.filter_lambda_.expression) {
|
||
frame[self_.filter_lambda_.inner_edge_symbol] = edge;
|
||
frame[self_.filter_lambda_.inner_node_symbol] = vertex;
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(),
|
||
"Accumulated path must be path");
|
||
Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath();
|
||
accumulated_path.Expand(edge);
|
||
accumulated_path.Expand(vertex);
|
||
curr_acc_path = accumulated_path;
|
||
|
||
if (self_.filter_lambda_.accumulated_weight_symbol) {
|
||
frame[self_.filter_lambda_.accumulated_weight_symbol.value()] = next_weight;
|
||
}
|
||
}
|
||
|
||
if (!EvaluateFilter(evaluator, self_.filter_lambda_.expression)) return;
|
||
}
|
||
|
||
auto next_state = create_state(vertex, depth);
|
||
|
||
auto found_it = total_cost_.find(next_state);
|
||
if (found_it != total_cost_.end() && (found_it->second.IsNull() || (found_it->second <= next_weight).ValueBool()))
|
||
return;
|
||
|
||
pq_.emplace(next_weight, depth + 1, vertex, edge, curr_acc_path);
|
||
};
|
||
|
||
auto restore_frame_state_after_expansion = [this, &frame]() {
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink();
|
||
}
|
||
};
|
||
|
||
// Populates the priority queue structure with expansions
|
||
// from the given vertex. skips expansions that don't satisfy
|
||
// the "where" condition.
|
||
auto expand_from_vertex = [this, &context, &expand_pair, &restore_frame_state_after_expansion](
|
||
const VertexAccessor &vertex, const TypedValue &weight, int64_t depth) {
|
||
if (self_.common_.direction != EdgeAtom::Direction::IN) {
|
||
auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges;
|
||
for (const auto &edge : out_edges) {
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
|
||
!(context.auth_checker->Has(edge.To(), storage::View::OLD,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ) &&
|
||
context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
|
||
continue;
|
||
}
|
||
#endif
|
||
expand_pair(edge, edge.To(), weight, depth);
|
||
restore_frame_state_after_expansion();
|
||
}
|
||
}
|
||
if (self_.common_.direction != EdgeAtom::Direction::OUT) {
|
||
auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges;
|
||
for (const auto &edge : in_edges) {
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
|
||
!(context.auth_checker->Has(edge.From(), storage::View::OLD,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ) &&
|
||
context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
|
||
continue;
|
||
}
|
||
#endif
|
||
expand_pair(edge, edge.From(), weight, depth);
|
||
restore_frame_state_after_expansion();
|
||
}
|
||
}
|
||
};
|
||
|
||
while (true) {
|
||
AbortCheck(context);
|
||
if (pq_.empty()) {
|
||
if (!input_cursor_->Pull(frame, context)) return false;
|
||
const auto &vertex_value = frame[self_.input_symbol_];
|
||
if (vertex_value.IsNull()) continue;
|
||
auto vertex = vertex_value.ValueVertex();
|
||
if (self_.common_.existing_node) {
|
||
const auto &node = frame[self_.common_.node_symbol];
|
||
// Due to optional matching the existing node could be null.
|
||
// Skip expansion for such nodes.
|
||
if (node.IsNull()) continue;
|
||
}
|
||
|
||
std::optional<Path> curr_acc_path;
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
// Add initial vertex of path to the accumulated path
|
||
curr_acc_path = Path(vertex);
|
||
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = curr_acc_path.value();
|
||
}
|
||
if (self_.upper_bound_) {
|
||
upper_bound_ = EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in weighted shortest path expansion");
|
||
upper_bound_set_ = true;
|
||
} else {
|
||
upper_bound_ = std::numeric_limits<int64_t>::max();
|
||
upper_bound_set_ = false;
|
||
}
|
||
if (upper_bound_ < 1)
|
||
throw QueryRuntimeException(
|
||
"Maximum depth in weighted shortest path expansion must be at "
|
||
"least 1.");
|
||
|
||
frame[self_.weight_lambda_->inner_edge_symbol] = TypedValue();
|
||
frame[self_.weight_lambda_->inner_node_symbol] = vertex;
|
||
TypedValue current_weight =
|
||
CalculateNextWeight(self_.weight_lambda_, /* total_weight */ TypedValue(), evaluator);
|
||
|
||
// Clear existing data structures.
|
||
previous_.clear();
|
||
total_cost_.clear();
|
||
yielded_vertices_.clear();
|
||
|
||
pq_.emplace(current_weight, 0, vertex, std::nullopt, curr_acc_path);
|
||
// We are adding the starting vertex to the set of yielded vertices
|
||
// because we don't want to yield paths that end with the starting
|
||
// vertex.
|
||
yielded_vertices_.insert(vertex);
|
||
}
|
||
|
||
while (!pq_.empty()) {
|
||
AbortCheck(context);
|
||
auto [current_weight, current_depth, current_vertex, current_edge, curr_acc_path] = pq_.top();
|
||
pq_.pop();
|
||
|
||
auto current_state = create_state(current_vertex, current_depth);
|
||
|
||
// Check if the vertex has already been processed.
|
||
if (total_cost_.find(current_state) != total_cost_.end()) {
|
||
continue;
|
||
}
|
||
previous_.emplace(current_state, current_edge);
|
||
total_cost_.emplace(current_state, current_weight);
|
||
|
||
// Expand only if what we've just expanded is less than max depth.
|
||
if (current_depth < upper_bound_) {
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = std::move(curr_acc_path.value());
|
||
}
|
||
expand_from_vertex(current_vertex, current_weight, current_depth);
|
||
}
|
||
|
||
// If we yielded a path for a vertex already, make the expansion but
|
||
// don't return the path again.
|
||
if (yielded_vertices_.find(current_vertex) != yielded_vertices_.end()) continue;
|
||
|
||
// Reconstruct the path.
|
||
auto last_vertex = current_vertex;
|
||
auto last_depth = current_depth;
|
||
auto *pull_memory = context.evaluation_context.memory;
|
||
utils::pmr::vector<TypedValue> edge_list(pull_memory);
|
||
while (true) {
|
||
// Origin_vertex must be in previous.
|
||
const auto &previous_edge = previous_.find(create_state(last_vertex, last_depth))->second;
|
||
if (!previous_edge) break;
|
||
last_vertex = previous_edge->From() == last_vertex ? previous_edge->To() : previous_edge->From();
|
||
last_depth--;
|
||
edge_list.emplace_back(previous_edge.value());
|
||
}
|
||
|
||
// Place destination node on the frame, handle existence flag.
|
||
if (self_.common_.existing_node) {
|
||
const auto &node = frame[self_.common_.node_symbol];
|
||
if ((node != TypedValue(current_vertex, pull_memory)).ValueBool()) {
|
||
continue;
|
||
}
|
||
// Prevent expanding other paths, because we found the
|
||
// shortest to existing node.
|
||
ClearQueue();
|
||
} else {
|
||
frame[self_.common_.node_symbol] = current_vertex;
|
||
}
|
||
|
||
if (!self_.is_reverse_) {
|
||
// Place edges on the frame in the correct order.
|
||
std::reverse(edge_list.begin(), edge_list.end());
|
||
}
|
||
frame[self_.common_.edge_symbol] = std::move(edge_list);
|
||
frame[self_.total_weight_.value()] = current_weight;
|
||
yielded_vertices_.insert(current_vertex);
|
||
return true;
|
||
}
|
||
}
|
||
}
|
||
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override {
|
||
input_cursor_->Reset();
|
||
previous_.clear();
|
||
total_cost_.clear();
|
||
yielded_vertices_.clear();
|
||
ClearQueue();
|
||
}
|
||
|
||
private:
|
||
const ExpandVariable &self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
|
||
// Upper bound on the path length.
|
||
int64_t upper_bound_{-1};
|
||
bool upper_bound_set_{false};
|
||
|
||
struct WspStateHash {
|
||
size_t operator()(const std::pair<VertexAccessor, int64_t> &key) const {
|
||
return utils::HashCombine<VertexAccessor, int64_t>{}(key.first, key.second);
|
||
}
|
||
};
|
||
|
||
// Maps vertices to weights they got in expansion.
|
||
utils::pmr::unordered_map<std::pair<VertexAccessor, int64_t>, TypedValue, WspStateHash> total_cost_;
|
||
|
||
// Maps vertices to edges used to reach them.
|
||
utils::pmr::unordered_map<std::pair<VertexAccessor, int64_t>, std::optional<EdgeAccessor>, WspStateHash> previous_;
|
||
|
||
// Keeps track of vertices for which we yielded a path already.
|
||
utils::pmr::unordered_set<VertexAccessor> yielded_vertices_;
|
||
|
||
// Priority queue comparator. Keep lowest weight on top of the queue.
|
||
class PriorityQueueComparator {
|
||
public:
|
||
bool operator()(
|
||
const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, std::optional<Path>> &lhs,
|
||
const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, std::optional<Path>> &rhs) {
|
||
const auto &lhs_weight = std::get<0>(lhs);
|
||
const auto &rhs_weight = std::get<0>(rhs);
|
||
// Null defines minimum value for all types
|
||
if (lhs_weight.IsNull()) {
|
||
return false;
|
||
}
|
||
|
||
if (rhs_weight.IsNull()) {
|
||
return true;
|
||
}
|
||
|
||
ValidateWeightTypes(lhs_weight, rhs_weight);
|
||
return (lhs_weight > rhs_weight).ValueBool();
|
||
}
|
||
};
|
||
|
||
std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>, std::optional<Path>>,
|
||
utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>,
|
||
std::optional<Path>>>,
|
||
PriorityQueueComparator>
|
||
pq_;
|
||
|
||
void ClearQueue() {
|
||
while (!pq_.empty()) pq_.pop();
|
||
}
|
||
};
|
||
|
||
class ExpandAllShortestPathsCursor : public query::plan::Cursor {
|
||
public:
|
||
ExpandAllShortestPathsCursor(const ExpandVariable &self, utils::MemoryResource *mem)
|
||
: self_(self),
|
||
input_cursor_(self_.input_->MakeCursor(mem)),
|
||
visited_cost_(mem),
|
||
total_cost_(mem),
|
||
next_edges_(mem),
|
||
traversal_stack_(mem),
|
||
pq_(mem) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP("ExpandAllShortestPathsCursor");
|
||
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::OLD);
|
||
|
||
auto *memory = context.evaluation_context.memory;
|
||
|
||
auto create_state = [this](const VertexAccessor &vertex, int64_t depth) {
|
||
return std::make_pair(vertex, upper_bound_set_ ? depth : 0);
|
||
};
|
||
|
||
// For the given (edge, direction, weight, depth) tuple checks if they
|
||
// satisfy the "where" condition. if so, places them in the priority
|
||
// queue.
|
||
auto expand_vertex = [this, &evaluator, &frame](const EdgeAccessor &edge, const EdgeAtom::Direction direction,
|
||
const TypedValue &total_weight, int64_t depth) {
|
||
auto const &next_vertex = direction == EdgeAtom::Direction::IN ? edge.From() : edge.To();
|
||
|
||
// Evaluate current weight
|
||
frame[self_.weight_lambda_->inner_edge_symbol] = edge;
|
||
frame[self_.weight_lambda_->inner_node_symbol] = next_vertex;
|
||
TypedValue next_weight = CalculateNextWeight(self_.weight_lambda_, total_weight, evaluator);
|
||
|
||
// If filter expression exists, evaluate filter
|
||
std::optional<Path> curr_acc_path = std::nullopt;
|
||
if (self_.filter_lambda_.expression) {
|
||
frame[self_.filter_lambda_.inner_edge_symbol] = edge;
|
||
frame[self_.filter_lambda_.inner_node_symbol] = next_vertex;
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(),
|
||
"Accumulated path must be path");
|
||
Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath();
|
||
accumulated_path.Expand(edge);
|
||
accumulated_path.Expand(next_vertex);
|
||
curr_acc_path = accumulated_path;
|
||
|
||
if (self_.filter_lambda_.accumulated_weight_symbol) {
|
||
frame[self_.filter_lambda_.accumulated_weight_symbol.value()] = next_weight;
|
||
}
|
||
}
|
||
|
||
if (!EvaluateFilter(evaluator, self_.filter_lambda_.expression)) return;
|
||
}
|
||
|
||
auto found_it = visited_cost_.find(next_vertex);
|
||
// Check if the vertex has already been processed.
|
||
if (found_it != visited_cost_.end()) {
|
||
auto weight = found_it->second;
|
||
|
||
if (weight.IsNull() || (next_weight <= weight).ValueBool()) {
|
||
// Has been visited, but now found a shorter path
|
||
visited_cost_[next_vertex] = next_weight;
|
||
} else {
|
||
// Continue and do not expand if current weight is larger
|
||
return;
|
||
}
|
||
} else {
|
||
visited_cost_[next_vertex] = next_weight;
|
||
}
|
||
|
||
DirectedEdge directed_edge = {edge, direction, next_weight};
|
||
pq_.emplace(next_weight, depth + 1, next_vertex, directed_edge, curr_acc_path);
|
||
};
|
||
|
||
auto restore_frame_state_after_expansion = [this, &frame]() {
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink();
|
||
}
|
||
};
|
||
|
||
// Populates the priority queue structure with expansions
|
||
// from the given vertex. skips expansions that don't satisfy
|
||
// the "where" condition.
|
||
auto expand_from_vertex = [this, &expand_vertex, &context, &restore_frame_state_after_expansion](
|
||
const VertexAccessor &vertex, const TypedValue &weight, int64_t depth) {
|
||
if (self_.common_.direction != EdgeAtom::Direction::IN) {
|
||
auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges;
|
||
for (const auto &edge : out_edges) {
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
|
||
!(context.auth_checker->Has(edge.To(), storage::View::OLD,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ) &&
|
||
context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
|
||
continue;
|
||
}
|
||
#endif
|
||
expand_vertex(edge, EdgeAtom::Direction::OUT, weight, depth);
|
||
restore_frame_state_after_expansion();
|
||
}
|
||
}
|
||
if (self_.common_.direction != EdgeAtom::Direction::OUT) {
|
||
auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges;
|
||
for (const auto &edge : in_edges) {
|
||
#ifdef MG_ENTERPRISE
|
||
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
|
||
!(context.auth_checker->Has(edge.From(), storage::View::OLD,
|
||
memgraph::query::AuthQuery::FineGrainedPrivilege::READ) &&
|
||
context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) {
|
||
continue;
|
||
}
|
||
#endif
|
||
expand_vertex(edge, EdgeAtom::Direction::IN, weight, depth);
|
||
restore_frame_state_after_expansion();
|
||
}
|
||
}
|
||
};
|
||
|
||
std::optional<VertexAccessor> start_vertex;
|
||
|
||
auto create_path = [this, &frame, &memory]() {
|
||
auto ¤t_level = traversal_stack_.back();
|
||
auto &edges_on_frame = frame[self_.common_.edge_symbol].ValueList();
|
||
|
||
// Clean out the current stack
|
||
if (current_level.empty()) {
|
||
if (!edges_on_frame.empty()) {
|
||
if (!self_.is_reverse_)
|
||
edges_on_frame.erase(edges_on_frame.end());
|
||
else
|
||
edges_on_frame.erase(edges_on_frame.begin());
|
||
}
|
||
traversal_stack_.pop_back();
|
||
return false;
|
||
}
|
||
|
||
auto [current_edge, current_edge_direction, current_weight] = current_level.back();
|
||
current_level.pop_back();
|
||
|
||
// Edges order depends on direction of expansion
|
||
if (!self_.is_reverse_)
|
||
edges_on_frame.emplace_back(current_edge);
|
||
else
|
||
edges_on_frame.emplace(edges_on_frame.begin(), current_edge);
|
||
|
||
auto next_vertex = current_edge_direction == EdgeAtom::Direction::IN ? current_edge.From() : current_edge.To();
|
||
frame[self_.total_weight_.value()] = current_weight;
|
||
|
||
if (next_edges_.find({next_vertex, traversal_stack_.size()}) != next_edges_.end()) {
|
||
auto next_vertex_edges = next_edges_[{next_vertex, traversal_stack_.size()}];
|
||
traversal_stack_.emplace_back(std::move(next_vertex_edges));
|
||
} else {
|
||
// Signal the end of iteration
|
||
utils::pmr::list<DirectedEdge> empty(memory);
|
||
traversal_stack_.emplace_back(std::move(empty));
|
||
}
|
||
|
||
if ((current_weight > visited_cost_.at(next_vertex)).ValueBool()) return false;
|
||
|
||
// Place destination node on the frame, handle existence flag
|
||
if (self_.common_.existing_node) {
|
||
const auto &node = frame[self_.common_.node_symbol];
|
||
ExpectType(self_.common_.node_symbol, node, TypedValue::Type::Vertex);
|
||
if (node.ValueVertex() != next_vertex) return false;
|
||
} else {
|
||
frame[self_.common_.node_symbol] = next_vertex;
|
||
}
|
||
return true;
|
||
};
|
||
|
||
auto create_DFS_traversal_tree = [this, &context, &memory, &frame, &create_state, &expand_from_vertex]() {
|
||
while (!pq_.empty()) {
|
||
AbortCheck(context);
|
||
|
||
auto [current_weight, current_depth, current_vertex, directed_edge, acc_path] = pq_.top();
|
||
pq_.pop();
|
||
|
||
const auto &[current_edge, direction, weight] = directed_edge;
|
||
auto current_state = create_state(current_vertex, current_depth);
|
||
|
||
auto position = total_cost_.find(current_state);
|
||
if (position != total_cost_.end()) {
|
||
if ((position->second < current_weight).ValueBool()) continue;
|
||
} else {
|
||
total_cost_.emplace(current_state, current_weight);
|
||
if (current_depth < upper_bound_) {
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
DMG_ASSERT(acc_path.has_value(), "Path must be already filled in AllShortestPath DFS traversals");
|
||
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = std::move(acc_path.value());
|
||
}
|
||
expand_from_vertex(current_vertex, current_weight, current_depth);
|
||
}
|
||
}
|
||
|
||
// Searching for a previous vertex in the expansion
|
||
auto prev_vertex = direction == EdgeAtom::Direction::IN ? current_edge.To() : current_edge.From();
|
||
|
||
// Update the parent
|
||
if (next_edges_.find({prev_vertex, current_depth - 1}) == next_edges_.end()) {
|
||
utils::pmr::list<DirectedEdge> empty(memory);
|
||
next_edges_[{prev_vertex, current_depth - 1}] = std::move(empty);
|
||
}
|
||
next_edges_.at({prev_vertex, current_depth - 1}).emplace_back(directed_edge);
|
||
}
|
||
};
|
||
|
||
// upper_bound_set is used when storing visited edges, because with an upper bound we also consider suboptimal paths
|
||
// if they are shorter in depth
|
||
if (self_.upper_bound_) {
|
||
upper_bound_ = EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in all shortest path expansion");
|
||
upper_bound_set_ = true;
|
||
} else {
|
||
upper_bound_ = std::numeric_limits<int64_t>::max();
|
||
upper_bound_set_ = false;
|
||
}
|
||
|
||
// Check if upper bound is valid
|
||
if (upper_bound_ < 1) {
|
||
throw QueryRuntimeException("Maximum depth in all shortest paths expansion must be at least 1.");
|
||
}
|
||
|
||
// On first Pull run, traversal stack and priority queue are empty, so we start a pulling stream
|
||
// and create a DFS traversal tree (main part of algorithm). Then we return the first path
|
||
// created from the DFS traversal tree (basically a DFS algorithm).
|
||
// On each subsequent Pull run, paths are created from the traversal stack and returned.
|
||
while (true) {
|
||
// Check if there is an external error.
|
||
AbortCheck(context);
|
||
|
||
// The algorithm is run all at once by create_DFS_traversal_tree, after which we
|
||
// traverse the tree iteratively by preserving the traversal state on stack.
|
||
while (!traversal_stack_.empty()) {
|
||
if (create_path()) return true;
|
||
}
|
||
|
||
// If priority queue is empty start new pulling stream.
|
||
if (pq_.empty()) {
|
||
// Finish if there is nothing to pull
|
||
if (!input_cursor_->Pull(frame, context)) return false;
|
||
|
||
const auto &vertex_value = frame[self_.input_symbol_];
|
||
if (vertex_value.IsNull()) continue;
|
||
|
||
start_vertex = vertex_value.ValueVertex();
|
||
if (self_.common_.existing_node) {
|
||
const auto &node = frame[self_.common_.node_symbol];
|
||
// Due to optional matching the existing node could be null.
|
||
// Skip expansion for such nodes.
|
||
if (node.IsNull()) continue;
|
||
}
|
||
|
||
// Clear existing data structures.
|
||
visited_cost_.clear();
|
||
next_edges_.clear();
|
||
traversal_stack_.clear();
|
||
total_cost_.clear();
|
||
|
||
if (self_.filter_lambda_.accumulated_path_symbol) {
|
||
// Add initial vertex of path to the accumulated path
|
||
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(*start_vertex);
|
||
}
|
||
|
||
frame[self_.weight_lambda_->inner_edge_symbol] = TypedValue();
|
||
frame[self_.weight_lambda_->inner_node_symbol] = *start_vertex;
|
||
TypedValue current_weight =
|
||
CalculateNextWeight(self_.weight_lambda_, /* total_weight */ TypedValue(), evaluator);
|
||
|
||
expand_from_vertex(*start_vertex, current_weight, 0);
|
||
visited_cost_.emplace(*start_vertex, 0);
|
||
frame[self_.common_.edge_symbol] = TypedValue::TVector(memory);
|
||
}
|
||
|
||
// Create a DFS traversal tree from the start node
|
||
create_DFS_traversal_tree();
|
||
|
||
// The DFS traversal tree is created; seed the traversal stack with the edges recorded for the start vertex.
|
||
if (start_vertex && next_edges_.find({*start_vertex, 0}) != next_edges_.end()) {
|
||
auto start_vertex_edges = next_edges_[{*start_vertex, 0}];
|
||
traversal_stack_.emplace_back(std::move(start_vertex_edges));
|
||
}
|
||
}
|
||
}
|
||
|
||
// Shutting down this cursor only forwards the call to the input cursor.
void Shutdown() override {
  input_cursor_->Shutdown();
}
|
||
|
||
// Resets the input cursor and drops all per-expansion state: the priority queue, the traversal stack and the
// cost / next-edge bookkeeping maps.
void Reset() override {
  input_cursor_->Reset();

  ClearQueue();
  traversal_stack_.clear();
  next_edges_.clear();
  total_cost_.clear();
  visited_cost_.clear();
}
|
||
|
||
private:
|
||
const ExpandVariable &self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
|
||
// Upper bound on the path length.
|
||
int64_t upper_bound_{-1};
|
||
bool upper_bound_set_{false};
|
||
|
||
struct AspStateHash {
|
||
size_t operator()(const std::pair<VertexAccessor, int64_t> &key) const {
|
||
return utils::HashCombine<VertexAccessor, int64_t>{}(key.first, key.second);
|
||
}
|
||
};
|
||
|
||
using DirectedEdge = std::tuple<EdgeAccessor, EdgeAtom::Direction, TypedValue>;
|
||
using NextEdgesState = std::pair<VertexAccessor, int64_t>;
|
||
// Maps vertices to minimum weights they got in expansion.
|
||
utils::pmr::unordered_map<VertexAccessor, TypedValue> visited_cost_;
|
||
// Maps (vertex, depth) states to the weight they got in expansion.
|
||
utils::pmr::unordered_map<NextEdgesState, TypedValue, AspStateHash> total_cost_;
|
||
// Maps a (vertex, depth) state to the list of edges that can be expanded from it.
|
||
utils::pmr::unordered_map<NextEdgesState, utils::pmr::list<DirectedEdge>, AspStateHash> next_edges_;
|
||
// Stack indicating the traversal level.
|
||
utils::pmr::list<utils::pmr::list<DirectedEdge>> traversal_stack_;
|
||
|
||
// Priority queue comparator. Keep lowest weight on top of the queue.
|
||
class PriorityQueueComparator {
|
||
public:
|
||
bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>> &lhs,
|
||
const std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>> &rhs) {
|
||
const auto &lhs_weight = std::get<0>(lhs);
|
||
const auto &rhs_weight = std::get<0>(rhs);
|
||
// Null defines minimum value for all types
|
||
if (lhs_weight.IsNull()) {
|
||
return false;
|
||
}
|
||
|
||
if (rhs_weight.IsNull()) {
|
||
return true;
|
||
}
|
||
|
||
ValidateWeightTypes(lhs_weight, rhs_weight);
|
||
return (lhs_weight > rhs_weight).ValueBool();
|
||
}
|
||
};
|
||
|
||
// Priority queue - core element of the algorithm.
|
||
// Stores: {weight, depth, next vertex, directed edge (edge, direction, weight), optional accumulated path}
|
||
std::priority_queue<
|
||
std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>>,
|
||
utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, DirectedEdge, std::optional<Path>>>,
|
||
PriorityQueueComparator>
|
||
pq_;
|
||
|
||
void ClearQueue() {
|
||
while (!pq_.empty()) pq_.pop();
|
||
}
|
||
};
|
||
|
||
// Bumps the ExpandVariable operator metric and builds the cursor matching the expansion type.
// Breadth-first expansion uses the source-target cursor when the destination node is already bound and the
// single-source cursor otherwise. A SINGLE expansion is never expected here and is reported as a fatal error.
UniqueCursorPtr ExpandVariable::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ExpandVariableOperator);

  switch (type_) {
    case EdgeAtom::Type::DEPTH_FIRST:
      return MakeUniqueCursorPtr<ExpandVariableCursor>(mem, *this, mem);
    case EdgeAtom::Type::BREADTH_FIRST:
      if (!common_.existing_node) {
        return MakeUniqueCursorPtr<SingleSourceShortestPathCursor>(mem, *this, mem);
      }
      return MakeUniqueCursorPtr<STShortestPathCursor>(mem, *this, mem);
    case EdgeAtom::Type::ALL_SHORTEST_PATHS:
      return MakeUniqueCursorPtr<ExpandAllShortestPathsCursor>(mem, *this, mem);
    case EdgeAtom::Type::WEIGHTED_SHORTEST_PATH:
      return MakeUniqueCursorPtr<ExpandWeightedShortestPathCursor>(mem, *this, mem);
    case EdgeAtom::Type::SINGLE:
      LOG_FATAL("ExpandVariable should not be planned for a single expansion!");
  }
}
|
||
|
||
class ConstructNamedPathCursor : public Cursor {
|
||
public:
|
||
ConstructNamedPathCursor(ConstructNamedPath self, utils::MemoryResource *mem)
|
||
: self_(std::move(self)), input_cursor_(self_.input()->MakeCursor(mem)) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP("ConstructNamedPath");
|
||
|
||
if (!input_cursor_->Pull(frame, context)) return false;
|
||
|
||
auto symbol_it = self_.path_elements_.begin();
|
||
DMG_ASSERT(symbol_it != self_.path_elements_.end(), "Named path must contain at least one node");
|
||
|
||
const auto &start_vertex = frame[*symbol_it++];
|
||
auto *pull_memory = context.evaluation_context.memory;
|
||
// In an OPTIONAL MATCH everything could be Null.
|
||
if (start_vertex.IsNull()) {
|
||
frame[self_.path_symbol_] = TypedValue(pull_memory);
|
||
return true;
|
||
}
|
||
|
||
DMG_ASSERT(start_vertex.IsVertex(), "First named path element must be a vertex");
|
||
query::Path path(start_vertex.ValueVertex(), pull_memory);
|
||
|
||
// If the last path element symbol was for an edge list, then
|
||
// the next symbol is a vertex and it should not append to the path
|
||
// because
|
||
// expansion already did it.
|
||
bool last_was_edge_list = false;
|
||
|
||
for (; symbol_it != self_.path_elements_.end(); symbol_it++) {
|
||
const auto &expansion = frame[*symbol_it];
|
||
// We can have Null (OPTIONAL MATCH), a vertex, an edge, or an edge
|
||
// list (variable expand or BFS).
|
||
switch (expansion.type()) {
|
||
case TypedValue::Type::Null:
|
||
frame[self_.path_symbol_] = TypedValue(pull_memory);
|
||
return true;
|
||
case TypedValue::Type::Vertex:
|
||
if (!last_was_edge_list) path.Expand(expansion.ValueVertex());
|
||
last_was_edge_list = false;
|
||
break;
|
||
case TypedValue::Type::Edge:
|
||
path.Expand(expansion.ValueEdge());
|
||
break;
|
||
case TypedValue::Type::List: {
|
||
last_was_edge_list = true;
|
||
// We need to expand all edges in the list and intermediary
|
||
// vertices.
|
||
const auto &edges = expansion.ValueList();
|
||
for (const auto &edge_value : edges) {
|
||
const auto &edge = edge_value.ValueEdge();
|
||
const auto &from = edge.From();
|
||
if (path.vertices().back() == from)
|
||
path.Expand(edge, edge.To());
|
||
else
|
||
path.Expand(edge, from);
|
||
}
|
||
break;
|
||
}
|
||
default:
|
||
LOG_FATAL("Unsupported type in named path construction");
|
||
|
||
break;
|
||
}
|
||
}
|
||
|
||
frame[self_.path_symbol_] = path;
|
||
return true;
|
||
}
|
||
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override { input_cursor_->Reset(); }
|
||
|
||
private:
|
||
const ConstructNamedPath self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
};
|
||
|
||
ACCEPT_WITH_INPUT(ConstructNamedPath)
|
||
|
||
// Bumps the ConstructNamedPath operator metric and creates the cursor, using `mem` both to allocate the cursor
// and as the cursor's memory resource.
UniqueCursorPtr ConstructNamedPath::MakeCursor(utils::MemoryResource *mem) const {
  using memgraph::metrics::IncrementCounter;
  IncrementCounter(memgraph::metrics::ConstructNamedPathOperator);
  return MakeUniqueCursorPtr<ConstructNamedPathCursor>(mem, *this, mem);
}
|
||
|
||
std::vector<Symbol> ConstructNamedPath::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
symbols.emplace_back(path_symbol_);
|
||
return symbols;
|
||
}
|
||
|
||
// Builds a filter over `input`; a null `input` is replaced by a fresh `Once` operator.
Filter::Filter(const std::shared_ptr<LogicalOperator> &input,
               const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters, Expression *expression)
    : input_(input ? input : std::make_shared<Once>()),
      pattern_filters_(pattern_filters),
      expression_(expression) {}
|
||
|
||
// Same as the three-argument constructor, but additionally takes ownership of `all_filters`.
// A null `input` is replaced by a fresh `Once` operator.
Filter::Filter(const std::shared_ptr<LogicalOperator> &input,
               const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters, Expression *expression,
               Filters all_filters)
    : input_(input ? input : std::make_shared<Once>()),
      pattern_filters_(pattern_filters),
      expression_(expression),
      all_filters_(std::move(all_filters)) {}
|
||
|
||
// Visitor entry point. When PreVisit agrees, the input and every pattern filter operator are visited (in that
// order). PostVisit is always called and its result is returned.
bool Filter::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  const bool descend = visitor.PreVisit(*this);
  if (!descend) {
    return visitor.PostVisit(*this);
  }

  input_->Accept(visitor);
  for (const auto &child : pattern_filters_) {
    child->Accept(visitor);
  }
  return visitor.PostVisit(*this);
}
|
||
|
||
// Bumps the Filter operator metric and creates a FilterCursor backed by `mem`.
UniqueCursorPtr Filter::MakeCursor(utils::MemoryResource *mem) const {
  using memgraph::metrics::IncrementCounter;
  IncrementCounter(memgraph::metrics::FilterOperator);
  return MakeUniqueCursorPtr<FilterCursor>(mem, *this, mem);
}
|
||
|
||
// A filter adds no symbols of its own; it reports whatever its input modifies.
std::vector<Symbol> Filter::ModifiedSymbols(const SymbolTable &table) const {
  return input_->ModifiedSymbols(table);
}
|
||
|
||
static std::vector<UniqueCursorPtr> MakeCursorVector(const std::vector<std::shared_ptr<LogicalOperator>> &ops,
|
||
utils::MemoryResource *mem) {
|
||
std::vector<UniqueCursorPtr> cursors;
|
||
cursors.reserve(ops.size());
|
||
|
||
if (!ops.empty()) {
|
||
for (const auto &op : ops) {
|
||
cursors.push_back(op->MakeCursor(mem));
|
||
}
|
||
}
|
||
|
||
return cursors;
|
||
}
|
||
|
||
// Creates the input cursor and one cursor per pattern filter operator, all using `mem`.
Filter::FilterCursor::FilterCursor(const Filter &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self_.input_->MakeCursor(mem)),
      pattern_filter_cursors_(MakeCursorVector(self_.pattern_filters_, mem)) {}
|
||
|
||
// Pulls from the input until a row satisfies the filter expression.
// Before the expression is evaluated for a row, every pattern filter cursor is pulled once.
// Returns true for the first accepted row and false once the input is exhausted.
bool Filter::FilterCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP_BY_REF(self_);

  // Like all filters, newly set values should not affect filtering of old
  // nodes and edges.
  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::OLD, context.frame_change_collector);

  for (;;) {
    if (!input_cursor_->Pull(frame, context)) {
      return false;
    }

    for (const auto &pattern_cursor : pattern_filter_cursors_) {
      pattern_cursor->Pull(frame, context);
    }

    if (EvaluateFilter(evaluator, self_.expression_)) {
      return true;
    }
  }
}
|
||
|
||
// Forwards the shutdown to the input cursor.
void Filter::FilterCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
// Forwards the reset to the input cursor.
void Filter::FilterCursor::Reset() {
  input_cursor_->Reset();
}
|
||
|
||
// Stores the input operator as given and takes ownership of the output symbol.
EvaluatePatternFilter::EvaluatePatternFilter(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol)
    : input_(input),
      output_symbol_(std::move(output_symbol)) {}
|
||
|
||
ACCEPT_WITH_INPUT(EvaluatePatternFilter);
|
||
|
||
// Bumps the EvaluatePatternFilter operator metric and creates the cursor backed by `mem`.
UniqueCursorPtr EvaluatePatternFilter::MakeCursor(utils::MemoryResource *mem) const {
  using memgraph::metrics::IncrementCounter;
  IncrementCounter(memgraph::metrics::EvaluatePatternFilterOperator);
  return MakeUniqueCursorPtr<EvaluatePatternFilterCursor>(mem, *this, mem);
}
|
||
|
||
// Keeps a reference to the operator and creates the cursor of its input using `mem`.
EvaluatePatternFilter::EvaluatePatternFilterCursor::EvaluatePatternFilterCursor(const EvaluatePatternFilter &self,
                                                                                utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self_.input_->MakeCursor(mem)) {}
|
||
|
||
std::vector<Symbol> EvaluatePatternFilter::ModifiedSymbols(const SymbolTable &table) const {
|
||
return input_->ModifiedSymbols(table);
|
||
}
|
||
|
||
// Does not pull from the input itself. Instead it stores a callable under the output symbol; when invoked, the
// callable resets the input cursor, pulls from it once and writes the boolean result of that pull into the
// TypedValue it was given. Always returns true.
//
// The callable holds references to `frame` and `context` and a raw pointer to the input cursor, so it must not
// be invoked after any of them is gone.
bool EvaluatePatternFilter::EvaluatePatternFilterCursor::Pull(Frame &frame, ExecutionContext &context) {
  SCOPED_PROFILE_OP("EvaluatePatternFilter");
  // The operator itself is intentionally not captured: the callable never uses it, and capturing it by value
  // copied the whole operator on every Pull.
  std::function<void(TypedValue *)> function = [&frame, input_cursor = this->input_cursor_.get(),
                                                &context](TypedValue *return_value) {
    OOMExceptionEnabler oom_exception;
    input_cursor->Reset();

    *return_value = TypedValue(input_cursor->Pull(frame, context), context.evaluation_context.memory);
  };

  frame[self_.output_symbol_] = TypedValue(std::move(function));
  return true;
}
|
||
|
||
// Forwards the shutdown to the input cursor.
void EvaluatePatternFilter::EvaluatePatternFilterCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
// Forwards the reset to the input cursor.
void EvaluatePatternFilter::EvaluatePatternFilterCursor::Reset() {
  input_cursor_->Reset();
}
|
||
|
||
// Builds a Produce over `input`; a null `input` is replaced by a fresh `Once` operator.
Produce::Produce(const std::shared_ptr<LogicalOperator> &input, const std::vector<NamedExpression *> &named_expressions)
    : input_(input ? input : std::make_shared<Once>()),
      named_expressions_(named_expressions) {}
|
||
|
||
ACCEPT_WITH_INPUT(Produce)
|
||
|
||
// Bumps the Produce operator metric and creates a ProduceCursor backed by `mem`.
UniqueCursorPtr Produce::MakeCursor(utils::MemoryResource *mem) const {
  using memgraph::metrics::IncrementCounter;
  IncrementCounter(memgraph::metrics::ProduceOperator);
  return MakeUniqueCursorPtr<ProduceCursor>(mem, *this, mem);
}
|
||
|
||
std::vector<Symbol> Produce::OutputSymbols(const SymbolTable &symbol_table) const {
|
||
std::vector<Symbol> symbols;
|
||
for (const auto &named_expr : named_expressions_) {
|
||
symbols.emplace_back(symbol_table.at(*named_expr));
|
||
}
|
||
return symbols;
|
||
}
|
||
|
||
// The symbols modified by Produce are exactly its output symbols.
std::vector<Symbol> Produce::ModifiedSymbols(const SymbolTable &table) const {
  return OutputSymbols(table);
}
|
||
|
||
// Keeps a reference to the operator and creates the cursor of its input using `mem`.
Produce::ProduceCursor::ProduceCursor(const Produce &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self_.input_->MakeCursor(mem)) {}
|
||
|
||
// Pulls one row from the input and evaluates every named expression against it.
// For each expression whose name is tracked by the frame change collector, the tracked value is reset before the
// expression is evaluated. Returns false when the input is exhausted.
bool Produce::ProduceCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP_BY_REF(self_);

  if (!input_cursor_->Pull(frame, context)) {
    return false;
  }

  // Produce should always yield the latest results.
  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::NEW, context.frame_change_collector);
  for (auto *expression : self_.named_expressions_) {
    const bool is_tracked =
        context.frame_change_collector && context.frame_change_collector->IsKeyTracked(expression->name_);
    if (is_tracked) {
      context.frame_change_collector->ResetTrackingValue(expression->name_);
    }
    expression->Accept(evaluator);
  }
  return true;
}
|
||
|
||
// Forwards the shutdown to the input cursor.
void Produce::ProduceCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
// Forwards the reset to the input cursor.
void Produce::ProduceCursor::Reset() {
  input_cursor_->Reset();
}
|
||
|
||
// Stores the input operator, the expressions to delete and the detach flag.
// Note: the parameters `input_` and `detach_` share their names with the members; inside the initializer list the
// parenthesised names refer to the parameters.
Delete::Delete(const std::shared_ptr<LogicalOperator> &input_, const std::vector<Expression *> &expressions,
               bool detach_)
    : input_(input_),
      expressions_(expressions),
      detach_(detach_) {}
|
||
|
||
ACCEPT_WITH_INPUT(Delete)
|
||
|
||
// Bumps the Delete operator metric and creates a DeleteCursor backed by `mem`.
UniqueCursorPtr Delete::MakeCursor(utils::MemoryResource *mem) const {
  using memgraph::metrics::IncrementCounter;
  IncrementCounter(memgraph::metrics::DeleteOperator);
  return MakeUniqueCursorPtr<DeleteCursor>(mem, *this, mem);
}
|
||
|
||
// Delete introduces no symbols; it reports whatever its input modifies.
std::vector<Symbol> Delete::ModifiedSymbols(const SymbolTable &table) const {
  return input_->ModifiedSymbols(table);
}
|
||
|
||
// Keeps a reference to the operator and creates the cursor of its input using `mem`.
Delete::DeleteCursor::DeleteCursor(const Delete &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self_.input_->MakeCursor(mem)) {}
|
||
|
||
// Evaluates all delete expressions for the current frame and records the resulting vertices and edges in
// `buffer_`; nothing is deleted here.
//
// Accepted values are vertices, edges, paths (all their vertices and edges are recorded) and Null (ignored).
// Any other type throws QueryRuntimeException. In enterprise builds with a valid license and an auth checker, a
// QueryRuntimeException is also thrown when the fine-grained privileges checked below are missing.
void Delete::DeleteCursor::UpdateDeleteBuffer(Frame &frame, ExecutionContext &context) {
  // Delete should get the latest information, this way it is also possible
  // to delete newly added nodes and edges.
  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::NEW);

  auto *pull_memory = context.evaluation_context.memory;
  // collect expressions results so edges can get deleted before vertices
  // this is necessary because an edge that gets deleted could block vertex
  // deletion
  utils::pmr::vector<TypedValue> expression_results(pull_memory);
  expression_results.reserve(self_.expressions_.size());
  for (Expression *expression : self_.expressions_) {
    expression_results.emplace_back(expression->Accept(evaluator));
  }

  // True when the vertex may be deleted: always in community builds, otherwise CREATE_DELETE on the vertex is
  // required whenever fine-grained checks are active.
  auto vertex_auth_checker = [&context](const VertexAccessor &va) -> bool {
#ifdef MG_ENTERPRISE
    return !(license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
             !context.auth_checker->Has(va, storage::View::NEW, query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE));
#else
    return true;
#endif
  };

  // True when the edge may be deleted: always in community builds, otherwise CREATE_DELETE on the edge and UPDATE
  // on both of its endpoints are required whenever fine-grained checks are active.
  auto edge_auth_checker = [&context](const EdgeAccessor &ea) -> bool {
#ifdef MG_ENTERPRISE
    return !(
        license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
        !(context.auth_checker->Has(ea, query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE) &&
          context.auth_checker->Has(ea.To(), storage::View::NEW, query::AuthQuery::FineGrainedPrivilege::UPDATE) &&
          context.auth_checker->Has(ea.From(), storage::View::NEW, query::AuthQuery::FineGrainedPrivilege::UPDATE)));
#else
    return true;
#endif
  };

  for (TypedValue &expression_result : expression_results) {
    AbortCheck(context);
    switch (expression_result.type()) {
      case TypedValue::Type::Vertex: {
        auto va = expression_result.ValueVertex();
        if (vertex_auth_checker(va)) {
          buffer_.nodes.push_back(va);
        } else {
          throw QueryRuntimeException("Vertex not deleted due to not having enough permission!");
        }
        break;
      }
      case TypedValue::Type::Edge: {
        auto ea = expression_result.ValueEdge();
        if (edge_auth_checker(ea)) {
          buffer_.edges.push_back(ea);
        } else {
          throw QueryRuntimeException("Edge not deleted due to not having enough permission!");
        }
        break;
      }
      case TypedValue::Type::Path: {
        // The path is only read, so bind it by reference instead of copying it.
        const auto &path = expression_result.ValuePath();
#ifdef MG_ENTERPRISE
        auto edges_res = std::any_of(path.edges().cbegin(), path.edges().cend(),
                                     [&edge_auth_checker](const auto &ea) { return !edge_auth_checker(ea); });
        auto vertices_res = std::any_of(path.vertices().cbegin(), path.vertices().cend(),
                                        [&vertex_auth_checker](const auto &va) { return !vertex_auth_checker(va); });

        if (edges_res || vertices_res) {
          throw QueryRuntimeException(
              "Path not deleted due to not having enough permission on all edges and vertices on the path!");
        }
#endif
        buffer_.nodes.insert(buffer_.nodes.begin(), path.vertices().begin(), path.vertices().end());
        buffer_.edges.insert(buffer_.edges.begin(), path.edges().begin(), path.edges().end());
        // Explicit break: previously this case fell through into the Null case below.
        break;
      }
      case TypedValue::Type::Null:
        break;
      default:
        throw QueryRuntimeException("Edges, vertices and paths can be deleted.");
    }
  }
}
|
||
|
||
// While the input produces rows, only buffers what has to be deleted and returns true.
// On the first pull after the input is exhausted, performs a single DetachDelete over the whole buffer, updates
// the deletion statistics, registers deleted objects with the trigger context collector (when present) and
// returns false. Every later pull returns false immediately until the cursor is reset.
bool Delete::DeleteCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("Delete");

  if (delete_executed_) return false;

  if (input_cursor_->Pull(frame, context)) {
    UpdateDeleteBuffer(frame, context);
    return true;
  }

  auto delete_result =
      context.db_accessor->DetachDelete(std::move(buffer_.nodes), std::move(buffer_.edges), self_.detach_);
  if (delete_result.HasError()) {
    switch (delete_result.GetError()) {
      case storage::Error::SERIALIZATION_ERROR:
        throw TransactionSerializationException();
      case storage::Error::VERTEX_HAS_EDGES:
        throw RemoveAttachedVertexException();
      case storage::Error::DELETED_OBJECT:
      case storage::Error::PROPERTIES_DISABLED:
      case storage::Error::NONEXISTENT_OBJECT:
        throw QueryRuntimeException("Unexpected error when deleting a node.");
    }
  }

  const auto &maybe_deleted = *delete_result;
  if (maybe_deleted) {
    const auto &deleted_nodes = maybe_deleted->first;
    const auto &deleted_edges = maybe_deleted->second;

    context.execution_stats[ExecutionStats::Key::DELETED_NODES] += static_cast<int64_t>(deleted_nodes.size());
    context.execution_stats[ExecutionStats::Key::DELETED_EDGES] += static_cast<int64_t>(deleted_edges.size());

    // Update deleted objects for triggers
    if (context.trigger_context_collector) {
      for (const auto &node : deleted_nodes) {
        context.trigger_context_collector->RegisterDeletedObject(node);
      }

      if (context.trigger_context_collector->ShouldRegisterDeletedObject<query::EdgeAccessor>()) {
        for (const auto &edge : deleted_edges) {
          context.trigger_context_collector->RegisterDeletedObject(edge);
        }
      }
    }
  }

  delete_executed_ = true;
  return false;
}
|
||
|
||
// Forwards the shutdown to the input cursor.
void Delete::DeleteCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
// Re-arms the cursor so a later Pull can execute a delete again, and resets the input cursor.
// The delete buffer itself is not touched here.
void Delete::DeleteCursor::Reset() {
  delete_executed_ = false;
  input_cursor_->Reset();
}
|
||
|
||
// Stores the input operator, the property to set, the property lookup identifying the target and the value
// expression.
SetProperty::SetProperty(const std::shared_ptr<LogicalOperator> &input, storage::PropertyId property,
                         PropertyLookup *lhs, Expression *rhs)
    : input_(input),
      property_(property),
      lhs_(lhs),
      rhs_(rhs) {}
|
||
|
||
ACCEPT_WITH_INPUT(SetProperty)
|
||
|
||
// Bumps the SetProperty operator metric and creates a SetPropertyCursor backed by `mem`.
UniqueCursorPtr SetProperty::MakeCursor(utils::MemoryResource *mem) const {
  using memgraph::metrics::IncrementCounter;
  IncrementCounter(memgraph::metrics::SetPropertyOperator);
  return MakeUniqueCursorPtr<SetPropertyCursor>(mem, *this, mem);
}
|
||
|
||
std::vector<Symbol> SetProperty::ModifiedSymbols(const SymbolTable &table) const {
|
||
return input_->ModifiedSymbols(table);
|
||
}
|
||
|
||
// Keeps a reference to the operator and creates the cursor of its input using `mem`.
SetProperty::SetPropertyCursor::SetPropertyCursor(const SetProperty &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self.input_->MakeCursor(mem)) {}
|
||
|
||
// Pulls one row and sets the operator's property on the vertex or edge that the left-hand side evaluates to.
//
// Both sides are always evaluated. A Null target is skipped; any other non-vertex, non-edge target throws
// QueryRuntimeException. In enterprise builds with a valid license and an auth checker, a missing UPDATE privilege
// also throws. A successful set bumps the UPDATED_PROPERTIES statistic and is reported to the trigger context
// collector when one is present. For vertices, the text index is updated when the TEXT_SEARCH experiment is enabled.
// Returns false only when the input is exhausted.
bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("SetProperty");

  if (!input_cursor_->Pull(frame, context)) {
    return false;
  }

  // Set, just like Create needs to see the latest changes.
  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::NEW);
  TypedValue target = self_.lhs_->expression_->Accept(evaluator);
  TypedValue new_value = self_.rhs_->Accept(evaluator);

  switch (target.type()) {
    case TypedValue::Type::Null:
      // Skip setting properties on Null (can occur in optional match).
      break;
    case TypedValue::Type::Vertex: {
      auto &vertex = target.ValueVertex();
#ifdef MG_ENTERPRISE
      if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
          !context.auth_checker->Has(vertex, storage::View::NEW,
                                     memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) {
        throw QueryRuntimeException("Vertex property not set due to not having enough permission!");
      }
#endif
      auto previous_value = PropsSetChecked(&vertex, self_.property_, new_value);
      context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1;
      if (context.trigger_context_collector) {
        // The new value cannot be moved because it was created with the allocator that is only valid during
        // current pull
        context.trigger_context_collector->RegisterSetObjectProperty(
            vertex, self_.property_, TypedValue{std::move(previous_value)}, TypedValue{new_value});
      }
      if (flags::AreExperimentsEnabled(flags::Experiments::TEXT_SEARCH)) {
        context.db_accessor->TextIndexUpdateVertex(vertex);
      }
      break;
    }
    case TypedValue::Type::Edge: {
      auto &edge = target.ValueEdge();
#ifdef MG_ENTERPRISE
      if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
          !context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) {
        throw QueryRuntimeException("Edge property not set due to not having enough permission!");
      }
#endif
      auto previous_value = PropsSetChecked(&edge, self_.property_, new_value);
      context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1;
      if (context.trigger_context_collector) {
        // The new value cannot be moved because it was created with the allocator that is only valid
        // during current pull
        context.trigger_context_collector->RegisterSetObjectProperty(
            edge, self_.property_, TypedValue{std::move(previous_value)}, TypedValue{new_value});
      }
      break;
    }
    case TypedValue::Type::Map:
    // Semantically modifying a map makes sense, but it's not supported due
    // to all the copying we do (when PropertyValue -> TypedValue and in
    // ExpressionEvaluator). So even though we set a map property here, that
    // is never visible to the user and it's not stored.
    // TODO: fix above described bug
    default:
      throw QueryRuntimeException("Properties can only be set on edges and vertices.");
  }
  return true;
}
|
||
|
||
// Forwards the shutdown to the input cursor.
void SetProperty::SetPropertyCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
// Forwards the reset to the input cursor.
void SetProperty::SetPropertyCursor::Reset() {
  input_cursor_->Reset();
}
|
||
|
||
// Stores the input operator, takes ownership of the target symbol and records the value expression and the
// operation to apply.
SetProperties::SetProperties(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Expression *rhs, Op op)
    : input_(input),
      input_symbol_(std::move(input_symbol)),
      rhs_(rhs),
      op_(op) {}
|
||
|
||
ACCEPT_WITH_INPUT(SetProperties)
|
||
|
||
// Bumps the SetProperties operator metric and creates a SetPropertiesCursor backed by `mem`.
UniqueCursorPtr SetProperties::MakeCursor(utils::MemoryResource *mem) const {
  using memgraph::metrics::IncrementCounter;
  IncrementCounter(memgraph::metrics::SetPropertiesOperator);
  return MakeUniqueCursorPtr<SetPropertiesCursor>(mem, *this, mem);
}
|
||
|
||
std::vector<Symbol> SetProperties::ModifiedSymbols(const SymbolTable &table) const {
|
||
return input_->ModifiedSymbols(table);
|
||
}
|
||
|
||
// Keeps a reference to the operator and creates the cursor of its input using `mem`.
SetProperties::SetPropertiesCursor::SetPropertiesCursor(const SetProperties &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self.input_->MakeCursor(mem)) {}
|
||
|
||
namespace {
|
||
|
||
template <typename T>
|
||
concept AccessorWithProperties = requires(T value, storage::PropertyId property_id,
|
||
storage::PropertyValue property_value,
|
||
std::map<storage::PropertyId, storage::PropertyValue> properties) {
|
||
{ value.ClearProperties() } -> std::same_as<storage::Result<std::map<storage::PropertyId, storage::PropertyValue>>>;
|
||
{value.SetProperty(property_id, property_value)};
|
||
{value.UpdateProperties(properties)};
|
||
};
|
||
|
||
/// Helper function that sets the given values on either a Vertex or an Edge.
|
||
///
|
||
/// @tparam TRecordAccessor Either RecordAccessor<Vertex> or
|
||
/// RecordAccessor<Edge>
|
||
template <AccessorWithProperties TRecordAccessor>
|
||
void SetPropertiesOnRecord(TRecordAccessor *record, const TypedValue &rhs, SetProperties::Op op,
|
||
ExecutionContext *context,
|
||
std::unordered_map<std::string, storage::PropertyId> &cached_name_id) {
|
||
using PropertiesMap = std::map<storage::PropertyId, storage::PropertyValue>;
|
||
std::optional<PropertiesMap> old_values;
|
||
const bool should_register_change =
|
||
context->trigger_context_collector &&
|
||
context->trigger_context_collector->ShouldRegisterObjectPropertyChange<TRecordAccessor>();
|
||
if (op == SetProperties::Op::REPLACE) {
|
||
auto maybe_value = record->ClearProperties();
|
||
if (maybe_value.HasError()) {
|
||
switch (maybe_value.GetError()) {
|
||
case storage::Error::DELETED_OBJECT:
|
||
throw QueryRuntimeException("Trying to set properties on a deleted graph element.");
|
||
case storage::Error::SERIALIZATION_ERROR:
|
||
throw TransactionSerializationException();
|
||
case storage::Error::PROPERTIES_DISABLED:
|
||
throw QueryRuntimeException("Can't set property because properties on edges are disabled.");
|
||
case storage::Error::VERTEX_HAS_EDGES:
|
||
case storage::Error::NONEXISTENT_OBJECT:
|
||
throw QueryRuntimeException("Unexpected error when setting properties.");
|
||
}
|
||
}
|
||
|
||
if (should_register_change) {
|
||
old_values.emplace(std::move(*maybe_value));
|
||
}
|
||
}
|
||
|
||
auto get_props = [](const auto &record) {
|
||
auto maybe_props = record.Properties(storage::View::NEW);
|
||
if (maybe_props.HasError()) {
|
||
switch (maybe_props.GetError()) {
|
||
case storage::Error::DELETED_OBJECT:
|
||
throw QueryRuntimeException("Trying to get properties from a deleted object.");
|
||
case storage::Error::NONEXISTENT_OBJECT:
|
||
throw query::QueryRuntimeException("Trying to get properties from an object that doesn't exist.");
|
||
case storage::Error::SERIALIZATION_ERROR:
|
||
case storage::Error::VERTEX_HAS_EDGES:
|
||
case storage::Error::PROPERTIES_DISABLED:
|
||
throw QueryRuntimeException("Unexpected error when getting properties.");
|
||
}
|
||
}
|
||
return *maybe_props;
|
||
};
|
||
|
||
auto register_set_property = [&](auto &&returned_old_value, auto key, auto &&new_value) {
|
||
auto old_value = [&]() -> storage::PropertyValue {
|
||
if (!old_values) {
|
||
return std::forward<decltype(returned_old_value)>(returned_old_value);
|
||
}
|
||
|
||
if (auto it = old_values->find(key); it != old_values->end()) {
|
||
return std::move(it->second);
|
||
}
|
||
|
||
return {};
|
||
}();
|
||
|
||
context->trigger_context_collector->RegisterSetObjectProperty(
|
||
*record, key, TypedValue(std::move(old_value)), TypedValue(std::forward<decltype(new_value)>(new_value)));
|
||
};
|
||
|
||
auto update_props = [&, record](PropertiesMap &new_properties) {
|
||
auto updated_properties = UpdatePropertiesChecked(record, new_properties);
|
||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||
context->execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += new_properties.size();
|
||
|
||
if (should_register_change) {
|
||
for (const auto &[id, old_value, new_value] : updated_properties) {
|
||
register_set_property(std::move(old_value), id, std::move(new_value));
|
||
}
|
||
}
|
||
};
|
||
|
||
switch (rhs.type()) {
|
||
case TypedValue::Type::Edge: {
|
||
PropertiesMap new_properties = get_props(rhs.ValueEdge());
|
||
update_props(new_properties);
|
||
break;
|
||
}
|
||
case TypedValue::Type::Vertex: {
|
||
PropertiesMap new_properties = get_props(rhs.ValueVertex());
|
||
update_props(new_properties);
|
||
if (flags::AreExperimentsEnabled(flags::Experiments::TEXT_SEARCH)) {
|
||
context->db_accessor->TextIndexUpdateVertex(rhs.ValueVertex());
|
||
}
|
||
break;
|
||
}
|
||
case TypedValue::Type::Map: {
|
||
PropertiesMap new_properties;
|
||
for (const auto &[string_key, value] : rhs.ValueMap()) {
|
||
storage::PropertyId property_id;
|
||
if (auto it = cached_name_id.find(std::string(string_key)); it != cached_name_id.end()) [[likely]] {
|
||
property_id = it->second;
|
||
} else {
|
||
property_id = context->db_accessor->NameToProperty(string_key);
|
||
cached_name_id.emplace(string_key, property_id);
|
||
}
|
||
new_properties.emplace(property_id, value);
|
||
}
|
||
update_props(new_properties);
|
||
break;
|
||
}
|
||
default:
|
||
throw QueryRuntimeException(
|
||
"Right-hand side in SET expression must be a node, an edge or a "
|
||
"map.");
|
||
}
|
||
|
||
if (should_register_change && old_values) {
|
||
// register removed properties
|
||
for (auto &[property_id, property_value] : *old_values) {
|
||
context->trigger_context_collector->RegisterRemovedObjectProperty(*record, property_id,
|
||
TypedValue(std::move(property_value)));
|
||
}
|
||
}
|
||
}
|
||
|
||
} // namespace
|
||
|
||
/// Pulls one row from the input and applies the SET to the vertex or edge bound
/// to `input_symbol_`.
///
/// Returns false once the input is exhausted and true otherwise; a Null target
/// (possible after an optional match) is skipped without error.
/// @throw QueryRuntimeException if the target is neither a vertex, an edge nor
///        Null, or (enterprise builds) if the UPDATE permission check fails.
bool SetProperties::SetPropertiesCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("SetProperties");

  const bool has_row = input_cursor_->Pull(frame, context);
  if (!has_row) {
    return false;
  }

  TypedValue &target = frame[self_.input_symbol_];

  // Set, just like Create needs to see the latest changes.
  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::NEW);
  TypedValue new_values = self_.rhs_->Accept(evaluator);

  switch (target.type()) {
    case TypedValue::Type::Vertex: {
      auto &vertex = target.ValueVertex();
#ifdef MG_ENTERPRISE
      if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
          !context.auth_checker->Has(vertex, storage::View::NEW,
                                     memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) {
        throw QueryRuntimeException("Vertex properties not set due to not having enough permission!");
      }
#endif
      SetPropertiesOnRecord(&vertex, new_values, self_.op_, &context, cached_name_id_);
      // Keep the text index in sync with the freshly written properties.
      if (flags::AreExperimentsEnabled(flags::Experiments::TEXT_SEARCH)) {
        context.db_accessor->TextIndexUpdateVertex(vertex);
      }
      break;
    }
    case TypedValue::Type::Edge: {
      auto &edge = target.ValueEdge();
#ifdef MG_ENTERPRISE
      if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
          !context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) {
        throw QueryRuntimeException("Edge properties not set due to not having enough permission!");
      }
#endif
      SetPropertiesOnRecord(&edge, new_values, self_.op_, &context, cached_name_id_);
      break;
    }
    case TypedValue::Type::Null: {
      // Skip setting properties on Null (can occur in optional match).
      break;
    }
    default: {
      throw QueryRuntimeException("Properties can only be set on edges and vertices.");
    }
  }
  return true;
}
|
||
|
||
/// Shuts down the input cursor; this cursor has nothing else to release.
void SetProperties::SetPropertiesCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
/// Resets the input cursor; no other state is touched here.
void SetProperties::SetPropertiesCursor::Reset() {
  input_cursor_->Reset();
}
|
||
|
||
/// Stores the input operator, the symbol holding the target vertex and the
/// labels to add. `input_symbol` and `labels` are taken by value and moved in.
SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
                     std::vector<StorageLabelType> labels)
    : input_(input),
      input_symbol_(std::move(input_symbol)),
      labels_(std::move(labels)) {
}
|
||
|
||
ACCEPT_WITH_INPUT(SetLabels)
|
||
|
||
/// Builds the cursor that executes this operator and bumps the
/// SetLabelsOperator metric once per created cursor.
UniqueCursorPtr SetLabels::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::SetLabelsOperator);

  auto cursor = MakeUniqueCursorPtr<SetLabelsCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> SetLabels::ModifiedSymbols(const SymbolTable &table) const {
|
||
return input_->ModifiedSymbols(table);
|
||
}
|
||
|
||
/// Keeps a reference to the owning operator and creates the input cursor from `mem`.
SetLabels::SetLabelsCursor::SetLabelsCursor(const SetLabels &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self.input_->MakeCursor(mem)) {
}
|
||
|
||
/// Pulls one row from the input and adds the operator's labels to the vertex
/// bound to `input_symbol_`.
///
/// Returns false once the input is exhausted and true otherwise; a Null value
/// (possible after an optional match) is skipped without touching anything.
/// @throw QueryRuntimeException on failed permission checks (enterprise
///        builds), a deleted node or an unexpected storage error.
/// @throw TransactionSerializationException on a storage serialization error.
bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("SetLabels");
  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::NEW);
  const bool has_row = input_cursor_->Pull(frame, context);
  if (!has_row) {
    return false;
  }
  auto labels = EvaluateLabels(self_.labels_, evaluator, context.db_accessor);

#ifdef MG_ENTERPRISE
  // First gate: the user must be allowed to create/delete every label involved.
  if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
      !context.auth_checker->Has(labels, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) {
    throw QueryRuntimeException("Couldn't set label due to not having enough permission!");
  }
#endif

  TypedValue &target = frame[self_.input_symbol_];
  // Skip setting labels on Null (can occur in optional match).
  if (target.IsNull()) {
    return true;
  }
  ExpectType(self_.input_symbol_, target, TypedValue::Type::Vertex);
  auto &vertex = target.ValueVertex();

#ifdef MG_ENTERPRISE
  // Second gate: the user must be allowed to update this particular vertex.
  if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
      !context.auth_checker->Has(vertex, storage::View::OLD,
                                 memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) {
    throw QueryRuntimeException("Couldn't set label due to not having enough permission!");
  }
#endif

  for (auto label : labels) {
    auto add_result = vertex.AddLabel(label);
    if (add_result.HasError()) {
      switch (add_result.GetError()) {
        case storage::Error::SERIALIZATION_ERROR:
          throw TransactionSerializationException();
        case storage::Error::DELETED_OBJECT:
          throw QueryRuntimeException("Trying to set a label on a deleted node.");
        case storage::Error::VERTEX_HAS_EDGES:
        case storage::Error::PROPERTIES_DISABLED:
        case storage::Error::NONEXISTENT_OBJECT:
          throw QueryRuntimeException("Unexpected error when setting a label.");
      }
    }

    // Only report to triggers when a collector exists and the result value is true.
    if (context.trigger_context_collector && *add_result) {
      context.trigger_context_collector->RegisterSetVertexLabel(vertex, label);
    }
  }

  if (flags::AreExperimentsEnabled(flags::Experiments::TEXT_SEARCH)) {
    context.db_accessor->TextIndexUpdateVertex(vertex);
  }

  return true;
}
|
||
|
||
/// Shuts down the input cursor; this cursor has nothing else to release.
void SetLabels::SetLabelsCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
/// Resets the input cursor; no other state is touched here.
void SetLabels::SetLabelsCursor::Reset() {
  input_cursor_->Reset();
}
|
||
|
||
/// Stores the input operator, the id of the property to remove and the
/// property-lookup expression that identifies the owning vertex or edge.
RemoveProperty::RemoveProperty(const std::shared_ptr<LogicalOperator> &input, storage::PropertyId property,
                               PropertyLookup *lhs)
    : input_(input),
      property_(property),
      lhs_(lhs) {
}
|
||
|
||
ACCEPT_WITH_INPUT(RemoveProperty)
|
||
|
||
/// Builds the cursor that executes this operator and bumps the
/// RemovePropertyOperator metric once per created cursor.
UniqueCursorPtr RemoveProperty::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::RemovePropertyOperator);

  auto cursor = MakeUniqueCursorPtr<RemovePropertyCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> RemoveProperty::ModifiedSymbols(const SymbolTable &table) const {
|
||
return input_->ModifiedSymbols(table);
|
||
}
|
||
|
||
/// Keeps a reference to the owning operator and creates the input cursor from `mem`.
RemoveProperty::RemovePropertyCursor::RemovePropertyCursor(const RemoveProperty &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self.input_->MakeCursor(mem)) {
}
|
||
|
||
/// Pulls one row from the input and removes `property_` from the vertex or
/// edge that the left-hand-side expression evaluates to.
///
/// Returns false once the input is exhausted and true otherwise; a Null owner
/// (possible after an optional match) is skipped without error.
/// @throw QueryRuntimeException if the owner is neither a vertex, an edge nor
///        Null, on failed permission checks (enterprise builds), or on a
///        storage error other than a serialization error.
/// @throw TransactionSerializationException on a storage serialization error.
bool RemoveProperty::RemovePropertyCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("RemoveProperty");

  const bool has_row = input_cursor_->Pull(frame, context);
  if (!has_row) {
    return false;
  }

  // Remove, just like Delete needs to see the latest changes.
  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::NEW);
  TypedValue owner = self_.lhs_->expression_->Accept(evaluator);

  // Shared by the vertex and edge branches: removes the property, maps storage
  // errors to query exceptions and reports the old value to the trigger
  // collector when one is installed.
  auto drop_property = [property = self_.property_, &context](auto *accessor) {
    auto removed = accessor->RemoveProperty(property);
    if (removed.HasError()) {
      switch (removed.GetError()) {
        case storage::Error::DELETED_OBJECT:
          throw QueryRuntimeException("Trying to remove a property on a deleted graph element.");
        case storage::Error::SERIALIZATION_ERROR:
          throw TransactionSerializationException();
        case storage::Error::PROPERTIES_DISABLED:
          throw QueryRuntimeException(
              "Can't remove property because properties on edges are "
              "disabled.");
        case storage::Error::VERTEX_HAS_EDGES:
        case storage::Error::NONEXISTENT_OBJECT:
          throw QueryRuntimeException("Unexpected error when removing property.");
      }
    }

    if (context.trigger_context_collector) {
      context.trigger_context_collector->RegisterRemovedObjectProperty(*accessor, property,
                                                                       TypedValue(std::move(*removed)));
    }
  };

  switch (owner.type()) {
    case TypedValue::Type::Vertex: {
      auto &vertex = owner.ValueVertex();
#ifdef MG_ENTERPRISE
      if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
          !context.auth_checker->Has(vertex, storage::View::NEW,
                                     memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) {
        throw QueryRuntimeException("Vertex property not removed due to not having enough permission!");
      }
#endif
      drop_property(&vertex);
      // Keep the text index in sync with the changed properties.
      if (flags::AreExperimentsEnabled(flags::Experiments::TEXT_SEARCH)) {
        context.db_accessor->TextIndexUpdateVertex(vertex);
      }
      break;
    }
    case TypedValue::Type::Edge: {
      auto &edge = owner.ValueEdge();
#ifdef MG_ENTERPRISE
      if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
          !context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) {
        throw QueryRuntimeException("Edge property not removed due to not having enough permission!");
      }
#endif
      drop_property(&edge);
      break;
    }
    case TypedValue::Type::Null: {
      // Skip removing properties on Null (can occur in optional match).
      break;
    }
    default: {
      throw QueryRuntimeException("Properties can only be removed from vertices and edges.");
    }
  }
  return true;
}
|
||
|
||
/// Shuts down the input cursor; this cursor has nothing else to release.
void RemoveProperty::RemovePropertyCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
/// Resets the input cursor; no other state is touched here.
void RemoveProperty::RemovePropertyCursor::Reset() {
  input_cursor_->Reset();
}
|
||
|
||
/// Stores the input operator, the symbol holding the target vertex and the
/// labels to remove. `input_symbol` and `labels` are taken by value and moved in.
RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
                           std::vector<StorageLabelType> labels)
    : input_(input),
      input_symbol_(std::move(input_symbol)),
      labels_(std::move(labels)) {
}
|
||
|
||
ACCEPT_WITH_INPUT(RemoveLabels)
|
||
|
||
/// Builds the cursor that executes this operator and bumps the
/// RemoveLabelsOperator metric once per created cursor.
UniqueCursorPtr RemoveLabels::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::RemoveLabelsOperator);

  auto cursor = MakeUniqueCursorPtr<RemoveLabelsCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> RemoveLabels::ModifiedSymbols(const SymbolTable &table) const {
|
||
return input_->ModifiedSymbols(table);
|
||
}
|
||
|
||
/// Keeps a reference to the owning operator and creates the input cursor from `mem`.
RemoveLabels::RemoveLabelsCursor::RemoveLabelsCursor(const RemoveLabels &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self.input_->MakeCursor(mem)) {
}
|
||
|
||
/// Pulls one row from the input and removes the operator's labels from the
/// vertex bound to `input_symbol_`.
///
/// Returns false once the input is exhausted and true otherwise; a Null value
/// (possible after an optional match) is skipped without touching anything.
/// The DELETED_LABELS execution stat is incremented once per label processed.
/// @throw QueryRuntimeException on failed permission checks (enterprise
///        builds), a deleted node or an unexpected storage error.
/// @throw TransactionSerializationException on a storage serialization error.
bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("RemoveLabels");
  ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                storage::View::NEW);
  if (!input_cursor_->Pull(frame, context)) return false;
  auto labels = EvaluateLabels(self_.labels_, evaluator, context.db_accessor);

#ifdef MG_ENTERPRISE
  if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
      !context.auth_checker->Has(labels, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) {
    throw QueryRuntimeException("Couldn't remove label due to not having enough permission!");
  }
#endif

  TypedValue &vertex_value = frame[self_.input_symbol_];
  // Skip removing labels on Null (can occur in optional match).
  if (vertex_value.IsNull()) return true;
  ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex);
  auto &vertex = vertex_value.ValueVertex();

#ifdef MG_ENTERPRISE
  if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
      !context.auth_checker->Has(vertex, storage::View::OLD,
                                 memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE)) {
    throw QueryRuntimeException("Couldn't remove label due to not having enough permission!");
  }
#endif

  for (auto label : labels) {
    auto maybe_value = vertex.RemoveLabel(label);
    if (maybe_value.HasError()) {
      switch (maybe_value.GetError()) {
        case storage::Error::SERIALIZATION_ERROR:
          throw TransactionSerializationException();
        case storage::Error::DELETED_OBJECT:
          throw QueryRuntimeException("Trying to remove labels from a deleted node.");
        case storage::Error::VERTEX_HAS_EDGES:
        case storage::Error::PROPERTIES_DISABLED:
        case storage::Error::NONEXISTENT_OBJECT:
          throw QueryRuntimeException("Unexpected error when removing labels from a node.");
      }
    }

    context.execution_stats[ExecutionStats::Key::DELETED_LABELS] += 1;
    if (context.trigger_context_collector && *maybe_value) {
      context.trigger_context_collector->RegisterRemovedVertexLabel(vertex, label);
    }
  }

  if (flags::AreExperimentsEnabled(flags::Experiments::TEXT_SEARCH)) {
    // Pass the labels evaluated above instead of evaluating them a second
    // time: the frame has not changed since, so the list is the same.
    context.db_accessor->TextIndexUpdateVertex(vertex, labels);
  }

  return true;
}
|
||
|
||
/// Shuts down the input cursor; this cursor has nothing else to release.
void RemoveLabels::RemoveLabelsCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
/// Resets the input cursor; no other state is touched here.
void RemoveLabels::RemoveLabelsCursor::Reset() {
  input_cursor_->Reset();
}
|
||
|
||
/// Stores the input operator, the symbol of the newly expanded edge(s) and the
/// symbols of previously matched edges it must be compared against.
EdgeUniquenessFilter::EdgeUniquenessFilter(const std::shared_ptr<LogicalOperator> &input, Symbol expand_symbol,
                                           const std::vector<Symbol> &previous_symbols)
    : input_(input),
      expand_symbol_(std::move(expand_symbol)),
      previous_symbols_(previous_symbols) {
}
|
||
|
||
ACCEPT_WITH_INPUT(EdgeUniquenessFilter)
|
||
|
||
/// Builds the cursor that executes this operator and bumps the
/// EdgeUniquenessFilterOperator metric once per created cursor.
UniqueCursorPtr EdgeUniquenessFilter::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::EdgeUniquenessFilterOperator);

  auto cursor = MakeUniqueCursorPtr<EdgeUniquenessFilterCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> EdgeUniquenessFilter::ModifiedSymbols(const SymbolTable &table) const {
|
||
return input_->ModifiedSymbols(table);
|
||
}
|
||
|
||
/// Keeps a reference to the owning operator and creates the input cursor from `mem`.
EdgeUniquenessFilter::EdgeUniquenessFilterCursor::EdgeUniquenessFilterCursor(const EdgeUniquenessFilter &self,
                                                                             utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self.input_->MakeCursor(mem)) {
}
|
||
|
||
namespace {
|
||
/**
|
||
* Returns true if:
|
||
* - a and b are either edge or edge-list values, and there
|
||
* is at least one matching edge in the two values
|
||
*/
|
||
bool ContainsSameEdge(const TypedValue &a, const TypedValue &b) {
|
||
auto compare_to_list = [](const TypedValue &list, const TypedValue &other) {
|
||
for (const TypedValue &list_elem : list.ValueList())
|
||
if (ContainsSameEdge(list_elem, other)) return true;
|
||
return false;
|
||
};
|
||
|
||
if (a.type() == TypedValue::Type::List) return compare_to_list(a, b);
|
||
if (b.type() == TypedValue::Type::List) return compare_to_list(b, a);
|
||
|
||
return a.ValueEdge() == b.ValueEdge();
|
||
}
|
||
} // namespace
|
||
|
||
/// Pulls from the input until a row is found whose expanded edge(s) do not
/// coincide with any of the previously matched edges.
///
/// Returns true for the first such row and false once the input is exhausted.
bool EdgeUniquenessFilter::EdgeUniquenessFilterCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("EdgeUniquenessFilter");

  const auto is_unique_expansion = [&]() {
    const auto &expanded = frame[self_.expand_symbol_];
    // This shouldn't raise a TypedValueException, because the planner
    // makes sure these are all of the expected type. In case they are not
    // an error should be raised long before this code is executed.
    return std::none_of(
        self_.previous_symbols_.begin(), self_.previous_symbols_.end(),
        [&](const Symbol &previous_symbol) { return ContainsSameEdge(frame[previous_symbol], expanded); });
  };

  while (input_cursor_->Pull(frame, context)) {
    if (is_unique_expansion()) {
      return true;
    }
  }
  return false;
}
|
||
|
||
/// Shuts down the input cursor; this cursor has nothing else to release.
void EdgeUniquenessFilter::EdgeUniquenessFilterCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
/// Resets the input cursor; no other state is touched here.
void EdgeUniquenessFilter::EdgeUniquenessFilterCursor::Reset() {
  input_cursor_->Reset();
}
|
||
|
||
/// Stores the input operator; a null `input` is replaced by a fresh `Once`.
EmptyResult::EmptyResult(const std::shared_ptr<LogicalOperator> &input)
    : input_(input != nullptr ? input : std::make_shared<Once>()) {
}
|
||
|
||
ACCEPT_WITH_INPUT(EmptyResult)
|
||
|
||
/// EmptyResult outputs no symbols: always returns an empty vector.
std::vector<Symbol> EmptyResult::OutputSymbols(const SymbolTable &) const {  // NOLINT(hicpp-named-parameter)
  return std::vector<Symbol>{};
}
|
||
|
||
/// EmptyResult reports no modified symbols: always returns an empty vector.
std::vector<Symbol> EmptyResult::ModifiedSymbols(const SymbolTable &) const {  // NOLINT(hicpp-named-parameter)
  return std::vector<Symbol>{};
}
|
||
|
||
class EmptyResultCursor : public Cursor {
|
||
public:
|
||
EmptyResultCursor(const EmptyResult &self, utils::MemoryResource *mem)
|
||
: input_cursor_(self.input_->MakeCursor(mem)) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
SCOPED_PROFILE_OP("EmptyResult");
|
||
|
||
if (!pulled_all_input_) {
|
||
while (input_cursor_->Pull(frame, context)) {
|
||
AbortCheck(context);
|
||
}
|
||
pulled_all_input_ = true;
|
||
}
|
||
return false;
|
||
}
|
||
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override {
|
||
input_cursor_->Reset();
|
||
pulled_all_input_ = false;
|
||
}
|
||
|
||
private:
|
||
const UniqueCursorPtr input_cursor_;
|
||
bool pulled_all_input_{false};
|
||
};
|
||
|
||
/// Builds the cursor that executes this operator and bumps the
/// EmptyResultOperator metric once per created cursor.
UniqueCursorPtr EmptyResult::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::EmptyResultOperator);

  auto cursor = MakeUniqueCursorPtr<EmptyResultCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
/// Stores the input operator, the symbols whose values are cached per row and
/// whether the command should be advanced after the input has been consumed.
Accumulate::Accumulate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &symbols,
                       bool advance_command)
    : input_(input),
      symbols_(symbols),
      advance_command_(advance_command) {
}
|
||
|
||
ACCEPT_WITH_INPUT(Accumulate)
|
||
|
||
/// Returns a copy of the accumulated symbols; the symbol table is not consulted.
std::vector<Symbol> Accumulate::ModifiedSymbols(const SymbolTable &) const {
  return symbols_;
}
|
||
|
||
class AccumulateCursor : public Cursor {
|
||
public:
|
||
AccumulateCursor(const Accumulate &self, utils::MemoryResource *mem)
|
||
: self_(self), input_cursor_(self.input_->MakeCursor(mem)), cache_(mem) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP("Accumulate");
|
||
|
||
auto &dba = *context.db_accessor;
|
||
// cache all the input
|
||
if (!pulled_all_input_) {
|
||
while (input_cursor_->Pull(frame, context)) {
|
||
utils::pmr::vector<TypedValue> row(cache_.get_allocator().GetMemoryResource());
|
||
row.reserve(self_.symbols_.size());
|
||
for (const Symbol &symbol : self_.symbols_) row.emplace_back(frame[symbol]);
|
||
cache_.emplace_back(std::move(row));
|
||
}
|
||
pulled_all_input_ = true;
|
||
cache_it_ = cache_.begin();
|
||
|
||
if (self_.advance_command_) dba.AdvanceCommand();
|
||
}
|
||
|
||
AbortCheck(context);
|
||
if (cache_it_ == cache_.end()) return false;
|
||
auto row_it = (cache_it_++)->begin();
|
||
for (const Symbol &symbol : self_.symbols_) {
|
||
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) {
|
||
context.frame_change_collector->ResetTrackingValue(symbol.name());
|
||
}
|
||
frame[symbol] = *row_it++;
|
||
}
|
||
return true;
|
||
}
|
||
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override {
|
||
input_cursor_->Reset();
|
||
cache_.clear();
|
||
cache_it_ = cache_.begin();
|
||
pulled_all_input_ = false;
|
||
}
|
||
|
||
private:
|
||
const Accumulate &self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
utils::pmr::deque<utils::pmr::vector<TypedValue>> cache_;
|
||
decltype(cache_.begin()) cache_it_ = cache_.begin();
|
||
bool pulled_all_input_{false};
|
||
};
|
||
|
||
/// Builds the cursor that executes this operator and bumps the
/// AccumulateOperator metric once per created cursor.
UniqueCursorPtr Accumulate::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::AccumulateOperator);

  auto cursor = MakeUniqueCursorPtr<AccumulateCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
/// Stores the aggregation elements, group-by expressions and remembered
/// symbols; a null `input` is replaced by a fresh `Once`.
Aggregate::Aggregate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Aggregate::Element> &aggregations,
                     const std::vector<Expression *> &group_by, const std::vector<Symbol> &remember)
    : input_(input != nullptr ? input : std::make_shared<Once>()),
      aggregations_(aggregations),
      group_by_(group_by),
      remember_(remember) {
}
|
||
|
||
ACCEPT_WITH_INPUT(Aggregate)
|
||
|
||
std::vector<Symbol> Aggregate::ModifiedSymbols(const SymbolTable &) const {
|
||
auto symbols = remember_;
|
||
for (const auto &elem : aggregations_) symbols.push_back(elem.output_sym);
|
||
return symbols;
|
||
}
|
||
|
||
namespace {
|
||
/** Returns the default TypedValue for an Aggregation element.
|
||
* This value is valid both for returning when where are no inputs
|
||
* to the aggregation op, and for initializing an aggregation result
|
||
* when there are */
|
||
TypedValue DefaultAggregationOpValue(const Aggregate::Element &element, utils::MemoryResource *memory) {
|
||
switch (element.op) {
|
||
case Aggregation::Op::MIN:
|
||
case Aggregation::Op::MAX:
|
||
case Aggregation::Op::AVG:
|
||
return TypedValue(memory);
|
||
case Aggregation::Op::COUNT:
|
||
case Aggregation::Op::SUM:
|
||
return TypedValue(0, memory);
|
||
case Aggregation::Op::COLLECT_LIST:
|
||
return TypedValue(TypedValue::TVector(memory));
|
||
case Aggregation::Op::COLLECT_MAP:
|
||
return TypedValue(TypedValue::TMap(memory));
|
||
case Aggregation::Op::PROJECT:
|
||
return TypedValue(query::Graph(memory));
|
||
}
|
||
}
|
||
} // namespace
|
||
|
||
class AggregateCursor : public Cursor {
|
||
public:
|
||
AggregateCursor(const Aggregate &self, utils::MemoryResource *mem)
|
||
: self_(self),
|
||
input_cursor_(self_.input_->MakeCursor(mem)),
|
||
aggregation_(mem),
|
||
reused_group_by_(self.group_by_.size(), mem) {}
|
||
|
||
/// Produces one aggregated row per call.
///
/// The first call consumes the whole input via ProcessAll. If that yields no
/// groups, a single row of default aggregation values (and Null remember
/// values) is produced when there are no group-by expressions; with group-by
/// expressions nothing is produced. Otherwise each call writes the values of
/// the next group onto the frame.
///
/// @return true if a row was written to the frame, false otherwise.
bool Pull(Frame &frame, ExecutionContext &context) override {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP_BY_REF(self_);

  // Drops the value the frame change collector tracks for `symbol`, if any.
  // Called after every frame write below, in both the empty-input and the
  // per-group path, so that a tracked value never outlives the frame slot it
  // was taken from.
  const auto reset_tracking = [&context](const Symbol &symbol) {
    if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) {
      context.frame_change_collector->ResetTrackingValue(symbol.name());
    }
  };

  if (!pulled_all_input_) {
    if (!ProcessAll(&frame, &context) && !self_.group_by_.empty()) return false;
    pulled_all_input_ = true;
    aggregation_it_ = aggregation_.begin();

    if (aggregation_.empty()) {
      auto *pull_memory = context.evaluation_context.memory;
      // place default aggregation values on the frame
      for (const auto &elem : self_.aggregations_) {
        frame[elem.output_sym] = DefaultAggregationOpValue(elem, pull_memory);
        reset_tracking(elem.output_sym);
      }

      // place null as remember values on the frame
      for (const Symbol &remember_sym : self_.remember_) {
        frame[remember_sym] = TypedValue(pull_memory);
        reset_tracking(remember_sym);
      }
      return true;
    }
  }
  if (aggregation_it_ == aggregation_.end()) return false;

  // place aggregation values on the frame
  auto aggregation_values_it = aggregation_it_->second.values_.begin();
  for (const auto &aggregation_elem : self_.aggregations_) {
    frame[aggregation_elem.output_sym] = *aggregation_values_it++;
    reset_tracking(aggregation_elem.output_sym);
  }

  // place remember values on the frame
  auto remember_values_it = aggregation_it_->second.remember_.begin();
  for (const Symbol &remember_sym : self_.remember_) {
    frame[remember_sym] = *remember_values_it++;
    reset_tracking(remember_sym);
  }

  ++aggregation_it_;
  return true;
}
|
||
|
||
/// Shuts down the input cursor; this cursor has nothing else to release.
void Shutdown() override {
  input_cursor_->Shutdown();
}
|
||
|
||
/// Resets the input and discards all aggregated groups so that the next Pull
/// aggregates the input from scratch.
void Reset() override {
  input_cursor_->Reset();
  aggregation_.clear();
  // clear() invalidated the iterator; re-seat it on the now-empty map.
  aggregation_it_ = aggregation_.begin();
  pulled_all_input_ = false;
}
|
||
|
||
private:
|
||
// Data structure for a single aggregation cache.
|
||
// Does NOT include the group-by values since those are a key in the
|
||
// aggregation map. The vectors in an AggregationValue contain one element for
|
||
// each aggregation in this LogicalOp.
|
||
struct AggregationValue {
|
||
explicit AggregationValue(utils::MemoryResource *mem)
|
||
: counts_(mem), values_(mem), remember_(mem), unique_values_(mem) {}
|
||
|
||
// how many input rows have been aggregated in respective values_ element so
|
||
// far
|
||
// TODO: The counting value type should be changed to an unsigned type once
|
||
// TypedValue can support signed integer values larger than 64bits so that
|
||
// precision isn't lost.
|
||
utils::pmr::vector<int64_t> counts_;
|
||
// aggregated values. Initially Null (until at least one input row with a
|
||
// valid value gets processed)
|
||
utils::pmr::vector<TypedValue> values_;
|
||
// remember values.
|
||
utils::pmr::vector<TypedValue> remember_;
|
||
|
||
using TSet = utils::pmr::unordered_set<TypedValue, TypedValue::Hash, TypedValue::BoolEqual>;
|
||
|
||
utils::pmr::vector<TSet> unique_values_;
|
||
};
|
||
|
||
const Aggregate &self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
// storage for aggregated data
|
||
// map key is the vector of group-by values
|
||
// map value is an AggregationValue struct
|
||
utils::pmr::unordered_map<utils::pmr::vector<TypedValue>, AggregationValue,
|
||
// use FNV collection hashing specialized for a
|
||
// vector of TypedValues
|
||
utils::FnvCollection<utils::pmr::vector<TypedValue>, TypedValue, TypedValue::Hash>,
|
||
// custom equality
|
||
TypedValueVectorEqual>
|
||
aggregation_;
|
||
// this is for object reuse, to avoid re-allocating this buffer
|
||
utils::pmr::vector<TypedValue> reused_group_by_;
|
||
// iterator over the accumulated cache
|
||
decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin();
|
||
// this LogicalOp pulls all from the input on its first pull
|
||
// this switch tracks if this has been performed
|
||
bool pulled_all_input_{false};
|
||
|
||
/**
|
||
* Pulls from the input operator until exhausted and aggregates the
|
||
* results. If the input operator is not provided, a single call
|
||
* to ProcessOne is issued.
|
||
*
|
||
* Accumulation automatically groups the results so that `aggregation_`
|
||
* cache cardinality depends on number of
|
||
* aggregation results, and not on the number of inputs.
|
||
*/
|
||
bool ProcessAll(Frame *frame, ExecutionContext *context) {
|
||
ExpressionEvaluator evaluator(frame, context->symbol_table, context->evaluation_context, context->db_accessor,
|
||
storage::View::NEW);
|
||
|
||
bool pulled = false;
|
||
while (input_cursor_->Pull(*frame, *context)) {
|
||
ProcessOne(*frame, &evaluator);
|
||
pulled = true;
|
||
}
|
||
if (!pulled) return false;
|
||
|
||
// post processing
|
||
for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) {
|
||
switch (self_.aggregations_[pos].op) {
|
||
case Aggregation::Op::AVG: {
|
||
// calculate AVG aggregations (so far they have only been summed)
|
||
for (auto &kv : aggregation_) {
|
||
AggregationValue &agg_value = kv.second;
|
||
auto count = agg_value.counts_[pos];
|
||
auto *pull_memory = context->evaluation_context.memory;
|
||
if (count > 0) {
|
||
agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
|
||
}
|
||
}
|
||
break;
|
||
}
|
||
case Aggregation::Op::COUNT: {
|
||
// Copy counts to be the value
|
||
for (auto &kv : aggregation_) {
|
||
AggregationValue &agg_value = kv.second;
|
||
agg_value.values_[pos] = agg_value.counts_[pos];
|
||
}
|
||
break;
|
||
}
|
||
case Aggregation::Op::MIN:
|
||
case Aggregation::Op::MAX:
|
||
case Aggregation::Op::SUM:
|
||
case Aggregation::Op::COLLECT_LIST:
|
||
case Aggregation::Op::COLLECT_MAP:
|
||
case Aggregation::Op::PROJECT:
|
||
break;
|
||
}
|
||
}
|
||
return true;
|
||
}
|
||
|
||
/**
|
||
* Performs a single accumulation.
|
||
*/
|
||
void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) {
|
||
// Preallocated group_by, since most of the time the aggregation key won't be unique
|
||
reused_group_by_.clear();
|
||
evaluator->ResetPropertyLookupCache();
|
||
|
||
for (Expression *expression : self_.group_by_) {
|
||
reused_group_by_.emplace_back(expression->Accept(*evaluator));
|
||
}
|
||
auto *mem = aggregation_.get_allocator().GetMemoryResource();
|
||
auto res = aggregation_.try_emplace(reused_group_by_, mem);
|
||
auto &agg_value = res.first->second;
|
||
if (res.second /*was newly inserted*/) EnsureInitialized(frame, &agg_value);
|
||
Update(evaluator, &agg_value);
|
||
}
|
||
|
||
/** Ensures the new AggregationValue has been initialized. This means
|
||
* that the value vectors are filled with an appropriate number of Nulls,
|
||
* counts are set to 0 and remember values are remembered.
|
||
*/
|
||
void EnsureInitialized(const Frame &frame, AggregateCursor::AggregationValue *agg_value) const {
|
||
if (!agg_value->values_.empty()) return;
|
||
|
||
const auto num_of_aggregations = self_.aggregations_.size();
|
||
agg_value->values_.reserve(num_of_aggregations);
|
||
agg_value->unique_values_.reserve(num_of_aggregations);
|
||
|
||
auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
|
||
for (const auto &agg_elem : self_.aggregations_) {
|
||
agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem));
|
||
agg_value->unique_values_.emplace_back(AggregationValue::TSet(mem));
|
||
}
|
||
agg_value->counts_.resize(num_of_aggregations, 0);
|
||
|
||
agg_value->remember_.reserve(self_.remember_.size());
|
||
for (const Symbol &remember_sym : self_.remember_) {
|
||
agg_value->remember_.push_back(frame[remember_sym]);
|
||
}
|
||
}
|
||
|
||
/** Updates the given AggregationValue with new data. Assumes that
|
||
* the AggregationValue has been initialized */
|
||
void Update(ExpressionEvaluator *evaluator, AggregateCursor::AggregationValue *agg_value) {
|
||
DMG_ASSERT(self_.aggregations_.size() == agg_value->values_.size(),
|
||
"Expected as much AggregationValue.values_ as there are "
|
||
"aggregations.");
|
||
DMG_ASSERT(self_.aggregations_.size() == agg_value->counts_.size(),
|
||
"Expected as much AggregationValue.counts_ as there are "
|
||
"aggregations.");
|
||
|
||
auto count_it = agg_value->counts_.begin();
|
||
auto value_it = agg_value->values_.begin();
|
||
auto unique_values_it = agg_value->unique_values_.begin();
|
||
auto agg_elem_it = self_.aggregations_.begin();
|
||
const auto counts_end = agg_value->counts_.end();
|
||
for (; count_it != counts_end; ++count_it, ++value_it, ++unique_values_it, ++agg_elem_it) {
|
||
// COUNT(*) is the only case where input expression is optional
|
||
// handle it here
|
||
auto input_expr_ptr = agg_elem_it->value;
|
||
if (!input_expr_ptr) {
|
||
*count_it += 1;
|
||
// value is deferred to post-processing
|
||
continue;
|
||
}
|
||
|
||
TypedValue input_value = input_expr_ptr->Accept(*evaluator);
|
||
|
||
// Aggregations skip Null input values.
|
||
if (input_value.IsNull()) continue;
|
||
const auto &agg_op = agg_elem_it->op;
|
||
if (agg_elem_it->distinct) {
|
||
auto insert_result = unique_values_it->insert(input_value);
|
||
if (!insert_result.second) {
|
||
continue;
|
||
}
|
||
}
|
||
*count_it += 1;
|
||
if (*count_it == 1) {
|
||
// first value, nothing to aggregate. check type, set and continue.
|
||
switch (agg_op) {
|
||
case Aggregation::Op::MIN:
|
||
case Aggregation::Op::MAX:
|
||
EnsureOkForMinMax(input_value);
|
||
*value_it = std::move(input_value);
|
||
break;
|
||
case Aggregation::Op::SUM:
|
||
case Aggregation::Op::AVG:
|
||
EnsureOkForAvgSum(input_value);
|
||
*value_it = std::move(input_value);
|
||
break;
|
||
case Aggregation::Op::COUNT:
|
||
// value is deferred to post-processing
|
||
break;
|
||
case Aggregation::Op::COLLECT_LIST:
|
||
value_it->ValueList().push_back(std::move(input_value));
|
||
break;
|
||
case Aggregation::Op::PROJECT: {
|
||
EnsureOkForProject(input_value);
|
||
value_it->ValueGraph().Expand(input_value.ValuePath());
|
||
break;
|
||
}
|
||
case Aggregation::Op::COLLECT_MAP:
|
||
auto key = agg_elem_it->key->Accept(*evaluator);
|
||
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
|
||
value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
|
||
break;
|
||
}
|
||
continue;
|
||
}
|
||
|
||
// aggregation of existing values
|
||
switch (agg_op) {
|
||
case Aggregation::Op::COUNT:
|
||
// value is deferred to post-processing
|
||
break;
|
||
case Aggregation::Op::MIN: {
|
||
EnsureOkForMinMax(input_value);
|
||
try {
|
||
TypedValue comparison_result = input_value < *value_it;
|
||
// since we skip nulls we either have a valid comparison, or
|
||
// an exception was just thrown above
|
||
// safe to assume a bool TypedValue
|
||
if (comparison_result.ValueBool()) *value_it = std::move(input_value);
|
||
} catch (const TypedValueException &) {
|
||
throw QueryRuntimeException("Unable to get MIN of '{}' and '{}'.", input_value.type(), value_it->type());
|
||
}
|
||
break;
|
||
}
|
||
case Aggregation::Op::MAX: {
|
||
// all comments as for Op::Min
|
||
EnsureOkForMinMax(input_value);
|
||
try {
|
||
TypedValue comparison_result = input_value > *value_it;
|
||
if (comparison_result.ValueBool()) *value_it = std::move(input_value);
|
||
} catch (const TypedValueException &) {
|
||
throw QueryRuntimeException("Unable to get MAX of '{}' and '{}'.", input_value.type(), value_it->type());
|
||
}
|
||
break;
|
||
}
|
||
case Aggregation::Op::AVG:
|
||
// for averaging we sum first and divide by count once all
|
||
// the input has been processed
|
||
case Aggregation::Op::SUM:
|
||
EnsureOkForAvgSum(input_value);
|
||
*value_it = *value_it + input_value;
|
||
break;
|
||
case Aggregation::Op::COLLECT_LIST:
|
||
value_it->ValueList().push_back(std::move(input_value));
|
||
break;
|
||
case Aggregation::Op::PROJECT: {
|
||
EnsureOkForProject(input_value);
|
||
value_it->ValueGraph().Expand(input_value.ValuePath());
|
||
break;
|
||
}
|
||
case Aggregation::Op::COLLECT_MAP:
|
||
auto key = agg_elem_it->key->Accept(*evaluator);
|
||
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
|
||
value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
|
||
break;
|
||
} // end switch over Aggregation::Op enum
|
||
} // end loop over all aggregations
|
||
}
|
||
|
||
/** Checks if the given TypedValue is legal in MIN and MAX. If not
|
||
* an appropriate exception is thrown. */
|
||
void EnsureOkForMinMax(const TypedValue &value) const {
|
||
switch (value.type()) {
|
||
case TypedValue::Type::Bool:
|
||
case TypedValue::Type::Int:
|
||
case TypedValue::Type::Double:
|
||
case TypedValue::Type::String:
|
||
return;
|
||
default:
|
||
throw QueryRuntimeException(
|
||
"Only boolean, numeric and string values are allowed in "
|
||
"MIN and MAX aggregations.");
|
||
}
|
||
}
|
||
|
||
/** Checks if the given TypedValue is legal in AVG and SUM. If not
|
||
* an appropriate exception is thrown. */
|
||
void EnsureOkForAvgSum(const TypedValue &value) const {
|
||
switch (value.type()) {
|
||
case TypedValue::Type::Int:
|
||
case TypedValue::Type::Double:
|
||
return;
|
||
default:
|
||
throw QueryRuntimeException("Only numeric values allowed in SUM and AVG aggregations.");
|
||
}
|
||
}
|
||
|
||
/** Checks if the given TypedValue is legal in PROJECT and PROJECT_TRANSITIVE. If not
|
||
* an appropriate exception is thrown. */
|
||
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
|
||
void EnsureOkForProject(const TypedValue &value) const {
|
||
switch (value.type()) {
|
||
case TypedValue::Type::Path:
|
||
return;
|
||
default:
|
||
throw QueryRuntimeException("Only path values allowed in PROJECT aggregation.");
|
||
}
|
||
}
|
||
};
|
||
|
||
/**
 * Builds the AggregateCursor for this operator. `mem` is passed both to
 * MakeUniqueCursorPtr and to the cursor's constructor. Each call increments
 * the AggregateOperator metric counter.
 */
UniqueCursorPtr Aggregate::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::AggregateOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<AggregateCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
/**
 * Creates a Skip operator over `input`; `expression` is the expression that
 * is later evaluated to obtain the number of rows to skip.
 */
Skip::Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression)
    : input_(input),
      expression_(expression) {}
|
||
|
||
// Expands the project macro ACCEPT_WITH_INPUT for Skip. The macro is defined outside this section; presumably
// it generates Skip::Accept, visiting this operator and its single input -- confirm at the macro definition.
ACCEPT_WITH_INPUT(Skip)
|
||
|
||
/**
 * Builds the SkipCursor for this operator. `mem` is passed both to
 * MakeUniqueCursorPtr and to the cursor's constructor. Each call increments
 * the SkipOperator metric counter.
 */
UniqueCursorPtr Skip::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::SkipOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<SkipCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> Skip::OutputSymbols(const SymbolTable &symbol_table) const {
|
||
// Propagate this to potential Produce.
|
||
return input_->OutputSymbols(symbol_table);
|
||
}
|
||
|
||
/** Reports exactly the symbols that the input operator reports as modified. */
std::vector<Symbol> Skip::ModifiedSymbols(const SymbolTable &table) const {
  return input_->ModifiedSymbols(table);
}
|
||
|
||
/** Binds the cursor to its operator and creates the input operator's cursor using `mem`. */
Skip::SkipCursor::SkipCursor(const Skip &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self_.input_->MakeCursor(mem)) {}
|
||
|
||
/**
 * Pulls from the input, discarding the first `to_skip_` rows, and returns
 * true for every row after those. Returns false once the input is exhausted.
 *
 * The skip expression is evaluated after the first successful input pull
 * (`to_skip_ == -1` marks "not evaluated yet"). A non-integer or negative
 * result raises a QueryRuntimeException.
 */
bool Skip::SkipCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("Skip");

  while (input_cursor_->Pull(frame, context)) {
    const bool skip_count_pending = (to_skip_ == -1);
    if (skip_count_pending) {
      // The skip expression contains no identifiers, so the chosen graph
      // view does not matter.
      ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                    storage::View::OLD);
      TypedValue skip_value = self_.expression_->Accept(evaluator);
      if (skip_value.type() != TypedValue::Type::Int) {
        throw QueryRuntimeException("Number of elements to skip must be an integer.");
      }

      to_skip_ = skip_value.ValueInt();
      if (to_skip_ < 0) {
        throw QueryRuntimeException("Number of elements to skip must be non-negative.");
      }
    }

    // Count this row, then emit it only if enough rows were dropped before it.
    const bool past_skipped_prefix = !(skipped_ < to_skip_);
    ++skipped_;
    if (past_skipped_prefix) {
      return true;
    }
  }
  return false;
}
|
||
|
||
/** Forwards the shutdown to the input cursor. */
void Skip::SkipCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
/**
 * Resets the input cursor and returns this cursor to its initial state: the
 * skip count is marked as not evaluated (-1) and no rows count as skipped.
 */
void Skip::SkipCursor::Reset() {
  input_cursor_->Reset();

  to_skip_ = -1;
  skipped_ = 0;
}
|
||
|
||
/**
 * Creates a Limit operator over `input`; `expression` is the expression that
 * is later evaluated to obtain the maximum number of rows to return.
 */
Limit::Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression)
    : input_(input),
      expression_(expression) {}
|
||
|
||
// Expands the project macro ACCEPT_WITH_INPUT for Limit. The macro is defined outside this section; presumably
// it generates Limit::Accept, visiting this operator and its single input -- confirm at the macro definition.
ACCEPT_WITH_INPUT(Limit)
|
||
|
||
/**
 * Builds the LimitCursor for this operator. `mem` is passed both to
 * MakeUniqueCursorPtr and to the cursor's constructor. Each call increments
 * the LimitOperator metric counter.
 */
UniqueCursorPtr Limit::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::LimitOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<LimitCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> Limit::OutputSymbols(const SymbolTable &symbol_table) const {
|
||
// Propagate this to potential Produce.
|
||
return input_->OutputSymbols(symbol_table);
|
||
}
|
||
|
||
/** Reports exactly the symbols that the input operator reports as modified. */
std::vector<Symbol> Limit::ModifiedSymbols(const SymbolTable &table) const {
  return input_->ModifiedSymbols(table);
}
|
||
|
||
/** Binds the cursor to its operator and creates the input operator's cursor using `mem`. */
Limit::LimitCursor::LimitCursor(const Limit &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self_.input_->MakeCursor(mem)) {}
|
||
|
||
/**
 * Forwards pulls to the input until `limit_` pulls have been made, after
 * which it returns false without touching the input.
 *
 * The limit expression is evaluated on the first call, before any input pull
 * (`limit_ == -1` marks "not evaluated yet"): a limit of 0 must not pull from
 * the input at all, and the expression cannot reference identifiers, so no
 * input row is needed to evaluate it. A non-integer or negative result raises
 * a QueryRuntimeException.
 */
bool Limit::LimitCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("Limit");

  const bool limit_pending = (limit_ == -1);
  if (limit_pending) {
    // No identifiers in the limit expression, so the graph view is irrelevant.
    ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
                                  storage::View::OLD);
    TypedValue limit_value = self_.expression_->Accept(evaluator);
    if (limit_value.type() != TypedValue::Type::Int) {
      throw QueryRuntimeException("Limit on number of returned elements must be an integer.");
    }

    limit_ = limit_value.ValueInt();
    if (limit_ < 0) {
      throw QueryRuntimeException("Limit on number of returned elements must be non-negative.");
    }
  }

  // Count the attempt first, then refuse it if the limit was already used up.
  const bool limit_exhausted = pulled_ >= limit_;
  ++pulled_;
  if (limit_exhausted) {
    return false;
  }

  return input_cursor_->Pull(frame, context);
}
|
||
|
||
/** Forwards the shutdown to the input cursor. */
void Limit::LimitCursor::Shutdown() {
  input_cursor_->Shutdown();
}
|
||
|
||
/**
 * Resets the input cursor and returns this cursor to its initial state: the
 * limit is marked as not evaluated (-1) and the pull counter is zeroed.
 */
void Limit::LimitCursor::Reset() {
  input_cursor_->Reset();

  limit_ = -1;
  pulled_ = 0;
}
|
||
|
||
/**
 * Creates an OrderBy operator over `input`. Each SortItem in `order_by` is
 * split into its ordering, which goes into the comparator `compare_`, and its
 * expression, which goes into `order_by_`. `output_symbols` is stored as given.
 */
OrderBy::OrderBy(const std::shared_ptr<LogicalOperator> &input, const std::vector<SortItem> &order_by,
                 const std::vector<Symbol> &output_symbols)
    : input_(input), output_symbols_(output_symbols) {
  const auto item_count = order_by.size();

  std::vector<OrderedTypedValueCompare> comparators;
  comparators.reserve(item_count);
  order_by_.reserve(item_count);

  for (const auto &sort_item : order_by) {
    comparators.emplace_back(sort_item.ordering);
    order_by_.emplace_back(sort_item.expression);
  }

  compare_ = TypedValueVectorCompare(std::move(comparators));
}
|
||
|
||
// Expands the project macro ACCEPT_WITH_INPUT for OrderBy. The macro is defined outside this section; presumably
// it generates OrderBy::Accept, visiting this operator and its single input -- confirm at the macro definition.
ACCEPT_WITH_INPUT(OrderBy)
|
||
|
||
std::vector<Symbol> OrderBy::OutputSymbols(const SymbolTable &symbol_table) const {
|
||
// Propagate this to potential Produce.
|
||
return input_->OutputSymbols(symbol_table);
|
||
}
|
||
|
||
/** Reports exactly the symbols that the input operator reports as modified. */
std::vector<Symbol> OrderBy::ModifiedSymbols(const SymbolTable &table) const {
  return input_->ModifiedSymbols(table);
}
|
||
|
||
class OrderByCursor : public Cursor {
|
||
public:
|
||
OrderByCursor(const OrderBy &self, utils::MemoryResource *mem)
|
||
: self_(self), input_cursor_(self_.input_->MakeCursor(mem)), cache_(mem) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP_BY_REF(self_);
|
||
|
||
if (!did_pull_all_) [[unlikely]] {
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::OLD);
|
||
auto *pull_mem = context.evaluation_context.memory;
|
||
auto *query_mem = cache_.get_allocator().GetMemoryResource();
|
||
|
||
utils::pmr::vector<utils::pmr::vector<TypedValue>> order_by(pull_mem); // Not cached, pull memory
|
||
utils::pmr::vector<utils::pmr::vector<TypedValue>> output(query_mem); // Cached, query memory
|
||
|
||
while (input_cursor_->Pull(frame, context)) {
|
||
// collect the order_by elements
|
||
utils::pmr::vector<TypedValue> order_by_elem(pull_mem);
|
||
order_by_elem.reserve(self_.order_by_.size());
|
||
for (auto const &expression_ptr : self_.order_by_) {
|
||
order_by_elem.emplace_back(expression_ptr->Accept(evaluator));
|
||
}
|
||
order_by.emplace_back(std::move(order_by_elem));
|
||
|
||
// collect the output elements
|
||
utils::pmr::vector<TypedValue> output_elem(query_mem);
|
||
output_elem.reserve(self_.output_symbols_.size());
|
||
for (const Symbol &output_sym : self_.output_symbols_) {
|
||
output_elem.emplace_back(frame[output_sym]);
|
||
}
|
||
output.emplace_back(std::move(output_elem));
|
||
}
|
||
|
||
// sorting with range zip
|
||
// we compare on just the projection of the 1st range (order_by)
|
||
// this will also permute the 2nd range (output)
|
||
ranges::sort(
|
||
ranges::views::zip(order_by, output), self_.compare_.lex_cmp(),
|
||
[](auto const &value) -> auto const & { return std::get<0>(value); });
|
||
|
||
// no longer need the order_by terms
|
||
order_by.clear();
|
||
cache_ = std::move(output);
|
||
|
||
did_pull_all_ = true;
|
||
cache_it_ = cache_.begin();
|
||
}
|
||
|
||
if (cache_it_ == cache_.end()) return false;
|
||
|
||
AbortCheck(context);
|
||
|
||
// place the output values on the frame
|
||
DMG_ASSERT(self_.output_symbols_.size() == cache_it_->size(),
|
||
"Number of values does not match the number of output symbols "
|
||
"in OrderBy");
|
||
auto output_sym_it = self_.output_symbols_.begin();
|
||
for (TypedValue &output : *cache_it_) {
|
||
if (context.frame_change_collector) {
|
||
context.frame_change_collector->ResetTrackingValue(output_sym_it->name());
|
||
}
|
||
frame[*output_sym_it++] = std::move(output);
|
||
}
|
||
cache_it_++;
|
||
return true;
|
||
}
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override {
|
||
input_cursor_->Reset();
|
||
did_pull_all_ = false;
|
||
cache_.clear();
|
||
cache_it_ = cache_.begin();
|
||
}
|
||
|
||
private:
|
||
const OrderBy &self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
bool did_pull_all_{false};
|
||
// a cache of elements pulled from the input
|
||
// the cache is filled and sorted on first Pull
|
||
utils::pmr::vector<utils::pmr::vector<TypedValue>> cache_;
|
||
// iterator over the cache_, maintains state between Pulls
|
||
decltype(cache_.begin()) cache_it_ = cache_.begin();
|
||
};
|
||
|
||
/**
 * Builds the OrderByCursor for this operator. `mem` is passed both to
 * MakeUniqueCursorPtr and to the cursor's constructor. Each call increments
 * the OrderByOperator metric counter.
 */
UniqueCursorPtr OrderBy::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::OrderByOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<OrderByCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
/**
 * Creates a Merge operator from its input, match branch and create branch.
 * A null `input` is replaced by a fresh Once operator.
 */
Merge::Merge(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &merge_match,
             const std::shared_ptr<LogicalOperator> &merge_create)
    : input_(input ? input : std::make_shared<Once>()),
      merge_match_(merge_match),
      merge_create_(merge_create) {}
|
||
|
||
/**
 * Visits this operator. If PreVisit returns true, the input, the match
 * branch and the create branch are visited in that order, stopping at the
 * first one whose Accept returns false. The result is that of PostVisit.
 */
bool Merge::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  if (visitor.PreVisit(*this)) {
    if (input_->Accept(visitor) && merge_match_->Accept(visitor)) {
      merge_create_->Accept(visitor);
    }
  }
  return visitor.PostVisit(*this);
}
|
||
|
||
/**
 * Builds the MergeCursor for this operator. `mem` is passed both to
 * MakeUniqueCursorPtr and to the cursor's constructor. Each call increments
 * the MergeOperator metric counter.
 */
UniqueCursorPtr Merge::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::MergeOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<MergeCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> Merge::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
// Match and create branches should have the same symbols, so just take one
|
||
// of them.
|
||
auto my_symbols = merge_match_->OutputSymbols(table);
|
||
symbols.insert(symbols.end(), my_symbols.begin(), my_symbols.end());
|
||
return symbols;
|
||
}
|
||
|
||
/** Creates the cursors of the input, the match branch and the create branch, all using `mem`. */
Merge::MergeCursor::MergeCursor(const Merge &self, utils::MemoryResource *mem)
    : input_cursor_(self.input_->MakeCursor(mem)),
      merge_match_cursor_(self.merge_match_->MakeCursor(mem)),
      merge_create_cursor_(self.merge_create_->MakeCursor(mem)) {}
|
||
|
||
/**
 * For every input row, yields all rows of the match branch; if the match
 * branch yields nothing for that input row, the result of a single pull
 * from the create branch is returned instead. Returns false when the input
 * is exhausted.
 *
 * `pull_input_` is true when the next step has to fetch a new input row and
 * false while the match branch is still being drained for the current one.
 */
bool Merge::MergeCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("Merge");

  for (;;) {
    if (pull_input_) {
      if (!input_cursor_->Pull(frame, context)) {
        // input is exhausted, we're done
        return false;
      }
      // A new input row: restart both branches. The match branch keeps
      // iteration state, and the create branch may begin with a Once.
      merge_match_cursor_->Reset();
      merge_create_cursor_->Reset();
    }

    if (merge_match_cursor_->Pull(frame, context)) {
      // Keep draining the match branch on the next call.
      pull_input_ = false;
      return true;
    }

    if (pull_input_) {
      // The match branch produced nothing for the row pulled just now, so
      // the create branch takes over.
      return merge_create_cursor_->Pull(frame, context);
    }

    // The match branch ran dry after at least one successful pull; go on
    // with the next input row.
    pull_input_ = true;
  }
}
|
||
|
||
/** Shuts down the input, match and create cursors, in that order. */
void Merge::MergeCursor::Shutdown() {
  input_cursor_->Shutdown();

  merge_match_cursor_->Shutdown();

  merge_create_cursor_->Shutdown();
}
|
||
|
||
/** Resets all three child cursors and makes the next Pull start with a fresh input row. */
void Merge::MergeCursor::Reset() {
  input_cursor_->Reset();

  merge_match_cursor_->Reset();

  merge_create_cursor_->Reset();

  pull_input_ = true;
}
|
||
|
||
/**
 * Creates an Optional operator from its input, the optional branch and the
 * symbols that the optional branch sets. A null `input` is replaced by a
 * fresh Once operator.
 */
Optional::Optional(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &optional,
                   const std::vector<Symbol> &optional_symbols)
    : input_(input ? input : std::make_shared<Once>()),
      optional_(optional),
      optional_symbols_(optional_symbols) {}
|
||
|
||
/**
 * Visits this operator. If PreVisit returns true, the input is visited and,
 * only if its Accept returns true, the optional branch as well. The result
 * is that of PostVisit.
 */
bool Optional::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  if (visitor.PreVisit(*this)) {
    if (input_->Accept(visitor)) {
      optional_->Accept(visitor);
    }
  }
  return visitor.PostVisit(*this);
}
|
||
|
||
/**
 * Builds the OptionalCursor for this operator. `mem` is passed both to
 * MakeUniqueCursorPtr and to the cursor's constructor. Each call increments
 * the OptionalOperator metric counter.
 */
UniqueCursorPtr Optional::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::OptionalOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<OptionalCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> Optional::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
auto my_symbols = optional_->ModifiedSymbols(table);
|
||
symbols.insert(symbols.end(), my_symbols.begin(), my_symbols.end());
|
||
return symbols;
|
||
}
|
||
|
||
/** Binds the cursor to its operator and creates the cursors of the input and the optional branch using `mem`. */
Optional::OptionalCursor::OptionalCursor(const Optional &self, utils::MemoryResource *mem)
    : self_(self),
      input_cursor_(self.input_->MakeCursor(mem)),
      optional_cursor_(self.optional_->MakeCursor(mem)) {}
|
||
|
||
/**
 * For every input row, yields all rows of the optional branch; if the
 * optional branch yields nothing for that input row, the optional symbols
 * are overwritten with `TypedValue(memory)` (Null) and the row is yielded
 * once. Returns false when the input is exhausted.
 *
 * `pull_input_` is true when the next step has to fetch a new input row and
 * false while the optional branch is still being drained for the current one.
 */
bool Optional::OptionalCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("Optional");

  for (;;) {
    if (pull_input_) {
      if (!input_cursor_->Pull(frame, context)) {
        // input is exhausted, we're done
        return false;
      }
      // A new input row: restart the optional branch, which keeps iteration
      // state between pulls.
      optional_cursor_->Reset();
    }

    if (optional_cursor_->Pull(frame, context)) {
      // Keep draining the optional branch on the next call.
      pull_input_ = false;
      return true;
    }

    if (pull_input_) {
      // The optional branch produced nothing for the row pulled just now:
      // blank out its symbols, make sure the next call pulls the input
      // again, and yield the row.
      for (const Symbol &optional_symbol : self_.optional_symbols_) {
        frame[optional_symbol] = TypedValue(context.evaluation_context.memory);
      }
      pull_input_ = true;
      return true;
    }

    // The optional branch ran dry after at least one successful pull; go on
    // with the next input row.
    pull_input_ = true;
  }
}
|
||
|
||
/** Shuts down the input cursor and then the optional branch cursor. */
void Optional::OptionalCursor::Shutdown() {
  input_cursor_->Shutdown();

  optional_cursor_->Shutdown();
}
|
||
|
||
/** Resets both child cursors and makes the next Pull start with a fresh input row. */
void Optional::OptionalCursor::Reset() {
  input_cursor_->Reset();

  optional_cursor_->Reset();

  pull_input_ = true;
}
|
||
|
||
/**
 * Creates an Unwind operator: `input_expression` is the expression that is
 * unwound and `output_symbol` the symbol each element is written to. A null
 * `input` is replaced by a fresh Once operator.
 */
Unwind::Unwind(const std::shared_ptr<LogicalOperator> &input, Expression *input_expression, Symbol output_symbol)
    : input_(input ? input : std::make_shared<Once>()),
      input_expression_(input_expression),
      output_symbol_(std::move(output_symbol)) {}
|
||
|
||
// Expands the project macro ACCEPT_WITH_INPUT for Unwind. The macro is defined outside this section; presumably
// it generates Unwind::Accept, visiting this operator and its single input -- confirm at the macro definition.
ACCEPT_WITH_INPUT(Unwind)
|
||
|
||
std::vector<Symbol> Unwind::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
symbols.emplace_back(output_symbol_);
|
||
return symbols;
|
||
}
|
||
|
||
class UnwindCursor : public Cursor {
|
||
public:
|
||
UnwindCursor(const Unwind &self, utils::MemoryResource *mem)
|
||
: self_(self), input_cursor_(self.input_->MakeCursor(mem)), input_value_(mem) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP("Unwind");
|
||
while (true) {
|
||
AbortCheck(context);
|
||
// if we reached the end of our list of values
|
||
// pull from the input
|
||
if (input_value_it_ == input_value_.end()) {
|
||
if (!input_cursor_->Pull(frame, context)) return false;
|
||
|
||
// successful pull from input, initialize value and iterator
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::OLD);
|
||
TypedValue input_value = self_.input_expression_->Accept(evaluator);
|
||
if (input_value.type() != TypedValue::Type::List)
|
||
throw QueryRuntimeException("Argument of UNWIND must be a list, but '{}' was provided.", input_value.type());
|
||
// Move the evaluted input_value_list to our vector.
|
||
input_value_ = std::move(input_value.ValueList());
|
||
input_value_it_ = input_value_.begin();
|
||
}
|
||
|
||
// if we reached the end of our list of values goto back to top
|
||
if (input_value_it_ == input_value_.end()) continue;
|
||
|
||
frame[self_.output_symbol_] = std::move(*input_value_it_++);
|
||
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_.output_symbol_.name_)) {
|
||
context.frame_change_collector->ResetTrackingValue(self_.output_symbol_.name_);
|
||
}
|
||
return true;
|
||
}
|
||
}
|
||
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override {
|
||
input_cursor_->Reset();
|
||
input_value_.clear();
|
||
input_value_it_ = input_value_.end();
|
||
}
|
||
|
||
private:
|
||
const Unwind &self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
// typed values we are unwinding and yielding
|
||
utils::pmr::vector<TypedValue> input_value_;
|
||
// current position in input_value_
|
||
decltype(input_value_)::iterator input_value_it_ = input_value_.end();
|
||
};
|
||
|
||
/**
 * Builds the UnwindCursor for this operator. `mem` is passed both to
 * MakeUniqueCursorPtr and to the cursor's constructor. Each call increments
 * the UnwindOperator metric counter.
 */
UniqueCursorPtr Unwind::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::UnwindOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<UnwindCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
class DistinctCursor : public Cursor {
|
||
public:
|
||
DistinctCursor(const Distinct &self, utils::MemoryResource *mem)
|
||
: self_(self), input_cursor_(self.input_->MakeCursor(mem)), seen_rows_(mem) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP("Distinct");
|
||
|
||
while (true) {
|
||
if (!input_cursor_->Pull(frame, context)) {
|
||
// Nothing left to pull, we can dispose of seen_rows now
|
||
seen_rows_.clear();
|
||
return false;
|
||
}
|
||
|
||
utils::pmr::vector<TypedValue> row(seen_rows_.get_allocator().GetMemoryResource());
|
||
row.reserve(self_.value_symbols_.size());
|
||
|
||
for (const auto &symbol : self_.value_symbols_) {
|
||
row.emplace_back(frame.at(symbol));
|
||
}
|
||
|
||
if (seen_rows_.insert(std::move(row)).second) {
|
||
return true;
|
||
}
|
||
}
|
||
}
|
||
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
void Reset() override {
|
||
input_cursor_->Reset();
|
||
seen_rows_.clear();
|
||
}
|
||
|
||
private:
|
||
const Distinct &self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
// a set of already seen rows
|
||
utils::pmr::unordered_set<utils::pmr::vector<TypedValue>,
|
||
// use FNV collection hashing specialized for a
|
||
// vector of TypedValue
|
||
utils::FnvCollection<utils::pmr::vector<TypedValue>, TypedValue, TypedValue::Hash>,
|
||
TypedValueVectorEqual>
|
||
seen_rows_;
|
||
};
|
||
|
||
/**
 * Creates a Distinct operator that compares rows by the values of
 * `value_symbols`. A null `input` is replaced by a fresh Once operator.
 */
Distinct::Distinct(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &value_symbols)
    : input_(input ? input : std::make_shared<Once>()),
      value_symbols_(value_symbols) {}
|
||
|
||
// Expands the project macro ACCEPT_WITH_INPUT for Distinct. The macro is defined outside this section; presumably
// it generates Distinct::Accept, visiting this operator and its single input -- confirm at the macro definition.
ACCEPT_WITH_INPUT(Distinct)
|
||
|
||
/**
 * Builds the DistinctCursor for this operator. `mem` is passed both to
 * MakeUniqueCursorPtr and to the cursor's constructor. Each call increments
 * the DistinctOperator metric counter.
 */
UniqueCursorPtr Distinct::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::DistinctOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<DistinctCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
std::vector<Symbol> Distinct::OutputSymbols(const SymbolTable &symbol_table) const {
|
||
// Propagate this to potential Produce.
|
||
return input_->OutputSymbols(symbol_table);
|
||
}
|
||
|
||
/** Reports exactly the symbols that the input operator reports as modified. */
std::vector<Symbol> Distinct::ModifiedSymbols(const SymbolTable &table) const {
  return input_->ModifiedSymbols(table);
}
|
||
|
||
/**
 * Creates a Union operator over two child operators. `left_symbols` and
 * `right_symbols` are the symbols read from the respective child's rows;
 * `union_symbols` are the symbols the combined rows are written to.
 */
Union::Union(const std::shared_ptr<LogicalOperator> &left_op, const std::shared_ptr<LogicalOperator> &right_op,
             const std::vector<Symbol> &union_symbols, const std::vector<Symbol> &left_symbols,
             const std::vector<Symbol> &right_symbols)
    : left_op_(left_op),
      right_op_(right_op),
      union_symbols_(union_symbols),
      left_symbols_(left_symbols),
      right_symbols_(right_symbols) {}
|
||
|
||
/**
 * Builds the UnionCursor for this operator. `mem` is passed both to
 * MakeUniqueCursorPtr and to the cursor's constructor. Each call increments
 * the UnionOperator metric counter.
 */
UniqueCursorPtr Union::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::UnionOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<Union::UnionCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
/**
 * Visits this operator. If PreVisit returns true, the left child is visited
 * and, only if its Accept returns true, the right child as well. The result
 * is that of PostVisit.
 */
bool Union::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  if (visitor.PreVisit(*this) && left_op_->Accept(visitor)) {
    right_op_->Accept(visitor);
  }
  return visitor.PostVisit(*this);
}
|
||
|
||
/** The output symbols of a Union are its union symbols; the symbol table is not consulted. */
std::vector<Symbol> Union::OutputSymbols(const SymbolTable &) const {
  return union_symbols_;
}
|
||
|
||
/** The symbols modified by a Union are its union symbols; the symbol table is not consulted. */
std::vector<Symbol> Union::ModifiedSymbols(const SymbolTable &) const {
  return union_symbols_;
}
|
||
|
||
// Expands the project macro WITHOUT_SINGLE_INPUT for Union. The macro is defined outside this section; presumably
// it generates the single-input accessors for an operator that has no single input -- confirm at the macro definition.
WITHOUT_SINGLE_INPUT(Union);
|
||
|
||
/** Binds the cursor to its operator and creates the cursors of both children using `mem`. */
Union::UnionCursor::UnionCursor(const Union &self, utils::MemoryResource *mem)
    : self_(self),
      left_cursor_(self.left_op_->MakeCursor(mem)),
      right_cursor_(self.right_op_->MakeCursor(mem)) {}
|
||
|
||
/**
 * Yields the rows of the left child and, once the left child yields nothing,
 * the rows of the right child. The values of the child's symbols are
 * collected by symbol name and then written to the union symbols of the same
 * name. Returns false when neither child yields a row.
 */
bool Union::UnionCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP_BY_REF(self_);

  utils::pmr::unordered_map<std::string, TypedValue> values_by_name(context.evaluation_context.memory);

  // Resets change tracking for `symbol` if the collector tracks it.
  const auto reset_tracking = [&context](const auto &symbol) {
    if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) {
      context.frame_change_collector->ResetTrackingValue(symbol.name());
    }
  };

  // Copies the frame values of `symbols` into the by-name map.
  const auto collect = [&](const auto &symbols) {
    for (const auto &symbol : symbols) {
      values_by_name[symbol.name()] = frame[symbol];
      reset_tracking(symbol);
    }
  };

  if (left_cursor_->Pull(frame, context)) {
    collect(self_.left_symbols_);
  } else if (right_cursor_->Pull(frame, context)) {
    collect(self_.right_symbols_);
  } else {
    return false;
  }

  // Put the collected values on the frame under the union symbols.
  for (const auto &union_symbol : self_.union_symbols_) {
    frame[union_symbol] = values_by_name[union_symbol.name()];
    reset_tracking(union_symbol);
  }
  return true;
}
|
||
|
||
/** Shuts down the left child cursor and then the right one. */
void Union::UnionCursor::Shutdown() {
  left_cursor_->Shutdown();

  right_cursor_->Shutdown();
}
|
||
|
||
/** Resets the left child cursor and then the right one. */
void Union::UnionCursor::Reset() {
  left_cursor_->Reset();

  right_cursor_->Reset();
}
|
||
|
||
std::vector<Symbol> Cartesian::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = left_op_->ModifiedSymbols(table);
|
||
auto right = right_op_->ModifiedSymbols(table);
|
||
symbols.insert(symbols.end(), right.begin(), right.end());
|
||
return symbols;
|
||
}
|
||
|
||
/**
 * Visits this operator. If PreVisit returns true, the left child is visited
 * and, only if its Accept returns true, the right child as well. The result
 * is that of PostVisit.
 */
bool Cartesian::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  if (visitor.PreVisit(*this)) {
    if (left_op_->Accept(visitor)) {
      right_op_->Accept(visitor);
    }
  }
  return visitor.PostVisit(*this);
}
|
||
|
||
// Expands the project macro WITHOUT_SINGLE_INPUT for Cartesian. The macro is defined outside this section; presumably
// it generates the single-input accessors for an operator that has no single input -- confirm at the macro definition.
WITHOUT_SINGLE_INPUT(Cartesian);
|
||
|
||
namespace {
|
||
|
||
class CartesianCursor : public Cursor {
|
||
public:
|
||
CartesianCursor(const Cartesian &self, utils::MemoryResource *mem)
|
||
: self_(self),
|
||
left_op_frames_(mem),
|
||
right_op_frame_(mem),
|
||
left_op_cursor_(self.left_op_->MakeCursor(mem)),
|
||
right_op_cursor_(self_.right_op_->MakeCursor(mem)) {
|
||
MG_ASSERT(left_op_cursor_ != nullptr, "CartesianCursor: Missing left operator cursor.");
|
||
MG_ASSERT(right_op_cursor_ != nullptr, "CartesianCursor: Missing right operator cursor.");
|
||
}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP_BY_REF(self_);
|
||
|
||
if (!cartesian_pull_initialized_) {
|
||
// Pull all left_op frames.
|
||
while (left_op_cursor_->Pull(frame, context)) {
|
||
left_op_frames_.emplace_back(frame.elems().begin(), frame.elems().end());
|
||
}
|
||
|
||
// We're setting the iterator to 'end' here so it pulls the right
|
||
// cursor.
|
||
left_op_frames_it_ = left_op_frames_.end();
|
||
cartesian_pull_initialized_ = true;
|
||
}
|
||
|
||
// If left operator yielded zero results there is no cartesian product.
|
||
if (left_op_frames_.empty()) {
|
||
return false;
|
||
}
|
||
|
||
auto restore_frame = [&frame, &context](const auto &symbols, const auto &restore_from) {
|
||
for (const auto &symbol : symbols) {
|
||
frame[symbol] = restore_from[symbol.position()];
|
||
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) {
|
||
context.frame_change_collector->ResetTrackingValue(symbol.name());
|
||
}
|
||
}
|
||
};
|
||
|
||
if (left_op_frames_it_ == left_op_frames_.end()) {
|
||
// Advance right_op_cursor_.
|
||
if (!right_op_cursor_->Pull(frame, context)) return false;
|
||
|
||
right_op_frame_.assign(frame.elems().begin(), frame.elems().end());
|
||
left_op_frames_it_ = left_op_frames_.begin();
|
||
} else {
|
||
// Make sure right_op_cursor last pulled results are on frame.
|
||
restore_frame(self_.right_symbols_, right_op_frame_);
|
||
}
|
||
|
||
AbortCheck(context);
|
||
|
||
restore_frame(self_.left_symbols_, *left_op_frames_it_);
|
||
left_op_frames_it_++;
|
||
return true;
|
||
}
|
||
|
||
void Shutdown() override {
|
||
left_op_cursor_->Shutdown();
|
||
right_op_cursor_->Shutdown();
|
||
}
|
||
|
||
void Reset() override {
|
||
left_op_cursor_->Reset();
|
||
right_op_cursor_->Reset();
|
||
right_op_frame_.clear();
|
||
left_op_frames_.clear();
|
||
left_op_frames_it_ = left_op_frames_.end();
|
||
cartesian_pull_initialized_ = false;
|
||
}
|
||
|
||
private:
|
||
const Cartesian &self_;
|
||
utils::pmr::vector<utils::pmr::vector<TypedValue>> left_op_frames_;
|
||
utils::pmr::vector<TypedValue> right_op_frame_;
|
||
const UniqueCursorPtr left_op_cursor_;
|
||
const UniqueCursorPtr right_op_cursor_;
|
||
utils::pmr::vector<utils::pmr::vector<TypedValue>>::iterator left_op_frames_it_;
|
||
bool cartesian_pull_initialized_{false};
|
||
};
|
||
|
||
} // namespace
|
||
|
||
// Bumps the CartesianOperator metric and creates the cursor for this operator
// using the given memory resource.
UniqueCursorPtr Cartesian::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::CartesianOperator);
  return MakeUniqueCursorPtr<CartesianCursor>(mem, *this, mem);
}
|
||
|
||
// Constructs an OutputTable that always yields the given, fixed set of rows.
// `rows` is already an owned by-value parameter, so it is moved into the
// callback's capture instead of being copied a second time.
OutputTable::OutputTable(std::vector<Symbol> output_symbols, std::vector<std::vector<TypedValue>> rows)
    : output_symbols_(std::move(output_symbols)),
      callback_([rows = std::move(rows)](Frame *, ExecutionContext *) { return rows; }) {}
|
||
|
||
// Constructs an OutputTable whose rows are produced by `callback`.
OutputTable::OutputTable(
    std::vector<Symbol> output_symbols,
    std::function<std::vector<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback)
    : output_symbols_(std::move(output_symbols)),
      callback_(std::move(callback)) {}
|
||
|
||
WITHOUT_SINGLE_INPUT(OutputTable);
|
||
|
||
class OutputTableCursor : public Cursor {
|
||
public:
|
||
explicit OutputTableCursor(const OutputTable &self) : self_(self) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
if (!pulled_) {
|
||
rows_ = self_.callback_(&frame, &context);
|
||
for (const auto &row : rows_) {
|
||
MG_ASSERT(row.size() == self_.output_symbols_.size(), "Wrong number of columns in row!");
|
||
}
|
||
pulled_ = true;
|
||
}
|
||
if (current_row_ < rows_.size()) {
|
||
for (size_t i = 0; i < self_.output_symbols_.size(); ++i) {
|
||
frame[self_.output_symbols_[i]] = rows_[current_row_][i];
|
||
if (context.frame_change_collector &&
|
||
context.frame_change_collector->IsKeyTracked(self_.output_symbols_[i].name())) {
|
||
context.frame_change_collector->ResetTrackingValue(self_.output_symbols_[i].name());
|
||
}
|
||
}
|
||
current_row_++;
|
||
return true;
|
||
}
|
||
return false;
|
||
}
|
||
|
||
void Reset() override {
|
||
pulled_ = false;
|
||
current_row_ = 0;
|
||
rows_.clear();
|
||
}
|
||
|
||
void Shutdown() override {}
|
||
|
||
private:
|
||
const OutputTable &self_;
|
||
size_t current_row_{0};
|
||
std::vector<std::vector<TypedValue>> rows_;
|
||
bool pulled_{false};
|
||
};
|
||
|
||
// Creates the cursor for this operator using the given memory resource.
UniqueCursorPtr OutputTable::MakeCursor(utils::MemoryResource *mem) const {
  return MakeUniqueCursorPtr<OutputTableCursor>(mem, *this);
}
|
||
|
||
// Constructs an OutputTableStream whose rows are produced one at a time by
// `callback`.
OutputTableStream::OutputTableStream(
    std::vector<Symbol> output_symbols,
    std::function<std::optional<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback)
    : output_symbols_(std::move(output_symbols)),
      callback_(std::move(callback)) {}
|
||
|
||
WITHOUT_SINGLE_INPUT(OutputTableStream);
|
||
|
||
class OutputTableStreamCursor : public Cursor {
|
||
public:
|
||
explicit OutputTableStreamCursor(const OutputTableStream *self) : self_(self) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
const auto row = self_->callback_(&frame, &context);
|
||
if (row) {
|
||
MG_ASSERT(row->size() == self_->output_symbols_.size(), "Wrong number of columns in row!");
|
||
for (size_t i = 0; i < self_->output_symbols_.size(); ++i) {
|
||
frame[self_->output_symbols_[i]] = row->at(i);
|
||
if (context.frame_change_collector &&
|
||
context.frame_change_collector->IsKeyTracked(self_->output_symbols_[i].name())) {
|
||
context.frame_change_collector->ResetTrackingValue(self_->output_symbols_[i].name());
|
||
}
|
||
}
|
||
return true;
|
||
}
|
||
return false;
|
||
}
|
||
|
||
// TODO(tsabolcec): Come up with better approach for handling `Reset()`.
|
||
// One possibility is to implement a custom closure utility class with
|
||
// `Reset()` method.
|
||
void Reset() override { throw utils::NotYetImplemented("OutputTableStreamCursor::Reset"); }
|
||
|
||
void Shutdown() override {}
|
||
|
||
private:
|
||
const OutputTableStream *self_;
|
||
};
|
||
|
||
// Creates the cursor for this operator using the given memory resource.
UniqueCursorPtr OutputTableStream::MakeCursor(utils::MemoryResource *mem) const {
  return MakeUniqueCursorPtr<OutputTableStreamCursor>(mem, this);
}
|
||
|
||
// Constructs a CallProcedure operator. When `input` is null, a `Once` operator
// is used as the input. The owned `input` pointer is moved (not copied) into
// the member to avoid an extra reference-count round trip.
CallProcedure::CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name, std::vector<Expression *> args,
                             std::vector<std::string> fields, std::vector<Symbol> symbols, Expression *memory_limit,
                             size_t memory_scale, bool is_write, int64_t procedure_id, bool void_procedure)
    : input_(input ? std::move(input) : std::make_shared<Once>()),
      procedure_name_(std::move(name)),
      arguments_(std::move(args)),
      result_fields_(std::move(fields)),
      result_symbols_(std::move(symbols)),
      memory_limit_(memory_limit),
      memory_scale_(memory_scale),
      is_write_(is_write),
      procedure_id_(procedure_id),
      void_procedure_(void_procedure) {}
|
||
|
||
ACCEPT_WITH_INPUT(CallProcedure);
|
||
|
||
// Returns the symbols under which the procedure's result fields are exposed.
// The symbol table is not consulted.
std::vector<Symbol> CallProcedure::OutputSymbols(const SymbolTable & /*table*/) const {
  return result_symbols_;
}
|
||
|
||
std::vector<Symbol> CallProcedure::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
symbols.insert(symbols.end(), result_symbols_.begin(), result_symbols_.end());
|
||
return symbols;
|
||
}
|
||
|
||
// Increments the call counter kept for `procedure_name` while holding the
// lock of the shared counter map.
void CallProcedure::IncrementCounter(const std::string &procedure_name) {
  procedure_counters_.WithLock([&procedure_name](auto &counters) { counters[procedure_name] += 1; });
}
|
||
|
||
std::unordered_map<std::string, int64_t> CallProcedure::GetAndResetCounters() {
|
||
auto counters = procedure_counters_.Lock();
|
||
auto ret = std::move(*counters);
|
||
counters->clear();
|
||
return ret;
|
||
}
|
||
|
||
namespace {
|
||
|
||
void CallCustomProcedure(const std::string_view fully_qualified_procedure_name, const mgp_proc &proc,
|
||
const std::vector<Expression *> &args, mgp_graph &graph, ExpressionEvaluator *evaluator,
|
||
utils::MemoryResource *memory, std::optional<size_t> memory_limit, mgp_result *result,
|
||
int64_t procedure_id, uint64_t transaction_id, const bool call_initializer = false) {
|
||
static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>,
|
||
"Expected mgp_value to use custom allocator and makes STL "
|
||
"containers aware of that");
|
||
// Build and type check procedure arguments.
|
||
mgp_list proc_args(memory);
|
||
std::vector<TypedValue> args_list;
|
||
args_list.reserve(args.size());
|
||
for (auto *expression : args) {
|
||
args_list.emplace_back(expression->Accept(*evaluator));
|
||
}
|
||
std::optional<query::Graph> subgraph;
|
||
std::optional<query::SubgraphDbAccessor> db_acc;
|
||
|
||
if (!args_list.empty() && args_list.front().type() == TypedValue::Type::Graph) {
|
||
auto subgraph_value = args_list.front().ValueGraph();
|
||
subgraph = query::Graph(std::move(subgraph_value), subgraph_value.GetMemoryResource());
|
||
args_list.erase(args_list.begin());
|
||
|
||
db_acc = query::SubgraphDbAccessor(*std::get<query::DbAccessor *>(graph.impl), &*subgraph);
|
||
graph.impl = &*db_acc;
|
||
}
|
||
|
||
procedure::ConstructArguments(args_list, proc, fully_qualified_procedure_name, proc_args, graph);
|
||
if (call_initializer) {
|
||
MG_ASSERT(proc.initializer);
|
||
mgp_memory initializer_memory{memory};
|
||
proc.initializer.value()(&proc_args, &graph, &initializer_memory);
|
||
}
|
||
if (memory_limit) {
|
||
SPDLOG_INFO("Running '{}' with memory limit of {}", fully_qualified_procedure_name,
|
||
utils::GetReadableSize(*memory_limit));
|
||
// Only allocations which can leak memory are
|
||
// our own mgp object allocations. Jemalloc can track
|
||
// memory correctly, but some memory may not be released
|
||
// immediately, so we want to give user info on leak still
|
||
// considering our allocations
|
||
utils::MemoryTrackingResource memory_tracking_resource{memory, *memory_limit};
|
||
// if we are already tracking, no harm no faul
|
||
// if we are not tracking, we need to start now, with unlimited memory
|
||
// for query, but limited for procedure
|
||
|
||
// check if transaction is tracked currently, so we
|
||
// can disable tracking on that arena if it is not
|
||
// once we are done with procedure tracking
|
||
|
||
bool is_transaction_tracked = memgraph::memory::IsTransactionTracked(transaction_id);
|
||
|
||
if (!is_transaction_tracked) {
|
||
// start tracking with unlimited limit on query
|
||
// which is same as not being tracked at all
|
||
memgraph::memory::TryStartTrackingOnTransaction(transaction_id, memgraph::memory::UNLIMITED_MEMORY);
|
||
}
|
||
memgraph::memory::StartTrackingCurrentThreadTransaction(transaction_id);
|
||
|
||
// due to mgp_batch_read_proc and mgp_batch_write_proc
|
||
// we can return to execution without exhausting whole
|
||
// memory. Here we need to update tracking
|
||
memgraph::memory::CreateOrContinueProcedureTracking(transaction_id, procedure_id, *memory_limit);
|
||
|
||
mgp_memory proc_memory{&memory_tracking_resource};
|
||
MG_ASSERT(result->signature == &proc.results);
|
||
|
||
utils::OnScopeExit on_scope_exit{[transaction_id = transaction_id]() {
|
||
memgraph::memory::StopTrackingCurrentThreadTransaction(transaction_id);
|
||
memgraph::memory::PauseProcedureTracking(transaction_id);
|
||
}};
|
||
|
||
// TODO: What about cross library boundary exceptions? OMG C++?!
|
||
proc.cb(&proc_args, &graph, result, &proc_memory);
|
||
|
||
auto leaked_bytes = memory_tracking_resource.GetAllocatedBytes();
|
||
if (leaked_bytes > 0U) {
|
||
spdlog::warn("Query procedure '{}' leaked {} *tracked* bytes", fully_qualified_procedure_name, leaked_bytes);
|
||
}
|
||
} else {
|
||
// TODO: Add a tracking MemoryResource without limits, so that we report
|
||
// memory leaks in procedure.
|
||
mgp_memory proc_memory{memory};
|
||
MG_ASSERT(result->signature == &proc.results);
|
||
// TODO: What about cross library boundary exceptions? OMG C++?!
|
||
proc.cb(&proc_args, &graph, result, &proc_memory);
|
||
}
|
||
}
|
||
|
||
} // namespace
|
||
|
||
class CallProcedureCursor : public Cursor {
|
||
const CallProcedure *self_;
|
||
UniqueCursorPtr input_cursor_;
|
||
mgp_result *result_;
|
||
decltype(result_->rows.end()) result_row_it_{result_->rows.end()};
|
||
size_t result_signature_size_{0};
|
||
bool stream_exhausted{true};
|
||
bool call_initializer{false};
|
||
std::optional<std::function<void()>> cleanup_{std::nullopt};
|
||
|
||
public:
|
||
CallProcedureCursor(const CallProcedure *self, utils::MemoryResource *mem)
|
||
: self_(self),
|
||
input_cursor_(self_->input_->MakeCursor(mem)),
|
||
// result_ needs to live throughout multiple Pull evaluations, until all
|
||
// rows are produced. We don't use the memory dedicated for QueryExecution (and Frame),
|
||
// but memory dedicated for procedure to wipe result_ and everything allocated in procedure all at once.
|
||
result_(utils::Allocator<mgp_result>(self_->memory_resource)
|
||
.new_object<mgp_result>(nullptr, self_->memory_resource)) {
|
||
MG_ASSERT(self_->result_fields_.size() == self_->result_symbols_.size(), "Incorrectly constructed CallProcedure");
|
||
}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP_BY_REF(*self_);
|
||
|
||
AbortCheck(context);
|
||
|
||
auto skip_rows_with_deleted_values = [this]() {
|
||
while (result_row_it_ != result_->rows.end() && result_row_it_->has_deleted_values) {
|
||
++result_row_it_;
|
||
}
|
||
};
|
||
|
||
// We need to fetch new procedure results after pulling from input.
|
||
// TODO: Look into openCypher's distinction between procedures returning an
|
||
// empty result set vs procedures which return `void`. We currently don't
|
||
// have procedures registering what they return.
|
||
// This `while` loop will skip over empty results.
|
||
while (result_row_it_ == result_->rows.end()) {
|
||
// It might be a good idea to resolve the procedure name once, at the
|
||
// start. Unfortunately, this could deadlock if we tried to invoke a
|
||
// procedure from a module (read lock) and reload a module (write lock)
|
||
// inside the same execution thread. Also, our RWLock is set up so that
|
||
// it's not possible for a single thread to request multiple read locks.
|
||
// Builtin module registration in query/procedure/module.cpp depends on
|
||
// this locking scheme.
|
||
const auto &maybe_found = procedure::FindProcedure(procedure::gModuleRegistry, self_->procedure_name_,
|
||
context.evaluation_context.memory);
|
||
if (!maybe_found) {
|
||
throw QueryRuntimeException("There is no procedure named '{}'.", self_->procedure_name_);
|
||
}
|
||
const auto &[module, proc] = *maybe_found;
|
||
if (proc->info.is_write != self_->is_write_) {
|
||
auto get_proc_type_str = [](bool is_write) { return is_write ? "write" : "read"; };
|
||
throw QueryRuntimeException("The procedure named '{}' was a {} procedure, but changed to be a {} procedure.",
|
||
self_->procedure_name_, get_proc_type_str(self_->is_write_),
|
||
get_proc_type_str(proc->info.is_write));
|
||
}
|
||
if (!proc->info.is_batched) {
|
||
stream_exhausted = true;
|
||
}
|
||
|
||
if (stream_exhausted) {
|
||
if (!input_cursor_->Pull(frame, context)) {
|
||
if (proc->cleanup) {
|
||
proc->cleanup.value()();
|
||
}
|
||
return false;
|
||
}
|
||
stream_exhausted = false;
|
||
if (proc->initializer) {
|
||
call_initializer = true;
|
||
MG_ASSERT(proc->cleanup);
|
||
proc->cleanup.value()();
|
||
}
|
||
}
|
||
if (!cleanup_ && proc->cleanup) [[unlikely]] {
|
||
cleanup_.emplace(*proc->cleanup);
|
||
}
|
||
// Unpluging memory without calling destruct on each object since everything was allocated with this memory
|
||
// resource
|
||
self_->monotonic_memory.Release();
|
||
result_ =
|
||
utils::Allocator<mgp_result>(self_->memory_resource).new_object<mgp_result>(nullptr, self_->memory_resource);
|
||
|
||
const auto graph_view = proc->info.is_write ? storage::View::NEW : storage::View::OLD;
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
graph_view);
|
||
|
||
result_->signature = &proc->results;
|
||
result_->is_transactional = storage::IsTransactional(context.db_accessor->GetStorageMode());
|
||
|
||
// Use special memory as invoking procedure is complex
|
||
// TODO: This will probably need to be changed when we add support for
|
||
// generator like procedures which yield a new result on new query calls.
|
||
auto *memory = self_->memory_resource;
|
||
auto memory_limit = EvaluateMemoryLimit(evaluator, self_->memory_limit_, self_->memory_scale_);
|
||
auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context);
|
||
const auto transaction_id = context.db_accessor->GetTransactionId();
|
||
MG_ASSERT(transaction_id.has_value());
|
||
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
|
||
result_, self_->procedure_id_, transaction_id.value(), call_initializer);
|
||
|
||
if (call_initializer) call_initializer = false;
|
||
|
||
// Reset result_.signature to nullptr, because outside of this scope we
|
||
// will no longer hold a lock on the `module`. If someone were to reload
|
||
// it, the pointer would be invalid.
|
||
result_signature_size_ = result_->signature->size();
|
||
result_->signature = nullptr;
|
||
if (result_->error_msg) {
|
||
memgraph::utils::MemoryTracker::OutOfMemoryExceptionBlocker blocker;
|
||
throw QueryRuntimeException("{}: {}", self_->procedure_name_, *result_->error_msg);
|
||
}
|
||
result_row_it_ = result_->rows.begin();
|
||
if (!result_->is_transactional) {
|
||
skip_rows_with_deleted_values();
|
||
}
|
||
|
||
stream_exhausted = result_row_it_ == result_->rows.end();
|
||
}
|
||
|
||
auto &values = result_row_it_->values;
|
||
// Check that the row has all fields as required by the result signature.
|
||
// C API guarantees that it's impossible to set fields which are not part of
|
||
// the result record, but it does not gurantee that some may be missing. See
|
||
// `mgp_result_record_insert`.
|
||
if (values.size() != result_signature_size_) {
|
||
throw QueryRuntimeException(
|
||
"Procedure '{}' did not yield all fields as required by its "
|
||
"signature.",
|
||
self_->procedure_name_);
|
||
}
|
||
for (size_t i = 0; i < self_->result_fields_.size(); ++i) {
|
||
std::string_view field_name(self_->result_fields_[i]);
|
||
auto result_it = values.find(field_name);
|
||
if (result_it == values.end()) {
|
||
throw QueryRuntimeException("Procedure '{}' did not yield a record with '{}' field.", self_->procedure_name_,
|
||
field_name);
|
||
}
|
||
frame[self_->result_symbols_[i]] = std::move(result_it->second);
|
||
if (context.frame_change_collector &&
|
||
context.frame_change_collector->IsKeyTracked(self_->result_symbols_[i].name())) {
|
||
context.frame_change_collector->ResetTrackingValue(self_->result_symbols_[i].name());
|
||
}
|
||
}
|
||
++result_row_it_;
|
||
if (!result_->is_transactional) {
|
||
skip_rows_with_deleted_values();
|
||
}
|
||
|
||
return true;
|
||
}
|
||
|
||
void Reset() override {
|
||
self_->monotonic_memory.Release();
|
||
result_ =
|
||
utils::Allocator<mgp_result>(self_->memory_resource).new_object<mgp_result>(nullptr, self_->memory_resource);
|
||
if (cleanup_) {
|
||
cleanup_.value()();
|
||
}
|
||
}
|
||
|
||
void Shutdown() override {
|
||
self_->monotonic_memory.Release();
|
||
if (cleanup_) {
|
||
cleanup_.value()();
|
||
}
|
||
}
|
||
};
|
||
|
||
class CallValidateProcedureCursor : public Cursor {
|
||
const CallProcedure *self_;
|
||
UniqueCursorPtr input_cursor_;
|
||
|
||
public:
|
||
CallValidateProcedureCursor(const CallProcedure *self, utils::MemoryResource *mem)
|
||
: self_(self), input_cursor_(self_->input_->MakeCursor(mem)) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP("CallValidateProcedureCursor");
|
||
|
||
AbortCheck(context);
|
||
if (!input_cursor_->Pull(frame, context)) {
|
||
return false;
|
||
}
|
||
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::NEW);
|
||
|
||
const auto args = self_->arguments_;
|
||
MG_ASSERT(args.size() == 3U);
|
||
|
||
const auto predicate = args[0]->Accept(evaluator);
|
||
const bool predicate_val = predicate.ValueBool();
|
||
|
||
if (predicate_val) [[unlikely]] {
|
||
const auto &message = args[1]->Accept(evaluator);
|
||
const auto &message_args = args[2]->Accept(evaluator);
|
||
|
||
using TString = std::remove_cvref_t<decltype(message.ValueString())>;
|
||
using TElement = std::remove_cvref_t<decltype(message_args.ValueList()[0])>;
|
||
|
||
utils::JStringFormatter<TString, TElement> formatter;
|
||
|
||
try {
|
||
const auto &msg = formatter.FormatString(message.ValueString(), message_args.ValueList());
|
||
throw QueryRuntimeException(msg);
|
||
} catch (const utils::JStringFormatException &e) {
|
||
throw QueryRuntimeException(e.what());
|
||
}
|
||
}
|
||
|
||
return true;
|
||
}
|
||
|
||
void Reset() override { input_cursor_->Reset(); }
|
||
|
||
void Shutdown() override {}
|
||
};
|
||
|
||
// Records the CallProcedureOperator metric and the per-procedure call counter,
// then creates the cursor matching the kind of procedure.
UniqueCursorPtr CallProcedure::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::CallProcedureOperator);
  CallProcedure::IncrementCounter(procedure_name_);

  if (!void_procedure_) {
    return MakeUniqueCursorPtr<CallProcedureCursor>(mem, this, mem);
  }

  // Currently we do not support Call procedures that do not return
  // anything. This cursor is way too specific, but it provides a workaround
  // to ensure GraphQL compatibility until we start supporting truly void
  // procedures.
  return MakeUniqueCursorPtr<CallValidateProcedureCursor>(mem, this, mem);
}
|
||
|
||
// Constructs a LoadCsv operator. When `input` is null, a `Once` operator is
// used as the input. `file` must not be null. The owned `input` pointer is
// moved (not copied) into the member.
LoadCsv::LoadCsv(std::shared_ptr<LogicalOperator> input, Expression *file, bool with_header, bool ignore_bad,
                 Expression *delimiter, Expression *quote, Expression *nullif, Symbol row_var)
    : input_(input ? std::move(input) : std::make_shared<Once>()),
      file_(file),
      with_header_(with_header),
      ignore_bad_(ignore_bad),
      delimiter_(delimiter),
      quote_(quote),
      nullif_(nullif),
      row_var_(std::move(row_var)) {
  MG_ASSERT(file_, "Something went wrong - '{}' member file_ shouldn't be a nullptr", __func__);
}
|
||
|
||
ACCEPT_WITH_INPUT(LoadCsv)
|
||
|
||
class LoadCsvCursor;
|
||
|
||
// Returns the single symbol under which each parsed CSV row is exposed. The
// symbol table is not consulted.
std::vector<Symbol> LoadCsv::OutputSymbols(const SymbolTable & /*sym_table*/) const { return {row_var_}; }
|
||
|
||
std::vector<Symbol> LoadCsv::ModifiedSymbols(const SymbolTable &sym_table) const {
|
||
auto symbols = input_->ModifiedSymbols(sym_table);
|
||
symbols.push_back(row_var_);
|
||
return symbols;
|
||
};
|
||
|
||
namespace {
|
||
// copy-pasted from interpreter.cpp
|
||
// Evaluates `expression` with `eval`; a null expression yields a
// default-constructed TypedValue.
TypedValue EvaluateOptionalExpression(Expression *expression, ExpressionEvaluator *eval) {
  if (expression == nullptr) {
    return TypedValue();
  }
  return expression->Accept(*eval);
}
|
||
|
||
// Evaluates the (possibly null) expression and returns its value as a string
// allocated from the new/delete resource, or std::nullopt when the result is
// not a string.
auto ToOptionalString(ExpressionEvaluator *evaluator, Expression *expression) -> std::optional<utils::pmr::string> {
  const auto evaluated_expr = EvaluateOptionalExpression(expression, evaluator);
  if (evaluated_expr.IsString()) {
    return utils::pmr::string(evaluated_expr.ValueString(), utils::NewDeleteResource());
  }
  return std::nullopt;
}
|
||
|
||
// Converts a CSV row into a list TypedValue allocated from the row's memory
// resource. Columns equal to `nullif` (when set) become null values; all other
// columns are moved out of `row`.
TypedValue CsvRowToTypedList(csv::Reader::Row &row, std::optional<utils::pmr::string> &nullif) {
  auto *mem = row.get_allocator().GetMemoryResource();
  utils::pmr::vector<TypedValue> values(mem);
  values.reserve(row.size());
  for (auto &column : row) {
    const bool is_null = nullif.has_value() && !(column != nullif.value());
    if (is_null) {
      values.emplace_back();
    } else {
      values.emplace_back(std::move(column));
    }
  }
  return {std::move(values), mem};
}
|
||
|
||
// Converts a CSV row into a map TypedValue (header name -> column value)
// allocated from the row's memory resource. Columns equal to `nullif` (when
// set) are stored as null values. `header` is taken by value because its
// elements are moved into the map.
TypedValue CsvRowToTypedMap(csv::Reader::Row &row, csv::Reader::Header header,
                            std::optional<utils::pmr::string> &nullif) {
  // a valid row has the same number of elements as the header
  auto *mem = row.get_allocator().GetMemoryResource();
  utils::pmr::map<utils::pmr::string, TypedValue> m(mem);
  // Use an unsigned index to match `row.size()` and avoid a signed/unsigned
  // comparison.
  for (size_t i = 0; i < row.size(); ++i) {
    if (!nullif.has_value() || row[i] != nullif.value()) {
      m.emplace(std::move(header[i]), std::move(row[i]));
    } else {
      m.emplace(std::piecewise_construct, std::forward_as_tuple(std::move(header[i])), std::forward_as_tuple());
    }
  }
  return {std::move(m), mem};
}
|
||
|
||
} // namespace
|
||
|
||
class LoadCsvCursor : public Cursor {
|
||
const LoadCsv *self_;
|
||
const UniqueCursorPtr input_cursor_;
|
||
bool did_pull_;
|
||
std::optional<csv::Reader> reader_{};
|
||
std::optional<utils::pmr::string> nullif_;
|
||
|
||
public:
|
||
LoadCsvCursor(const LoadCsv *self, utils::MemoryResource *mem)
|
||
: self_(self), input_cursor_(self_->input_->MakeCursor(mem)), did_pull_{false} {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP_BY_REF(*self_);
|
||
|
||
AbortCheck(context);
|
||
|
||
// ToDo(the-joksim):
|
||
// - this is an ungodly hack because the pipeline of creating a plan
|
||
// doesn't allow evaluating the expressions contained in self_->file_,
|
||
// self_->delimiter_, and self_->quote_ earlier (say, in the interpreter.cpp)
|
||
// without massacring the code even worse than I did here
|
||
if (UNLIKELY(!reader_)) {
|
||
reader_ = MakeReader(&context.evaluation_context);
|
||
nullif_ = ParseNullif(&context.evaluation_context);
|
||
}
|
||
|
||
if (input_cursor_->Pull(frame, context)) {
|
||
if (did_pull_) {
|
||
throw QueryRuntimeException(
|
||
"LOAD CSV can be executed only once, please check if the cardinality of the operator before LOAD CSV "
|
||
"is "
|
||
"1");
|
||
}
|
||
did_pull_ = true;
|
||
reader_->Reset();
|
||
}
|
||
|
||
auto row = reader_->GetNextRow(context.evaluation_context.memory);
|
||
if (!row) {
|
||
return false;
|
||
}
|
||
if (!reader_->HasHeader()) {
|
||
frame[self_->row_var_] = CsvRowToTypedList(*row, nullif_);
|
||
} else {
|
||
frame[self_->row_var_] =
|
||
CsvRowToTypedMap(*row, csv::Reader::Header(reader_->GetHeader(), context.evaluation_context.memory), nullif_);
|
||
}
|
||
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_->row_var_.name())) {
|
||
context.frame_change_collector->ResetTrackingValue(self_->row_var_.name());
|
||
}
|
||
return true;
|
||
}
|
||
|
||
void Reset() override { input_cursor_->Reset(); }
|
||
void Shutdown() override { input_cursor_->Shutdown(); }
|
||
|
||
private:
|
||
csv::Reader MakeReader(EvaluationContext *eval_context) {
|
||
Frame frame(0);
|
||
SymbolTable symbol_table;
|
||
DbAccessor *dba = nullptr;
|
||
auto evaluator = ExpressionEvaluator(&frame, symbol_table, *eval_context, dba, storage::View::OLD);
|
||
|
||
auto maybe_file = ToOptionalString(&evaluator, self_->file_);
|
||
auto maybe_delim = ToOptionalString(&evaluator, self_->delimiter_);
|
||
auto maybe_quote = ToOptionalString(&evaluator, self_->quote_);
|
||
|
||
// No need to check if maybe_file is std::nullopt, as the parser makes sure
|
||
// we can't get a nullptr for the 'file_' member in the LoadCsv clause.
|
||
// Note that the reader has to be given its own memory resource, as it
|
||
// persists between pulls, so it can't use the evalutation context memory
|
||
// resource.
|
||
return csv::Reader(
|
||
csv::CsvSource::Create(*maybe_file),
|
||
csv::Reader::Config(self_->with_header_, self_->ignore_bad_, std::move(maybe_delim), std::move(maybe_quote)),
|
||
utils::NewDeleteResource());
|
||
}
|
||
|
||
std::optional<utils::pmr::string> ParseNullif(EvaluationContext *eval_context) {
|
||
Frame frame(0);
|
||
SymbolTable symbol_table;
|
||
DbAccessor *dba = nullptr;
|
||
auto evaluator = ExpressionEvaluator(&frame, symbol_table, *eval_context, dba, storage::View::OLD);
|
||
|
||
return ToOptionalString(&evaluator, self_->nullif_);
|
||
}
|
||
};
|
||
|
||
// Creates the cursor for this operator using the given memory resource.
UniqueCursorPtr LoadCsv::MakeCursor(utils::MemoryResource *mem) const {
  return MakeUniqueCursorPtr<LoadCsvCursor>(mem, this, mem);
}
|
||
|
||
class ForeachCursor : public Cursor {
|
||
public:
|
||
explicit ForeachCursor(const Foreach &foreach, utils::MemoryResource *mem)
|
||
: loop_variable_symbol_(foreach.loop_variable_symbol_),
|
||
input_(foreach.input_->MakeCursor(mem)),
|
||
updates_(foreach.update_clauses_->MakeCursor(mem)),
|
||
expression(foreach.expression_) {}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
OOMExceptionEnabler oom_exception;
|
||
SCOPED_PROFILE_OP(op_name_);
|
||
|
||
if (!input_->Pull(frame, context)) {
|
||
return false;
|
||
}
|
||
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::NEW);
|
||
TypedValue expr_result = expression->Accept(evaluator);
|
||
|
||
if (expr_result.IsNull()) {
|
||
return true;
|
||
}
|
||
|
||
if (!expr_result.IsList()) {
|
||
throw QueryRuntimeException("FOREACH expression must resolve to a list, but got '{}'.", expr_result.type());
|
||
}
|
||
|
||
const auto &cache_ = expr_result.ValueList();
|
||
for (const auto &index : cache_) {
|
||
frame[loop_variable_symbol_] = index;
|
||
while (updates_->Pull(frame, context)) {
|
||
}
|
||
ResetUpdates();
|
||
}
|
||
|
||
return true;
|
||
}
|
||
|
||
void Shutdown() override { input_->Shutdown(); }
|
||
|
||
void ResetUpdates() { updates_->Reset(); }
|
||
|
||
void Reset() override {
|
||
input_->Reset();
|
||
ResetUpdates();
|
||
}
|
||
|
||
private:
|
||
const Symbol loop_variable_symbol_;
|
||
const UniqueCursorPtr input_;
|
||
const UniqueCursorPtr updates_;
|
||
Expression *expression;
|
||
const char *op_name_{"Foreach"};
|
||
};
|
||
|
||
// Constructs a Foreach operator. When `input` is null, a `Once` operator is
// used as the input.
Foreach::Foreach(std::shared_ptr<LogicalOperator> input,
                 std::shared_ptr<LogicalOperator> updates,
                 Expression *expr,
                 Symbol loop_variable_symbol)
    : input_(input ? std::move(input) : std::make_shared<Once>()),
      update_clauses_(std::move(updates)),
      expression_(expr),
      loop_variable_symbol_(std::move(loop_variable_symbol)) {}
|
||
|
||
// Bumps the ForeachOperator metric and creates the cursor for this operator
// using the given memory resource.
UniqueCursorPtr Foreach::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ForeachOperator);

  return MakeUniqueCursorPtr<ForeachCursor>(mem, *this, mem);
}
|
||
|
||
std::vector<Symbol> Foreach::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
symbols.emplace_back(loop_variable_symbol_);
|
||
return symbols;
|
||
}
|
||
|
||
// Visits this operator; when PreVisit allows it, both the input and the
// update clauses are visited (the latter regardless of the former's result).
bool Foreach::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  const bool descend = visitor.PreVisit(*this);
  if (descend) {
    input_->Accept(visitor);
    update_clauses_->Accept(visitor);
  }
  return visitor.PostVisit(*this);
}
|
||
|
||
// Constructs an Apply operator. When `input` is null, a `Once` operator is
// used as the input. The parameters are owned by-value copies, so they are
// moved into the members; dropping the top-level `const` in this definition
// does not change the function's signature.
Apply::Apply(std::shared_ptr<LogicalOperator> input, std::shared_ptr<LogicalOperator> subquery,
             bool subquery_has_return)
    : input_(input ? std::move(input) : std::make_shared<Once>()),
      subquery_(std::move(subquery)),
      subquery_has_return_(subquery_has_return) {}
|
||
|
||
// Visitor entry point. When PreVisit returns true the input branch is visited first; the subquery
// branch is visited only if visiting the input returned true. PostVisit is always called and its
// result is returned.
bool Apply::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  if (visitor.PreVisit(*this)) {
    if (input_->Accept(visitor)) {
      subquery_->Accept(visitor);
    }
  }
  return visitor.PostVisit(*this);
}
|
||
|
||
// Creates the cursor for this operator.
//
// The ApplyOperator metrics counter is incremented on every call, before the cursor is built.
// `mem` is handed to MakeUniqueCursorPtr and is also passed on after `*this`.
UniqueCursorPtr Apply::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::ApplyOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<ApplyCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
// Builds the cursor state for `self`: one child cursor per branch (both created with `mem`) and a
// copy of the operator's subquery_has_return_ flag.
Apply::ApplyCursor::ApplyCursor(const Apply &self, utils::MemoryResource *mem)
    : self_(self),
      input_(self.input_->MakeCursor(mem)),
      subquery_(self.subquery_->MakeCursor(mem)),
      subquery_has_return_(self.subquery_has_return_) {}
|
||
|
||
std::vector<Symbol> Apply::ModifiedSymbols(const SymbolTable &table) const {
|
||
// Since Apply is the Cartesian product, modified symbols are combined from
|
||
// both execution branches.
|
||
auto symbols = input_->ModifiedSymbols(table);
|
||
auto subquery_symbols = subquery_->ModifiedSymbols(table);
|
||
symbols.insert(symbols.end(), subquery_symbols.begin(), subquery_symbols.end());
|
||
return symbols;
|
||
}
|
||
|
||
// Produces the next row of the Apply operator.
//
// For the current input row the subquery cursor is pulled until it is exhausted; each successful
// subquery pull yields one output row. When the subquery is exhausted it is reset and the next
// input row is fetched on the following iteration. Returns false once the input is exhausted.
bool Apply::ApplyCursor::Pull(Frame &frame, ExecutionContext &context) {
  OOMExceptionEnabler oom_exception;
  SCOPED_PROFILE_OP("Apply");

  for (;;) {
    // A new input row is needed only after the subquery ran dry for the previous one.
    if (pull_input_) {
      if (!input_->Pull(frame, context)) {
        return false;
      }
    }

    if (subquery_->Pull(frame, context)) {
      // Stay on the same input row for the next Pull so the subquery can produce more results.
      pull_input_ = false;
      return true;
    }

    // The subquery has nothing (more) for this input row: rewind it and move on to the next row.
    pull_input_ = true;
    subquery_->Reset();

    // Without a RETURN in the subquery the input row itself is emitted instead of being skipped.
    if (!subquery_has_return_) {
      return true;
    }
  }
}
|
||
|
||
// Shuts down both child cursors, input first and then the subquery.
void Apply::ApplyCursor::Shutdown() {
  input_->Shutdown();
  subquery_->Shutdown();
}
|
||
|
||
// Resets both child cursors and re-arms pull_input_ so that the next Pull starts by fetching an
// input row.
void Apply::ApplyCursor::Reset() {
  input_->Reset();
  subquery_->Reset();

  pull_input_ = true;
}
|
||
|
||
// Constructs the IndexedJoin logical operator.
//
// `main_branch` is the branch pulled first; a null `main_branch` is replaced by a freshly created
// `Once`. `sub_branch` is the branch pulled repeatedly for every main-branch row.
//
// The shared pointers are taken by (non-const) value and moved into the members, so no extra
// reference-count increment/decrement pair is paid. Top-level const on a by-value parameter is not
// part of the function type, so this definition still matches the existing declaration.
IndexedJoin::IndexedJoin(std::shared_ptr<LogicalOperator> main_branch, std::shared_ptr<LogicalOperator> sub_branch)
    : main_branch_(main_branch ? std::move(main_branch) : std::make_shared<Once>()),
      sub_branch_(std::move(sub_branch)) {}
|
||
|
||
// NOTE(review): WITHOUT_SINGLE_INPUT is defined outside this part of the file; presumably it emits
// the single-input accessor definitions for an operator that has two branches instead of one input
// - confirm against the macro definition.
WITHOUT_SINGLE_INPUT(IndexedJoin);
|
||
|
||
// Visitor entry point. When PreVisit returns true the main branch is visited first; the sub branch
// is visited only if visiting the main branch returned true. PostVisit is always called and its
// result is returned.
bool IndexedJoin::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  if (visitor.PreVisit(*this)) {
    if (main_branch_->Accept(visitor)) {
      sub_branch_->Accept(visitor);
    }
  }
  return visitor.PostVisit(*this);
}
|
||
|
||
// Creates the cursor for this operator.
//
// The IndexedJoinOperator metrics counter is incremented on every call, before the cursor is
// built. `mem` is handed to MakeUniqueCursorPtr and is also passed on after `*this`.
UniqueCursorPtr IndexedJoin::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::IndexedJoinOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<IndexedJoinCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
// Builds the cursor state for `self`: one child cursor per branch, both created with `mem`.
IndexedJoin::IndexedJoinCursor::IndexedJoinCursor(const IndexedJoin &self, utils::MemoryResource *mem)
    : self_(self),
      main_branch_(self.main_branch_->MakeCursor(mem)),
      sub_branch_(self.sub_branch_->MakeCursor(mem)) {}
|
||
|
||
std::vector<Symbol> IndexedJoin::ModifiedSymbols(const SymbolTable &table) const {
|
||
// Since Apply is the Cartesian product, modified symbols are combined from
|
||
// both execution branches.
|
||
auto symbols = main_branch_->ModifiedSymbols(table);
|
||
auto sub_branch_symbols = sub_branch_->ModifiedSymbols(table);
|
||
symbols.insert(symbols.end(), sub_branch_symbols.begin(), sub_branch_symbols.end());
|
||
return symbols;
|
||
}
|
||
|
||
// Produces the next row of the IndexedJoin operator.
//
// For the current main-branch row the sub-branch cursor is pulled until it is exhausted; each
// successful sub-branch pull yields one output row. A main-branch row for which the sub branch
// yields nothing is skipped. Returns false once the main branch is exhausted.
//
// NOTE(review): unlike Apply::ApplyCursor::Pull, no OOMExceptionEnabler is declared here - confirm
// that this is intentional.
bool IndexedJoin::IndexedJoinCursor::Pull(Frame &frame, ExecutionContext &context) {
  SCOPED_PROFILE_OP("IndexedJoin");

  for (;;) {
    // A new main-branch row is needed only after the sub branch ran dry for the previous one.
    if (pull_input_) {
      if (!main_branch_->Pull(frame, context)) {
        return false;
      }
    }

    if (sub_branch_->Pull(frame, context)) {
      // Stay on the same main-branch row for the next Pull so the sub branch can produce more.
      pull_input_ = false;
      return true;
    }

    // The sub branch has nothing (more) for this row: rewind it and try the next main-branch row.
    pull_input_ = true;
    sub_branch_->Reset();
  }
}
|
||
|
||
// Shuts down both child cursors, main branch first and then the sub branch.
void IndexedJoin::IndexedJoinCursor::Shutdown() {
  main_branch_->Shutdown();
  sub_branch_->Shutdown();
}
|
||
|
||
// Resets both child cursors and re-arms pull_input_ so that the next Pull starts by fetching a
// main-branch row.
void IndexedJoin::IndexedJoinCursor::Reset() {
  main_branch_->Reset();
  sub_branch_->Reset();

  pull_input_ = true;
}
|
||
|
||
std::vector<Symbol> HashJoin::ModifiedSymbols(const SymbolTable &table) const {
|
||
auto symbols = left_op_->ModifiedSymbols(table);
|
||
auto right = right_op_->ModifiedSymbols(table);
|
||
symbols.insert(symbols.end(), right.begin(), right.end());
|
||
return symbols;
|
||
}
|
||
|
||
// Visitor entry point. When PreVisit returns true the left operator is visited first; the right
// operator is visited only if visiting the left one returned true. PostVisit is always called and
// its result is returned.
bool HashJoin::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  if (visitor.PreVisit(*this)) {
    if (left_op_->Accept(visitor)) {
      right_op_->Accept(visitor);
    }
  }
  return visitor.PostVisit(*this);
}
|
||
|
||
// NOTE(review): WITHOUT_SINGLE_INPUT is defined outside this part of the file; presumably it emits
// the single-input accessor definitions for an operator that has two branches instead of one input
// - confirm against the macro definition.
WITHOUT_SINGLE_INPUT(HashJoin);
|
||
|
||
namespace {
|
||
|
||
class HashJoinCursor : public Cursor {
|
||
public:
|
||
HashJoinCursor(const HashJoin &self, utils::MemoryResource *mem)
|
||
: self_(self),
|
||
left_op_cursor_(self.left_op_->MakeCursor(mem)),
|
||
right_op_cursor_(self_.right_op_->MakeCursor(mem)),
|
||
hashtable_(mem),
|
||
right_op_frame_(mem) {
|
||
MG_ASSERT(left_op_cursor_ != nullptr, "HashJoinCursor: Missing left operator cursor.");
|
||
MG_ASSERT(right_op_cursor_ != nullptr, "HashJoinCursor: Missing right operator cursor.");
|
||
}
|
||
|
||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||
SCOPED_PROFILE_OP("HashJoin");
|
||
|
||
if (!hash_join_initialized_) {
|
||
InitializeHashJoin(frame, context);
|
||
hash_join_initialized_ = true;
|
||
}
|
||
|
||
// If left_op yielded zero results, there is no cartesian product.
|
||
if (hashtable_.empty()) {
|
||
return false;
|
||
}
|
||
|
||
auto restore_frame = [&frame, &context](const auto &symbols, const auto &restore_from) {
|
||
for (const auto &symbol : symbols) {
|
||
frame[symbol] = restore_from[symbol.position()];
|
||
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) {
|
||
context.frame_change_collector->ResetTrackingValue(symbol.name());
|
||
}
|
||
}
|
||
};
|
||
|
||
if (!common_value_found_) {
|
||
// Pull from the right_op until there’s a mergeable frame
|
||
while (true) {
|
||
auto pulled = right_op_cursor_->Pull(frame, context);
|
||
if (!pulled) return false;
|
||
|
||
// Check if the join value from the pulled frame is shared with any left frames
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::OLD);
|
||
auto right_value = self_.hash_join_condition_->expression2_->Accept(evaluator);
|
||
if (hashtable_.contains(right_value)) {
|
||
// If so, finish pulling for now and proceed to joining the pulled frame
|
||
right_op_frame_.assign(frame.elems().begin(), frame.elems().end());
|
||
common_value_found_ = true;
|
||
common_value = right_value;
|
||
left_op_frame_it_ = hashtable_[common_value].begin();
|
||
break;
|
||
}
|
||
}
|
||
} else {
|
||
// Restore the right frame ahead of restoring the left frame
|
||
restore_frame(self_.right_symbols_, right_op_frame_);
|
||
}
|
||
|
||
restore_frame(self_.left_symbols_, *left_op_frame_it_);
|
||
|
||
left_op_frame_it_++;
|
||
// When all left frames with the common value have been joined, move on to pulling and joining the next right
|
||
// frame
|
||
if (common_value_found_ && left_op_frame_it_ == hashtable_[common_value].end()) {
|
||
common_value_found_ = false;
|
||
}
|
||
|
||
return true;
|
||
}
|
||
|
||
void Shutdown() override {
|
||
left_op_cursor_->Shutdown();
|
||
right_op_cursor_->Shutdown();
|
||
}
|
||
|
||
void Reset() override {
|
||
left_op_cursor_->Reset();
|
||
right_op_cursor_->Reset();
|
||
hashtable_.clear();
|
||
right_op_frame_.clear();
|
||
left_op_frame_it_ = {};
|
||
hash_join_initialized_ = false;
|
||
common_value_found_ = false;
|
||
}
|
||
|
||
private:
|
||
void InitializeHashJoin(Frame &frame, ExecutionContext &context) {
|
||
// Pull all left_op_ frames
|
||
while (left_op_cursor_->Pull(frame, context)) {
|
||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||
storage::View::OLD);
|
||
auto left_value = self_.hash_join_condition_->expression1_->Accept(evaluator);
|
||
if (left_value.type() != TypedValue::Type::Null) {
|
||
hashtable_[left_value].emplace_back(frame.elems().begin(), frame.elems().end());
|
||
}
|
||
}
|
||
}
|
||
|
||
const HashJoin &self_;
|
||
const UniqueCursorPtr left_op_cursor_;
|
||
const UniqueCursorPtr right_op_cursor_;
|
||
utils::pmr::unordered_map<TypedValue, utils::pmr::vector<utils::pmr::vector<TypedValue>>, TypedValue::Hash,
|
||
TypedValue::BoolEqual>
|
||
hashtable_;
|
||
utils::pmr::vector<TypedValue> right_op_frame_;
|
||
utils::pmr::vector<utils::pmr::vector<TypedValue>>::iterator left_op_frame_it_;
|
||
bool hash_join_initialized_{false};
|
||
bool common_value_found_{false};
|
||
TypedValue common_value;
|
||
};
|
||
} // namespace
|
||
|
||
// Creates the cursor for this operator.
//
// The HashJoinOperator metrics counter is incremented on every call, before the cursor is built.
// `mem` is handed to MakeUniqueCursorPtr and is also passed on after `*this`.
UniqueCursorPtr HashJoin::MakeCursor(utils::MemoryResource *mem) const {
  memgraph::metrics::IncrementCounter(memgraph::metrics::HashJoinOperator);

  UniqueCursorPtr cursor = MakeUniqueCursorPtr<HashJoinCursor>(mem, *this, mem);
  return cursor;
}
|
||
|
||
// Constructs the RollUpApply logical operator from its two branches.
//
// `input` is copied into input_. `second_branch` is an rvalue reference, i.e. the caller gives the
// pointer away; it is moved into list_collection_branch_. Inside the constructor the named
// parameter is an lvalue, so without std::move it would be copied and the caller's pointer left
// untouched.
//
// Neither branch is checked for null here; Accept throws when either one is missing.
RollUpApply::RollUpApply(const std::shared_ptr<LogicalOperator> &input,
                         std::shared_ptr<LogicalOperator> &&second_branch)
    : input_(input), list_collection_branch_(std::move(second_branch)) {}
|
||
|
||
// Always returns an empty list of symbols; the symbol table argument is unused.
std::vector<Symbol> RollUpApply::OutputSymbols(const SymbolTable & /*symbol_table*/) const {
  return {};
}
|
||
|
||
// The modified symbols are defined to be exactly the output symbols.
std::vector<Symbol> RollUpApply::ModifiedSymbols(const SymbolTable &table) const {
  auto symbols = OutputSymbols(table);
  return symbols;
}
|
||
|
||
// Visitor entry point. When PreVisit returns true the input branch is visited first; the list
// collection branch is visited only if visiting the input returned true. PostVisit is always called
// (unless the exception below is thrown) and its result is returned.
//
// Throws utils::NotYetImplemented when PreVisit returned true and either branch is null.
bool RollUpApply::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  if (!visitor.PreVisit(*this)) {
    return visitor.PostVisit(*this);
  }

  const bool has_both_branches = input_ && list_collection_branch_;
  if (!has_both_branches) {
    throw utils::NotYetImplemented("One of the branches in pattern comprehension is null! Please contact support.");
  }

  if (input_->Accept(visitor)) {
    list_collection_branch_->Accept(visitor);
  }

  return visitor.PostVisit(*this);
}
|
||
|
||
} // namespace memgraph::query::plan
|