mirror of
https://github.com/libp2p/go-libp2p-core.git
synced 2025-02-10 06:40:09 +08:00
pass the peer ID to SecureInbound in the SecureTransport and SecureMuxer (#211)
The peer ID may be empty. This will be the common case. In that case, connections from any peer are accepted.
This commit is contained in:
parent
094b0d3f8b
commit
60a3d1748e
@ -60,7 +60,7 @@ func (t *Transport) LocalPrivateKey() ci.PrivKey {
|
|||||||
//
|
//
|
||||||
// SecureInbound may fail if the remote peer sends an ID and public key that are inconsistent
|
// SecureInbound may fail if the remote peer sends an ID and public key that are inconsistent
|
||||||
// with each other, or if a network error occurs during the ID exchange.
|
// with each other, or if a network error occurs during the ID exchange.
|
||||||
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) {
|
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
|
||||||
conn := &Conn{
|
conn := &Conn{
|
||||||
Conn: insecure,
|
Conn: insecure,
|
||||||
local: t.id,
|
local: t.id,
|
||||||
@ -72,6 +72,10 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.S
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if t.key != nil && p != "" && p != conn.remote {
|
||||||
|
return nil, fmt.Errorf("remote peer sent unexpected peer ID. expected=%s received=%s", p, conn.remote)
|
||||||
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,157 +1,128 @@
|
|||||||
package insecure
|
package insecure
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"github.com/libp2p/go-libp2p-core/peer"
|
|
||||||
"github.com/libp2p/go-libp2p-core/sec"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
ci "github.com/libp2p/go-libp2p-core/crypto"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-libp2p-core/crypto"
|
||||||
|
"github.com/libp2p/go-libp2p-core/peer"
|
||||||
|
"github.com/libp2p/go-libp2p-core/sec"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Run a set of sessions through the session setup and verification.
|
// Run a set of sessions through the session setup and verification.
|
||||||
func TestConnections(t *testing.T) {
|
func TestConnections(t *testing.T) {
|
||||||
clientTpt := newTestTransport(t, ci.RSA, 2048)
|
clientTpt := newTestTransport(t, crypto.RSA, 2048)
|
||||||
serverTpt := newTestTransport(t, ci.Ed25519, 1024)
|
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)
|
||||||
|
|
||||||
testConnection(t, clientTpt, serverTpt)
|
clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "")
|
||||||
|
require.NoError(t, clientErr)
|
||||||
|
require.NoError(t, serverErr)
|
||||||
|
testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
|
||||||
|
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
|
||||||
|
testReadWrite(t, clientConn, serverConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerIdMatchInbound(t *testing.T) {
|
||||||
|
clientTpt := newTestTransport(t, crypto.RSA, 2048)
|
||||||
|
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)
|
||||||
|
|
||||||
|
clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), clientTpt.LocalPeer())
|
||||||
|
require.NoError(t, clientErr)
|
||||||
|
require.NoError(t, serverErr)
|
||||||
|
testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
|
||||||
|
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
|
||||||
|
testReadWrite(t, clientConn, serverConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerIDMismatchInbound(t *testing.T) {
|
||||||
|
clientTpt := newTestTransport(t, crypto.RSA, 2048)
|
||||||
|
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)
|
||||||
|
|
||||||
|
_, _, _, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "a-random-peer")
|
||||||
|
require.Error(t, serverErr)
|
||||||
|
require.Contains(t, serverErr.Error(), "remote peer sent unexpected peer ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerIDMismatchOutbound(t *testing.T) {
|
||||||
|
clientTpt := newTestTransport(t, crypto.RSA, 2048)
|
||||||
|
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)
|
||||||
|
|
||||||
|
_, _, clientErr, _ := connect(t, clientTpt, serverTpt, "a random peer", "")
|
||||||
|
require.Error(t, clientErr)
|
||||||
|
require.Contains(t, clientErr.Error(), "remote peer sent unexpected peer ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestTransport(t *testing.T, typ, bits int) *Transport {
|
func newTestTransport(t *testing.T, typ, bits int) *Transport {
|
||||||
priv, pub, err := ci.GenerateKeyPair(typ, bits)
|
priv, pub, err := crypto.GenerateKeyPair(typ, bits)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
id, err := peer.IDFromPublicKey(pub)
|
id, err := peer.IDFromPublicKey(pub)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewWithIdentity(id, priv)
|
return NewWithIdentity(id, priv)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new pair of connected TCP sockets.
|
// Create a new pair of connected TCP sockets.
|
||||||
func newConnPair(t *testing.T) (net.Conn, net.Conn) {
|
func newConnPair(t *testing.T) (net.Conn, net.Conn) {
|
||||||
lstnr, err := net.Listen("tcp", "localhost:0")
|
lstnr, err := net.Listen("tcp", "localhost:0")
|
||||||
if err != nil {
|
require.NoError(t, err, "failed to listen")
|
||||||
t.Fatalf("Failed to listen: %v", err)
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var clientErr error
|
var clientErr error
|
||||||
var client net.Conn
|
var client net.Conn
|
||||||
addr := lstnr.Addr()
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
|
addr := lstnr.Addr()
|
||||||
client, clientErr = net.Dial(addr.Network(), addr.String())
|
client, clientErr = net.Dial(addr.Network(), addr.String())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
server, err := lstnr.Accept()
|
server, err := lstnr.Accept()
|
||||||
|
require.NoError(t, err, "failed to accept")
|
||||||
|
|
||||||
<-done
|
<-done
|
||||||
|
|
||||||
lstnr.Close()
|
lstnr.Close()
|
||||||
|
require.NoError(t, clientErr, "failed to connect")
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to accept: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if clientErr != nil {
|
|
||||||
t.Fatalf("Failed to connect: %v", clientErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return client, server
|
return client, server
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new pair of connected sessions based off of the provided
|
func connect(t *testing.T, clientTpt, serverTpt *Transport, clientExpectsID, serverExpectsID peer.ID) (clientConn sec.SecureConn, serverConn sec.SecureConn, clientErr, serverErr error) {
|
||||||
// session generators.
|
|
||||||
func connect(t *testing.T, clientTpt, serverTpt *Transport) (sec.SecureConn, sec.SecureConn) {
|
|
||||||
client, server := newConnPair(t)
|
client, server := newConnPair(t)
|
||||||
|
|
||||||
// Connect the client and server sessions
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
var clientConn sec.SecureConn
|
|
||||||
var clientErr error
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, serverTpt.LocalPeer())
|
clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID)
|
||||||
}()
|
}()
|
||||||
|
serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID)
|
||||||
serverConn, serverErr := serverTpt.SecureInbound(context.TODO(), server)
|
|
||||||
<-done
|
<-done
|
||||||
|
return
|
||||||
if serverErr != nil {
|
|
||||||
t.Fatal(serverErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if clientErr != nil {
|
|
||||||
t.Fatal(clientErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return clientConn, serverConn
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the peer IDs
|
// Check the peer IDs
|
||||||
func testIDs(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) {
|
func testIDs(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) {
|
||||||
if clientConn.LocalPeer() != clientTpt.LocalPeer() {
|
require.Equal(t, clientConn.LocalPeer(), clientTpt.LocalPeer(), "Client Local Peer ID mismatch.")
|
||||||
t.Fatal("Client Local Peer ID mismatch.")
|
require.Equal(t, clientConn.RemotePeer(), serverTpt.LocalPeer(), "Client Remote Peer ID mismatch.")
|
||||||
}
|
require.Equal(t, clientConn.LocalPeer(), serverConn.RemotePeer(), "Server Local Peer ID mismatch.")
|
||||||
|
|
||||||
if clientConn.RemotePeer() != serverTpt.LocalPeer() {
|
|
||||||
t.Fatal("Client Remote Peer ID mismatch.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if clientConn.LocalPeer() != serverConn.RemotePeer() {
|
|
||||||
t.Fatal("Server Local Peer ID mismatch.")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the keys
|
// Check the keys
|
||||||
func testKeys(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) {
|
func testKeys(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) {
|
||||||
sk := serverConn.LocalPrivateKey()
|
sk := serverConn.LocalPrivateKey()
|
||||||
pk := sk.GetPublic()
|
require.True(t, sk.Equals(serverTpt.LocalPrivateKey()), "private key mismatch")
|
||||||
|
require.True(t, sk.GetPublic().Equals(clientConn.RemotePublicKey()), "public key mismatch")
|
||||||
if !sk.Equals(serverTpt.LocalPrivateKey()) {
|
|
||||||
t.Error("Private key Mismatch.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !pk.Equals(clientConn.RemotePublicKey()) {
|
|
||||||
t.Error("Public key mismatch.")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check sending and receiving messages
|
// Check sending and receiving messages
|
||||||
func testReadWrite(t *testing.T, clientConn, serverConn sec.SecureConn) {
|
func testReadWrite(t *testing.T, clientConn, serverConn sec.SecureConn) {
|
||||||
before := []byte("hello world")
|
before := []byte("hello world")
|
||||||
_, err := clientConn.Write(before)
|
_, err := clientConn.Write(before)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
after := make([]byte, len(before))
|
after := make([]byte, len(before))
|
||||||
_, err = io.ReadFull(serverConn, after)
|
_, err = io.ReadFull(serverConn, after)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
require.Equal(t, before, after, "message mismatch")
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(before, after) {
|
|
||||||
t.Errorf("Message mismatch. %v != %v", before, after)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup a new session with a pair of locally connected sockets
|
|
||||||
func testConnection(t *testing.T, clientTpt, serverTpt *Transport) {
|
|
||||||
clientConn, serverConn := connect(t, clientTpt, serverTpt)
|
|
||||||
|
|
||||||
testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
|
|
||||||
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
|
|
||||||
testReadWrite(t, clientConn, serverConn)
|
|
||||||
|
|
||||||
clientConn.Close()
|
|
||||||
serverConn.Close()
|
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,8 @@ type SecureConn interface {
|
|||||||
// plain-text, native connections into authenticated, encrypted connections.
|
// plain-text, native connections into authenticated, encrypted connections.
|
||||||
type SecureTransport interface {
|
type SecureTransport interface {
|
||||||
// SecureInbound secures an inbound connection.
|
// SecureInbound secures an inbound connection.
|
||||||
SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, error)
|
// If p is empty, connections from any peer are accepted.
|
||||||
|
SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error)
|
||||||
|
|
||||||
// SecureOutbound secures an outbound connection.
|
// SecureOutbound secures an outbound connection.
|
||||||
SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error)
|
SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error)
|
||||||
@ -29,9 +30,10 @@ type SecureTransport interface {
|
|||||||
// and open outbound connections with simultaneous open.
|
// and open outbound connections with simultaneous open.
|
||||||
type SecureMuxer interface {
|
type SecureMuxer interface {
|
||||||
// SecureInbound secures an inbound connection.
|
// SecureInbound secures an inbound connection.
|
||||||
// The returned boolean indicates whether the connection should be trated as a server
|
// The returned boolean indicates whether the connection should be treated as a server
|
||||||
// connection; in the case of SecureInbound it should always be true.
|
// connection; in the case of SecureInbound it should always be true.
|
||||||
SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, bool, error)
|
// If p is empty, connections from any peer are accepted.
|
||||||
|
SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, bool, error)
|
||||||
|
|
||||||
// SecureOutbound secures an outbound connection.
|
// SecureOutbound secures an outbound connection.
|
||||||
// The returned boolean indicates whether the connection should be treated as a server
|
// The returned boolean indicates whether the connection should be treated as a server
|
||||||
|
Loading…
Reference in New Issue
Block a user