diff --git a/kv/storage/mem_storage.go b/kv/storage/mem_storage.go index c6e56b50..38c054be 100644 --- a/kv/storage/mem_storage.go +++ b/kv/storage/mem_storage.go @@ -35,7 +35,7 @@ func (s *MemStorage) Stop() error { } func (s *MemStorage) Reader(ctx *kvrpcpb.Context) (StorageReader, error) { - return &memReader{s}, nil + return &memReader{s, 0}, nil } func (s *MemStorage) Write(ctx *kvrpcpb.Context, batch []Modify) error { @@ -131,7 +131,8 @@ func (s *MemStorage) Len(cf string) int { // memReader is a StorageReader which reads from a MemStorage. type memReader struct { - inner *MemStorage + inner *MemStorage + iterCount int } func (mr *memReader) GetCF(cf string, key []byte) ([]byte, error) { @@ -168,18 +169,24 @@ func (mr *memReader) IterCF(cf string) engine_util.DBIterator { return nil } + mr.iterCount += 1 min := data.Min() if min == nil { - return &memIter{data, memItem{}} + return &memIter{data, memItem{}, mr} } - return &memIter{data, min.(memItem)} + return &memIter{data, min.(memItem), mr} } -func (r *memReader) Close() {} +func (r *memReader) Close() { + if r.iterCount > 0 { + panic("Unclosed iterator") + } +} type memIter struct { - data *llrb.LLRB - item memItem + data *llrb.LLRB + item memItem + reader *memReader } func (it *memIter) Item() engine_util.DBItem { @@ -212,7 +219,9 @@ func (it *memIter) Seek(key []byte) { }) } -func (it *memIter) Close() {} +func (it *memIter) Close() { + it.reader.iterCount -= 1 +} type memItem struct { key []byte