diff --git a/pem.go b/pem.go index 7f51a17..893d7ea 100644 --- a/pem.go +++ b/pem.go @@ -7,7 +7,6 @@ package openssl import "C" import ( - "encoding/pem" "errors" "io/ioutil" "runtime" @@ -51,7 +50,12 @@ func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte, return nil, errors.New("failed to allocate memory BIO") } defer C.BIO_free(bio) - if int(C.PEM_write_bio_PrivateKey(bio, key.key, nil, nil, C.int(0), nil, + rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) + if rsa == nil { + return nil, errors.New("failed getting rsa key") + } + defer C.RSA_free(rsa) + if int(C.PEM_write_bio_RSAPrivateKey(bio, rsa, nil, nil, C.int(0), nil, nil)) != 1 { return nil, errors.New("failed dumping private key") } @@ -60,19 +64,20 @@ func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte, func (key *pKey) MarshalPKCS1PrivateKeyDER() (der_block []byte, err error) { - // TODO: i can't decipher how to get a generic PKCS1 Private Key in DER - // format out of the openssl docs, so until someone who knows better - // can chastise me for this, we'll do it this way. - pem_block, err := key.MarshalPKCS1PrivateKeyPEM() - if err != nil { - return nil, err + bio := C.BIO_new(C.BIO_s_mem()) + if bio == nil { + return nil, errors.New("failed to allocate memory BIO") } - var p *pem.Block - p, pem_block = pem.Decode(pem_block) - if len(pem_block) > 0 || p == nil { - return nil, errors.New("something went wrong with PEM generation") + defer C.BIO_free(bio) + rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) + if rsa == nil { + return nil, errors.New("failed getting rsa key") } - return p.Bytes, nil + defer C.RSA_free(rsa) + if int(C.i2d_RSAPrivateKey_bio(bio, rsa)) != 1 { + return nil, errors.New("failed dumping private key der") + } + return ioutil.ReadAll(asAnyBio(bio)) } func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte, @@ -82,40 +87,60 @@ func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte, return nil, errors.New("failed to allocate memory BIO") } defer C.BIO_free(bio) - if int(C.PEM_write_bio_PUBKEY(bio, key.key)) != 1 { - return nil, errors.New("failed dumping public key") + rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) + if rsa == nil { + return nil, errors.New("failed getting rsa key") + } + defer C.RSA_free(rsa) + if int(C.PEM_write_bio_RSA_PUBKEY(bio, rsa)) != 1 { + return nil, errors.New("failed dumping public key pem") } return ioutil.ReadAll(asAnyBio(bio)) } func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte, err error) { - // TODO: i can't decipher how to get a generic PKIX Public Key in DER - // format out of the openssl docs, so until someone who knows better - // can chastise me for this, we'll do it this way. - pem_block, err := key.MarshalPKIXPublicKeyPEM() - if err != nil { - return nil, err + bio := C.BIO_new(C.BIO_s_mem()) + if bio == nil { + return nil, errors.New("failed to allocate memory BIO") } - var p *pem.Block - p, pem_block = pem.Decode(pem_block) - if len(pem_block) > 0 || p == nil { - return nil, errors.New("something went wrong with PEM generation") + defer C.BIO_free(bio) + rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) + if rsa == nil { + return nil, errors.New("failed getting rsa key") } - return p.Bytes, nil + defer C.RSA_free(rsa) + if int(C.i2d_RSA_PUBKEY_bio(bio, rsa)) != 1 { + return nil, errors.New("failed dumping public key der") + } + return ioutil.ReadAll(asAnyBio(bio)) } // LoadPrivateKey loads a private key from a PEM-encoded block. func LoadPrivateKey(pem_block []byte) (PrivateKey, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), C.int(len(pem_block))) - key := C.PEM_read_bio_PrivateKey(bio, nil, nil, nil) - C.BIO_free(bio) - if key == nil { - return nil, errorFromErrorQueue() + if bio == nil { + return nil, errors.New("failed creating bio") } + defer C.BIO_free(bio) + + rsakey := C.PEM_read_bio_RSAPrivateKey(bio, nil, nil, nil) + if rsakey == nil { + return nil, errors.New("failed reading rsa key") + } + defer C.RSA_free(rsakey) + + // convert to PKEY + key := C.EVP_PKEY_new() + if key == nil { + return nil, errors.New("failed converting to evp_pkey") + } + if C.EVP_PKEY_set1_RSA(key, (*C.struct_rsa_st)(rsakey)) != 1 { + C.EVP_PKEY_free(key) + return nil, errors.New("failed converting to evp_pkey") + } + p := &pKey{key: key} runtime.SetFinalizer(p, func(p *pKey) { C.EVP_PKEY_free(p.key) diff --git a/pem_test.go b/pem_test.go new file mode 100644 index 0000000..f29cbf1 --- /dev/null +++ b/pem_test.go @@ -0,0 +1,76 @@ +// Copyright (C) 2014 Space Monkey, Inc. + +package openssl + +import ( + "bytes" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "io/ioutil" + "testing" +) + +func TestMarshal(t *testing.T) { + key, err := LoadPrivateKey(keyBytes) + if err != nil { + t.Fatal(err) + } + cert, err := LoadCertificate(certBytes) + if err != nil { + t.Fatal(err) + } + + pem, err := cert.MarshalPEM() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pem, certBytes) { + ioutil.WriteFile("generated", pem, 0644) + ioutil.WriteFile("hardcoded", certBytes, 0644) + t.Fatal("invalid cert pem bytes") + } + + pem, err = key.MarshalPKCS1PrivateKeyPEM() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pem, keyBytes) { + ioutil.WriteFile("generated", pem, 0644) + ioutil.WriteFile("hardcoded", keyBytes, 0644) + t.Fatal("invalid private key pem bytes") + } + tls_cert, err := tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + t.Fatal(err) + } + tls_key, ok := tls_cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + t.Fatal("FASDFASDF") + } + _ = tls_key + + der, err := key.MarshalPKCS1PrivateKeyDER() + if err != nil { + t.Fatal(err) + } + tls_der := x509.MarshalPKCS1PrivateKey(tls_key) + if !bytes.Equal(der, tls_der) { + t.Fatal("invalid private key der bytes: %s\n v.s. %s\n", hex.Dump(der), hex.Dump(tls_der)) + } + + der, err = key.MarshalPKIXPublicKeyDER() + if err != nil { + t.Fatal(err) + } + tls_der, err = x509.MarshalPKIXPublicKey(&tls_key.PublicKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(der, tls_der) { + ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644) + ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + t.Fatal("invalid public key der bytes") + } +} diff --git a/ssl_test.go b/ssl_test.go index 40b85f2..1d716d8 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -17,8 +17,7 @@ import ( ) var ( - certBytes = []byte(` ------BEGIN CERTIFICATE----- + certBytes = []byte(`-----BEGIN CERTIFICATE----- MIIDxDCCAqygAwIBAgIVAMcK/0VWQr2O3MNfJCydqR7oVELcMA0GCSqGSIb3DQEB BQUAMIGQMUkwRwYDVQQDE0A1NjdjZGRmYzRjOWZiNTYwZTk1M2ZlZjA1N2M0NGFm MDdiYjc4MDIzODIxYTA5NThiY2RmMGMwNzJhOTdiMThhMQswCQYDVQQGEwJVUzEN @@ -42,8 +41,7 @@ sRkg/uxcJf7wC5Y0BLlp1+aPwdmZD87T3a1uQ1Ij93jmHG+2T9U20MklHAePOl0q yTqdSPnSH1c= -----END CERTIFICATE----- `) - keyBytes = []byte(` ------BEGIN RSA PRIVATE KEY----- + keyBytes = []byte(`-----BEGIN RSA PRIVATE KEY----- MIIEpQIBAAKCAQEA3X94nDbxbK5a5zS4vEqHLHKpUmxavqRL5oXEqKoAy6nm56rv C3e9xySe+DBlxIEV/MWU+RYpzjC99QkerfRP493aleqfhn3ZRS3tyKrQtP2z1Zwg wYqwcoASOLgqzKvtVYQMT1nJaw6O5fUEWG7BMR/ZX5/kcr8XjTGYjgEmrL1WTZ3G @@ -70,11 +68,11 @@ ByGgkUECgYEAmop45kRi974g4MPvyLplcE4syb19ifrHj76gPRBi94Cp8jZosY1T ucCCa4lOGgPtXJ0Qf1c8yq5vh4yqkQjrgUTkr+CFDGR6y4CxmNDQxEMYIajaIiSY qmgvgyRayemfO2zR0CPgC6wSoGBth+xW6g+WA8y0z76ZSaWpFi8lVM4= -----END RSA PRIVATE KEY----- - `) +`) ) func NetPipe(t testing.TB) (net.Conn, net.Conn) { - l, err := net.Listen("tcp", ":0") + l, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatal(err) } @@ -479,7 +477,7 @@ func TestStdlibOpenSSLFullDuplexRenegotiation(t *testing.T) { func LotsOfConns(t *testing.T, payload_size int64, loops, clients int, sleep time.Duration, newListener func(net.Listener) net.Listener, newClient func(net.Conn) (net.Conn, error)) { - tcp_listener, err := net.Listen("tcp", ":0") + tcp_listener, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatal(err) }