From 525a0e67fe8b9e786ba4a4f3a3615a50eca7799c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C5=81ukasz=20Magiera?= <magik6k@gmail.com>
Date: Sat, 22 Jun 2019 12:05:03 +0200
Subject: [PATCH] fix close deadlock and Sub type error

---
 basic.go      | 28 +++++++++++++++++++++++-----
 basic_test.go | 21 +++++++++++++++++++++
 2 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/basic.go b/basic.go
index ca102b5..2fe3798 100644
--- a/basic.go
+++ b/basic.go
@@ -104,9 +104,21 @@ func (s *sub) Out() <-chan interface{} {
 }
 
 func (s *sub) Close() error {
-	close(s.ch)
+	stop := make(chan struct{})
+	go func() {
+		for {
+			select {
+			case <-s.ch:
+			case <-stop:
+				close(s.ch)
+				return
+			}
+		}
+	}()
+
 	for _, n := range s.nodes {
 		n.lk.Lock()
+
 		for i := 0; i < len(n.sinks); i++ {
 			if n.sinks[i] == s.ch {
 				n.sinks[i], n.sinks[len(n.sinks)-1] = n.sinks[len(n.sinks)-1], nil
@@ -114,12 +126,16 @@ func (s *sub) Close() error {
 				break
 			}
 		}
+
 		tryDrop := len(n.sinks) == 0 && atomic.LoadInt32(&n.nEmitters) == 0
+
 		n.lk.Unlock()
+
 		if tryDrop {
 			s.dropper(n.typ)
 		}
 	}
+	close(stop)
 	return nil
 }
 
@@ -148,12 +164,14 @@ func (b *basicBus) Subscribe(evtTypes interface{}, opts ...event.SubscriptionOpt
 		dropper: b.tryDropNode,
 	}
 
-	for i, etyp := range types {
-		typ := reflect.TypeOf(etyp)
-
-		if typ.Kind() != reflect.Ptr {
+	for _, etyp := range types {
+		if reflect.TypeOf(etyp).Kind() != reflect.Ptr {
 			return nil, errors.New("subscribe called with non-pointer type")
 		}
+	}
+
+	for i, etyp := range types {
+		typ := reflect.TypeOf(etyp)
 
 		err = b.withNode(typ.Elem(), func(n *node) {
 			n.sinks = append(n.sinks, out.ch)
diff --git a/basic_test.go b/basic_test.go
index fc23e61..62b84eb 100644
--- a/basic_test.go
+++ b/basic_test.go
@@ -297,6 +297,27 @@ func TestStateful(t *testing.T) {
 	}
 }
 
+func TestCloseBlocking(t *testing.T) {
+	bus := NewBus()
+	em, err := bus.Emitter(new(EventB))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	sub, err := bus.Subscribe(new(EventB))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	go func() {
+		em.Emit(EventB(159))
+	}()
+
+	time.Sleep(10 * time.Millisecond) // make sure that emit is blocked
+
+	sub.Close() // cancel sub
+}
+
 func testMany(t testing.TB, subs, emits, msgs int, stateful bool) {
 	if race.WithRace() && subs+emits > 5000 {
 		t.SkipNow()