go-libp2p-resource-manager/allowlist.go
2022-06-30 09:55:46 -07:00

216 lines
4.8 KiB
Go

package rcmgr
import (
"bytes"
"errors"
"fmt"
"net"
"sync"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// Allowlist records the IP networks that are allowlisted, either for every
// peer or only for specific peers.
type Allowlist struct {
// mu is taken by the Allowlist methods before they read or modify the
// fields below.
mu sync.RWMutex
// a simple structure of lists of networks. There is probably a faster way
// to check if an IP address is in this network than iterating over this
// list, but this is good enough for small numbers of networks (<1_000).
// Analyze the benchmark before trying to optimize this.
// Any peer with an IP inside one of these networks is allowed.
allowedNetworks []*net.IPNet
// Networks keyed by peer ID: each list of networks is allowlisted only for
// the peer it is stored under.
allowedPeerByNetwork map[peer.ID][]*net.IPNet
}
// WithAllowlistedMultiaddrs returns an Option that adds every multiaddr in mas
// to the resource manager's allowlist. The option stops at, and returns, the
// first error reported while adding an entry.
func WithAllowlistedMultiaddrs(mas []multiaddr.Multiaddr) Option {
	return func(rm *resourceManager) error {
		for i := range mas {
			if err := rm.allowlist.Add(mas[i]); err != nil {
				return err
			}
		}
		return nil
	}
}
// newAllowlist builds an empty Allowlist whose per-peer map is already
// initialized; the list of unrestricted networks starts out nil.
func newAllowlist() Allowlist {
	byPeer := make(map[peer.ID][]*net.IPNet)
	return Allowlist{allowedPeerByNetwork: byPeer}
}
// toIPNet converts an allowlist multiaddr into the IP network it describes and
// the peer ID it is restricted to, if any.
//
// The multiaddr must contain an /ip4 or /ip6 component. When an /ipcidr
// component is present it supplies the prefix length and the network is parsed
// with net.ParseCIDR; otherwise the network covers exactly the one address
// (a 32-bit mask for IPv4, a 128-bit mask for IPv6). When a /p2p component is
// present it is decoded into the returned peer ID; otherwise the returned peer
// ID is empty.
func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, peer.ID, error) {
	var (
		addr, prefixLen, peerStr string
		allowedPeer              peer.ID
		v4                       bool
	)

	multiaddr.ForEach(ma, func(c multiaddr.Component) bool {
		switch code := c.Protocol().Code; code {
		case multiaddr.P_IP4, multiaddr.P_IP6:
			v4 = code == multiaddr.P_IP4
			addr = c.Value()
		case multiaddr.P_IPCIDR:
			prefixLen = c.Value()
		case multiaddr.P_P2P:
			peerStr = c.Value()
		}
		// Keep walking the components until all three pieces have been seen.
		return addr == "" || prefixLen == "" || peerStr == ""
	})

	if addr == "" {
		return nil, allowedPeer, errors.New("missing ip address")
	}

	if peerStr != "" {
		var err error
		if allowedPeer, err = peer.Decode(peerStr); err != nil {
			return nil, allowedPeer, fmt.Errorf("failed to decode allowed peer: %w", err)
		}
	}

	// An explicit prefix length: let the standard library build the network.
	if prefixLen != "" {
		_, ipnet, err := net.ParseCIDR(addr + "/" + prefixLen)
		return ipnet, allowedPeer, err
	}

	// No prefix length: the network is the single address itself.
	ip := net.ParseIP(addr)
	if ip == nil {
		return nil, allowedPeer, errors.New("invalid ip address")
	}
	bits := 128
	if v4 {
		bits = 32
	}
	return &net.IPNet{IP: ip, Mask: net.CIDRMask(bits, bits)}, allowedPeer, nil
}
// Add inserts the network described by ma into the allowlist.
//
// ma must contain an IP address and may additionally carry an /ipcidr prefix
// length and/or a /p2p peer ID. With a peer ID the network is recorded for that
// peer only; without one it is recorded for all peers.
// Valid examples: /ip4/1.2.3.4/p2p/QmFoo, /ip4/1.2.3.4, /ip4/1.2.3.0/ipcidr/24.
// /p2p/QmFoo on its own is rejected because it has no IP address.
func (al *Allowlist) Add(ma multiaddr.Multiaddr) error {
	network, restrictedTo, err := toIPNet(ma)
	if err != nil {
		return err
	}

	al.mu.Lock()
	defer al.mu.Unlock()

	// No peer constraint: the network is open to everyone.
	if restrictedTo == "" {
		al.allowedNetworks = append(al.allowedNetworks, network)
		return nil
	}

	// Peer-constrained entry; create the map lazily so a zero-value Allowlist
	// can be used as well.
	if al.allowedPeerByNetwork == nil {
		al.allowedPeerByNetwork = make(map[peer.ID][]*net.IPNet)
	}
	al.allowedPeerByNetwork[restrictedTo] = append(al.allowedPeerByNetwork[restrictedTo], network)
	return nil
}
// Remove deletes one entry matching ma from the allowlist.
//
// ma is interpreted the same way as in Add: if it carries a /p2p peer ID, the
// entry is looked up among that peer's networks, otherwise among the networks
// open to all peers. An entry matches when both its IP and its mask equal those
// derived from ma. At most one entry is removed; removing something that is not
// present is not an error. The order of the remaining entries is not preserved.
//
// An error is returned only when ma cannot be converted to an IP network.
func (al *Allowlist) Remove(ma multiaddr.Multiaddr) error {
	ipnet, allowedPeer, err := toIPNet(ma)
	if err != nil {
		return err
	}
	al.mu.Lock()
	defer al.mu.Unlock()

	ipNetList := al.allowedNetworks
	if allowedPeer != "" {
		// We have a peerID constraint. Reading from a nil map is safe and
		// yields a nil slice, which the loop below simply skips.
		ipNetList = al.allowedPeerByNetwork[allowedPeer]
	}

	for i := len(ipNetList) - 1; i >= 0; i-- {
		if !ipNetList[i].IP.Equal(ipnet.IP) || !bytes.Equal(ipNetList[i].Mask, ipnet.Mask) {
			continue
		}

		// Swap remove. Clear the vacated tail slot so the backing array does
		// not keep the removed *net.IPNet reachable.
		last := len(ipNetList) - 1
		ipNetList[i] = ipNetList[last]
		ipNetList[last] = nil
		ipNetList = ipNetList[:last]

		switch {
		case allowedPeer == "":
			al.allowedNetworks = ipNetList
		case len(ipNetList) == 0:
			// Drop the key entirely so peers without any remaining network
			// do not accumulate in the map.
			delete(al.allowedPeerByNetwork, allowedPeer)
		default:
			al.allowedPeerByNetwork[allowedPeer] = ipNetList
		}

		// We only remove one thing.
		return nil
	}

	return nil
}
// Allowed reports whether the IP address of ma falls inside any network in the
// allowlist. Peer constraints are ignored here: a network recorded for a
// specific peer counts as a match as well. It returns false when no IP address
// can be extracted from ma.
func (al *Allowlist) Allowed(ma multiaddr.Multiaddr) bool {
	addr, err := manet.ToIP(ma)
	if err != nil {
		return false
	}

	al.mu.RLock()
	defer al.mu.RUnlock()

	// Networks open to every peer.
	for i := range al.allowedNetworks {
		if al.allowedNetworks[i].Contains(addr) {
			return true
		}
	}

	// Networks recorded for individual peers, regardless of which peer.
	for _, peerNets := range al.allowedPeerByNetwork {
		for i := range peerNets {
			if peerNets[i].Contains(addr) {
				return true
			}
		}
	}

	return false
}
// AllowedPeerAndMultiaddr reports whether the IP address of ma is allowlisted
// for peerID: either it falls inside a network open to all peers, or inside a
// network recorded specifically for peerID. It returns false when no IP address
// can be extracted from ma.
func (al *Allowlist) AllowedPeerAndMultiaddr(peerID peer.ID, ma multiaddr.Multiaddr) bool {
	addr, err := manet.ToIP(ma)
	if err != nil {
		return false
	}

	al.mu.RLock()
	defer al.mu.RUnlock()

	// A match here isn't constrained by any peer ID.
	for _, open := range al.allowedNetworks {
		if open.Contains(addr) {
			return true
		}
	}

	// A missing key yields a nil slice, so no separate presence check is
	// needed before ranging over this peer's networks.
	for _, scoped := range al.allowedPeerByNetwork[peerID] {
		if scoped.Contains(addr) {
			return true
		}
	}

	return false
}