From 84ca9d337a7da7b33d36d1236f6a80215d15e989 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Tue, 14 Jun 2022 11:46:26 -0700 Subject: [PATCH] PR comments --- allowlist.go | 75 +++++++++++++++++++++++------------------------ allowlist_test.go | 23 +++++++-------- rcmgr.go | 2 +- 3 files changed, 47 insertions(+), 53 deletions(-) diff --git a/allowlist.go b/allowlist.go index 325b7fe..0bb73fe 100644 --- a/allowlist.go +++ b/allowlist.go @@ -3,6 +3,7 @@ package rcmgr import ( "bytes" "errors" + "fmt" "net" "github.com/libp2p/go-libp2p-core/peer" @@ -23,16 +24,17 @@ type allowlist struct { allowedPeerByNetwork map[peer.ID][]*net.IPNet } -func newAllowList() allowlist { +func newAllowlist() allowlist { return allowlist{ allowedPeerByNetwork: make(map[peer.ID][]*net.IPNet), } } -func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, string, error) { +func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, peer.ID, error) { var ipString string var mask string - var allowedPeer string + var allowedPeerStr string + var allowedPeer peer.ID var isIPV4 bool multiaddr.ForEach(ma, func(c multiaddr.Component) bool { @@ -44,15 +46,23 @@ func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, string, error) { mask = c.Value() } if c.Protocol().Code == multiaddr.P_P2P { - allowedPeer = c.Value() + allowedPeerStr = c.Value() } - return ipString == "" || mask == "" || allowedPeer == "" + return ipString == "" || mask == "" || allowedPeerStr == "" }) if ipString == "" { return nil, allowedPeer, errors.New("missing ip address") } + if allowedPeerStr != "" { + var err error + allowedPeer, err = peer.Decode(allowedPeerStr) + if err != nil { + return nil, allowedPeer, fmt.Errorf("failed to decode allowed peer: %w", err) + } + } + if mask == "" { ip := net.ParseIP(ipString) if ip == nil { @@ -74,59 +84,48 @@ func toIPNet(ma multiaddr.Multiaddr) (*net.IPNet, string, error) { } +// Add takes a multiaddr and adds it to the allowlist. The multiaddr should be +// an ip address of the peer with or without a `/p2p` protocol. +// e.g. /ip4/1.2.3.4/p2p/QmFoo and /ip4/1.2.3.4 are valid. +// /p2p/QmFoo is not valid. func (al *allowlist) Add(ma multiaddr.Multiaddr) error { - ipnet, allowedPeerStr, err := toIPNet(ma) + ipnet, allowedPeer, err := toIPNet(ma) if err != nil { return err } - if allowedPeerStr != "" { + if allowedPeer != peer.ID("") { // We have a peerID constraint - allowedPeer, err := peer.Decode(allowedPeerStr) - if err != nil { - return err - } - if ipnet != nil { - al.allowedPeerByNetwork[allowedPeer] = append(al.allowedPeerByNetwork[allowedPeer], ipnet) - } + al.allowedPeerByNetwork[allowedPeer] = append(al.allowedPeerByNetwork[allowedPeer], ipnet) } else { - if ipnet != nil { - al.allowedNetworks = append(al.allowedNetworks, ipnet) - } + al.allowedNetworks = append(al.allowedNetworks, ipnet) } return nil } func (al *allowlist) Remove(ma multiaddr.Multiaddr) error { - ipnet, allowedPeerStr, err := toIPNet(ma) + ipnet, allowedPeer, err := toIPNet(ma) if err != nil { return err } ipNetList := al.allowedNetworks - var allowedPeer peer.ID - if allowedPeerStr != "" { + if allowedPeer != peer.ID("") { // We have a peerID constraint - allowedPeer, err = peer.Decode(allowedPeerStr) - if err != nil { - return err - } ipNetList = al.allowedPeerByNetwork[allowedPeer] } - if ipnet != nil { - i := len(ipNetList) - for i > 0 { - i-- - if ipNetList[i].IP.Equal(ipnet.IP) && bytes.Equal(ipNetList[i].Mask, ipnet.Mask) { - if i == len(ipNetList)-1 { - // Trim this element from the end - ipNetList = ipNetList[:i] - } else { - // swap remove - ipNetList[i] = ipNetList[len(ipNetList)-1] - ipNetList = ipNetList[:len(ipNetList)-1] - } + i := len(ipNetList) + for i > 0 { + i-- + if ipNetList[i].IP.Equal(ipnet.IP) && bytes.Equal(ipNetList[i].Mask, ipnet.Mask) { + if i == len(ipNetList)-1 { + // Trim this element from the end + ipNetList = ipNetList[:i] + } else { + // swap remove + ipNetList[i] = ipNetList[len(ipNetList)-1] + ipNetList = ipNetList[:len(ipNetList)-1] } } } @@ -146,9 +145,7 @@ func (al *allowlist) Allowed(ma multiaddr.Multiaddr) bool { return false } - _ = ip for _, network := range al.allowedNetworks { - _ = network if network.Contains(ip) { return true } diff --git a/allowlist_test.go b/allowlist_test.go index 79f5744..3a8e9f9 100644 --- a/allowlist_test.go +++ b/allowlist_test.go @@ -12,8 +12,8 @@ import ( ) func TestAllowedSimple(t *testing.T) { - allowlist := newAllowList() - ma, _ := multiaddr.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") + allowlist := newAllowlist() + ma := multiaddr.StringCast("/ip4/1.2.3.4/tcp/1234") err := allowlist.Add(ma) if err != nil { t.Fatalf("failed to add ip4: %s", err) @@ -124,7 +124,7 @@ func TestAllowedWithPeer(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - allowlist := newAllowList() + allowlist := newAllowlist() for _, maStr := range tc.allowlist { ma, err := multiaddr.NewMultiaddr(maStr) if err != nil { @@ -151,7 +151,7 @@ func TestRemoved(t *testing.T) { allowedMA string } peerA := test.RandPeerIDFatal(t) - maA, _ := multiaddr.NewMultiaddr("/ip4/1.2.3.4") + maA := multiaddr.StringCast("/ip4/1.2.3.4") testCases := []testCase{ {name: "ip4", allowedMA: "/ip4/1.2.3.4"}, @@ -162,13 +162,10 @@ func TestRemoved(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - allowlist := newAllowList() - ma, err := multiaddr.NewMultiaddr(tc.allowedMA) - if err != nil { - t.Fatalf("failed to parse ma: %s", err) - } + allowlist := newAllowlist() + ma := multiaddr.StringCast(tc.allowedMA) - err = allowlist.Add(ma) + err := allowlist.Add(ma) if err != nil { t.Fatalf("failed to add ip4: %s", err) } @@ -188,7 +185,7 @@ func TestRemoved(t *testing.T) { // BenchmarkAllowlistCheck benchmarks the allowlist with plausible conditions. func BenchmarkAllowlistCheck(b *testing.B) { - allowlist := newAllowList() + allowlist := newAllowlist() // How often do we expect a peer to be specified? 1 in N ratioOfSpecifiedPeers := 10 @@ -227,9 +224,9 @@ func BenchmarkAllowlistCheck(b *testing.B) { var ma multiaddr.Multiaddr if i%ratioOfSpecifiedPeers == 0 { - ma, err = multiaddr.NewMultiaddr(ipString + "/p2p/" + peer.Encode(test.RandPeerIDFatal(b))) + ma = multiaddr.StringCast(ipString + "/p2p/" + peer.Encode(test.RandPeerIDFatal(b))) } else { - ma, err = multiaddr.NewMultiaddr(ipString) + ma = multiaddr.StringCast(ipString) } if err != nil { b.Fatalf("Failed to generate multiaddr: %v", ipString) diff --git a/rcmgr.go b/rcmgr.go index 2315d8b..8a91700 100644 --- a/rcmgr.go +++ b/rcmgr.go @@ -125,7 +125,7 @@ var _ network.StreamManagementScope = (*streamScope)(nil) type Option func(*resourceManager) error func NewResourceManager(limits Limiter, opts ...Option) (network.ResourceManager, error) { - allowlist := newAllowList() + allowlist := newAllowlist() r := &resourceManager{ limits: limits, allowlist: &allowlist,