go-libp2p-peerstore/pstoreds/protobook.go

220 lines
3.9 KiB
Go
Raw Normal View History

2019-05-17 18:13:18 +08:00
package pstoreds
import (
"errors"
2019-05-17 18:13:18 +08:00
"fmt"
"sync"
"github.com/libp2p/go-libp2p-core/peer"
2019-05-17 18:13:18 +08:00
pstore "github.com/libp2p/go-libp2p-core/peerstore"
2019-05-17 18:13:18 +08:00
)
// protoSegment is a single lock stripe of the protocol book. It carries no
// data of its own: the embedded RWMutex is what the dsProtoBook methods take
// (read or write) around their accesses to the metadata store for every peer
// that maps to this segment.
type protoSegment struct {
	sync.RWMutex
}
type protoSegments [256]*protoSegment
2019-05-17 18:13:18 +08:00
func (s *protoSegments) get(p peer.ID) *protoSegment {
return s[byte(p[len(p)-1])]
2019-05-17 18:13:18 +08:00
}
// errTooManyProtocols is returned by SetProtocols and AddProtocols when the
// request would exceed the book's maxProtos limit.
var errTooManyProtocols = errors.New("too many protocols")
// ProtoBookOption configures a dsProtoBook while it is being built by
// NewProtoBook. Returning a non-nil error aborts construction.
type ProtoBookOption func(*dsProtoBook) error
// WithMaxProtocols returns an option that sets the book's protocol limit
// (maxProtos) to num. The value is stored as given; no range check is
// performed and the option itself never fails.
func WithMaxProtocols(num int) ProtoBookOption {
	var opt ProtoBookOption = func(book *dsProtoBook) error {
		book.maxProtos = num
		return nil
	}
	return opt
}
type dsProtoBook struct {
segments protoSegments
meta pstore.PeerMetadata
maxProtos int
2019-05-17 18:13:18 +08:00
}
// Compile-time check that *dsProtoBook implements pstore.ProtoBook.
var _ pstore.ProtoBook = (*dsProtoBook)(nil)
2019-05-17 18:13:18 +08:00
func NewProtoBook(meta pstore.PeerMetadata, opts ...ProtoBookOption) (*dsProtoBook, error) {
pb := &dsProtoBook{
meta: meta,
segments: func() (ret protoSegments) {
for i := range ret {
ret[i] = &protoSegment{}
}
return ret
}(),
maxProtos: 1024,
}
for _, opt := range opts {
if err := opt(pb); err != nil {
return nil, err
}
}
return pb, nil
2019-05-17 18:13:18 +08:00
}
func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error {
if err := p.Validate(); err != nil {
return err
}
if len(protos) > pb.maxProtos {
return errTooManyProtocols
}
2019-05-17 18:13:18 +08:00
protomap := make(map[string]struct{}, len(protos))
for _, proto := range protos {
protomap[proto] = struct{}{}
}
s := pb.segments.get(p)
s.Lock()
defer s.Unlock()
2019-05-17 18:13:18 +08:00
return pb.meta.Put(p, "protocols", protomap)
}
func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error {
if err := p.Validate(); err != nil {
return err
}
s := pb.segments.get(p)
s.Lock()
defer s.Unlock()
2019-05-17 18:13:18 +08:00
pmap, err := pb.getProtocolMap(p)
if err != nil {
return err
}
if len(pmap)+len(protos) > pb.maxProtos {
return errTooManyProtocols
}
2019-05-17 18:13:18 +08:00
for _, proto := range protos {
pmap[proto] = struct{}{}
}
return pb.meta.Put(p, "protocols", pmap)
}
func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) {
if err := p.Validate(); err != nil {
return nil, err
}
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
2019-05-17 18:13:18 +08:00
pmap, err := pb.getProtocolMap(p)
if err != nil {
return nil, err
}
res := make([]string, 0, len(pmap))
for proto := range pmap {
res = append(res, proto)
}
return res, nil
}
func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) {
if err := p.Validate(); err != nil {
return nil, err
}
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
2019-05-17 18:13:18 +08:00
pmap, err := pb.getProtocolMap(p)
if err != nil {
return nil, err
}
res := make([]string, 0, len(protos))
for _, proto := range protos {
if _, ok := pmap[proto]; ok {
res = append(res, proto)
}
}
return res, nil
}
2020-05-11 15:37:08 +08:00
func (pb *dsProtoBook) FirstSupportedProtocol(p peer.ID, protos ...string) (string, error) {
2020-05-08 14:40:14 +08:00
if err := p.Validate(); err != nil {
2020-05-11 15:37:08 +08:00
return "", err
2020-05-08 14:40:14 +08:00
}
s := pb.segments.get(p)
s.RLock()
defer s.RUnlock()
pmap, err := pb.getProtocolMap(p)
if err != nil {
2020-05-11 15:37:08 +08:00
return "", err
2020-05-08 14:40:14 +08:00
}
for _, proto := range protos {
if _, ok := pmap[proto]; ok {
2020-05-11 15:37:08 +08:00
return proto, nil
2020-05-08 14:40:14 +08:00
}
}
2020-05-11 15:37:08 +08:00
return "", nil
2020-05-08 14:40:14 +08:00
}
// RemoveProtocols deletes protos from the protocol set recorded for peer p.
//
// The peer's segment write lock is held across the whole read-modify-write.
// If the stored set is empty nothing is written and nil is returned;
// otherwise the set is written back after the deletions, whether or not any
// of the names were actually present. It returns the error from p.Validate
// if the ID is rejected, any error from loading the stored set, or the
// result of the final Put.
func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error {
	if err := p.Validate(); err != nil {
		return err
	}

	seg := pb.segments.get(p)
	seg.Lock()
	defer seg.Unlock()

	known, err := pb.getProtocolMap(p)
	switch {
	case err != nil:
		return err
	case len(known) == 0:
		// Nothing stored, so there is nothing to remove or rewrite.
		return nil
	}

	for i := range protos {
		delete(known, protos[i])
	}
	return pb.meta.Put(p, "protocols", known)
}
2019-05-17 18:13:18 +08:00
// getProtocolMap loads the protocol set stored for peer p under the
// "protocols" metadata key.
//
// A missing entry (pstore.ErrNotFound, matched with errors.Is so a wrapped
// sentinel is recognised too) yields a fresh empty map and a nil error. Any
// other store error is returned unchanged. If the stored value is not a
// map[string]struct{}, an error is returned.
//
// The returned map is the value handed back by the store, not a copy; callers
// in this file mutate it and Put it back. It does no locking itself: every
// caller in this file already holds the peer's segment lock.
func (pb *dsProtoBook) getProtocolMap(p peer.ID) (map[string]struct{}, error) {
	stored, err := pb.meta.Get(p, "protocols")
	if errors.Is(err, pstore.ErrNotFound) {
		return make(map[string]struct{}), nil
	}
	if err != nil {
		return nil, err
	}

	set, ok := stored.(map[string]struct{})
	if !ok {
		return nil, fmt.Errorf("stored protocol set was not a map")
	}
	return set, nil
}
2021-10-23 20:31:58 +08:00
// RemovePeer forwards the removal of peer p to the underlying metadata
// store. It neither validates p nor takes the peer's segment lock.
func (pb *dsProtoBook) RemovePeer(p peer.ID) {
	pb.meta.RemovePeer(p)
}