diff --git a/src/utils/memory.hpp b/src/utils/memory.hpp index 75ff7a69f..82af8b420 100644 --- a/src/utils/memory.hpp +++ b/src/utils/memory.hpp @@ -569,4 +569,37 @@ class SynchronizedPoolResource final : public MemoryResource { } }; +class LimitedMemoryResource final : public utils::MemoryResource { + public: + explicit LimitedMemoryResource(utils::MemoryResource *memory, + size_t max_allocated_bytes) + : memory_(memory), max_allocated_bytes_(max_allocated_bytes) {} + + size_t GetAllocatedBytes() const noexcept { + return max_allocated_bytes_ - available_bytes_; + } + + private: + utils::MemoryResource *memory_; + size_t max_allocated_bytes_; + size_t available_bytes_{max_allocated_bytes_}; + + void *DoAllocate(size_t bytes, size_t alignment) override { + if (bytes > available_bytes_) + throw utils::BadAlloc("Memory allocation limit exceeded!"); + available_bytes_ -= bytes; + return memory_->Allocate(bytes, alignment); + } + + void DoDeallocate(void *p, size_t bytes, size_t alignment) override { + CHECK(available_bytes_ + bytes > available_bytes_); + available_bytes_ += bytes; + return memory_->Deallocate(p, bytes, alignment); + } + + bool DoIsEqual(const MemoryResource &other) const noexcept override { + return this == &other; + } +}; + } // namespace utils