PR comments

This commit is contained in:
Marco Munizaga 2022-06-14 11:46:26 -07:00
parent 297cd00321
commit 84ca9d337a
3 changed files with 47 additions and 53 deletions

View File

@ -3,6 +3,7 @@ package rcmgr
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"net" "net"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
@ -23,16 +24,17 @@ type allowlist struct {
allowedPeerByNetwork map[peer.ID][]*net.IPNet allowedPeerByNetwork map[peer.ID][]*net.IPNet
} }
func newAllowList() allowlist { func newAllowlist() allowlist {
return allowlist{ return allowlist{
allowedPeerByNetwork: make(map[peer.ID][]*net.IPNet), allowedPeerByNetwork: make(map[peer.ID][]*net.IPNet),
} }
} }
func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, string, error) { func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, peer.ID, error) {
var ipString string var ipString string
var mask string var mask string
var allowedPeer string var allowedPeerStr string
var allowedPeer peer.ID
var isIPV4 bool var isIPV4 bool
multiaddr.ForEach(ma, func(c multiaddr.Component) bool { multiaddr.ForEach(ma, func(c multiaddr.Component) bool {
@ -44,15 +46,23 @@ func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, string, error) {
mask = c.Value() mask = c.Value()
} }
if c.Protocol().Code == multiaddr.P_P2P { if c.Protocol().Code == multiaddr.P_P2P {
allowedPeer = c.Value() allowedPeerStr = c.Value()
} }
return ipString == "" || mask == "" || allowedPeer == "" return ipString == "" || mask == "" || allowedPeerStr == ""
}) })
if ipString == "" { if ipString == "" {
return nil, allowedPeer, errors.New("missing ip address") return nil, allowedPeer, errors.New("missing ip address")
} }
if allowedPeerStr != "" {
var err error
allowedPeer, err = peer.Decode(allowedPeerStr)
if err != nil {
return nil, allowedPeer, fmt.Errorf("failed to decode allowed peer: %w", err)
}
}
if mask == "" { if mask == "" {
ip := net.ParseIP(ipString) ip := net.ParseIP(ipString)
if ip == nil { if ip == nil {
@ -74,59 +84,48 @@ func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, string, error) {
} }
// Add takes a multiaddr and adds it to the allowlist. The multiaddr should be
// an ip address of the peer with or without a `/p2p` protocol.
// e.g. /ip4/1.2.3.4/p2p/QmFoo and /ip4/1.2.3.4 are valid.
// /p2p/QmFoo is not valid.
func (al *allowlist) Add(ma multiaddr.Multiaddr) error { func (al *allowlist) Add(ma multiaddr.Multiaddr) error {
ipnet, allowedPeerStr, err := toIPNet(ma) ipnet, allowedPeer, err := toIPNet(ma)
if err != nil { if err != nil {
return err return err
} }
if allowedPeerStr != "" { if allowedPeer != peer.ID("") {
// We have a peerID constraint // We have a peerID constraint
allowedPeer, err := peer.Decode(allowedPeerStr) al.allowedPeerByNetwork[allowedPeer] = append(al.allowedPeerByNetwork[allowedPeer], ipnet)
if err != nil {
return err
}
if ipnet != nil {
al.allowedPeerByNetwork[allowedPeer] = append(al.allowedPeerByNetwork[allowedPeer], ipnet)
}
} else { } else {
if ipnet != nil { al.allowedNetworks = append(al.allowedNetworks, ipnet)
al.allowedNetworks = append(al.allowedNetworks, ipnet)
}
} }
return nil return nil
} }
func (al *allowlist) Remove(ma multiaddr.Multiaddr) error { func (al *allowlist) Remove(ma multiaddr.Multiaddr) error {
ipnet, allowedPeerStr, err := toIPNet(ma) ipnet, allowedPeer, err := toIPNet(ma)
if err != nil { if err != nil {
return err return err
} }
ipNetList := al.allowedNetworks ipNetList := al.allowedNetworks
var allowedPeer peer.ID if allowedPeer != peer.ID("") {
if allowedPeerStr != "" {
// We have a peerID constraint // We have a peerID constraint
allowedPeer, err = peer.Decode(allowedPeerStr)
if err != nil {
return err
}
ipNetList = al.allowedPeerByNetwork[allowedPeer] ipNetList = al.allowedPeerByNetwork[allowedPeer]
} }
if ipnet != nil { i := len(ipNetList)
i := len(ipNetList) for i > 0 {
for i > 0 { i--
i-- if ipNetList[i].IP.Equal(ipnet.IP) && bytes.Equal(ipNetList[i].Mask, ipnet.Mask) {
if ipNetList[i].IP.Equal(ipnet.IP) && bytes.Equal(ipNetList[i].Mask, ipnet.Mask) { if i == len(ipNetList)-1 {
if i == len(ipNetList)-1 { // Trim this element from the end
// Trim this element from the end ipNetList = ipNetList[:i]
ipNetList = ipNetList[:i] } else {
} else { // swap remove
// swap remove ipNetList[i] = ipNetList[len(ipNetList)-1]
ipNetList[i] = ipNetList[len(ipNetList)-1] ipNetList = ipNetList[:len(ipNetList)-1]
ipNetList = ipNetList[:len(ipNetList)-1]
}
} }
} }
} }
@ -146,9 +145,7 @@ func (al *allowlist) Allowed(ma multiaddr.Multiaddr) bool {
return false return false
} }
_ = ip
for _, network := range al.allowedNetworks { for _, network := range al.allowedNetworks {
_ = network
if network.Contains(ip) { if network.Contains(ip) {
return true return true
} }

View File

@ -12,8 +12,8 @@ import (
) )
func TestAllowedSimple(t *testing.T) { func TestAllowedSimple(t *testing.T) {
allowlist := newAllowList() allowlist := newAllowlist()
ma, _ := multiaddr.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") ma := multiaddr.StringCast("/ip4/1.2.3.4/tcp/1234")
err := allowlist.Add(ma) err := allowlist.Add(ma)
if err != nil { if err != nil {
t.Fatalf("failed to add ip4: %s", err) t.Fatalf("failed to add ip4: %s", err)
@ -124,7 +124,7 @@ func TestAllowedWithPeer(t *testing.T) {
for _, tc := range testcases { for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
allowlist := newAllowList() allowlist := newAllowlist()
for _, maStr := range tc.allowlist { for _, maStr := range tc.allowlist {
ma, err := multiaddr.NewMultiaddr(maStr) ma, err := multiaddr.NewMultiaddr(maStr)
if err != nil { if err != nil {
@ -151,7 +151,7 @@ func TestRemoved(t *testing.T) {
allowedMA string allowedMA string
} }
peerA := test.RandPeerIDFatal(t) peerA := test.RandPeerIDFatal(t)
maA, _ := multiaddr.NewMultiaddr("/ip4/1.2.3.4") maA := multiaddr.StringCast("/ip4/1.2.3.4")
testCases := []testCase{ testCases := []testCase{
{name: "ip4", allowedMA: "/ip4/1.2.3.4"}, {name: "ip4", allowedMA: "/ip4/1.2.3.4"},
@ -162,13 +162,10 @@ func TestRemoved(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
allowlist := newAllowList() allowlist := newAllowlist()
ma, err := multiaddr.NewMultiaddr(tc.allowedMA) ma := multiaddr.StringCast(tc.allowedMA)
if err != nil {
t.Fatalf("failed to parse ma: %s", err)
}
err = allowlist.Add(ma) err := allowlist.Add(ma)
if err != nil { if err != nil {
t.Fatalf("failed to add ip4: %s", err) t.Fatalf("failed to add ip4: %s", err)
} }
@ -188,7 +185,7 @@ func TestRemoved(t *testing.T) {
// BenchmarkAllowlistCheck benchmarks the allowlist with plausible conditions. // BenchmarkAllowlistCheck benchmarks the allowlist with plausible conditions.
func BenchmarkAllowlistCheck(b *testing.B) { func BenchmarkAllowlistCheck(b *testing.B) {
allowlist := newAllowList() allowlist := newAllowlist()
// How often do we expect a peer to be specified? 1 in N // How often do we expect a peer to be specified? 1 in N
ratioOfSpecifiedPeers := 10 ratioOfSpecifiedPeers := 10
@ -227,9 +224,9 @@ func BenchmarkAllowlistCheck(b *testing.B) {
var ma multiaddr.Multiaddr var ma multiaddr.Multiaddr
if i%ratioOfSpecifiedPeers == 0 { if i%ratioOfSpecifiedPeers == 0 {
ma, err = multiaddr.NewMultiaddr(ipString + "/p2p/" + peer.Encode(test.RandPeerIDFatal(b))) ma = multiaddr.StringCast(ipString + "/p2p/" + peer.Encode(test.RandPeerIDFatal(b)))
} else { } else {
ma, err = multiaddr.NewMultiaddr(ipString) ma = multiaddr.StringCast(ipString)
} }
if err != nil { if err != nil {
b.Fatalf("Failed to generate multiaddr: %v", ipString) b.Fatalf("Failed to generate multiaddr: %v", ipString)

View File

@ -125,7 +125,7 @@ var _ network.StreamManagementScope = (*streamScope)(nil)
type Option func(*resourceManager) error type Option func(*resourceManager) error
func NewResourceManager(limits Limiter, opts ...Option) (network.ResourceManager, error) { func NewResourceManager(limits Limiter, opts ...Option) (network.ResourceManager, error) {
allowlist := newAllowList() allowlist := newAllowlist()
r := &resourceManager{ r := &resourceManager{
limits: limits, limits: limits,
allowlist: &allowlist, allowlist: &allowlist,