diff --git a/src/memory/query_memory_control.cpp b/src/memory/query_memory_control.cpp index 91730c900..5e569bd13 100644 --- a/src/memory/query_memory_control.cpp +++ b/src/memory/query_memory_control.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -36,14 +36,22 @@ namespace memgraph::memory { void QueriesMemoryControl::UpdateThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) { auto accessor = thread_id_to_transaction_id.access(); - accessor.insert({thread_id, transaction_id}); + auto elem = accessor.find(thread_id); + if (elem == accessor.end()) { + accessor.insert({thread_id, {transaction_id, 1}}); + } else { + elem->transaction_id.cnt++; + } } void QueriesMemoryControl::EraseThreadToTransactionId(const std::thread::id &thread_id, uint64_t transaction_id) { auto accessor = thread_id_to_transaction_id.access(); auto elem = accessor.find(thread_id); MG_ASSERT(elem != accessor.end() && elem->transaction_id == transaction_id); - accessor.remove(thread_id); + elem->transaction_id.cnt--; + if (elem->transaction_id.cnt == 0) { + accessor.remove(thread_id); + } } void QueriesMemoryControl::TrackAllocOnCurrentThread(size_t size) { diff --git a/src/memory/query_memory_control.hpp b/src/memory/query_memory_control.hpp index 901917757..3852027a5 100644 --- a/src/memory/query_memory_control.hpp +++ b/src/memory/query_memory_control.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -78,9 +78,20 @@ class QueriesMemoryControl { bool IsThreadTracked(); private: + struct TransactionId { + uint64_t id; + uint64_t cnt; + + bool operator<(const TransactionId &other) const { return id < other.id; } + bool operator==(const TransactionId &other) const { return id == other.id; } + + bool operator<(uint64_t other) const { return id < other; } + bool operator==(uint64_t other) const { return id == other; } + }; + struct ThreadIdToTransactionId { std::thread::id thread_id; - uint64_t transaction_id; + TransactionId transaction_id; bool operator<(const ThreadIdToTransactionId &other) const { return thread_id < other.thread_id; } bool operator==(const ThreadIdToTransactionId &other) const { return thread_id == other.thread_id; } @@ -98,6 +109,9 @@ class QueriesMemoryControl { bool operator<(uint64_t other) const { return transaction_id < other; } bool operator==(uint64_t other) const { return transaction_id == other; } + + bool operator<(TransactionId other) const { return transaction_id < other.id; } + bool operator==(TransactionId other) const { return transaction_id == other.id; } }; utils::SkipList thread_id_to_transaction_id;