diff --git a/src/query/frame_change.hpp b/src/query/frame_change.hpp index 32fe1f36e..7baf1fe41 100644 --- a/src/query/frame_change.hpp +++ b/src/query/frame_change.hpp @@ -8,41 +8,42 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. +#include +#include #include "query/typed_value.hpp" +#include "utils/fnv.hpp" #include "utils/memory.hpp" #include "utils/pmr/unordered_map.hpp" #include "utils/pmr/vector.hpp" namespace memgraph::query { // Key is hash output, value is vector of unique elements -using CachedType = utils::pmr::unordered_map>; +using CachedType = utils::pmr::unordered_map>; struct CachedValue { + using allocator_type = utils::Allocator; + // Cached value, this can be probably templateized CachedType cache_; - explicit CachedValue(utils::MemoryResource *mem) : cache_(mem) {} + explicit CachedValue(utils::MemoryResource *mem) : cache_{mem} {}; + CachedValue(const CachedValue &other, utils::MemoryResource *mem) : cache_(other.cache_, mem) {} + CachedValue(CachedValue &&other, utils::MemoryResource *mem) : cache_(std::move(other.cache_), mem){}; - CachedValue(CachedType &&cache, memgraph::utils::MemoryResource *memory) : cache_(std::move(cache), memory) {} + CachedValue(CachedValue &&other) noexcept : CachedValue(std::move(other), other.GetMemoryResource()) {} - CachedValue(const CachedValue &other, memgraph::utils::MemoryResource *memory) : cache_(other.cache_, memory) {} + CachedValue(const CachedValue &other) + : CachedValue(other, std::allocator_traits::select_on_container_copy_construction( + other.GetMemoryResource()) + .GetMemoryResource()) {} - CachedValue(CachedValue &&other, memgraph::utils::MemoryResource *memory) : cache_(std::move(other.cache_), memory) {} - - CachedValue(CachedValue &&other) noexcept = delete; - - /// Copy construction without memgraph::utils::MemoryResource is not allowed. - CachedValue(const CachedValue &) = delete; + utils::MemoryResource *GetMemoryResource() const { return cache_.get_allocator().GetMemoryResource(); } CachedValue &operator=(const CachedValue &) = delete; CachedValue &operator=(CachedValue &&) = delete; ~CachedValue() = default; - memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { - return cache_.get_allocator().GetMemoryResource(); - } - bool CacheValue(const TypedValue &maybe_list) { if (!maybe_list.IsList()) { return false; @@ -70,7 +71,7 @@ struct CachedValue { } private: - static bool IsValueInVec(const std::vector &vec_values, const TypedValue &value) { + static bool IsValueInVec(const utils::pmr::vector &vec_values, const TypedValue &value) { return std::any_of(vec_values.begin(), vec_values.end(), [&value](auto &vec_value) { const auto is_value_equal = vec_value == value; if (is_value_equal.IsNull()) return false; @@ -82,35 +83,70 @@ struct CachedValue { // Class tracks keys for which user can cache values which help with faster search or faster retrieval // in the future. Used for IN LIST operator. class FrameChangeCollector { + /** Allocator type so that STL containers are aware that we need one */ + using allocator_type = utils::Allocator; + public: - explicit FrameChangeCollector() : tracked_values_(&memory_resource_){}; + explicit FrameChangeCollector(utils::MemoryResource *mem = utils::NewDeleteResource()) : tracked_values_{mem} {} + + FrameChangeCollector(FrameChangeCollector &&other, utils::MemoryResource *mem) + : tracked_values_(std::move(other.tracked_values_), mem) {} + FrameChangeCollector(const FrameChangeCollector &other, utils::MemoryResource *mem) + : tracked_values_(other.tracked_values_, mem) {} + + FrameChangeCollector(const FrameChangeCollector &other) + : FrameChangeCollector(other, std::allocator_traits::select_on_container_copy_construction( + other.GetMemoryResource()) + .GetMemoryResource()){}; + + FrameChangeCollector(FrameChangeCollector &&other) noexcept + : FrameChangeCollector(std::move(other), other.GetMemoryResource()) {} + + /** Copy assign other, utils::MemoryResource of `this` is used */ + FrameChangeCollector &operator=(const FrameChangeCollector &) = default; + + /** Move assign other, utils::MemoryResource of `this` is used. */ + FrameChangeCollector &operator=(FrameChangeCollector &&) noexcept = default; + + utils::MemoryResource *GetMemoryResource() const { return tracked_values_.get_allocator().GetMemoryResource(); } CachedValue &AddTrackingKey(const std::string &key) { - const auto &[it, _] = tracked_values_.emplace(key, tracked_values_.get_allocator().GetMemoryResource()); + const auto &[it, _] = tracked_values_.emplace( + std::piecewise_construct, std::forward_as_tuple(utils::pmr::string(key, utils::NewDeleteResource())), + std::forward_as_tuple()); return it->second; } - bool IsKeyTracked(const std::string &key) const { return tracked_values_.contains(key); } + bool IsKeyTracked(const std::string &key) const { + return tracked_values_.contains(utils::pmr::string(key, utils::NewDeleteResource())); + } bool IsKeyValueCached(const std::string &key) const { - return IsKeyTracked(key) && !tracked_values_.at(key).cache_.empty(); + return IsKeyTracked(key) && !tracked_values_.at(utils::pmr::string(key, utils::NewDeleteResource())).cache_.empty(); } bool ResetTrackingValue(const std::string &key) { - if (!tracked_values_.contains(key)) { + if (!tracked_values_.contains(utils::pmr::string(key, utils::NewDeleteResource()))) { return false; } - tracked_values_.erase(key); + tracked_values_.erase(utils::pmr::string(key, utils::NewDeleteResource())); AddTrackingKey(key); return true; } - CachedValue &GetCachedValue(const std::string &key) { return tracked_values_.at(key); } + CachedValue &GetCachedValue(const std::string &key) { + return tracked_values_.at(utils::pmr::string(key, utils::NewDeleteResource())); + } bool IsTrackingValues() const { return !tracked_values_.empty(); } + ~FrameChangeCollector() = default; + private: - utils::MonotonicBufferResource memory_resource_{0}; - memgraph::utils::pmr::unordered_map tracked_values_; + struct PmrStringHash { + size_t operator()(const utils::pmr::string &key) const { return utils::Fnv(key); } + }; + + utils::pmr::unordered_map tracked_values_; }; } // namespace memgraph::query