diff --git a/conn.go b/conn.go index d5fd538..f5aef7d 100644 --- a/conn.go +++ b/conn.go @@ -6,6 +6,11 @@ package openssl // #include // #include // #include +// +// int sk_X509_num_not_a_macro(STACK_OF(X509) *sk) { return sk_X509_num(sk); } +// X509 *sk_X509_value_not_a_macro(STACK_OF(X509)* sk, int i) { +// return sk_X509_value(sk, i); +// } import "C" import ( @@ -230,11 +235,11 @@ func (c *Conn) Handshake() error { // communicating. Only valid after a handshake. func (c *Conn) PeerCertificate() (*Certificate, error) { c.mtx.Lock() + defer c.mtx.Unlock() if c.is_shutdown { return nil, errors.New("connection closed") } x := C.SSL_get_peer_certificate(c.ssl) - c.mtx.Unlock() if x == nil { return nil, errors.New("no peer certificate found") } @@ -245,13 +250,41 @@ func (c *Conn) PeerCertificate() (*Certificate, error) { return cert, nil } +// PeerCertificateChain returns the certificate chain of the peer. If called on +// the client side, the stack also contains the peer's certificate; if called +// on the server side, the peer's certificate must be obtained separately using +// PeerCertificate. +func (c *Conn) PeerCertificateChain() (rv []*Certificate, err error) { + c.mtx.Lock() + defer c.mtx.Unlock() + if c.is_shutdown { + return nil, errors.New("connection closed") + } + sk := C.SSL_get_peer_cert_chain(c.ssl) + if sk == nil { + return nil, errors.New("no peer certificates found") + } + sk_num := int(C.sk_X509_num_not_a_macro(sk)) + rv = make([]*Certificate, 0, sk_num) + for i := 0; i < sk_num; i++ { + x := C.sk_X509_value_not_a_macro(sk, C.int(i)) + // ref holds on to the underlying connection memory so we don't need to + // worry about incrementing refcounts manually or freeing the X509 + rv = append(rv, &Certificate{x: x, ref: c}) + } + return rv, nil +} + type ConnectionState struct { - Certificate *Certificate - CertificateError error + Certificate *Certificate + CertificateError error + CertificateChain []*Certificate + CertificateChainError error } func (c *Conn) ConnectionState() (rv ConnectionState) { rv.Certificate, rv.CertificateError = c.PeerCertificate() + rv.CertificateChain, rv.CertificateChainError = c.PeerCertificateChain() return } diff --git a/pem.go b/pem.go index b04fbf0..22488e8 100644 --- a/pem.go +++ b/pem.go @@ -155,7 +155,8 @@ func LoadPrivateKey(pem_block []byte) (PrivateKey, error) { } type Certificate struct { - x *C.X509 + x *C.X509 + ref interface{} } // LoadCertificate loads an X509 certificate from a PEM-encoded block.