mirror of
https://github.com/libp2p/go-libp2p-core.git
synced 2025-03-13 11:00:10 +08:00
use a fallback basicEquals function everywhere
This also ensures we check that the types are equal, even if we're comparing directly with `k1.Equals(k2)` instead of `KeyEquals(k1, k2)`.
This commit is contained in:
parent
2df9672ee4
commit
9a4415d1a6
@ -119,7 +119,7 @@ func (ePriv *ECDSAPrivateKey) Raw() ([]byte, error) {
|
||||
func (ePriv *ECDSAPrivateKey) Equals(o Key) bool {
|
||||
oPriv, ok := o.(*ECDSAPrivateKey)
|
||||
if !ok {
|
||||
return false
|
||||
return basicEquals(ePriv, o)
|
||||
}
|
||||
|
||||
return ePriv.priv.D.Cmp(oPriv.priv.D) == 0
|
||||
@ -163,7 +163,7 @@ func (ePub ECDSAPublicKey) Raw() ([]byte, error) {
|
||||
func (ePub *ECDSAPublicKey) Equals(o Key) bool {
|
||||
oPub, ok := o.(*ECDSAPublicKey)
|
||||
if !ok {
|
||||
return false
|
||||
return basicEquals(ePub, o)
|
||||
}
|
||||
|
||||
return ePub.pub.X != nil && ePub.pub.Y != nil && oPub.pub.X != nil && oPub.pub.Y != nil &&
|
||||
|
@ -4,6 +4,7 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/elliptic"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
@ -363,9 +364,21 @@ func KeyEqual(k1, k2 Key) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
return k1.Equals(k2)
|
||||
}
|
||||
|
||||
func basicEquals(k1, k2 Key) bool {
|
||||
if k1.Type() != k2.Type() {
|
||||
return false
|
||||
}
|
||||
|
||||
return k1.Equals(k2)
|
||||
a, err := k1.Raw()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
b, err := k2.Raw()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(a, b)
|
||||
}
|
||||
|
@ -3,8 +3,6 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
pb "github.com/libp2p/go-libp2p-core/crypto/pb"
|
||||
|
||||
openssl "github.com/libp2p/go-openssl"
|
||||
@ -65,16 +63,7 @@ func (pk *opensslPublicKey) Raw() ([]byte, error) {
|
||||
func (pk *opensslPublicKey) Equals(k Key) bool {
|
||||
k0, ok := k.(*RsaPublicKey)
|
||||
if !ok {
|
||||
a, err := pk.Raw()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
b, err := k.Raw()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(a, b)
|
||||
return basicEquals(pk, k)
|
||||
}
|
||||
|
||||
return pk.key.Equal(k0.opensslPublicKey.key)
|
||||
@ -112,16 +101,7 @@ func (sk *opensslPrivateKey) Raw() ([]byte, error) {
|
||||
func (sk *opensslPrivateKey) Equals(k Key) bool {
|
||||
k0, ok := k.(*RsaPrivateKey)
|
||||
if !ok {
|
||||
a, err := sk.Raw()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
b, err := k.Raw()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(a, b)
|
||||
return basicEquals(sk, k)
|
||||
}
|
||||
|
||||
return sk.key.Equal(k0.opensslPrivateKey.key)
|
||||
|
@ -3,7 +3,6 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@ -67,15 +66,7 @@ func (pk *RsaPublicKey) Equals(k Key) bool {
|
||||
// make sure this is an rsa public key
|
||||
other, ok := (k).(*RsaPublicKey)
|
||||
if !ok {
|
||||
a, err := pk.Raw()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
b, err := k.Raw()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(a, b)
|
||||
return basicEquals(pk, k)
|
||||
}
|
||||
|
||||
return pk.k.N.Cmp(other.k.N) == 0 && pk.k.E == other.k.E
|
||||
@ -111,15 +102,7 @@ func (sk *RsaPrivateKey) Equals(k Key) bool {
|
||||
// make sure this is an rsa public key
|
||||
other, ok := (k).(*RsaPrivateKey)
|
||||
if !ok {
|
||||
a, err := sk.Raw()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
b, err := k.Raw()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(a, b)
|
||||
return basicEquals(sk, k)
|
||||
}
|
||||
|
||||
a := sk.sk
|
||||
|
@ -66,7 +66,7 @@ func (k *Secp256k1PrivateKey) Raw() ([]byte, error) {
|
||||
func (k *Secp256k1PrivateKey) Equals(o Key) bool {
|
||||
sk, ok := o.(*Secp256k1PrivateKey)
|
||||
if !ok {
|
||||
return false
|
||||
return basicEquals(k, o)
|
||||
}
|
||||
|
||||
return k.D.Cmp(sk.D) == 0
|
||||
@ -107,7 +107,7 @@ func (k *Secp256k1PublicKey) Raw() ([]byte, error) {
|
||||
func (k *Secp256k1PublicKey) Equals(o Key) bool {
|
||||
sk, ok := o.(*Secp256k1PublicKey)
|
||||
if !ok {
|
||||
return false
|
||||
return basicEquals(k, o)
|
||||
}
|
||||
|
||||
return (*btcec.PublicKey)(k).IsEqual((*btcec.PublicKey)(sk))
|
||||
|
Loading…
Reference in New Issue
Block a user