PR comments

This commit is contained in:
Marco Munizaga 2022-06-14 11:46:26 -07:00
parent 297cd00321
commit 84ca9d337a
3 changed files with 47 additions and 53 deletions

View File

@ -3,6 +3,7 @@ package rcmgr
import (
"bytes"
"errors"
"fmt"
"net"
"github.com/libp2p/go-libp2p-core/peer"
@ -23,16 +24,17 @@ type allowlist struct {
allowedPeerByNetwork map[peer.ID][]*net.IPNet
}
func newAllowList() allowlist {
func newAllowlist() allowlist {
return allowlist{
allowedPeerByNetwork: make(map[peer.ID][]*net.IPNet),
}
}
func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, string, error) {
func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, peer.ID, error) {
var ipString string
var mask string
var allowedPeer string
var allowedPeerStr string
var allowedPeer peer.ID
var isIPV4 bool
multiaddr.ForEach(ma, func(c multiaddr.Component) bool {
@ -44,15 +46,23 @@ func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, string, error) {
mask = c.Value()
}
if c.Protocol().Code == multiaddr.P_P2P {
allowedPeer = c.Value()
allowedPeerStr = c.Value()
}
return ipString == "" || mask == "" || allowedPeer == ""
return ipString == "" || mask == "" || allowedPeerStr == ""
})
if ipString == "" {
return nil, allowedPeer, errors.New("missing ip address")
}
if allowedPeerStr != "" {
var err error
allowedPeer, err = peer.Decode(allowedPeerStr)
if err != nil {
return nil, allowedPeer, fmt.Errorf("failed to decode allowed peer: %w", err)
}
}
if mask == "" {
ip := net.ParseIP(ipString)
if ip == nil {
@ -74,59 +84,48 @@ func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, string, error) {
}
// Add takes a multiaddr and adds it to the allowlist. The multiaddr should be
// an ip address of the peer with or without a `/p2p` protocol.
// e.g. /ip4/1.2.3.4/p2p/QmFoo and /ip4/1.2.3.4 are valid.
// /p2p/QmFoo is not valid.
func (al *allowlist) Add(ma multiaddr.Multiaddr) error {
ipnet, allowedPeerStr, err := toIPNet(ma)
ipnet, allowedPeer, err := toIPNet(ma)
if err != nil {
return err
}
if allowedPeerStr != "" {
if allowedPeer != peer.ID("") {
// We have a peerID constraint
allowedPeer, err := peer.Decode(allowedPeerStr)
if err != nil {
return err
}
if ipnet != nil {
al.allowedPeerByNetwork[allowedPeer] = append(al.allowedPeerByNetwork[allowedPeer], ipnet)
}
al.allowedPeerByNetwork[allowedPeer] = append(al.allowedPeerByNetwork[allowedPeer], ipnet)
} else {
if ipnet != nil {
al.allowedNetworks = append(al.allowedNetworks, ipnet)
}
al.allowedNetworks = append(al.allowedNetworks, ipnet)
}
return nil
}
func (al *allowlist) Remove(ma multiaddr.Multiaddr) error {
ipnet, allowedPeerStr, err := toIPNet(ma)
ipnet, allowedPeer, err := toIPNet(ma)
if err != nil {
return err
}
ipNetList := al.allowedNetworks
var allowedPeer peer.ID
if allowedPeerStr != "" {
if allowedPeer != peer.ID("") {
// We have a peerID constraint
allowedPeer, err = peer.Decode(allowedPeerStr)
if err != nil {
return err
}
ipNetList = al.allowedPeerByNetwork[allowedPeer]
}
if ipnet != nil {
i := len(ipNetList)
for i > 0 {
i--
if ipNetList[i].IP.Equal(ipnet.IP) && bytes.Equal(ipNetList[i].Mask, ipnet.Mask) {
if i == len(ipNetList)-1 {
// Trim this element from the end
ipNetList = ipNetList[:i]
} else {
// swap remove
ipNetList[i] = ipNetList[len(ipNetList)-1]
ipNetList = ipNetList[:len(ipNetList)-1]
}
i := len(ipNetList)
for i > 0 {
i--
if ipNetList[i].IP.Equal(ipnet.IP) && bytes.Equal(ipNetList[i].Mask, ipnet.Mask) {
if i == len(ipNetList)-1 {
// Trim this element from the end
ipNetList = ipNetList[:i]
} else {
// swap remove
ipNetList[i] = ipNetList[len(ipNetList)-1]
ipNetList = ipNetList[:len(ipNetList)-1]
}
}
}
@ -146,9 +145,7 @@ func (al *allowlist) Allowed(ma multiaddr.Multiaddr) bool {
return false
}
_ = ip
for _, network := range al.allowedNetworks {
_ = network
if network.Contains(ip) {
return true
}

View File

@ -12,8 +12,8 @@ import (
)
func TestAllowedSimple(t *testing.T) {
allowlist := newAllowList()
ma, _ := multiaddr.NewMultiaddr("/ip4/1.2.3.4/tcp/1234")
allowlist := newAllowlist()
ma := multiaddr.StringCast("/ip4/1.2.3.4/tcp/1234")
err := allowlist.Add(ma)
if err != nil {
t.Fatalf("failed to add ip4: %s", err)
@ -124,7 +124,7 @@ func TestAllowedWithPeer(t *testing.T) {
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
allowlist := newAllowList()
allowlist := newAllowlist()
for _, maStr := range tc.allowlist {
ma, err := multiaddr.NewMultiaddr(maStr)
if err != nil {
@ -151,7 +151,7 @@ func TestRemoved(t *testing.T) {
allowedMA string
}
peerA := test.RandPeerIDFatal(t)
maA, _ := multiaddr.NewMultiaddr("/ip4/1.2.3.4")
maA := multiaddr.StringCast("/ip4/1.2.3.4")
testCases := []testCase{
{name: "ip4", allowedMA: "/ip4/1.2.3.4"},
@ -162,13 +162,10 @@ func TestRemoved(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
allowlist := newAllowList()
ma, err := multiaddr.NewMultiaddr(tc.allowedMA)
if err != nil {
t.Fatalf("failed to parse ma: %s", err)
}
allowlist := newAllowlist()
ma := multiaddr.StringCast(tc.allowedMA)
err = allowlist.Add(ma)
err := allowlist.Add(ma)
if err != nil {
t.Fatalf("failed to add ip4: %s", err)
}
@ -188,7 +185,7 @@ func TestRemoved(t *testing.T) {
// BenchmarkAllowlistCheck benchmarks the allowlist with plausible conditions.
func BenchmarkAllowlistCheck(b *testing.B) {
allowlist := newAllowList()
allowlist := newAllowlist()
// How often do we expect a peer to be specified? 1 in N
ratioOfSpecifiedPeers := 10
@ -227,9 +224,9 @@ func BenchmarkAllowlistCheck(b *testing.B) {
var ma multiaddr.Multiaddr
if i%ratioOfSpecifiedPeers == 0 {
ma, err = multiaddr.NewMultiaddr(ipString + "/p2p/" + peer.Encode(test.RandPeerIDFatal(b)))
ma = multiaddr.StringCast(ipString + "/p2p/" + peer.Encode(test.RandPeerIDFatal(b)))
} else {
ma, err = multiaddr.NewMultiaddr(ipString)
ma = multiaddr.StringCast(ipString)
}
if err != nil {
b.Fatalf("Failed to generate multiaddr: %v", ipString)

View File

@ -125,7 +125,7 @@ var _ network.StreamManagementScope = (*streamScope)(nil)
type Option func(*resourceManager) error
func NewResourceManager(limits Limiter, opts ...Option) (network.ResourceManager, error) {
allowlist := newAllowList()
allowlist := newAllowlist()
r := &resourceManager{
limits: limits,
allowlist: &allowlist,