diff --git a/basic.go b/basic.go index 7ddf9e1..5d67edf 100644 --- a/basic.go +++ b/basic.go @@ -22,12 +22,7 @@ func NewBus() Bus { } } -func (b *bus) withNode(evtType interface{}, cb func(*node)) error { - typ := reflect.TypeOf(evtType) - if typ.Kind() != reflect.Ptr { - return errors.New("subscribe called with non-pointer type") - } - typ = typ.Elem() +func (b *bus) withNode(typ reflect.Type, cb func(*node)) error { path := typePath(typ) b.lk.Lock() @@ -45,8 +40,8 @@ func (b *bus) withNode(evtType interface{}, cb func(*node)) error { return nil } -func (b *bus) tryDropNode(evtType interface{}) { - path := typePath(reflect.TypeOf(evtType).Elem()) +func (b *bus) tryDropNode(typ reflect.Type) { + path := typePath(typ) b.lk.Lock() n, ok := b.nodes[path] @@ -67,18 +62,27 @@ func (b *bus) tryDropNode(evtType interface{}) { b.lk.Unlock() } -func (b *bus) Subscribe(evtType interface{}, _ ...SubOption) (s <-chan interface{}, c CancelFunc, err error) { - err = b.withNode(evtType, func(n *node) { - out, i := n.sub(0) - s = out +func (b *bus) Subscribe(typedChan interface{}, _ ...SubOption) (c CancelFunc, err error) { + refCh := reflect.ValueOf(typedChan) + typ := refCh.Type() + if typ.Kind() != reflect.Chan { + return nil, errors.New("expected a channel") + } + if typ.ChanDir() & reflect.SendDir == 0 { + return nil, errors.New("channel doesn't allow send") + } + + err = b.withNode(typ.Elem(), func(n *node) { + // when all subs are waiting on this channel, setting this to 1 doesn't + // really affect benchmarks + i := n.sub(refCh) c = func() { n.lk.Lock() delete(n.sinks, i) - close(out) tryDrop := len(n.sinks) == 0 && n.nEmitters == 0 n.lk.Unlock() if tryDrop { - b.tryDropNode(evtType) + b.tryDropNode(typ.Elem()) } } }) @@ -86,7 +90,13 @@ func (b *bus) Subscribe(evtType interface{}, _ ...SubOption) (s <-chan interface } func (b *bus) Emitter(evtType interface{}, _ ...EmitterOption) (e EmitFunc, c CancelFunc, err error) { - err = b.withNode(evtType, func(n *node) { + typ := reflect.TypeOf(evtType) + if typ.Kind() != reflect.Ptr { + return nil, nil, errors.New("emitter called with non-pointer type") + } + typ = typ.Elem() + + err = b.withNode(typ, func(n *node) { atomic.AddInt32(&n.nEmitters, 1) closed := false @@ -100,7 +110,7 @@ func (b *bus) Emitter(evtType interface{}, _ ...EmitterOption) (e EmitFunc, c Ca c = func() { closed = true if atomic.AddInt32(&n.nEmitters, -1) == 0 { - b.tryDropNode(evtType) + b.tryDropNode(typ) } } }) @@ -124,34 +134,33 @@ type node struct { // TODO: we could make emit a bit faster by making this into an array, but // it doesn't seem needed for now - sinks map[int]chan interface{} + sinks map[int]reflect.Value } func newNode(typ reflect.Type) *node { return &node{ typ: typ, - sinks: map[int]chan interface{}{}, + sinks: map[int]reflect.Value{}, } } -func (n *node) sub(buf int) (chan interface{}, int) { - out := make(chan interface{}, buf) +func (n *node) sub(outChan reflect.Value) int { i := n.sinkC n.sinkC++ - n.sinks[i] = out - return out, i + n.sinks[i] = outChan + return i } func (n *node) emit(event interface{}) { - etype := reflect.TypeOf(event) - if etype != n.typ { - panic(fmt.Sprintf("Emit called with wrong type. expected: %s, got: %s", n.typ, etype)) + eval := reflect.ValueOf(event) + if eval.Type() != n.typ { + panic(fmt.Sprintf("Emit called with wrong type. expected: %s, got: %s", n.typ, eval.Type())) } n.lk.RLock() for _, ch := range n.sinks { - ch <- event + ch.Send(eval) } n.lk.RUnlock() } diff --git a/basic_test.go b/basic_test.go index 5562540..f77e5d7 100644 --- a/basic_test.go +++ b/basic_test.go @@ -12,7 +12,8 @@ type EventB int func TestEmit(t *testing.T) { bus := NewBus() - events, cancel, err := bus.Subscribe(new(EventA)) + events := make(chan EventA) + cancel, err := bus.Subscribe(events) if err != nil { t.Fatal(err) } @@ -33,7 +34,8 @@ func TestEmit(t *testing.T) { func TestSub(t *testing.T) { bus := NewBus() - events, cancel, err := bus.Subscribe(new(EventB)) + events := make(chan EventB) + cancel, err := bus.Subscribe(events) if err != nil { t.Fatal(err) } @@ -45,7 +47,7 @@ func TestSub(t *testing.T) { go func() { defer cancel() - event = (<-events).(EventB) + event = <-events wait.Done() }() @@ -114,7 +116,7 @@ func TestClosingRaces(t *testing.T) { lk.RLock() defer lk.RUnlock() - _, cancel, _ := b.Subscribe(new(EventA)) + cancel, _ := b.Subscribe(make(chan EventA)) time.Sleep(10 * time.Millisecond) cancel() @@ -157,14 +159,15 @@ func TestSubMany(t *testing.T) { for i := 0; i < n; i++ { go func() { - events, cancel, err := bus.Subscribe(new(EventB)) + events := make(chan EventB) + cancel, err := bus.Subscribe(events) if err != nil { panic(err) } defer cancel() ready.Done() - atomic.AddInt32(&r, int32((<-events).(EventB))) + atomic.AddInt32(&r, int32(<-events)) wait.Done() }() } @@ -185,6 +188,42 @@ func TestSubMany(t *testing.T) { } } +/*func TestSendTo(t *testing.T) { + testSendTo(t, 1000) +} + +func testSendTo(t testing.TB, msgs int) { + bus := NewBus() + + go func() { + emit, cancel, err := bus.Emitter(new(EventB)) + if err != nil { + panic(err) + } + defer cancel() + + for i := 0; i < msgs; i++ { + emit(EventB(97)) + } + }() + + ch := make(chan EventB) + cancel, err := bus.SendTo(ch) + if err != nil { + return + } + defer cancel() + + r := 0 + for i := 0; i < msgs; i++ { + r += int(<-ch) + } + + if int(r) != 97 * msgs { + t.Fatal("got wrong result") + } +}*/ + func testMany(t testing.TB, subs, emits, msgs int) { bus := NewBus() @@ -197,7 +236,8 @@ func testMany(t testing.TB, subs, emits, msgs int) { for i := 0; i < subs; i++ { go func() { - events, cancel, err := bus.Subscribe(new(EventB)) + events := make(chan EventB) + cancel, err := bus.Subscribe(events) if err != nil { panic(err) } @@ -205,7 +245,7 @@ func testMany(t testing.TB, subs, emits, msgs int) { ready.Done() for i := 0; i < emits * msgs; i++ { - atomic.AddInt64(&r, int64((<-events).(EventB))) + atomic.AddInt64(&r, int64(<-events)) } wait.Done() }() @@ -277,6 +317,12 @@ func BenchmarkMs1e0m6(b *testing.B) { testMany(b, 10, 1, 1000000) } +func BenchmarkMs0e0m6(b *testing.B) { + b.N = 1000000 + b.ReportAllocs() + testMany(b, 1, 1, 1000000) +} + func BenchmarkMs0e6m0(b *testing.B) { b.N = 1000000 b.ReportAllocs() diff --git a/interface.go b/interface.go index 80e9585..d8b78fe 100644 --- a/interface.go +++ b/interface.go @@ -18,8 +18,7 @@ type Bus interface { // defer cancel() // // evt := (<-sub).(os.Signal) // guaranteed to be safe - Subscribe(eventType interface{}, opts ...SubOption) (<-chan interface{}, CancelFunc, error) - + Subscribe(typedChan interface{}, opts ...SubOption) (CancelFunc, error) Emitter(eventType interface{}, opts ...EmitterOption) (EmitFunc, CancelFunc, error) } @@ -31,4 +30,3 @@ type Bus interface { type EmitFunc func(event interface{}) type CancelFunc func() -