diff --git a/peerstore.go b/peerstore.go index 98cc460..9747984 100644 --- a/peerstore.go +++ b/peerstore.go @@ -46,6 +46,7 @@ type Peerstore interface { GetProtocols(peer.ID) ([]string, error) AddProtocols(peer.ID, ...string) error + SetProtocols(peer.ID, ...string) error SupportsProtocols(peer.ID, ...string) ([]string, error) } @@ -235,6 +236,18 @@ func (ps *peerstore) PeerInfo(p peer.ID) PeerInfo { } } +func (ps *peerstore) SetProtocols(p peer.ID, protos ...string) error { + ps.protolock.Lock() + defer ps.protolock.Unlock() + + protomap := make(map[string]struct{}) + for _, proto := range protos { + protomap[proto] = struct{}{} + } + + return ps.Put(p, "protocols", protomap) +} + func (ps *peerstore) AddProtocols(p peer.ID, protos ...string) error { ps.protolock.Lock() defer ps.protolock.Unlock() diff --git a/peerstore_test.go b/peerstore_test.go index 29cf6fa..6dbd8ab 100644 --- a/peerstore_test.go +++ b/peerstore_test.go @@ -220,6 +220,20 @@ func TestPeerstoreProtoStore(t *testing.T) { if supported[0] != "a" || supported[1] != "b" { t.Fatal("got wrong supported array: ", supported) } + + err = ps.SetProtocols(p1, "other") + if err != nil { + t.Fatal(err) + } + + supported, err = ps.SupportsProtocols(p1, "q", "w", "a", "y", "b") + if err != nil { + t.Fatal(err) + } + + if len(supported) != 0 { + t.Fatal("none of those protocols should have been supported") + } } func TestBasicPeerstore(t *testing.T) {