diff --git a/concurrent/Lock.go b/concurrent/Lock.go index b2df61b..ca896ee 100644 --- a/concurrent/Lock.go +++ b/concurrent/Lock.go @@ -6,9 +6,20 @@ type ( Unlock() } RWLock interface { - Lock() - Unlock() + Lock RLock() RUnlock() } ) + +func WithLock[R any](lock Lock, f func() R) R { + lock.Lock() + defer lock.Unlock() + return f() +} + +func WithRLock[R any](lock RWLock, f func() R) R { + lock.RLock() + defer lock.RUnlock() + return f() +} diff --git a/concurrent/ReentrantLock.go b/concurrent/ReentrantLock.go index 38d9fcf..06313ee 100644 --- a/concurrent/ReentrantLock.go +++ b/concurrent/ReentrantLock.go @@ -1,26 +1,54 @@ package concurrent -import "sync" +import ( + "fmt" + "github.com/tursom/GoCollections/exceptions" + "sync" +) type ReentrantLock struct { - sync.Mutex - cnt int32 + lock sync.Mutex + cond sync.Cond + recursion int32 + host int64 } -func (l *ReentrantLock) Lock() { - defer func() { - r := recover() - if r != nil { - l.cnt++ - } - }() - - l.Mutex.Lock() +func NewReentrantLock() *ReentrantLock { + res := &ReentrantLock{ + recursion: 0, + host: 0, + } + res.cond = *sync.NewCond(&res.lock) + return res } -func (l *ReentrantLock) Unlock() { - l.cnt-- - if l.cnt == 0 { - l.Mutex.Unlock() +func (rt *ReentrantLock) Lock() { + id := GetGoroutineID() + rt.lock.Lock() + defer rt.lock.Unlock() + + if rt.host == id { + rt.recursion++ + return + } + + for rt.recursion != 0 { + rt.cond.Wait() + } + rt.host = id + rt.recursion = 1 +} + +func (rt *ReentrantLock) Unlock() { + rt.lock.Lock() + defer rt.lock.Unlock() + + if rt.recursion == 0 || rt.host != GetGoroutineID() { + panic(exceptions.NewWrongCallHostException(fmt.Sprintf("the wrong call host: (%d); current_id: %d; recursion: %d", rt.host, GetGoroutineID(), rt.recursion))) + } + + rt.recursion-- + if rt.recursion == 0 { + rt.cond.Signal() } } diff --git a/concurrent/ReentrantRWLock.go b/concurrent/ReentrantRWLock.go new file mode 100644 index 0000000..b8e8143 --- /dev/null +++ b/concurrent/ReentrantRWLock.go @@ -0,0 +1,71 @@ +package concurrent + +import ( + "fmt" + "github.com/tursom/GoCollections/exceptions" + "sync" +) + +type ReentrantRWLock struct { + lock sync.Mutex + rlock sync.RWMutex + cond sync.Cond + recursion int32 + host int64 +} + +func NewReentrantRWLock() *ReentrantRWLock { + res := &ReentrantRWLock{ + recursion: 0, + host: 0, + } + res.cond = *sync.NewCond(&res.lock) + return res +} + +func (rt *ReentrantRWLock) Lock() { + id := GetGoroutineID() + rt.lock.Lock() + defer rt.lock.Unlock() + + if rt.host == id { + rt.recursion++ + return + } + + for rt.recursion != 0 { + rt.cond.Wait() + } + rt.host = id + rt.recursion = 1 + rt.rlock.Lock() +} + +func (rt *ReentrantRWLock) Unlock() { + rt.lock.Lock() + defer rt.lock.Unlock() + + if rt.recursion == 0 || rt.host != GetGoroutineID() { + panic(exceptions.NewWrongCallHostException(fmt.Sprintf("the wrong call host: (%d); current_id: %d; recursion: %d", rt.host, GetGoroutineID(), rt.recursion))) + } + + rt.recursion-- + if rt.recursion == 0 { + rt.rlock.Unlock() + rt.cond.Signal() + } +} + +func (rt *ReentrantRWLock) RLock() { + if rt.host == GetGoroutineID() { + return + } + rt.rlock.RLock() +} + +func (rt *ReentrantRWLock) RUnlock() { + if rt.host == GetGoroutineID() { + return + } + rt.rlock.RUnlock() +} diff --git a/concurrent/ReentrantRWLock_test.go b/concurrent/ReentrantRWLock_test.go new file mode 100644 index 0000000..82d96ac --- /dev/null +++ b/concurrent/ReentrantRWLock_test.go @@ -0,0 +1,27 @@ +package concurrent + +import ( + "fmt" + "testing" + "time" +) + +func TestReentrantRWLock_RLock(t *testing.T) { + lock := NewReentrantRWLock() + lock.Lock() + defer lock.Unlock() + + go func() { + lock.Lock() + defer lock.Unlock() + fmt.Println("get lock") + }() + time.Sleep(time.Second) + lock.Lock() + defer lock.Unlock() + lock.RLock() + defer lock.RUnlock() + lock.RLock() + defer lock.RUnlock() + fmt.Println("release lock") +} diff --git a/concurrent/Util.go b/concurrent/Util.go new file mode 100644 index 0000000..be332ef --- /dev/null +++ b/concurrent/Util.go @@ -0,0 +1,7 @@ +package concurrent + +import "github.com/petermattis/goid" + +func GetGoroutineID() int64 { + return goid.Get() +} diff --git a/exceptions/WrongCallHostException.go b/exceptions/WrongCallHostException.go new file mode 100644 index 0000000..3fc288a --- /dev/null +++ b/exceptions/WrongCallHostException.go @@ -0,0 +1,11 @@ +package exceptions + +type WrongCallHostException struct { + RuntimeException +} + +func NewWrongCallHostException(message string) WrongCallHostException { + return WrongCallHostException{ + NewRuntimeException(nil, message, nil), + } +} diff --git a/go.mod b/go.mod index 48b649a..cf02b81 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/tursom/GoCollections go 1.18 + +require github.com/petermattis/goid v0.0.0-20220302125637-5f11c28912df