go-openssl/conn.go
Andrew Harding 420209b5b2 space monkey internal commit export
[katamari commit: ec10a16bdc8a55c53af32481120946fa49570b30]
2014-03-11 18:09:47 +00:00

468 lines
11 KiB
Go

// Copyright (C) 2014 Space Monkey, Inc.
// +build cgo
package openssl
// #include <openssl/ssl.h>
// #include <openssl/conf.h>
// #include <openssl/err.h>
//
// int sk_X509_num_not_a_macro(STACK_OF(X509) *sk) { return sk_X509_num(sk); }
// X509 *sk_X509_value_not_a_macro(STACK_OF(X509)* sk, int i) {
// return sk_X509_value(sk, i);
// }
import "C"
import (
"errors"
"io"
"net"
"runtime"
"sync"
"time"
"unsafe"
"code.spacemonkey.com/go/openssl/utils"
)
// Package-private sentinel errors.
//
// tryAgain is returned by the handlers built in getErrorHandler to tell the
// retry loops (Handshake, Read, Write, shutdownLoop) to repeat the SSL call.
// zeroReturn, wantRead and wantWrite are not referenced in the portion of the
// file visible here. NOTE(review): confirm whether they are still used
// elsewhere in the package.
var (
zeroReturn = errors.New("zero return")
wantRead = errors.New("want read")
wantWrite = errors.New("want write")
tryAgain = errors.New("try again")
)
// Conn is an SSL connection layered on top of an existing net.Conn. Data
// moves between the wrapped connection and the OpenSSL object through a pair
// of BIO adapters (into_ssl and from_ssl).
type Conn struct {
// conn is the wrapped transport; it is read by fillInputBuffer, written by
// flushOutputBuffer and closed by Close.
conn net.Conn
// ssl is the OpenSSL handle; newConn installs a finalizer that frees it.
ssl *C.SSL
ctx *Ctx // for gc
// into_ssl buffers bytes read from conn for OpenSSL to consume.
into_ssl *readBio
// from_ssl buffers bytes produced by OpenSSL until they are written to conn.
from_ssl *writeBio
// is_shutdown is set by Close; it is read and written under mtx.
is_shutdown bool
// mtx is held around the SSL_* calls made by handshake, shutdown, read and
// write, and around accesses to is_shutdown and want_read_future.
mtx sync.Mutex
// want_read_future is non-nil while some caller is refilling the input
// buffer; other callers that hit SSL_ERROR_WANT_READ wait on it instead of
// starting a second read (see getErrorHandler).
want_read_future *utils.Future
}
// newSSL allocates a new OpenSSL SSL object from ctx. If SSL_new fails, the
// returned error is built by errorFromErrorQueue.
//
// The goroutine is locked to its OS thread for the duration of the call,
// presumably so the error queue is read on the same thread that ran SSL_new
// (TODO confirm against errorFromErrorQueue).
func newSSL(ctx *C.SSL_CTX) (*C.SSL, error) {
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	if handle := C.SSL_new(ctx); handle != nil {
		return handle, nil
	}
	return nil, errorFromErrorQueue()
}
// newConn builds a Conn around conn using the SSL context ctx. It allocates
// the SSL object, creates the read and write BIO adapters, hands their C BIOs
// to the SSL object, and registers a finalizer that disconnects the adapters
// and frees the SSL object once the Conn is garbage collected.
//
// It returns an error if the SSL object or either BIO cannot be allocated; in
// that case everything allocated so far is released before returning.
func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) {
	ssl, err := newSSL(ctx.ctx)
	if err != nil {
		return nil, err
	}

	reader := &readBio{}
	writer := &writeBio{}
	readerCBIO := reader.MakeCBIO()
	writerCBIO := writer.MakeCBIO()
	if readerCBIO == nil || writerCBIO == nil {
		// BIO_free accepts NULL, so it is fine to free both unconditionally.
		C.BIO_free(readerCBIO)
		C.BIO_free(writerCBIO)
		C.SSL_free(ssl)
		return nil, errors.New("failed to allocate memory BIO")
	}

	// From here on the SSL object owns both C BIOs.
	C.SSL_set_bio(ssl, readerCBIO, writerCBIO)

	c := &Conn{
		conn:     conn,
		ssl:      ssl,
		ctx:      ctx,
		into_ssl: reader,
		from_ssl: writer,
	}
	runtime.SetFinalizer(c, func(c *Conn) {
		c.into_ssl.Disconnect(readerCBIO)
		c.from_ssl.Disconnect(writerCBIO)
		C.SSL_free(c.ssl)
	})
	return c, nil
}
// Client wraps an already-established stream connection and places the SSL
// object in the connect state, so that any later handshake is performed as
// the client side.
//
// IMPORTANT NOTE: when you build an SSL connection with this function rather
// than with Dial, verifying the peer's hostname is your responsibility. If
// you skip that step you are vulnerable to MITM attacks.
//
// Client connections probably won't work unless the client context you pass
// has a verify location set or has certificates added to its certificate
// store. This library does not yet use the system certificate store by
// default.
func Client(conn net.Conn, ctx *Ctx) (*Conn, error) {
	client, err := newConn(conn, ctx)
	if err != nil {
		return nil, err
	}
	C.SSL_set_connect_state(client.ssl)
	return client, nil
}
// Server wraps an already-established stream connection and places the SSL
// object in the accept state, so that any later handshake is performed as
// the server side.
func Server(conn net.Conn, ctx *Ctx) (*Conn, error) {
	server, err := newConn(conn, ctx)
	if err != nil {
		return nil, err
	}
	C.SSL_set_accept_state(server.ssl)
	return server, nil
}
// fillInputBuffer reads from the wrapped connection into the input BIO until
// a read either delivers data or fails. Reads that return zero bytes without
// an error are retried. On io.EOF the input BIO is marked as finished and the
// connection is closed, and the result of Close is returned. Any other
// outcome returns the read's error as-is (nil when data arrived).
func (c *Conn) fillInputBuffer() error {
	for {
		n, err := c.into_ssl.ReadFromOnce(c.conn)
		switch {
		case err == io.EOF:
			c.into_ssl.MarkEOF()
			return c.Close()
		case err != nil || n != 0:
			return err
		}
		// n == 0 with no error: nothing arrived yet, read again.
	}
}
// flushOutputBuffer writes whatever the output BIO currently holds to the
// wrapped connection and reports any error from that write.
func (c *Conn) flushOutputBuffer() error {
	if _, err := c.from_ssl.WriteTo(c.conn); err != nil {
		return err
	}
	return nil
}
// getErrorHandler classifies the outcome of an SSL call that did not succeed
// and returns a closure that deals with it. rv is the return value of the SSL
// call and errno is the error cgo reported alongside it.
//
// The callers in this file (handshake, shutdown, read, write) invoke this
// method while holding c.mtx and hand the closure back to their own callers,
// which run it through handleError after the mutex has been released. The
// blocking network I/O performed by the closures therefore happens outside
// the lock.
//
// A closure returns tryAgain when the SSL call should simply be repeated, or
// the error that ends the operation.
func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
errcode := C.SSL_get_error(c.ssl, rv)
switch errcode {
case C.SSL_ERROR_ZERO_RETURN:
// Close the connection (its error is deliberately dropped) and report
// io.ErrUnexpectedEOF; Read maps that to io.EOF and shutdownLoop to nil.
return func() error {
c.Close()
return io.ErrUnexpectedEOF
}
case C.SSL_ERROR_WANT_READ:
// Push out anything OpenSSL has queued for the peer, in the background.
go c.flushOutputBuffer()
if c.want_read_future != nil {
// Another caller is already refilling the input buffer; wait for its
// result rather than issuing a second read on the connection.
want_read_future := c.want_read_future
return func() error {
_, err := want_read_future.Get()
return err
}
}
// This caller becomes the one that refills the input buffer.
c.want_read_future = utils.NewFuture()
want_read_future := c.want_read_future
return func() (err error) {
defer func() {
// Clear the shared future under the mutex, then publish the outcome
// (tryAgain on success) to any waiters.
c.mtx.Lock()
c.want_read_future = nil
c.mtx.Unlock()
want_read_future.Set(nil, err)
}()
err = c.fillInputBuffer()
if err != nil {
return err
}
return tryAgain
}
case C.SSL_ERROR_WANT_WRITE:
// Drain the output buffer to the connection, then retry the SSL call.
return func() error {
err := c.flushOutputBuffer()
if err != nil {
return err
}
return tryAgain
}
case C.SSL_ERROR_SYSCALL:
// The error is computed now, not inside the closure.
var err error
if C.ERR_peek_error() == 0 {
// Nothing in the OpenSSL error queue: decide based on rv.
switch rv {
case 0:
err = errors.New("protocol-violating EOF")
case -1:
err = errno
default:
err = errorFromErrorQueue()
}
} else {
err = errorFromErrorQueue()
}
return func() error { return err }
default:
// Any other SSL error code: report whatever the error queue holds.
err := errorFromErrorQueue()
return func() error { return err }
}
}
// handleError runs the handler produced by an SSL operation and returns its
// result. A nil handler means the operation succeeded, so nil is returned.
func (c *Conn) handleError(errcb func() error) error {
	if errcb == nil {
		return nil
	}
	return errcb()
}
// handshake makes one SSL_do_handshake attempt while holding c.mtx. It
// returns nil when the handshake has completed, and otherwise a handler for
// the caller to run once the mutex is released. If the connection has already
// been shut down, the handler yields io.ErrUnexpectedEOF.
func (c *Conn) handshake() func() error {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.is_shutdown {
		return func() error { return io.ErrUnexpectedEOF }
	}

	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	if rv, errno := C.SSL_do_handshake(c.ssl); rv <= 0 {
		return c.getErrorHandler(rv, errno)
	}
	return nil
}
// Handshake performs an SSL handshake. If it is never called explicitly, the
// handshake runs before the first I/O on the encrypted stream.
//
// The handshake step is repeated for as long as it asks to be retried. Any
// remaining buffered output is then flushed in the background and the final
// result is returned.
func (c *Conn) Handshake() error {
	var err error
	for {
		err = c.handleError(c.handshake())
		if err != tryAgain {
			break
		}
	}
	go c.flushOutputBuffer()
	return err
}
// PeerCertificate returns the Certificate presented by the peer you are
// communicating with. It is only valid after a handshake.
//
// An error is returned if the connection has been closed or the peer supplied
// no certificate. The underlying X509 is released with X509_free by a
// finalizer when the returned Certificate is garbage collected.
func (c *Conn) PeerCertificate() (*Certificate, error) {
	c.mtx.Lock()
	defer c.mtx.Unlock()

	if c.is_shutdown {
		return nil, errors.New("connection closed")
	}
	peer := C.SSL_get_peer_certificate(c.ssl)
	if peer == nil {
		return nil, errors.New("no peer certificate found")
	}

	cert := &Certificate{x: peer}
	runtime.SetFinalizer(cert, func(cert *Certificate) {
		C.X509_free(cert.x)
	})
	return cert, nil
}
// PeerCertificateChain returns the peer's certificate chain. On the client
// side the stack also includes the peer's own certificate. On the server side
// it does not, and the peer's certificate must be fetched separately with
// PeerCertificate.
//
// An error is returned if the connection has been closed or no chain is
// available.
func (c *Conn) PeerCertificateChain() (rv []*Certificate, err error) {
	c.mtx.Lock()
	defer c.mtx.Unlock()

	if c.is_shutdown {
		return nil, errors.New("connection closed")
	}
	chain := C.SSL_get_peer_cert_chain(c.ssl)
	if chain == nil {
		return nil, errors.New("no peer certificates found")
	}

	count := int(C.sk_X509_num_not_a_macro(chain))
	rv = make([]*Certificate, 0, count)
	for idx := 0; idx < count; idx++ {
		// Each entry keeps a reference to the Conn in ref, which keeps the
		// connection's memory alive; the X509 values therefore need neither a
		// manual refcount bump nor an explicit free.
		entry := &Certificate{
			x:   C.sk_X509_value_not_a_macro(chain, C.int(idx)),
			ref: c,
		}
		rv = append(rv, entry)
	}
	return rv, nil
}
// ConnectionState bundles the peer certificate information of a Conn, as
// filled in by (*Conn).ConnectionState. Each value is paired with the error
// returned while fetching it.
type ConnectionState struct {
// Certificate and CertificateError are the results of PeerCertificate.
Certificate *Certificate
CertificateError error
// CertificateChain and CertificateChainError are the results of
// PeerCertificateChain.
CertificateChain []*Certificate
CertificateChainError error
}
// ConnectionState collects the peer certificate and the peer certificate
// chain, together with any error hit while fetching each, into a
// ConnectionState value.
func (c *Conn) ConnectionState() (rv ConnectionState) {
	cert, certErr := c.PeerCertificate()
	chain, chainErr := c.PeerCertificateChain()
	return ConnectionState{
		Certificate:           cert,
		CertificateError:      certErr,
		CertificateChain:      chain,
		CertificateChainError: chainErr,
	}
}
// shutdown makes one SSL_shutdown attempt while holding c.mtx. It returns nil
// when SSL_shutdown reports a non-negative result, and otherwise a handler
// for the caller to run once the mutex is released.
func (c *Conn) shutdown() func() error {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	rv, errno := C.SSL_shutdown(c.ssl)
	if rv < 0 {
		return c.getErrorHandler(rv, errno)
	}
	// rv > 0 means the shutdown completed.
	//
	// rv == 0: The OpenSSL docs say that in this case, the shutdown is not
	// finished, and we should call SSL_shutdown() a second time, if a
	// bidirectional shutdown is going to be performed. Further, the
	// output of SSL_get_error may be misleading, as an erroneous
	// SSL_ERROR_SYSCALL may be flagged even though no error occurred.
	// So, TODO: revisit bidirectional shutdown, possibly trying again.
	// Note: some broken clients won't engage in bidirectional shutdown
	// without tickling them to close by sending a TCP_FIN packet, or
	// shutting down the write-side of the connection.
	return nil
}
// shutdownLoop drives the SSL shutdown, retrying once if the first attempt
// asks to be repeated.
//
// A successful attempt is followed by flushing the output buffer, and the
// flush result is returned. io.ErrUnexpectedEOF from an attempt is treated as
// a clean shutdown. A retry request on the second attempt is reported as an
// error; any other error is returned unchanged.
func (c *Conn) shutdownLoop() error {
	for attempt := 1; ; attempt++ {
		err := c.handleError(c.shutdown())
		switch {
		case err == nil:
			return c.flushOutputBuffer()
		case err == io.ErrUnexpectedEOF:
			return nil
		case err != tryAgain:
			return err
		case attempt >= 2:
			return errors.New("shutdown requested a third time?")
		}
	}
}
// Close shuts down the SSL connection and then closes the wrapped
// connection. Only the first call does any work; later calls return nil.
// Errors from the SSL shutdown and from closing the wrapped connection are
// combined into the returned error.
func (c *Conn) Close() error {
	c.mtx.Lock()
	alreadyClosed := c.is_shutdown
	c.is_shutdown = true
	c.mtx.Unlock()
	if alreadyClosed {
		return nil
	}

	var errs utils.ErrorGroup
	errs.Add(c.shutdownLoop())
	errs.Add(c.conn.Close())
	return errs.Finalize()
}
// read makes one SSL_read attempt into b while holding c.mtx.
//
// On success it returns the number of bytes stored in b and a nil handler.
// Otherwise it returns 0 and a handler for the caller to run once the mutex
// is released. If the connection has already been shut down, the handler
// yields io.EOF. An empty b returns (0, nil) without touching the connection.
func (c *Conn) read(b []byte) (int, func() error) {
	if len(b) == 0 {
		return 0, nil
	}
	// SSL_read takes its length as a C int. Converting a larger len(b) would
	// silently truncate or wrap, handing OpenSSL a zero, negative or otherwise
	// wrong size. A short read is always permitted, so cap the request at the
	// largest value a 32-bit signed int can hold.
	const maxReadLen = 1<<31 - 1
	if len(b) > maxReadLen {
		b = b[:maxReadLen]
	}

	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.is_shutdown {
		return 0, func() error { return io.EOF }
	}

	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
	if rv > 0 {
		return int(rv), nil
	}
	return 0, c.getErrorHandler(rv, errno)
}
// Read reads up to len(b) bytes into b and returns the number of bytes read
// together with any error. io.EOF is returned once the caller can expect no
// further data.
//
// The read attempt is repeated for as long as it asks to be retried. After a
// successful read, any pending output is flushed in the background.
func (c *Conn) Read(b []byte) (n int, err error) {
	if len(b) == 0 {
		return 0, nil
	}
	for {
		count, errcb := c.read(b)
		switch err = c.handleError(errcb); err {
		case nil:
			go c.flushOutputBuffer()
			return count, nil
		case tryAgain:
			continue
		case io.ErrUnexpectedEOF:
			// A close seen by the SSL layer is plain end-of-stream to readers.
			return 0, io.EOF
		default:
			return 0, err
		}
	}
}
// write makes one SSL_write attempt with the contents of b while holding
// c.mtx.
//
// On success it returns the number of bytes accepted and a nil handler.
// Otherwise it returns 0 and a handler for the caller to run once the mutex
// is released. If the connection has already been shut down, the handler
// yields a "connection closed" error. An empty b returns (0, nil).
//
// NOTE(review): len(b) is converted to a C int without a range check, so a
// buffer larger than a C int can represent would be passed with a wrong
// length. Confirm whether callers can ever supply one.
func (c *Conn) write(b []byte) (int, func() error) {
	if len(b) == 0 {
		return 0, nil
	}

	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.is_shutdown {
		closedErr := errors.New("connection closed")
		return 0, func() error { return closedErr }
	}

	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
	if rv <= 0 {
		return 0, c.getErrorHandler(rv, errno)
	}
	return int(rv), nil
}
// Write encrypts the contents of b and writes the result to the underlying
// stream. Performance will be vastly improved if the size of b is a multiple
// of SSLRecordSize.
//
// The write attempt is repeated for as long as it asks to be retried. After a
// successful write the output buffer is flushed synchronously, and the flush
// error, if any, is returned alongside the byte count.
func (c *Conn) Write(b []byte) (written int, err error) {
	if len(b) == 0 {
		return 0, nil
	}
	for {
		count, errcb := c.write(b)
		err = c.handleError(errcb)
		if err == tryAgain {
			continue
		}
		if err != nil {
			return 0, err
		}
		return count, c.flushOutputBuffer()
	}
}
// VerifyHostname fetches the peer certificate via PeerCertificate and calls
// VerifyHostname on it for host. If the certificate cannot be obtained, that
// error is returned instead.
func (c *Conn) VerifyHostname(host string) error {
	peer, err := c.PeerCertificate()
	if err != nil {
		return err
	}
	return peer.VerifyHostname(host)
}
// LocalAddr reports the local network address of the wrapped connection.
func (c *Conn) LocalAddr() net.Addr {
	local := c.conn.LocalAddr()
	return local
}
// RemoteAddr reports the remote network address of the wrapped connection.
func (c *Conn) RemoteAddr() net.Addr {
	remote := c.conn.RemoteAddr()
	return remote
}
// SetDeadline forwards t to SetDeadline on the wrapped connection and returns
// its result.
func (c *Conn) SetDeadline(t time.Time) error {
	if err := c.conn.SetDeadline(t); err != nil {
		return err
	}
	return nil
}
// SetReadDeadline forwards t to SetReadDeadline on the wrapped connection and
// returns its result.
func (c *Conn) SetReadDeadline(t time.Time) error {
	if err := c.conn.SetReadDeadline(t); err != nil {
		return err
	}
	return nil
}
// SetWriteDeadline forwards t to SetWriteDeadline on the wrapped connection
// and returns its result.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	if err := c.conn.SetWriteDeadline(t); err != nil {
		return err
	}
	return nil
}
// UnderlyingConn returns the net.Conn that this Conn wraps.
func (c *Conn) UnderlyingConn() net.Conn {
	wrapped := c.conn
	return wrapped
}