1
0
mirror of https://github.com/libp2p/go-libp2p-core.git synced 2025-04-28 17:10:14 +08:00

pass the peer ID to SecureInbound in the SecureTransport and SecureMuxer ()

The peer ID may be empty. This will be the common case. In that case,
connections from any peer are accepted.
This commit is contained in:
Marten Seemann 2021-09-08 11:34:32 +01:00 committed by GitHub
parent 094b0d3f8b
commit 60a3d1748e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 96 deletions

View File

@ -60,7 +60,7 @@ func (t *Transport) LocalPrivateKey() ci.PrivKey {
//
// SecureInbound may fail if the remote peer sends an ID and public key that are inconsistent
// with each other, or if a network error occurs during the ID exchange.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) {
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
conn := &Conn{
Conn: insecure,
local: t.id,
@ -72,6 +72,10 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.S
return nil, err
}
if t.key != nil && p != "" && p != conn.remote {
return nil, fmt.Errorf("remote peer sent unexpected peer ID. expected=%s received=%s", p, conn.remote)
}
return conn, nil
}

View File

@ -1,157 +1,128 @@
package insecure
import (
"bytes"
"context"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"
"io"
"net"
"testing"
ci "github.com/libp2p/go-libp2p-core/crypto"
"github.com/stretchr/testify/require"
"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"
)
// Run a set of sessions through the session setup and verification.
func TestConnections(t *testing.T) {
clientTpt := newTestTransport(t, ci.RSA, 2048)
serverTpt := newTestTransport(t, ci.Ed25519, 1024)
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)
testConnection(t, clientTpt, serverTpt)
clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "")
require.NoError(t, clientErr)
require.NoError(t, serverErr)
testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
testReadWrite(t, clientConn, serverConn)
}
func TestPeerIdMatchInbound(t *testing.T) {
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)
clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), clientTpt.LocalPeer())
require.NoError(t, clientErr)
require.NoError(t, serverErr)
testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
testReadWrite(t, clientConn, serverConn)
}
func TestPeerIDMismatchInbound(t *testing.T) {
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)
_, _, _, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "a-random-peer")
require.Error(t, serverErr)
require.Contains(t, serverErr.Error(), "remote peer sent unexpected peer ID")
}
func TestPeerIDMismatchOutbound(t *testing.T) {
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)
_, _, clientErr, _ := connect(t, clientTpt, serverTpt, "a random peer", "")
require.Error(t, clientErr)
require.Contains(t, clientErr.Error(), "remote peer sent unexpected peer ID")
}
func newTestTransport(t *testing.T, typ, bits int) *Transport {
priv, pub, err := ci.GenerateKeyPair(typ, bits)
if err != nil {
t.Fatal(err)
}
priv, pub, err := crypto.GenerateKeyPair(typ, bits)
require.NoError(t, err)
id, err := peer.IDFromPublicKey(pub)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
return NewWithIdentity(id, priv)
}
// Create a new pair of connected TCP sockets.
func newConnPair(t *testing.T) (net.Conn, net.Conn) {
lstnr, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
return nil, nil
}
require.NoError(t, err, "failed to listen")
var clientErr error
var client net.Conn
addr := lstnr.Addr()
done := make(chan struct{})
go func() {
defer close(done)
addr := lstnr.Addr()
client, clientErr = net.Dial(addr.Network(), addr.String())
}()
server, err := lstnr.Accept()
require.NoError(t, err, "failed to accept")
<-done
lstnr.Close()
if err != nil {
t.Fatalf("Failed to accept: %v", err)
}
if clientErr != nil {
t.Fatalf("Failed to connect: %v", clientErr)
}
require.NoError(t, clientErr, "failed to connect")
return client, server
}
// Create a new pair of connected sessions based off of the provided
// session generators.
func connect(t *testing.T, clientTpt, serverTpt *Transport) (sec.SecureConn, sec.SecureConn) {
func connect(t *testing.T, clientTpt, serverTpt *Transport, clientExpectsID, serverExpectsID peer.ID) (clientConn sec.SecureConn, serverConn sec.SecureConn, clientErr, serverErr error) {
client, server := newConnPair(t)
// Connect the client and server sessions
done := make(chan struct{})
var clientConn sec.SecureConn
var clientErr error
go func() {
defer close(done)
clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, serverTpt.LocalPeer())
clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID)
}()
serverConn, serverErr := serverTpt.SecureInbound(context.TODO(), server)
serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID)
<-done
if serverErr != nil {
t.Fatal(serverErr)
}
if clientErr != nil {
t.Fatal(clientErr)
}
return clientConn, serverConn
return
}
// Check the peer IDs
func testIDs(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) {
if clientConn.LocalPeer() != clientTpt.LocalPeer() {
t.Fatal("Client Local Peer ID mismatch.")
}
if clientConn.RemotePeer() != serverTpt.LocalPeer() {
t.Fatal("Client Remote Peer ID mismatch.")
}
if clientConn.LocalPeer() != serverConn.RemotePeer() {
t.Fatal("Server Local Peer ID mismatch.")
}
require.Equal(t, clientConn.LocalPeer(), clientTpt.LocalPeer(), "Client Local Peer ID mismatch.")
require.Equal(t, clientConn.RemotePeer(), serverTpt.LocalPeer(), "Client Remote Peer ID mismatch.")
require.Equal(t, clientConn.LocalPeer(), serverConn.RemotePeer(), "Server Local Peer ID mismatch.")
}
// Check the keys
func testKeys(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) {
sk := serverConn.LocalPrivateKey()
pk := sk.GetPublic()
if !sk.Equals(serverTpt.LocalPrivateKey()) {
t.Error("Private key Mismatch.")
}
if !pk.Equals(clientConn.RemotePublicKey()) {
t.Error("Public key mismatch.")
}
require.True(t, sk.Equals(serverTpt.LocalPrivateKey()), "private key mismatch")
require.True(t, sk.GetPublic().Equals(clientConn.RemotePublicKey()), "public key mismatch")
}
// Check sending and receiving messages
func testReadWrite(t *testing.T, clientConn, serverConn sec.SecureConn) {
before := []byte("hello world")
_, err := clientConn.Write(before)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
after := make([]byte, len(before))
_, err = io.ReadFull(serverConn, after)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(before, after) {
t.Errorf("Message mismatch. %v != %v", before, after)
}
}
// Setup a new session with a pair of locally connected sockets
func testConnection(t *testing.T, clientTpt, serverTpt *Transport) {
clientConn, serverConn := connect(t, clientTpt, serverTpt)
testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
testReadWrite(t, clientConn, serverConn)
clientConn.Close()
serverConn.Close()
require.NoError(t, err)
require.Equal(t, before, after, "message mismatch")
}

View File

@ -19,7 +19,8 @@ type SecureConn interface {
// plain-text, native connections into authenticated, encrypted connections.
type SecureTransport interface {
// SecureInbound secures an inbound connection.
SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, error)
// If p is empty, connections from any peer are accepted.
SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error)
// SecureOutbound secures an outbound connection.
SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error)
@ -29,9 +30,10 @@ type SecureTransport interface {
// and open outbound connections with simultaneous open.
type SecureMuxer interface {
// SecureInbound secures an inbound connection.
// The returned boolean indicates whether the connection should be trated as a server
// The returned boolean indicates whether the connection should be treated as a server
// connection; in the case of SecureInbound it should always be true.
SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, bool, error)
// If p is empty, connections from any peer are accepted.
SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, bool, error)
// SecureOutbound secures an outbound connection.
// The returned boolean indicates whether the connection should be treated as a server