Fix slow IN LIST evaluation (#901)

Antonio Filipovic 2023-05-29 17:52:20 +02:00 committed by GitHub
parent d842adbed3
commit d917c3f0fd
8 changed files with 412 additions and 67 deletions


@ -22,6 +22,8 @@
#include "query/trigger.hpp"
#include "utils/async_timer.hpp"
#include "query/frame_change.hpp"
namespace memgraph::query {
enum class TransactionStatus {
@ -82,6 +84,7 @@ struct ExecutionContext {
plan::ProfilingStats *stats_root{nullptr};
ExecutionStats execution_stats;
TriggerContextCollector *trigger_context_collector{nullptr};
FrameChangeCollector *frame_change_collector{nullptr};
utils::AsyncTimer timer;
#ifdef MG_ENTERPRISE
std::unique_ptr<FineGrainedAuthChecker> auth_checker{nullptr};

src/query/frame_change.hpp (new file, 122 lines added)

@ -0,0 +1,122 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/typed_value.hpp"
#include "utils/memory.hpp"
#include "utils/pmr/unordered_map.hpp"
#include "utils/pmr/vector.hpp"
namespace memgraph::query {
// Key is hash output, value is vector of unique elements
using CachedType = utils::pmr::unordered_map<size_t, std::vector<TypedValue>>;
struct CachedValue {
// Cached value; this could probably be templatized.
CachedType cache_;
explicit CachedValue(utils::MemoryResource *mem) : cache_(mem) {}
CachedValue(CachedType &&cache, memgraph::utils::MemoryResource *memory) : cache_(std::move(cache), memory) {}
CachedValue(const CachedValue &other, memgraph::utils::MemoryResource *memory) : cache_(other.cache_, memory) {}
CachedValue(CachedValue &&other, memgraph::utils::MemoryResource *memory) : cache_(std::move(other.cache_), memory) {}
CachedValue(CachedValue &&other) noexcept = delete;
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
CachedValue(const CachedValue &) = delete;
CachedValue &operator=(const CachedValue &) = delete;
CachedValue &operator=(CachedValue &&) = delete;
~CachedValue() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept {
return cache_.get_allocator().GetMemoryResource();
}
// Caches all elements of the given list value into cache_; returns false if the value is not a list
bool CacheValue(const TypedValue &value) {
if (!value.IsList()) {
return false;
}
const auto &list = value.ValueList();
TypedValue::Hash hash{};
for (const TypedValue &element : list) {
const auto key = hash(element);
auto &vector_values = cache_[key];
if (!IsValueInVec(vector_values, element)) {
vector_values.push_back(element);
}
}
return true;
}
// Checks whether cache_ contains the given value
bool ContainsValue(const TypedValue &value) const {
TypedValue::Hash hash{};
const auto key = hash(value);
if (cache_.contains(key)) {
return IsValueInVec(cache_.at(key), value);
}
return false;
}
private:
bool IsValueInVec(const std::vector<TypedValue> &vec_values, const TypedValue &value) const {
return std::any_of(vec_values.begin(), vec_values.end(), [&value](auto &vec_value) {
const auto is_value_equal = vec_value == value;
if (is_value_equal.IsNull()) return false;
return is_value_equal.ValueBool();
});
}
};
// Tracks keys for which values can be cached, enabling faster lookups later.
class FrameChangeCollector {
public:
explicit FrameChangeCollector(utils::MemoryResource *mem) : tracked_values_(mem){};
// Add a tracking key under which a value can be cached later
CachedValue &AddTrackingKey(const std::string &key) {
const auto &[it, _] = tracked_values_.emplace(key, tracked_values_.get_allocator().GetMemoryResource());
return it->second;
}
// Is key tracked
bool IsKeyTracked(const std::string &key) const { return tracked_values_.contains(key); }
// Is value for given key cached
bool IsKeyValueCached(const std::string &key) const {
return tracked_values_.contains(key) && !tracked_values_.at(key).cache_.empty();
}
// Reset value for tracking key
bool ResetTrackingValue(const std::string &key) {
if (tracked_values_.contains(key)) {
tracked_values_.erase(key);
AddTrackingKey(key);
}
return true;
}
// Get the value cached for the tracking key; throws if the key is not tracked
CachedValue &GetCachedValue(const std::string &key) { return tracked_values_.at(key); }
// Checks whether any keys are being tracked
bool IsTrackingValues() const { return !tracked_values_.empty(); }
private:
// Key is the output of utils::GetFrameChangeId, value is the CachedValue holding the cached list elements
memgraph::utils::pmr::unordered_map<std::string, CachedValue> tracked_values_;
};
} // namespace memgraph::query
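
The cache above groups list elements by hash so each membership check becomes an average O(1) bucket lookup instead of a scan of the whole list. A minimal sketch of the same idea, using plain std::unordered_map and int in place of the pmr containers and TypedValue; ListMembershipCache, Add, and Contains are illustrative names, not part of this commit:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for CachedValue: buckets elements by hash so that
// membership checks avoid rescanning the original list.
class ListMembershipCache {
 public:
  // Analogous to CachedValue::CacheValue - index every element once.
  void Add(const std::vector<int> &list) {
    for (int element : list) {
      auto &bucket = cache_[std::hash<int>{}(element)];
      if (std::find(bucket.begin(), bucket.end(), element) == bucket.end()) {
        bucket.push_back(element);
      }
    }
  }

  // Analogous to CachedValue::ContainsValue - one hash lookup plus a scan of
  // the (usually tiny) bucket.
  bool Contains(int value) const {
    const auto it = cache_.find(std::hash<int>{}(value));
    return it != cache_.end() &&
           std::find(it->second.begin(), it->second.end(), value) != it->second.end();
  }

 private:
  std::unordered_map<std::size_t, std::vector<int>> cache_;
};

Evaluating x IN list for many rows against the same list then costs one Add plus one Contains per row, which is the effect FrameChangeCollector achieves for TypedValue lists.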


@ -13,6 +13,7 @@
#pragma once
#include <algorithm>
#include <cstddef>
#include <limits>
#include <map>
#include <optional>
@ -28,7 +29,11 @@
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/interpret/frame.hpp"
#include "query/typed_value.hpp"
#include "spdlog/spdlog.h"
#include "utils/exceptions.hpp"
#include "utils/frame_change_id.hpp"
#include "utils/logging.hpp"
#include "utils/pmr/unordered_map.hpp"
namespace memgraph::query {
@ -103,8 +108,13 @@ class ReferenceExpressionEvaluator : public ExpressionVisitor<TypedValue *> {
class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
public:
ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, DbAccessor *dba,
storage::View view)
: frame_(frame), symbol_table_(&symbol_table), ctx_(&ctx), dba_(dba), view_(view) {}
storage::View view, FrameChangeCollector *frame_change_collector = nullptr)
: frame_(frame),
symbol_table_(&symbol_table),
ctx_(&ctx),
dba_(dba),
view_(view),
frame_change_collector_(frame_change_collector) {}
using ExpressionVisitor<TypedValue>::Visit;
@ -193,25 +203,78 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
}
TypedValue Visit(InListOperator &in_list) override {
TypedValue *_list_ptr = nullptr;
TypedValue _list;
auto literal = in_list.expression1_->Accept(*this);
auto _list = in_list.expression2_->Accept(*this);
if (_list.IsNull()) {
return TypedValue(ctx_->memory);
auto get_list_literal = [this, &in_list, &_list, &_list_ptr]() -> void {
ReferenceExpressionEvaluator reference_expression_evaluator{frame_, symbol_table_, ctx_};
_list_ptr = in_list.expression2_->Accept(reference_expression_evaluator);
if (nullptr == _list_ptr) {
_list = in_list.expression2_->Accept(*this);
_list_ptr = &_list;
}
};
auto do_list_literal_checks = [this, &literal, &_list_ptr]() -> std::optional<TypedValue> {
MG_ASSERT(_list_ptr, "List literal should have been defined");
if (_list_ptr->IsNull()) {
return TypedValue(ctx_->memory);
}
// Exceptions have higher priority than returning nulls when list expression
// is not null.
if (_list_ptr->type() != TypedValue::Type::List) {
throw QueryRuntimeException("IN expected a list, got {}.", _list_ptr->type());
}
const auto &list = _list_ptr->ValueList();
// If literal is NULL there is no need to try to compare it with every
// element in the list since result of every comparison will be NULL. There
// is one special case that we must test explicitly: if list is empty then
// result is false since no comparison will be performed.
if (list.empty()) return TypedValue(false, ctx_->memory);
if (literal.IsNull()) return TypedValue(ctx_->memory);
return {};
};
const auto cached_id = memgraph::utils::GetFrameChangeId(in_list);
const auto do_cache{frame_change_collector_ != nullptr && cached_id &&
frame_change_collector_->IsKeyTracked(*cached_id)};
if (do_cache) {
if (!frame_change_collector_->IsKeyValueCached(*cached_id)) {
// Check only first time if everything is okay, later when we use
// cache there is no need to check again as we did check first time
get_list_literal();
auto preoperational_checks = do_list_literal_checks();
if (preoperational_checks) {
return std::move(*preoperational_checks);
}
auto &cached_value = frame_change_collector_->GetCachedValue(*cached_id);
cached_value.CacheValue(*_list_ptr);
spdlog::trace("Value cached {}", *cached_id);
}
const auto &cached_value = frame_change_collector_->GetCachedValue(*cached_id);
if (cached_value.ContainsValue(literal)) {
return TypedValue(true, ctx_->memory);
}
// has null
if (cached_value.ContainsValue(TypedValue(ctx_->memory))) {
return TypedValue(ctx_->memory);
}
return TypedValue(false, ctx_->memory);
}
// Exceptions have higher priority than returning nulls when list expression
// is not null.
if (_list.type() != TypedValue::Type::List) {
throw QueryRuntimeException("IN expected a list, got {}.", _list.type());
// When caching is not an option, we need to evaluate list literal every time
// and do the checks
get_list_literal();
auto preoperational_checks = do_list_literal_checks();
if (preoperational_checks) {
return std::move(*preoperational_checks);
}
const auto &list = _list.ValueList();
// If literal is NULL there is no need to try to compare it with every
// element in the list since result of every comparison will be NULL. There
// is one special case that we must test explicitly: if list is empty then
// result is false since no comparison will be performed.
if (list.empty()) return TypedValue(false, ctx_->memory);
if (literal.IsNull()) return TypedValue(ctx_->memory);
spdlog::trace("Not using cache on IN LIST operator");
auto has_null = false;
for (const auto &element : list) {
auto result = literal == element;
@ -973,6 +1036,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
DbAccessor *dba_;
// which switching approach should be used when evaluating
storage::View view_;
FrameChangeCollector *frame_change_collector_;
}; // class ExpressionEvaluator
/// A helper function for evaluating an expression that's an int.
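
Both the cached and the uncached path of the IN LIST visitor above must preserve Cypher's three-valued semantics: a match yields true, no match in a list containing a NULL yields NULL, and no match in a NULL-free list yields false (a NULL list, handled earlier in the visitor, also yields NULL). A standalone sketch of that decision, using std::optional<bool> for the tri-state result and int elements; InListResult is an illustrative helper, not part of the commit:

#include <optional>
#include <vector>

// std::nullopt models Cypher NULL; the element type is simplified to int.
std::optional<bool> InListResult(std::optional<int> literal,
                                 const std::vector<std::optional<int>> &list) {
  if (list.empty()) return false;           // empty list: always false
  if (!literal) return std::nullopt;        // NULL literal: always NULL
  bool has_null = false;
  for (const auto &element : list) {
    if (!element) {
      has_null = true;                      // remember NULLs, keep scanning
      continue;
    }
    if (*element == *literal) return true;  // any match decides the result
  }
  return has_null ? std::optional<bool>{} : std::optional<bool>{false};
}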


@ -45,13 +45,16 @@
#include "query/frontend/semantic/required_privileges.hpp"
#include "query/frontend/semantic/symbol_generator.hpp"
#include "query/interpret/eval.hpp"
#include "query/interpret/frame.hpp"
#include "query/metadata.hpp"
#include "query/plan/planner.hpp"
#include "query/plan/profile.hpp"
#include "query/plan/vertex_count_cache.hpp"
#include "query/stream.hpp"
#include "query/stream/common.hpp"
#include "query/trigger.hpp"
#include "query/typed_value.hpp"
#include "spdlog/spdlog.h"
#include "storage/v2/edge.hpp"
#include "storage/v2/id_types.hpp"
#include "storage/v2/isolation_level.hpp"
@ -1004,7 +1007,8 @@ struct PullPlan {
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status,
TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<size_t> memory_limit = {}, bool use_monotonic_memory = true);
std::optional<size_t> memory_limit = {}, bool use_monotonic_memory = true,
FrameChangeCollector *frame_change_collector_ = nullptr);
std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols,
@ -1042,7 +1046,7 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status,
TriggerContextCollector *trigger_context_collector, const std::optional<size_t> memory_limit,
bool use_monotonic_memory)
bool use_monotonic_memory, FrameChangeCollector *frame_change_collector)
: plan_(plan),
cursor_(plan->plan().MakeCursor(execution_memory)),
frame_(plan->symbol_table().max_position(), execution_memory),
@ -1073,6 +1077,7 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
ctx_.transaction_status = transaction_status;
ctx_.is_profile_query = is_profile_query;
ctx_.trigger_context_collector = trigger_context_collector;
ctx_.frame_change_collector = frame_change_collector;
}
std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n,
@ -1263,11 +1268,28 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper,
RWType::NONE};
}
inline static void TryCaching(const AstStorage &ast_storage, FrameChangeCollector *frame_change_collector) {
if (!frame_change_collector) return;
for (const auto &tree : ast_storage.storage_) {
if (tree->GetTypeInfo() != memgraph::query::InListOperator::kType) {
continue;
}
auto *in_list_operator = utils::Downcast<InListOperator>(tree.get());
const auto cached_id = memgraph::utils::GetFrameChangeId(*in_list_operator);
if (!cached_id || cached_id->empty()) {
continue;
}
frame_change_collector->AddTrackingKey(*cached_id);
spdlog::trace("Tracking {} operator, by id: {}", InListOperator::kType.name, *cached_id);
}
}
PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
InterpreterContext *interpreter_context, DbAccessor *dba,
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications,
const std::string *username, std::atomic<TransactionStatus> *transaction_status,
TriggerContextCollector *trigger_context_collector = nullptr) {
TriggerContextCollector *trigger_context_collector = nullptr,
FrameChangeCollector *frame_change_collector = nullptr) {
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query);
Frame frame(0);
@ -1298,6 +1320,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
parsed_query.parameters,
parsed_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba);
TryCaching(plan->ast_storage(), frame_change_collector);
summary->insert_or_assign("cost_estimate", plan->cost());
auto rw_type_checker = plan::ReadWriteTypeChecker();
rw_type_checker.InferRWType(const_cast<plan::LogicalOperator &>(plan->plan()));
@ -1314,9 +1337,10 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
header.push_back(
utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first);
}
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, StringPointerToOptional(username), transaction_status,
trigger_context_collector, memory_limit, use_monotonic_memory);
auto pull_plan = std::make_shared<PullPlan>(
plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory,
StringPointerToOptional(username), transaction_status, trigger_context_collector, memory_limit,
use_monotonic_memory, frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr);
return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
@ -1377,7 +1401,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string
PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username,
std::atomic<TransactionStatus> *transaction_status) {
std::atomic<TransactionStatus> *transaction_status,
FrameChangeCollector *frame_change_collector) {
const std::string kProfileQueryStart = "profile ";
MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart),
@ -1436,39 +1461,42 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
auto cypher_query_plan = CypherQueryToPlan(
parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query,
parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba);
TryCaching(cypher_query_plan->ast_storage(), frame_change_collector);
auto rw_type_checker = plan::ReadWriteTypeChecker();
auto optional_username = StringPointerToOptional(username);
rw_type_checker.InferRWType(const_cast<plan::LogicalOperator &>(cypher_query_plan->plan()));
return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"},
std::move(parsed_query.required_privileges),
[plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters),
summary, dba, interpreter_context, execution_memory, memory_limit, optional_username,
// We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context.
stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{},
pull_plan = std::shared_ptr<PullPlanVector>(nullptr), transaction_status, use_monotonic_memory](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
// No output symbols are given so that nothing is streamed.
if (!stats_and_total_time) {
stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context,
execution_memory, optional_username, transaction_status,
nullptr, memory_limit, use_monotonic_memory)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time));
}
return PreparedQuery{
{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"},
std::move(parsed_query.required_privileges),
[plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), summary, dba,
interpreter_context, execution_memory, memory_limit, optional_username,
// We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context.
stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{},
pull_plan = std::shared_ptr<PullPlanVector>(nullptr), transaction_status, use_monotonic_memory,
frame_change_collector](AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
// No output symbols are given so that nothing is streamed.
if (!stats_and_total_time) {
stats_and_total_time =
PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, optional_username,
transaction_status, nullptr, memory_limit, use_monotonic_memory,
frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time));
}
MG_ASSERT(stats_and_total_time, "Failed to execute the query!");
MG_ASSERT(stats_and_total_time, "Failed to execute the query!");
if (pull_plan->Pull(stream, n)) {
summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump());
return QueryHandlerResult::ABORT;
}
if (pull_plan->Pull(stream, n)) {
summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump());
return QueryHandlerResult::ABORT;
}
return std::nullopt;
},
rw_type_checker.type};
return std::nullopt;
},
rw_type_checker.type};
}
PreparedQuery PrepareDumpQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary, DbAccessor *dba,
@ -2871,18 +2899,21 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
utils::MemoryResource *memory_resource =
std::visit([](auto &execution_memory) -> utils::MemoryResource * { return &execution_memory; },
query_execution->execution_memory);
frame_change_collector_.reset();
frame_change_collector_.emplace(memory_resource);
if (utils::Downcast<CypherQuery>(parsed_query.query)) {
prepared_query =
PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, memory_resource, &query_execution->notifications, username,
&transaction_status_, trigger_context_collector_ ? &*trigger_context_collector_ : nullptr);
prepared_query = PrepareCypherQuery(
std::move(parsed_query), &query_execution->summary, interpreter_context_, &*execution_db_accessor_,
memory_resource, &query_execution->notifications, username, &transaction_status_,
trigger_context_collector_ ? &*trigger_context_collector_ : nullptr, &*frame_change_collector_);
} else if (utils::Downcast<ExplainQuery>(parsed_query.query)) {
prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory_with_exception);
} else if (utils::Downcast<ProfileQuery>(parsed_query.query)) {
prepared_query = PrepareProfileQuery(
std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory_with_exception, username, &transaction_status_);
prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception, username,
&transaction_status_, &*frame_change_collector_);
} else if (utils::Downcast<DumpQuery>(parsed_query.query)) {
prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_,
memory_resource);
@ -3000,6 +3031,7 @@ void Interpreter::Abort() {
execution_db_accessor_.reset();
db_accessor_.reset();
trigger_context_collector_.reset();
frame_change_collector_.reset();
}
namespace {
@ -3102,6 +3134,10 @@ void Interpreter::Commit() {
trigger_context_collector_.reset();
}
if (frame_change_collector_) {
frame_change_collector_.reset();
}
if (trigger_context) {
// Run the triggers
for (const auto &trigger : interpreter_context_->trigger_store.BeforeCommitTriggers().access()) {


@ -403,6 +403,7 @@ class Interpreter final {
std::unique_ptr<storage::Storage::Accessor> db_accessor_;
std::optional<DbAccessor> execution_db_accessor_;
std::optional<TriggerContextCollector> trigger_context_collector_;
std::optional<FrameChangeCollector> frame_change_collector_;
std::optional<storage::IsolationLevel> interpreter_isolation_level;
std::optional<storage::IsolationLevel> next_transaction_isolation_level;


@ -61,6 +61,7 @@
#include "utils/readable_size.hpp"
#include "utils/string.hpp"
#include "utils/temporal.hpp"
#include "utils/typeinfo.hpp"
// macro for the default implementation of LogicalOperator::Accept
// that accepts the visitor and visits it's input_ operator
@ -2332,12 +2333,11 @@ bool Filter::FilterCursor::Pull(Frame &frame, ExecutionContext &context) {
// Like all filters, newly set values should not affect filtering of old
// nodes and edges.
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
storage::View::OLD);
storage::View::OLD, context.frame_change_collector);
while (input_cursor_->Pull(frame, context)) {
for (const auto &pattern_filter_cursor : pattern_filter_cursors_) {
pattern_filter_cursor->Pull(frame, context);
}
if (EvaluateFilter(evaluator, self_.expression_)) return true;
}
return false;
@ -2410,9 +2410,13 @@ bool Produce::ProduceCursor::Pull(Frame &frame, ExecutionContext &context) {
if (input_cursor_->Pull(frame, context)) {
// Produce should always yield the latest results.
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
storage::View::NEW);
for (auto named_expr : self_.named_expressions_) named_expr->Accept(evaluator);
storage::View::NEW, context.frame_change_collector);
for (auto *named_expr : self_.named_expressions_) {
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(named_expr->name_)) {
context.frame_change_collector->ResetTrackingValue(named_expr->name_);
}
named_expr->Accept(evaluator);
}
return true;
}
return false;
@ -3233,7 +3237,12 @@ class AccumulateCursor : public Cursor {
if (MustAbort(context)) throw HintedAbortError();
if (cache_it_ == cache_.end()) return false;
auto row_it = (cache_it_++)->begin();
for (const Symbol &symbol : self_.symbols_) frame[symbol] = *row_it++;
for (const Symbol &symbol : self_.symbols_) {
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) {
context.frame_change_collector->ResetTrackingValue(symbol.name());
}
frame[symbol] = *row_it++;
}
return true;
}
@ -3315,10 +3324,20 @@ class AggregateCursor : public Cursor {
if (aggregation_.empty()) {
auto *pull_memory = context.evaluation_context.memory;
// place default aggregation values on the frame
for (const auto &elem : self_.aggregations_)
for (const auto &elem : self_.aggregations_) {
frame[elem.output_sym] = DefaultAggregationOpValue(elem, pull_memory);
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(elem.output_sym.name())) {
context.frame_change_collector->ResetTrackingValue(elem.output_sym.name());
}
}
// place null as remember values on the frame
for (const Symbol &remember_sym : self_.remember_) frame[remember_sym] = TypedValue(pull_memory);
for (const Symbol &remember_sym : self_.remember_) {
frame[remember_sym] = TypedValue(pull_memory);
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(remember_sym.name())) {
context.frame_change_collector->ResetTrackingValue(remember_sym.name());
}
}
return true;
}
}
@ -3798,8 +3817,12 @@ class OrderByCursor : public Cursor {
"Number of values does not match the number of output symbols "
"in OrderBy");
auto output_sym_it = self_.output_symbols_.begin();
for (const TypedValue &output : cache_it_->remember) frame[*output_sym_it++] = output;
for (const TypedValue &output : cache_it_->remember) {
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(output_sym_it->name())) {
context.frame_change_collector->ResetTrackingValue(output_sym_it->name());
}
frame[*output_sym_it++] = output;
}
cache_it_++;
return true;
}
@ -4032,6 +4055,9 @@ class UnwindCursor : public Cursor {
if (input_value_it_ == input_value_.end()) continue;
frame[self_.output_symbol_] = *input_value_it_++;
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_.output_symbol_.name_)) {
context.frame_change_collector->ResetTrackingValue(self_.output_symbol_.name_);
}
return true;
}
}
@ -4161,11 +4187,17 @@ bool Union::UnionCursor::Pull(Frame &frame, ExecutionContext &context) {
// collect values from the left child
for (const auto &output_symbol : self_.left_symbols_) {
results[output_symbol.name()] = frame[output_symbol];
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(output_symbol.name())) {
context.frame_change_collector->ResetTrackingValue(output_symbol.name());
}
}
} else if (right_cursor_->Pull(frame, context)) {
// collect values from the right child
for (const auto &output_symbol : self_.right_symbols_) {
results[output_symbol.name()] = frame[output_symbol];
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(output_symbol.name())) {
context.frame_change_collector->ResetTrackingValue(output_symbol.name());
}
}
} else {
return false;
@ -4174,6 +4206,9 @@ bool Union::UnionCursor::Pull(Frame &frame, ExecutionContext &context) {
// put collected values on frame under union symbols
for (const auto &symbol : self_.union_symbols_) {
frame[symbol] = results[symbol.name()];
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) {
context.frame_change_collector->ResetTrackingValue(symbol.name());
}
}
return true;
}
@ -4238,9 +4273,12 @@ class CartesianCursor : public Cursor {
return false;
}
auto restore_frame = [&frame](const auto &symbols, const auto &restore_from) {
auto restore_frame = [&frame, &context](const auto &symbols, const auto &restore_from) {
for (const auto &symbol : symbols) {
frame[symbol] = restore_from[symbol.position()];
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) {
context.frame_change_collector->ResetTrackingValue(symbol.name());
}
}
};
@ -4318,6 +4356,10 @@ class OutputTableCursor : public Cursor {
if (current_row_ < rows_.size()) {
for (size_t i = 0; i < self_.output_symbols_.size(); ++i) {
frame[self_.output_symbols_[i]] = rows_[current_row_][i];
if (context.frame_change_collector &&
context.frame_change_collector->IsKeyTracked(self_.output_symbols_[i].name())) {
context.frame_change_collector->ResetTrackingValue(self_.output_symbols_[i].name());
}
}
current_row_++;
return true;
@ -4361,6 +4403,10 @@ class OutputTableStreamCursor : public Cursor {
MG_ASSERT(row->size() == self_->output_symbols_.size(), "Wrong number of columns in row!");
for (size_t i = 0; i < self_->output_symbols_.size(); ++i) {
frame[self_->output_symbols_[i]] = row->at(i);
if (context.frame_change_collector &&
context.frame_change_collector->IsKeyTracked(self_->output_symbols_[i].name())) {
context.frame_change_collector->ResetTrackingValue(self_->output_symbols_[i].name());
}
}
return true;
}
@ -4564,6 +4610,10 @@ class CallProcedureCursor : public Cursor {
field_name);
}
frame[self_->result_symbols_[i]] = std::move(result_it->second);
if (context.frame_change_collector &&
context.frame_change_collector->IsKeyTracked(self_->result_symbols_[i].name())) {
context.frame_change_collector->ResetTrackingValue(self_->result_symbols_[i].name());
}
}
++result_row_it_;
@ -4690,6 +4740,9 @@ class LoadCsvCursor : public Cursor {
frame[self_->row_var_] =
CsvRowToTypedMap(*row, csv::Reader::Header(reader_->GetHeader(), context.evaluation_context.memory));
}
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_->row_var_.name())) {
context.frame_change_collector->ResetTrackingValue(self_->row_var_.name());
}
return true;
}


@ -0,0 +1,22 @@
#include <string>
#include "query/frontend/ast/ast.hpp"
#include "spdlog/spdlog.h"
namespace memgraph::utils {
// Get the ID under which FrameChangeCollector can cache in_list.expression2_
inline std::optional<std::string> GetFrameChangeId(memgraph::query::InListOperator &in_list) {
if (in_list.expression2_->GetTypeInfo() == memgraph::query::ListLiteral::kType) {
std::stringstream ss;
ss << static_cast<const void *>(in_list.expression2_);
return ss.str();
}
if (in_list.expression2_->GetTypeInfo() == memgraph::query::Identifier::kType) {
auto *identifier = utils::Downcast<memgraph::query::Identifier>(in_list.expression2_);
return identifier->name_;
}
return {};
};
} // namespace memgraph::utils
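
GetFrameChangeId keys the cache on either the list literal's identity (its address in the cached AST) or the identifier's name, which is why the operators above call ResetTrackingValue whenever they overwrite a tracked symbol. A standalone sketch of the same dispatch, with the query AST replaced by a std::variant; ListLiteral and Identifier here are simplified stand-ins, not the query AST types:

#include <optional>
#include <sstream>
#include <string>
#include <variant>
#include <vector>

struct ListLiteral { std::vector<int> elements; };  // stand-in for query::ListLiteral
struct Identifier { std::string name; };            // stand-in for query::Identifier
using Expression = std::variant<ListLiteral, Identifier, int>;

// Mirrors GetFrameChangeId: literals are keyed by address, identifiers by
// name, anything else is not cacheable.
std::optional<std::string> FrameChangeId(const Expression &rhs) {
  if (const auto *literal = std::get_if<ListLiteral>(&rhs)) {
    std::stringstream ss;
    ss << static_cast<const void *>(literal);  // stable while the AST is alive
    return ss.str();
  }
  if (const auto *identifier = std::get_if<Identifier>(&rhs)) {
    return identifier->name;  // reset whenever the symbol is rewritten on the frame
  }
  return std::nullopt;  // e.g. a function call: skip caching
}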


@ -219,3 +219,47 @@ Feature: List operators
| [[1], 2] |
| [3] |
| 4 |
Scenario: Unwind + InList test1
When executing query:
"""
UNWIND [[1,2], [3,4]] as l
RETURN 2 in l as x
"""
Then the result should be:
| x |
| true |
| false |
Scenario: Unwind + InList test2
When executing query:
"""
WITH [[1,2], [3,4]] as list
UNWIND list as l
RETURN 2 in l as x
"""
Then the result should be:
| x |
| true |
| false |
Scenario: Unwind + InList test3
Given an empty graph
And having executed
"""
CREATE ({id: 1}), ({id: 2}), ({id: 3}), ({id: 4})
"""
When executing query:
"""
WITH [1, 2, 3] as list
MATCH (n) WHERE n.id in list
WITH n
WITH n, [1, 2] as list
WHERE n.id in list
RETURN n.id as id
ORDER BY id;
"""
Then the result should be:
| id |
| 1 |
| 2 |