diff --git a/bio.go b/bio.go index b2164a4..8be9c4f 100644 --- a/bio.go +++ b/bio.go @@ -56,226 +56,226 @@ static void BIO_set_retry_read_not_a_macro(BIO *b) { BIO_set_retry_read(b); } import "C" import ( - "errors" - "io" - "reflect" - "sync" - "unsafe" + "errors" + "io" + "reflect" + "sync" + "unsafe" ) const ( - SSLRecordSize = 16 * 1024 + SSLRecordSize = 16 * 1024 ) func nonCopyGoBytes(ptr uintptr, length int) []byte { - var slice []byte - header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) - header.Cap = length - header.Len = length - header.Data = ptr - return slice + var slice []byte + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + header.Cap = length + header.Len = length + header.Data = ptr + return slice } func nonCopyCString(data *C.char, size C.int) []byte { - return nonCopyGoBytes(uintptr(unsafe.Pointer(data)), int(size)) + return nonCopyGoBytes(uintptr(unsafe.Pointer(data)), int(size)) } //export cbioNew func cbioNew(b *C.BIO) C.int { - b.shutdown = 1 - b.init = 1 - b.num = -1 - b.ptr = nil - b.flags = 0 - return 1 + b.shutdown = 1 + b.init = 1 + b.num = -1 + b.ptr = nil + b.flags = 0 + return 1 } //export cbioFree func cbioFree(b *C.BIO) C.int { - return 1 + return 1 } type writeBio struct { - data_mtx sync.Mutex - op_mtx sync.Mutex - buf []byte + data_mtx sync.Mutex + op_mtx sync.Mutex + buf []byte } func loadWritePtr(b *C.BIO) *writeBio { - return (*writeBio)(unsafe.Pointer(b.ptr)) + return (*writeBio)(unsafe.Pointer(b.ptr)) } //export writeBioWrite func writeBioWrite(b *C.BIO, data *C.char, size C.int) C.int { - ptr := loadWritePtr(b) - if ptr == nil || data == nil || size < 0 { - return -1 - } - ptr.data_mtx.Lock() - defer ptr.data_mtx.Unlock() - C.BIO_clear_retry_flags_not_a_macro(b) - ptr.buf = append(ptr.buf, nonCopyCString(data, size)...) - return size + ptr := loadWritePtr(b) + if ptr == nil || data == nil || size < 0 { + return -1 + } + ptr.data_mtx.Lock() + defer ptr.data_mtx.Unlock() + C.BIO_clear_retry_flags_not_a_macro(b) + ptr.buf = append(ptr.buf, nonCopyCString(data, size)...) + return size } //export writeBioCtrl func writeBioCtrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long { - switch cmd { - case C.BIO_CTRL_WPENDING: - return writeBioPending(b) - case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH: - return 1 - default: - return 0 - } + switch cmd { + case C.BIO_CTRL_WPENDING: + return writeBioPending(b) + case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH: + return 1 + default: + return 0 + } } func writeBioPending(b *C.BIO) C.long { - ptr := loadWritePtr(b) - if ptr == nil { - return 0 - } - ptr.data_mtx.Lock() - defer ptr.data_mtx.Unlock() - return C.long(len(ptr.buf)) + ptr := loadWritePtr(b) + if ptr == nil { + return 0 + } + ptr.data_mtx.Lock() + defer ptr.data_mtx.Unlock() + return C.long(len(ptr.buf)) } func (b *writeBio) WriteTo(w io.Writer) (rv int64, err error) { - b.op_mtx.Lock() - defer b.op_mtx.Unlock() + b.op_mtx.Lock() + defer b.op_mtx.Unlock() - // write whatever data we currently have - b.data_mtx.Lock() - data := b.buf - b.data_mtx.Unlock() + // write whatever data we currently have + b.data_mtx.Lock() + data := b.buf + b.data_mtx.Unlock() - if len(data) == 0 { - return 0, nil - } - n, err := w.Write(data) + if len(data) == 0 { + return 0, nil + } + n, err := w.Write(data) - // subtract however much data we wrote from the buffer - b.data_mtx.Lock() - b.buf = b.buf[:copy(b.buf, b.buf[n:])] - b.data_mtx.Unlock() + // subtract however much data we wrote from the buffer + b.data_mtx.Lock() + b.buf = b.buf[:copy(b.buf, b.buf[n:])] + b.data_mtx.Unlock() - return int64(n), err + return int64(n), err } func (self *writeBio) Disconnect(b *C.BIO) { - if loadWritePtr(b) == self { - b.ptr = nil - } + if loadWritePtr(b) == self { + b.ptr = nil + } } func (b *writeBio) MakeCBIO() *C.BIO { - rv := C.BIO_new(C.BIO_s_writeBio()) - rv.ptr = unsafe.Pointer(b) - return rv + rv := C.BIO_new(C.BIO_s_writeBio()) + rv.ptr = unsafe.Pointer(b) + return rv } type readBio struct { - data_mtx sync.Mutex - op_mtx sync.Mutex - buf []byte - eof bool + data_mtx sync.Mutex + op_mtx sync.Mutex + buf []byte + eof bool } func loadReadPtr(b *C.BIO) *readBio { - return (*readBio)(unsafe.Pointer(b.ptr)) + return (*readBio)(unsafe.Pointer(b.ptr)) } //export readBioRead func readBioRead(b *C.BIO, data *C.char, size C.int) C.int { - ptr := loadReadPtr(b) - if ptr == nil || size < 0 { - return -1 - } - ptr.data_mtx.Lock() - defer ptr.data_mtx.Unlock() - C.BIO_clear_retry_flags_not_a_macro(b) - if len(ptr.buf) == 0 { - if ptr.eof { - return 0 - } - C.BIO_set_retry_read_not_a_macro(b) - return -1 - } - if size == 0 || data == nil { - return C.int(len(ptr.buf)) - } - n := copy(nonCopyCString(data, size), ptr.buf) - ptr.buf = ptr.buf[:copy(ptr.buf, ptr.buf[n:])] - return C.int(n) + ptr := loadReadPtr(b) + if ptr == nil || size < 0 { + return -1 + } + ptr.data_mtx.Lock() + defer ptr.data_mtx.Unlock() + C.BIO_clear_retry_flags_not_a_macro(b) + if len(ptr.buf) == 0 { + if ptr.eof { + return 0 + } + C.BIO_set_retry_read_not_a_macro(b) + return -1 + } + if size == 0 || data == nil { + return C.int(len(ptr.buf)) + } + n := copy(nonCopyCString(data, size), ptr.buf) + ptr.buf = ptr.buf[:copy(ptr.buf, ptr.buf[n:])] + return C.int(n) } //export readBioCtrl func readBioCtrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long { - switch cmd { - case C.BIO_CTRL_PENDING: - return readBioPending(b) - case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH: - return 1 - default: - return 0 - } + switch cmd { + case C.BIO_CTRL_PENDING: + return readBioPending(b) + case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH: + return 1 + default: + return 0 + } } func readBioPending(b *C.BIO) C.long { - ptr := loadReadPtr(b) - if ptr == nil { - return 0 - } - ptr.data_mtx.Lock() - defer ptr.data_mtx.Unlock() - return C.long(len(ptr.buf)) + ptr := loadReadPtr(b) + if ptr == nil { + return 0 + } + ptr.data_mtx.Lock() + defer ptr.data_mtx.Unlock() + return C.long(len(ptr.buf)) } func (b *readBio) ReadFromOnce(r io.Reader) (n int, err error) { - b.op_mtx.Lock() - defer b.op_mtx.Unlock() + b.op_mtx.Lock() + defer b.op_mtx.Unlock() - // make sure we have a destination that fits at least one SSL record - b.data_mtx.Lock() - if cap(b.buf) < len(b.buf)+SSLRecordSize { - new_buf := make([]byte, len(b.buf), len(b.buf)+SSLRecordSize) - copy(new_buf, b.buf) - b.buf = new_buf - } - dst := b.buf[len(b.buf):cap(b.buf)] - dst_slice := b.buf - b.data_mtx.Unlock() + // make sure we have a destination that fits at least one SSL record + b.data_mtx.Lock() + if cap(b.buf) < len(b.buf)+SSLRecordSize { + new_buf := make([]byte, len(b.buf), len(b.buf)+SSLRecordSize) + copy(new_buf, b.buf) + b.buf = new_buf + } + dst := b.buf[len(b.buf):cap(b.buf)] + dst_slice := b.buf + b.data_mtx.Unlock() - n, err = r.Read(dst) - b.data_mtx.Lock() - defer b.data_mtx.Unlock() - if n > 0 { - if len(dst_slice) != len(b.buf) { - // someone shrunk the buffer, so we read in too far ahead and we - // need to slide backwards - copy(b.buf[len(b.buf):len(b.buf)+n], dst) - } - b.buf = b.buf[:len(b.buf)+n] - } - return n, err + n, err = r.Read(dst) + b.data_mtx.Lock() + defer b.data_mtx.Unlock() + if n > 0 { + if len(dst_slice) != len(b.buf) { + // someone shrunk the buffer, so we read in too far ahead and we + // need to slide backwards + copy(b.buf[len(b.buf):len(b.buf)+n], dst) + } + b.buf = b.buf[:len(b.buf)+n] + } + return n, err } func (b *readBio) MakeCBIO() *C.BIO { - rv := C.BIO_new(C.BIO_s_readBio()) - rv.ptr = unsafe.Pointer(b) - return rv + rv := C.BIO_new(C.BIO_s_readBio()) + rv.ptr = unsafe.Pointer(b) + return rv } func (self *readBio) Disconnect(b *C.BIO) { - if loadReadPtr(b) == self { - b.ptr = nil - } + if loadReadPtr(b) == self { + b.ptr = nil + } } func (b *readBio) MarkEOF() { - b.data_mtx.Lock() - defer b.data_mtx.Unlock() - b.eof = true + b.data_mtx.Lock() + defer b.data_mtx.Unlock() + b.eof = true } type anyBio C.BIO @@ -283,18 +283,18 @@ type anyBio C.BIO func asAnyBio(b *C.BIO) *anyBio { return (*anyBio)(b) } func (b *anyBio) Read(buf []byte) (n int, err error) { - n = int(C.BIO_read((*C.BIO)(b), unsafe.Pointer(&buf[0]), C.int(len(buf)))) - if n <= 0 { - return 0, io.EOF - } - return n, nil + n = int(C.BIO_read((*C.BIO)(b), unsafe.Pointer(&buf[0]), C.int(len(buf)))) + if n <= 0 { + return 0, io.EOF + } + return n, nil } func (b *anyBio) Write(buf []byte) (written int, err error) { - n := int(C.BIO_write((*C.BIO)(b), unsafe.Pointer(&buf[0]), - C.int(len(buf)))) - if n != len(buf) { - return n, errors.New("BIO write failed") - } - return n, nil + n := int(C.BIO_write((*C.BIO)(b), unsafe.Pointer(&buf[0]), + C.int(len(buf)))) + if n != len(buf) { + return n, errors.New("BIO write failed") + } + return n, nil } diff --git a/conn.go b/conn.go index 6f19c20..ab3cac1 100644 --- a/conn.go +++ b/conn.go @@ -9,77 +9,77 @@ package openssl import "C" import ( - "errors" - "io" - "net" - "runtime" - "sync" - "time" - "unsafe" + "errors" + "io" + "net" + "runtime" + "sync" + "time" + "unsafe" - "code.spacemonkey.com/go/openssl/utils" + "code.spacemonkey.com/go/openssl/utils" ) var ( - zeroReturn = errors.New("zero return") - wantRead = errors.New("want read") - wantWrite = errors.New("want write") - tryAgain = errors.New("try again") + zeroReturn = errors.New("zero return") + wantRead = errors.New("want read") + wantWrite = errors.New("want write") + tryAgain = errors.New("try again") ) type Conn struct { - conn net.Conn - ssl *C.SSL - into_ssl *readBio - from_ssl *writeBio - is_shutdown bool - mtx sync.Mutex - want_read_future *utils.Future + conn net.Conn + ssl *C.SSL + into_ssl *readBio + from_ssl *writeBio + is_shutdown bool + mtx sync.Mutex + want_read_future *utils.Future } func newSSL(ctx *C.SSL_CTX) (*C.SSL, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - ssl := C.SSL_new(ctx) - if ssl == nil { - return nil, errorFromErrorQueue() - } - return ssl, nil + runtime.LockOSThread() + defer runtime.UnlockOSThread() + ssl := C.SSL_new(ctx) + if ssl == nil { + return nil, errorFromErrorQueue() + } + return ssl, nil } func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) { - ssl, err := newSSL(ctx.ctx) - if err != nil { - return nil, err - } + ssl, err := newSSL(ctx.ctx) + if err != nil { + return nil, err + } - into_ssl := &readBio{} - from_ssl := &writeBio{} + into_ssl := &readBio{} + from_ssl := &writeBio{} - into_ssl_cbio := into_ssl.MakeCBIO() - from_ssl_cbio := from_ssl.MakeCBIO() - if into_ssl_cbio == nil || from_ssl_cbio == nil { - // these frees are null safe - C.BIO_free(into_ssl_cbio) - C.BIO_free(from_ssl_cbio) - C.SSL_free(ssl) - return nil, errors.New("failed to allocate memory BIO") - } + into_ssl_cbio := into_ssl.MakeCBIO() + from_ssl_cbio := from_ssl.MakeCBIO() + if into_ssl_cbio == nil || from_ssl_cbio == nil { + // these frees are null safe + C.BIO_free(into_ssl_cbio) + C.BIO_free(from_ssl_cbio) + C.SSL_free(ssl) + return nil, errors.New("failed to allocate memory BIO") + } - // the ssl object takes ownership of these objects now - C.SSL_set_bio(ssl, into_ssl_cbio, from_ssl_cbio) + // the ssl object takes ownership of these objects now + C.SSL_set_bio(ssl, into_ssl_cbio, from_ssl_cbio) - c := &Conn{ - conn: conn, - ssl: ssl, - into_ssl: into_ssl, - from_ssl: from_ssl} - runtime.SetFinalizer(c, func(c *Conn) { - c.into_ssl.Disconnect(into_ssl_cbio) - c.from_ssl.Disconnect(from_ssl_cbio) - C.SSL_free(c.ssl) - }) - return c, nil + c := &Conn{ + conn: conn, + ssl: ssl, + into_ssl: into_ssl, + from_ssl: from_ssl} + runtime.SetFinalizer(c, func(c *Conn) { + c.into_ssl.Disconnect(into_ssl_cbio) + c.from_ssl.Disconnect(from_ssl_cbio) + C.SSL_free(c.ssl) + }) + return c, nil } // Client wraps an existing stream connection and puts it in the connect state @@ -94,319 +94,319 @@ func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) { // you're using. This library is not nice enough to use the system certificate // store by default for you yet. func Client(conn net.Conn, ctx *Ctx) (*Conn, error) { - c, err := newConn(conn, ctx) - if err != nil { - return nil, err - } - C.SSL_set_connect_state(c.ssl) - return c, nil + c, err := newConn(conn, ctx) + if err != nil { + return nil, err + } + C.SSL_set_connect_state(c.ssl) + return c, nil } // Server wraps an existing stream connection and puts it in the accept state // for any subsequent handshakes. func Server(conn net.Conn, ctx *Ctx) (*Conn, error) { - c, err := newConn(conn, ctx) - if err != nil { - return nil, err - } - C.SSL_set_accept_state(c.ssl) - return c, nil + c, err := newConn(conn, ctx) + if err != nil { + return nil, err + } + C.SSL_set_accept_state(c.ssl) + return c, nil } func (c *Conn) fillInputBuffer() error { - for { - n, err := c.into_ssl.ReadFromOnce(c.conn) - if n == 0 && err == nil { - continue - } - if err == io.EOF { - c.into_ssl.MarkEOF() - return c.Close() - } - return err - } + for { + n, err := c.into_ssl.ReadFromOnce(c.conn) + if n == 0 && err == nil { + continue + } + if err == io.EOF { + c.into_ssl.MarkEOF() + return c.Close() + } + return err + } } func (c *Conn) flushOutputBuffer() error { - _, err := c.from_ssl.WriteTo(c.conn) - return err + _, err := c.from_ssl.WriteTo(c.conn) + return err } func (c *Conn) getErrorHandler(rv C.int, errno error) func() error { - errcode := C.SSL_get_error(c.ssl, rv) - switch errcode { - case C.SSL_ERROR_ZERO_RETURN: - return func() error { - c.Close() - return io.ErrUnexpectedEOF - } - case C.SSL_ERROR_WANT_READ: - go c.flushOutputBuffer() - if c.want_read_future != nil { - want_read_future := c.want_read_future - return func() error { - _, err := want_read_future.Get() - return err - } - } - c.want_read_future = utils.NewFuture() - want_read_future := c.want_read_future - return func() (err error) { - defer func() { - c.mtx.Lock() - c.want_read_future = nil - c.mtx.Unlock() - want_read_future.Set(nil, err) - }() - err = c.fillInputBuffer() - if err != nil { - return err - } - return tryAgain - } - case C.SSL_ERROR_WANT_WRITE: - return func() error { - err := c.flushOutputBuffer() - if err != nil { - return err - } - return tryAgain - } - case C.SSL_ERROR_SYSCALL: - var err error - if C.ERR_peek_error() == 0 { - switch rv { - case 0: - err = errors.New("protocol-violating EOF") - case -1: - err = errno - default: - err = errorFromErrorQueue() - } - } else { - err = errorFromErrorQueue() - } - return func() error { return err } - default: - err := errorFromErrorQueue() - return func() error { return err } - } + errcode := C.SSL_get_error(c.ssl, rv) + switch errcode { + case C.SSL_ERROR_ZERO_RETURN: + return func() error { + c.Close() + return io.ErrUnexpectedEOF + } + case C.SSL_ERROR_WANT_READ: + go c.flushOutputBuffer() + if c.want_read_future != nil { + want_read_future := c.want_read_future + return func() error { + _, err := want_read_future.Get() + return err + } + } + c.want_read_future = utils.NewFuture() + want_read_future := c.want_read_future + return func() (err error) { + defer func() { + c.mtx.Lock() + c.want_read_future = nil + c.mtx.Unlock() + want_read_future.Set(nil, err) + }() + err = c.fillInputBuffer() + if err != nil { + return err + } + return tryAgain + } + case C.SSL_ERROR_WANT_WRITE: + return func() error { + err := c.flushOutputBuffer() + if err != nil { + return err + } + return tryAgain + } + case C.SSL_ERROR_SYSCALL: + var err error + if C.ERR_peek_error() == 0 { + switch rv { + case 0: + err = errors.New("protocol-violating EOF") + case -1: + err = errno + default: + err = errorFromErrorQueue() + } + } else { + err = errorFromErrorQueue() + } + return func() error { return err } + default: + err := errorFromErrorQueue() + return func() error { return err } + } } func (c *Conn) handleError(errcb func() error) error { - if errcb != nil { - return errcb() - } - return nil + if errcb != nil { + return errcb() + } + return nil } func (c *Conn) handshake() func() error { - c.mtx.Lock() - defer c.mtx.Unlock() - if c.is_shutdown { - return func() error { return io.ErrUnexpectedEOF } - } - runtime.LockOSThread() - defer runtime.UnlockOSThread() - rv, errno := C.SSL_do_handshake(c.ssl) - if rv > 0 { - return nil - } - return c.getErrorHandler(rv, errno) + c.mtx.Lock() + defer c.mtx.Unlock() + if c.is_shutdown { + return func() error { return io.ErrUnexpectedEOF } + } + runtime.LockOSThread() + defer runtime.UnlockOSThread() + rv, errno := C.SSL_do_handshake(c.ssl) + if rv > 0 { + return nil + } + return c.getErrorHandler(rv, errno) } // Handshake performs an SSL handshake. If a handshake is not manually // triggered, it will run before the first I/O on the encrypted stream. func (c *Conn) Handshake() error { - err := tryAgain - for err == tryAgain { - err = c.handleError(c.handshake()) - } - go c.flushOutputBuffer() - return err + err := tryAgain + for err == tryAgain { + err = c.handleError(c.handshake()) + } + go c.flushOutputBuffer() + return err } // PeerCertificate returns the Certificate of the peer with which you're // communicating. Only valid after a handshake. func (c *Conn) PeerCertificate() (*Certificate, error) { - c.mtx.Lock() - if c.is_shutdown { - return nil, errors.New("connection closed") - } - x := C.SSL_get_peer_certificate(c.ssl) - c.mtx.Unlock() - if x == nil { - return nil, errors.New("no peer certificate found") - } - cert := &Certificate{x: x} - runtime.SetFinalizer(cert, func(cert *Certificate) { - C.X509_free(cert.x) - }) - return cert, nil + c.mtx.Lock() + if c.is_shutdown { + return nil, errors.New("connection closed") + } + x := C.SSL_get_peer_certificate(c.ssl) + c.mtx.Unlock() + if x == nil { + return nil, errors.New("no peer certificate found") + } + cert := &Certificate{x: x} + runtime.SetFinalizer(cert, func(cert *Certificate) { + C.X509_free(cert.x) + }) + return cert, nil } func (c *Conn) shutdown() func() error { - c.mtx.Lock() - defer c.mtx.Unlock() - runtime.LockOSThread() - defer runtime.UnlockOSThread() - rv, errno := C.SSL_shutdown(c.ssl) - if rv > 0 { - return nil - } - if rv == 0 { - // The OpenSSL docs say that in this case, the shutdown is not - // finished, and we should call SSL_shutdown() a second time, if a - // bidirectional shutdown is going to be performed. Further, the - // output of SSL_get_error may be misleading, as an erroneous - // SSL_ERROR_SYSCALL may be flagged even though no error occurred. - // So, TODO: revisit bidrectional shutdown, possibly trying again. - // Note: some broken clients won't engage in bidirectional shutdown - // without tickling them to close by sending a TCP_FIN packet, or - // shutting down the write-side of the connection. - return nil - } else { - return c.getErrorHandler(rv, errno) - } + c.mtx.Lock() + defer c.mtx.Unlock() + runtime.LockOSThread() + defer runtime.UnlockOSThread() + rv, errno := C.SSL_shutdown(c.ssl) + if rv > 0 { + return nil + } + if rv == 0 { + // The OpenSSL docs say that in this case, the shutdown is not + // finished, and we should call SSL_shutdown() a second time, if a + // bidirectional shutdown is going to be performed. Further, the + // output of SSL_get_error may be misleading, as an erroneous + // SSL_ERROR_SYSCALL may be flagged even though no error occurred. + // So, TODO: revisit bidrectional shutdown, possibly trying again. + // Note: some broken clients won't engage in bidirectional shutdown + // without tickling them to close by sending a TCP_FIN packet, or + // shutting down the write-side of the connection. + return nil + } else { + return c.getErrorHandler(rv, errno) + } } func (c *Conn) shutdownLoop() error { - err := tryAgain - shutdown_tries := 0 - for err == tryAgain { - shutdown_tries = shutdown_tries + 1 - err = c.handleError(c.shutdown()) - if err == nil { - return c.flushOutputBuffer() - } - if err == tryAgain && shutdown_tries >= 2 { - return errors.New("shutdown requested a third time?") - } - } - if err == io.ErrUnexpectedEOF { - err = nil - } - return err + err := tryAgain + shutdown_tries := 0 + for err == tryAgain { + shutdown_tries = shutdown_tries + 1 + err = c.handleError(c.shutdown()) + if err == nil { + return c.flushOutputBuffer() + } + if err == tryAgain && shutdown_tries >= 2 { + return errors.New("shutdown requested a third time?") + } + } + if err == io.ErrUnexpectedEOF { + err = nil + } + return err } // Close shuts down the SSL connection and closes the underlying wrapped // connection. func (c *Conn) Close() error { - c.mtx.Lock() - if c.is_shutdown { - c.mtx.Unlock() - return nil - } - c.is_shutdown = true - c.mtx.Unlock() - var errs utils.ErrorGroup - errs.Add(c.shutdownLoop()) - errs.Add(c.conn.Close()) - return errs.Finalize() + c.mtx.Lock() + if c.is_shutdown { + c.mtx.Unlock() + return nil + } + c.is_shutdown = true + c.mtx.Unlock() + var errs utils.ErrorGroup + errs.Add(c.shutdownLoop()) + errs.Add(c.conn.Close()) + return errs.Finalize() } func (c *Conn) read(b []byte) (int, func() error) { - c.mtx.Lock() - defer c.mtx.Unlock() - if c.is_shutdown { - return 0, func() error { return io.EOF } - } - runtime.LockOSThread() - defer runtime.UnlockOSThread() - rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b))) - if rv > 0 { - return int(rv), nil - } - return 0, c.getErrorHandler(rv, errno) + c.mtx.Lock() + defer c.mtx.Unlock() + if c.is_shutdown { + return 0, func() error { return io.EOF } + } + runtime.LockOSThread() + defer runtime.UnlockOSThread() + rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b))) + if rv > 0 { + return int(rv), nil + } + return 0, c.getErrorHandler(rv, errno) } // Read reads up to len(b) bytes into b. It returns the number of bytes read // and an error if applicable. io.EOF is returned when the caller can expect // to see no more data. func (c *Conn) Read(b []byte) (n int, err error) { - if len(b) == 0 { - return 0, nil - } - err = tryAgain - for err == tryAgain { - n, errcb := c.read(b) - err = c.handleError(errcb) - if err == nil { - go c.flushOutputBuffer() - return n, nil - } - if err == io.ErrUnexpectedEOF { - err = io.EOF - } - } - return 0, err + if len(b) == 0 { + return 0, nil + } + err = tryAgain + for err == tryAgain { + n, errcb := c.read(b) + err = c.handleError(errcb) + if err == nil { + go c.flushOutputBuffer() + return n, nil + } + if err == io.ErrUnexpectedEOF { + err = io.EOF + } + } + return 0, err } func (c *Conn) write(b []byte) (int, func() error) { - c.mtx.Lock() - defer c.mtx.Unlock() - if c.is_shutdown { - err := errors.New("connection closed") - return 0, func() error { return err } - } - runtime.LockOSThread() - defer runtime.UnlockOSThread() - rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b))) - if rv > 0 { - return int(rv), nil - } - return 0, c.getErrorHandler(rv, errno) + c.mtx.Lock() + defer c.mtx.Unlock() + if c.is_shutdown { + err := errors.New("connection closed") + return 0, func() error { return err } + } + runtime.LockOSThread() + defer runtime.UnlockOSThread() + rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b))) + if rv > 0 { + return int(rv), nil + } + return 0, c.getErrorHandler(rv, errno) } // Write will encrypt the contents of b and write it to the underlying stream. // Performance will be vastly improved if the size of b is a multiple of // SSLRecordSize. func (c *Conn) Write(b []byte) (written int, err error) { - if len(b) == 0 { - return 0, nil - } - err = tryAgain - for err == tryAgain { - n, errcb := c.write(b) - err = c.handleError(errcb) - if err == nil { - return n, c.flushOutputBuffer() - } - } - return 0, err + if len(b) == 0 { + return 0, nil + } + err = tryAgain + for err == tryAgain { + n, errcb := c.write(b) + err = c.handleError(errcb) + if err == nil { + return n, c.flushOutputBuffer() + } + } + return 0, err } // VerifyHostname pulls the PeerCertificate and calls VerifyHostname on the // certificate. func (c *Conn) VerifyHostname(host string) error { - cert, err := c.PeerCertificate() - if err != nil { - return err - } - return cert.VerifyHostname(host) + cert, err := c.PeerCertificate() + if err != nil { + return err + } + return cert.VerifyHostname(host) } // LocalAddr returns the underlying connection's local address func (c *Conn) LocalAddr() net.Addr { - return c.conn.LocalAddr() + return c.conn.LocalAddr() } // RemoteAddr returns the underlying connection's remote address func (c *Conn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() + return c.conn.RemoteAddr() } // SetDeadline calls SetDeadline on the underlying connection. func (c *Conn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) + return c.conn.SetDeadline(t) } // SetReadDeadline calls SetReadDeadline on the underlying connection. func (c *Conn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) + return c.conn.SetReadDeadline(t) } // SetWriteDeadline calls SetWriteDeadline on the underlying connection. func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return c.conn.SetWriteDeadline(t) } diff --git a/ctx.go b/ctx.go index 4528627..291a9ee 100644 --- a/ctx.go +++ b/ctx.go @@ -33,158 +33,158 @@ package openssl import "C" import ( - "errors" - "io/ioutil" - "runtime" - "unsafe" + "errors" + "io/ioutil" + "runtime" + "unsafe" ) type Ctx struct { - ctx *C.SSL_CTX + ctx *C.SSL_CTX } func newCtx(method *C.SSL_METHOD) (*Ctx, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - ctx := C.SSL_CTX_new(method) - if ctx == nil { - return nil, errorFromErrorQueue() - } - c := &Ctx{ctx: ctx} - runtime.SetFinalizer(c, func(c *Ctx) { - C.SSL_CTX_free(c.ctx) - }) - return c, nil + runtime.LockOSThread() + defer runtime.UnlockOSThread() + ctx := C.SSL_CTX_new(method) + if ctx == nil { + return nil, errorFromErrorQueue() + } + c := &Ctx{ctx: ctx} + runtime.SetFinalizer(c, func(c *Ctx) { + C.SSL_CTX_free(c.ctx) + }) + return c, nil } type SSLVersion int const ( - SSLv3 SSLVersion = 0x02 - TLSv1 SSLVersion = 0x03 - TLSv1_1 SSLVersion = 0x04 - TLSv1_2 SSLVersion = 0x05 - AnyVersion SSLVersion = 0x06 + SSLv3 SSLVersion = 0x02 + TLSv1 SSLVersion = 0x03 + TLSv1_1 SSLVersion = 0x04 + TLSv1_2 SSLVersion = 0x05 + AnyVersion SSLVersion = 0x06 ) // NewCtxWithVersion creates an SSL context that is specific to the provided // SSL version. See http://www.openssl.org/docs/ssl/SSL_CTX_new.html for more. func NewCtxWithVersion(version SSLVersion) (*Ctx, error) { - var method *C.SSL_METHOD - switch version { - case SSLv3: - method = C.SSLv3_method() - case TLSv1: - method = C.TLSv1_method() - case TLSv1_1: - method = C.TLSv1_1_method() - case TLSv1_2: - method = C.TLSv1_2_method() - case AnyVersion: - method = C.SSLv23_method() - } - if method == nil { - return nil, errors.New("unknown ssl/tls version") - } - return newCtx(method) + var method *C.SSL_METHOD + switch version { + case SSLv3: + method = C.SSLv3_method() + case TLSv1: + method = C.TLSv1_method() + case TLSv1_1: + method = C.TLSv1_1_method() + case TLSv1_2: + method = C.TLSv1_2_method() + case AnyVersion: + method = C.SSLv23_method() + } + if method == nil { + return nil, errors.New("unknown ssl/tls version") + } + return newCtx(method) } // NewCtx creates a context that supports any TLS version 1.0 and newer. func NewCtx() (*Ctx, error) { - c, err := NewCtxWithVersion(AnyVersion) - if err == nil { - c.SetOptions(NoSSLv2 | NoSSLv3) - } - return c, err + c, err := NewCtxWithVersion(AnyVersion) + if err == nil { + c.SetOptions(NoSSLv2 | NoSSLv3) + } + return c, err } // NewCtxFromFiles calls NewCtx, loads the provided files, and configures the // context to use them. func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) { - ctx, err := NewCtx() - if err != nil { - return nil, err - } + ctx, err := NewCtx() + if err != nil { + return nil, err + } - cert_bytes, err := ioutil.ReadFile(cert_file) - if err != nil { - return nil, err - } + cert_bytes, err := ioutil.ReadFile(cert_file) + if err != nil { + return nil, err + } - cert, err := LoadCertificate(cert_bytes) - if err != nil { - return nil, err - } + cert, err := LoadCertificate(cert_bytes) + if err != nil { + return nil, err + } - err = ctx.UseCertificate(cert) - if err != nil { - return nil, err - } + err = ctx.UseCertificate(cert) + if err != nil { + return nil, err + } - key_bytes, err := ioutil.ReadFile(key_file) - if err != nil { - return nil, err - } + key_bytes, err := ioutil.ReadFile(key_file) + if err != nil { + return nil, err + } - key, err := LoadPrivateKey(key_bytes) - if err != nil { - return nil, err - } + key, err := LoadPrivateKey(key_bytes) + if err != nil { + return nil, err + } - err = ctx.UsePrivateKey(key) - if err != nil { - return nil, err - } + err = ctx.UsePrivateKey(key) + if err != nil { + return nil, err + } - return ctx, nil + return ctx, nil } // UseCertificate configures the context to present the given certificate to // peers. func (c *Ctx) UseCertificate(cert *Certificate) error { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - if int(C.SSL_CTX_use_certificate(c.ctx, cert.x)) != 1 { - return errorFromErrorQueue() - } - return nil + runtime.LockOSThread() + defer runtime.UnlockOSThread() + if int(C.SSL_CTX_use_certificate(c.ctx, cert.x)) != 1 { + return errorFromErrorQueue() + } + return nil } // UsePrivateKey configures the context to use the given private key for SSL // handshakes. func (c *Ctx) UsePrivateKey(key PrivateKey) error { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - if int(C.SSL_CTX_use_PrivateKey(c.ctx, key.evpPKey())) != 1 { - return errorFromErrorQueue() - } - return nil + runtime.LockOSThread() + defer runtime.UnlockOSThread() + if int(C.SSL_CTX_use_PrivateKey(c.ctx, key.evpPKey())) != 1 { + return errorFromErrorQueue() + } + return nil } type CertificateStore struct { - store *C.X509_STORE - ctx *Ctx // for gc + store *C.X509_STORE + ctx *Ctx // for gc } // GetCertificateStore returns the context's certificate store that will be // used for peer validation. func (c *Ctx) GetCertificateStore() *CertificateStore { - // we don't need to dealloc the cert store pointer here, because it points - // to a ctx internal. so we do need to keep the ctx around - return &CertificateStore{ - store: C.SSL_CTX_get_cert_store(c.ctx), - ctx: c} + // we don't need to dealloc the cert store pointer here, because it points + // to a ctx internal. so we do need to keep the ctx around + return &CertificateStore{ + store: C.SSL_CTX_get_cert_store(c.ctx), + ctx: c} } // AddCertificate marks the provided Certificate as a trusted certificate in // the given CertificateStore. func (s *CertificateStore) AddCertificate(cert *Certificate) error { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - if int(C.X509_STORE_add_cert(s.store, cert.x)) != 1 { - return errorFromErrorQueue() - } - return nil + runtime.LockOSThread() + defer runtime.UnlockOSThread() + if int(C.X509_STORE_add_cert(s.store, cert.x)) != 1 { + return errorFromErrorQueue() + } + return nil } // LoadVerifyLocations tells the context to trust all certificate authorities @@ -192,120 +192,120 @@ func (s *CertificateStore) AddCertificate(cert *Certificate) error { // See http://www.openssl.org/docs/ssl/SSL_CTX_load_verify_locations.html for // more. func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - var c_ca_file, c_ca_path *C.char - if ca_file != "" { - c_ca_file = C.CString(ca_file) - defer C.free(unsafe.Pointer(c_ca_file)) - } - if ca_path != "" { - c_ca_path = C.CString(ca_path) - defer C.free(unsafe.Pointer(c_ca_path)) - } - if C.SSL_CTX_load_verify_locations(c.ctx, c_ca_file, c_ca_path) != 1 { - return errorFromErrorQueue() - } - return nil + runtime.LockOSThread() + defer runtime.UnlockOSThread() + var c_ca_file, c_ca_path *C.char + if ca_file != "" { + c_ca_file = C.CString(ca_file) + defer C.free(unsafe.Pointer(c_ca_file)) + } + if ca_path != "" { + c_ca_path = C.CString(ca_path) + defer C.free(unsafe.Pointer(c_ca_path)) + } + if C.SSL_CTX_load_verify_locations(c.ctx, c_ca_file, c_ca_path) != 1 { + return errorFromErrorQueue() + } + return nil } type Options int const ( - // NoCompression is only valid if you are using OpenSSL 1.0.1 or newer - NoCompression Options = C.SSL_OP_NO_COMPRESSION - NoSSLv2 Options = C.SSL_OP_NO_SSLv2 - NoSSLv3 Options = C.SSL_OP_NO_SSLv3 - NoTLSv1 Options = C.SSL_OP_NO_TLSv1 - CipherServerPreference Options = C.SSL_OP_CIPHER_SERVER_PREFERENCE - NoSessionResumptionOrRenegotiation Options = C.SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION - NoTicket Options = C.SSL_OP_NO_TICKET + // NoCompression is only valid if you are using OpenSSL 1.0.1 or newer + NoCompression Options = C.SSL_OP_NO_COMPRESSION + NoSSLv2 Options = C.SSL_OP_NO_SSLv2 + NoSSLv3 Options = C.SSL_OP_NO_SSLv3 + NoTLSv1 Options = C.SSL_OP_NO_TLSv1 + CipherServerPreference Options = C.SSL_OP_CIPHER_SERVER_PREFERENCE + NoSessionResumptionOrRenegotiation Options = C.SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION + NoTicket Options = C.SSL_OP_NO_TICKET ) // SetOptions sets context options. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_options.html func (c *Ctx) SetOptions(options Options) Options { - return Options(C.SSL_CTX_set_options_not_a_macro( - c.ctx, C.long(options))) + return Options(C.SSL_CTX_set_options_not_a_macro( + c.ctx, C.long(options))) } type Modes int const ( - // ReleaseBuffers is only valid if you are using OpenSSL 1.0.1 or newer - ReleaseBuffers Modes = C.SSL_MODE_RELEASE_BUFFERS + // ReleaseBuffers is only valid if you are using OpenSSL 1.0.1 or newer + ReleaseBuffers Modes = C.SSL_MODE_RELEASE_BUFFERS ) // SetMode sets context modes. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_mode.html func (c *Ctx) SetMode(modes Modes) Modes { - return Modes(C.SSL_CTX_set_mode_not_a_macro(c.ctx, C.long(modes))) + return Modes(C.SSL_CTX_set_mode_not_a_macro(c.ctx, C.long(modes))) } type VerifyOptions int const ( - VerifyNone VerifyOptions = C.SSL_VERIFY_NONE - VerifyPeer VerifyOptions = C.SSL_VERIFY_PEER - VerifyFailIfNoPeerCert VerifyOptions = C.SSL_VERIFY_FAIL_IF_NO_PEER_CERT - VerifyClientOnce VerifyOptions = C.SSL_VERIFY_CLIENT_ONCE + VerifyNone VerifyOptions = C.SSL_VERIFY_NONE + VerifyPeer VerifyOptions = C.SSL_VERIFY_PEER + VerifyFailIfNoPeerCert VerifyOptions = C.SSL_VERIFY_FAIL_IF_NO_PEER_CERT + VerifyClientOnce VerifyOptions = C.SSL_VERIFY_CLIENT_ONCE ) // SetVerify controls peer verification settings. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html func (c *Ctx) SetVerify(options VerifyOptions) { - // TODO: take a callback - C.SSL_CTX_set_verify(c.ctx, C.int(options), nil) + // TODO: take a callback + C.SSL_CTX_set_verify(c.ctx, C.int(options), nil) } // SetVerifyDepth controls how many certificates deep the certificate // verification logic is willing to follow a certificate chain. See // https://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html func (c *Ctx) SetVerifyDepth(depth int) { - C.SSL_CTX_set_verify_depth(c.ctx, C.int(depth)) + C.SSL_CTX_set_verify_depth(c.ctx, C.int(depth)) } func (c *Ctx) SetSessionId(session_id []byte) error { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - if int(C.SSL_CTX_set_session_id_context(c.ctx, - (*C.uchar)(unsafe.Pointer(&session_id[0])), - C.uint(len(session_id)))) == 0 { - return errorFromErrorQueue() - } - return nil + runtime.LockOSThread() + defer runtime.UnlockOSThread() + if int(C.SSL_CTX_set_session_id_context(c.ctx, + (*C.uchar)(unsafe.Pointer(&session_id[0])), + C.uint(len(session_id)))) == 0 { + return errorFromErrorQueue() + } + return nil } // SetCipherList sets the list of available ciphers. The format of the list is // described at http://www.openssl.org/docs/apps/ciphers.html, but see // http://www.openssl.org/docs/ssl/SSL_CTX_set_cipher_list.html for more. func (c *Ctx) SetCipherList(list string) error { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - clist := C.CString(list) - defer C.free(unsafe.Pointer(clist)) - if int(C.SSL_CTX_set_cipher_list(c.ctx, clist)) == 0 { - return errorFromErrorQueue() - } - return nil + runtime.LockOSThread() + defer runtime.UnlockOSThread() + clist := C.CString(list) + defer C.free(unsafe.Pointer(clist)) + if int(C.SSL_CTX_set_cipher_list(c.ctx, clist)) == 0 { + return errorFromErrorQueue() + } + return nil } type SessionCacheModes int const ( - SessionCacheOff SessionCacheModes = C.SSL_SESS_CACHE_OFF - SessionCacheClient SessionCacheModes = C.SSL_SESS_CACHE_CLIENT - SessionCacheServer SessionCacheModes = C.SSL_SESS_CACHE_SERVER - SessionCacheBoth SessionCacheModes = C.SSL_SESS_CACHE_BOTH - NoAutoClear SessionCacheModes = C.SSL_SESS_CACHE_NO_AUTO_CLEAR - NoInternalLookup SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_LOOKUP - NoInternalStore SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_STORE - NoInternal SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL + SessionCacheOff SessionCacheModes = C.SSL_SESS_CACHE_OFF + SessionCacheClient SessionCacheModes = C.SSL_SESS_CACHE_CLIENT + SessionCacheServer SessionCacheModes = C.SSL_SESS_CACHE_SERVER + SessionCacheBoth SessionCacheModes = C.SSL_SESS_CACHE_BOTH + NoAutoClear SessionCacheModes = C.SSL_SESS_CACHE_NO_AUTO_CLEAR + NoInternalLookup SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_LOOKUP + NoInternalStore SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_STORE + NoInternal SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL ) // SetSessionCacheMode enables or disables session caching. See // http://www.openssl.org/docs/ssl/SSL_CTX_set_session_cache_mode.html func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes { - return SessionCacheModes( - C.SSL_CTX_set_session_cache_mode_not_a_macro(c.ctx, C.long(modes))) + return SessionCacheModes( + C.SSL_CTX_set_session_cache_mode_not_a_macro(c.ctx, C.long(modes))) } diff --git a/hostname.go b/hostname.go index 14d1820..37197a3 100644 --- a/hostname.go +++ b/hostname.go @@ -23,20 +23,20 @@ extern int X509_check_ip(X509 *x, const unsigned char *chk, size_t chklen, import "C" import ( - "errors" - "net" - "unsafe" + "errors" + "net" + "unsafe" ) var ( - ValidationError = errors.New("Host validation error") + ValidationError = errors.New("Host validation error") ) type CheckFlags int const ( - AlwaysCheckSubject CheckFlags = C.X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT - NoWildcards CheckFlags = C.X509_CHECK_FLAG_NO_WILDCARDS + AlwaysCheckSubject CheckFlags = C.X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT + NoWildcards CheckFlags = C.X509_CHECK_FLAG_NO_WILDCARDS ) // CheckHost checks that the X509 certificate is signed for the provided @@ -45,17 +45,17 @@ const ( // Specifically returns ValidationError if the Certificate didn't match but // there was no internal error. func (c *Certificate) CheckHost(host string, flags CheckFlags) error { - chost := unsafe.Pointer(C.CString(host)) - defer C.free(chost) - rv := C.X509_check_host(c.x, (*C.uchar)(chost), C.size_t(len(host)), - C.uint(flags)) - if rv > 0 { - return nil - } - if rv == 0 { - return ValidationError - } - return errors.New("hostname validation had an internal failure") + chost := unsafe.Pointer(C.CString(host)) + defer C.free(chost) + rv := C.X509_check_host(c.x, (*C.uchar)(chost), C.size_t(len(host)), + C.uint(flags)) + if rv > 0 { + return nil + } + if rv == 0 { + return ValidationError + } + return errors.New("hostname validation had an internal failure") } // CheckEmail checks that the X509 certificate is signed for the provided @@ -64,17 +64,17 @@ func (c *Certificate) CheckHost(host string, flags CheckFlags) error { // Specifically returns ValidationError if the Certificate didn't match but // there was no internal error. func (c *Certificate) CheckEmail(email string, flags CheckFlags) error { - cemail := unsafe.Pointer(C.CString(email)) - defer C.free(cemail) - rv := C.X509_check_email(c.x, (*C.uchar)(cemail), C.size_t(len(email)), - C.uint(flags)) - if rv > 0 { - return nil - } - if rv == 0 { - return ValidationError - } - return errors.New("email validation had an internal failure") + cemail := unsafe.Pointer(C.CString(email)) + defer C.free(cemail) + rv := C.X509_check_email(c.x, (*C.uchar)(cemail), C.size_t(len(email)), + C.uint(flags)) + if rv > 0 { + return nil + } + if rv == 0 { + return ValidationError + } + return errors.New("email validation had an internal failure") } // CheckIP checks that the X509 certificate is signed for the provided @@ -83,16 +83,16 @@ func (c *Certificate) CheckEmail(email string, flags CheckFlags) error { // Specifically returns ValidationError if the Certificate didn't match but // there was no internal error. func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error { - cip := unsafe.Pointer(&ip[0]) - rv := C.X509_check_ip(c.x, (*C.uchar)(cip), C.size_t(len(ip)), - C.uint(flags)) - if rv > 0 { - return nil - } - if rv == 0 { - return ValidationError - } - return errors.New("ip validation had an internal failure") + cip := unsafe.Pointer(&ip[0]) + rv := C.X509_check_ip(c.x, (*C.uchar)(cip), C.size_t(len(ip)), + C.uint(flags)) + if rv > 0 { + return nil + } + if rv == 0 { + return ValidationError + } + return errors.New("ip validation had an internal failure") } // VerifyHostname is a combination of CheckHost and CheckIP. If the provided @@ -101,14 +101,14 @@ func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error { // Specifically returns ValidationError if the Certificate didn't match but // there was no internal error. func (c *Certificate) VerifyHostname(host string) error { - var ip net.IP - if len(host) >= 3 && host[0] == '[' && host[len(host)-1] == ']' { - ip = net.ParseIP(host[1 : len(host)-1]) - } else { - ip = net.ParseIP(host) - } - if ip != nil { - return c.CheckIP(ip, 0) - } - return c.CheckHost(host, 0) + var ip net.IP + if len(host) >= 3 && host[0] == '[' && host[len(host)-1] == ']' { + ip = net.ParseIP(host[1 : len(host)-1]) + } else { + ip = net.ParseIP(host) + } + if ip != nil { + return c.CheckIP(ip, 0) + } + return c.CheckHost(host, 0) } diff --git a/http.go b/http.go index 66361a2..5fd79d4 100644 --- a/http.go +++ b/http.go @@ -3,37 +3,37 @@ package openssl import ( - "net/http" + "net/http" ) // ListenAndServeTLS will take an http.Handler and serve it using OpenSSL over // the given tcp address, configured to use the provided cert and key files. func ListenAndServeTLS(addr string, cert_file string, key_file string, - handler http.Handler) error { - return ServerListenAndServeTLS( - &http.Server{Addr: addr, Handler: handler}, cert_file, key_file) + handler http.Handler) error { + return ServerListenAndServeTLS( + &http.Server{Addr: addr, Handler: handler}, cert_file, key_file) } // ServerListenAndServeTLS will take an http.Server and serve it using OpenSSL // configured to use the provided cert and key files. func ServerListenAndServeTLS(srv *http.Server, - cert_file, key_file string) error { - addr := srv.Addr - if addr == "" { - addr = ":https" - } + cert_file, key_file string) error { + addr := srv.Addr + if addr == "" { + addr = ":https" + } - ctx, err := NewCtxFromFiles(cert_file, key_file) - if err != nil { - return err - } + ctx, err := NewCtxFromFiles(cert_file, key_file) + if err != nil { + return err + } - l, err := Listen("tcp", addr, ctx) - if err != nil { - return err - } + l, err := Listen("tcp", addr, ctx) + if err != nil { + return err + } - return srv.Serve(l) + return srv.Serve(l) } // TODO: http client integration diff --git a/init.go b/init.go index 2f38652..b1a01ec 100644 --- a/init.go +++ b/init.go @@ -68,58 +68,58 @@ static void OpenSSL_add_all_algorithms_not_a_macro() { import "C" import ( - "errors" - "fmt" - "strings" - "sync" + "errors" + "fmt" + "strings" + "sync" - "code.spacemonkey.com/go/openssl/utils" + "code.spacemonkey.com/go/openssl/utils" ) var ( - sslMutexes []sync.Mutex + sslMutexes []sync.Mutex ) func init() { - C.OPENSSL_config(nil) - C.ENGINE_load_builtin_engines() - C.SSL_load_error_strings() - C.SSL_library_init() - C.OpenSSL_add_all_algorithms_not_a_macro() - sslMutexes = make([]sync.Mutex, int(C.CRYPTO_num_locks())) - C.CRYPTO_set_id_callback((*[0]byte)(C.sslThreadId)) - C.CRYPTO_set_locking_callback((*[0]byte)(C.sslMutexOp)) + C.OPENSSL_config(nil) + C.ENGINE_load_builtin_engines() + C.SSL_load_error_strings() + C.SSL_library_init() + C.OpenSSL_add_all_algorithms_not_a_macro() + sslMutexes = make([]sync.Mutex, int(C.CRYPTO_num_locks())) + C.CRYPTO_set_id_callback((*[0]byte)(C.sslThreadId)) + C.CRYPTO_set_locking_callback((*[0]byte)(C.sslMutexOp)) - // TODO: support dynlock callbacks + // TODO: support dynlock callbacks } // errorFromErrorQueue needs to run in the same OS thread as the operation // that caused the possible error func errorFromErrorQueue() error { - var errs []string - for { - err := C.ERR_get_error() - if err == 0 { - break - } - errs = append(errs, fmt.Sprintf("%s:%s:%s", - C.GoString(C.ERR_lib_error_string(err)), - C.GoString(C.ERR_func_error_string(err)), - C.GoString(C.ERR_reason_error_string(err)))) - } - return errors.New(fmt.Sprintf("SSL errors: %s", strings.Join(errs, "\n"))) + var errs []string + for { + err := C.ERR_get_error() + if err == 0 { + break + } + errs = append(errs, fmt.Sprintf("%s:%s:%s", + C.GoString(C.ERR_lib_error_string(err)), + C.GoString(C.ERR_func_error_string(err)), + C.GoString(C.ERR_reason_error_string(err)))) + } + return errors.New(fmt.Sprintf("SSL errors: %s", strings.Join(errs, "\n"))) } //export sslMutexOp func sslMutexOp(mode, n C.int, file *C.char, line C.int) { - if mode&C.CRYPTO_LOCK > 0 { - sslMutexes[n].Lock() - } else { - sslMutexes[n].Unlock() - } + if mode&C.CRYPTO_LOCK > 0 { + sslMutexes[n].Lock() + } else { + sslMutexes[n].Unlock() + } } //export sslThreadId func sslThreadId() C.ulong { - return C.ulong(uintptr(utils.ThreadId())) + return C.ulong(uintptr(utils.ThreadId())) } diff --git a/net.go b/net.go index 5cc6ef6..ce538b1 100644 --- a/net.go +++ b/net.go @@ -3,49 +3,49 @@ package openssl import ( - "errors" - "net" + "errors" + "net" ) type listener struct { - net.Listener - ctx *Ctx + net.Listener + ctx *Ctx } func (l *listener) Accept() (c net.Conn, err error) { - c, err = l.Listener.Accept() - if err != nil { - return nil, err - } - return Server(c, l.ctx) + c, err = l.Listener.Accept() + if err != nil { + return nil, err + } + return Server(c, l.ctx) } // NewListener wraps an existing net.Listener such that all accepted // connections are wrapped as OpenSSL server connections using the provided // context ctx. func NewListener(inner net.Listener, ctx *Ctx) net.Listener { - return &listener{ - Listener: inner, - ctx: ctx} + return &listener{ + Listener: inner, + ctx: ctx} } // Listen is a wrapper around net.Listen that wraps incoming connections with // an OpenSSL server connection using the provided context ctx. func Listen(network, laddr string, ctx *Ctx) (net.Listener, error) { - if ctx == nil { - return nil, errors.New("no ssl context provided") - } - l, err := net.Listen(network, laddr) - if err != nil { - return nil, err - } - return NewListener(l, ctx), nil + if ctx == nil { + return nil, errors.New("no ssl context provided") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, ctx), nil } type DialFlags int const ( - InsecureSkipHostVerification DialFlags = 0x01 + InsecureSkipHostVerification DialFlags = 0x01 ) // Dial will connect to network/address and then wrap the corresponding @@ -59,39 +59,39 @@ const ( // This library is not nice enough to use the system certificate store by // default for you yet. func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) { - if ctx == nil { - var err error - ctx, err = NewCtx() - if err != nil { - return nil, err - } - // TODO: use operating system default certificate chain? - } - c, err := net.Dial(network, addr) - if err != nil { - return nil, err - } - conn, err := Client(c, ctx) - if err != nil { - c.Close() - return nil, err - } - err = conn.Handshake() - if err != nil { - c.Close() - return nil, err - } - if flags&InsecureSkipHostVerification == 0 { - host, _, err := net.SplitHostPort(addr) - if err != nil { - conn.Close() - return nil, err - } - err = conn.VerifyHostname(host) - if err != nil { - conn.Close() - return nil, err - } - } - return conn, nil + if ctx == nil { + var err error + ctx, err = NewCtx() + if err != nil { + return nil, err + } + // TODO: use operating system default certificate chain? + } + c, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + conn, err := Client(c, ctx) + if err != nil { + c.Close() + return nil, err + } + err = conn.Handshake() + if err != nil { + c.Close() + return nil, err + } + if flags&InsecureSkipHostVerification == 0 { + host, _, err := net.SplitHostPort(addr) + if err != nil { + conn.Close() + return nil, err + } + err = conn.VerifyHostname(host) + if err != nil { + conn.Close() + return nil, err + } + } + return conn, nil } diff --git a/oracle_stubs.go b/oracle_stubs.go index 126cc49..0285e89 100644 --- a/oracle_stubs.go +++ b/oracle_stubs.go @@ -4,13 +4,13 @@ package openssl import ( - "errors" - "net" - "time" + "errors" + "net" + "time" ) const ( - SSLRecordSize = 16 * 1024 + SSLRecordSize = 16 * 1024 ) type Conn struct{} @@ -37,11 +37,11 @@ type Ctx struct{} type SSLVersion int const ( - SSLv3 SSLVersion = 0x02 - TLSv1 SSLVersion = 0x03 - TLSv1_1 SSLVersion = 0x04 - TLSv1_2 SSLVersion = 0x05 - AnyVersion SSLVersion = 0x06 + SSLv3 SSLVersion = 0x02 + TLSv1 SSLVersion = 0x03 + TLSv1_1 SSLVersion = 0x04 + TLSv1_2 SSLVersion = 0x05 + AnyVersion SSLVersion = 0x06 ) func NewCtxWithVersion(version SSLVersion) (*Ctx, error) @@ -61,13 +61,13 @@ func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error type Options int const ( - NoCompression Options = 0 - NoSSLv2 Options = 0 - NoSSLv3 Options = 0 - NoTLSv1 Options = 0 - CipherServerPreference Options = 0 - NoSessionResumptionOrRenegotiation Options = 0 - NoTicket Options = 0 + NoCompression Options = 0 + NoSSLv2 Options = 0 + NoSSLv3 Options = 0 + NoTLSv1 Options = 0 + CipherServerPreference Options = 0 + NoSessionResumptionOrRenegotiation Options = 0 + NoTicket Options = 0 ) func (c *Ctx) SetOptions(options Options) Options @@ -75,7 +75,7 @@ func (c *Ctx) SetOptions(options Options) Options type Modes int const ( - ReleaseBuffers Modes = 0 + ReleaseBuffers Modes = 0 ) func (c *Ctx) SetMode(modes Modes) Modes @@ -83,10 +83,10 @@ func (c *Ctx) SetMode(modes Modes) Modes type VerifyOptions int const ( - VerifyNone VerifyOptions = 0 - VerifyPeer VerifyOptions = 0 - VerifyFailIfNoPeerCert VerifyOptions = 0 - VerifyClientOnce VerifyOptions = 0 + VerifyNone VerifyOptions = 0 + VerifyPeer VerifyOptions = 0 + VerifyFailIfNoPeerCert VerifyOptions = 0 + VerifyClientOnce VerifyOptions = 0 ) func (c *Ctx) SetVerify(options VerifyOptions) @@ -98,27 +98,27 @@ func (c *Ctx) SetCipherList(list string) error type SessionCacheModes int const ( - SessionCacheOff SessionCacheModes = 0 - SessionCacheClient SessionCacheModes = 0 - SessionCacheServer SessionCacheModes = 0 - SessionCacheBoth SessionCacheModes = 0 - NoAutoClear SessionCacheModes = 0 - NoInternalLookup SessionCacheModes = 0 - NoInternalStore SessionCacheModes = 0 - NoInternal SessionCacheModes = 0 + SessionCacheOff SessionCacheModes = 0 + SessionCacheClient SessionCacheModes = 0 + SessionCacheServer SessionCacheModes = 0 + SessionCacheBoth SessionCacheModes = 0 + NoAutoClear SessionCacheModes = 0 + NoInternalLookup SessionCacheModes = 0 + NoInternalStore SessionCacheModes = 0 + NoInternal SessionCacheModes = 0 ) func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes var ( - ValidationError = errors.New("Host validation error") + ValidationError = errors.New("Host validation error") ) type CheckFlags int const ( - AlwaysCheckSubject CheckFlags = 0 - NoWildcards CheckFlags = 0 + AlwaysCheckSubject CheckFlags = 0 + NoWildcards CheckFlags = 0 ) func (c *Certificate) CheckHost(host string, flags CheckFlags) error @@ -127,15 +127,15 @@ func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error func (c *Certificate) VerifyHostname(host string) error type PublicKey interface { - MarshalPKIXPublicKeyPEM() (pem_block []byte, err error) - MarshalPKIXPublicKeyDER() (der_block []byte, err error) - evpPKey() struct{} + MarshalPKIXPublicKeyPEM() (pem_block []byte, err error) + MarshalPKIXPublicKeyDER() (der_block []byte, err error) + evpPKey() struct{} } type PrivateKey interface { - PublicKey - MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error) - MarshalPKCS1PrivateKeyDER() (der_block []byte, err error) + PublicKey + MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error) + MarshalPKCS1PrivateKeyDER() (der_block []byte, err error) } func LoadPrivateKey(pem_block []byte) (PrivateKey, error) diff --git a/pem.go b/pem.go index 5db8644..3a2893a 100644 --- a/pem.go +++ b/pem.go @@ -8,191 +8,191 @@ package openssl import "C" import ( - "errors" - "io/ioutil" - "runtime" - "unsafe" + "errors" + "io/ioutil" + "runtime" + "unsafe" ) type PublicKey interface { - // MarshalPKIXPublicKeyPEM converts the public key to PEM-encoded PKIX - // format - MarshalPKIXPublicKeyPEM() (pem_block []byte, err error) + // MarshalPKIXPublicKeyPEM converts the public key to PEM-encoded PKIX + // format + MarshalPKIXPublicKeyPEM() (pem_block []byte, err error) - // MarshalPKIXPublicKeyDER converts the public key to DER-encoded PKIX - // format - MarshalPKIXPublicKeyDER() (der_block []byte, err error) + // MarshalPKIXPublicKeyDER converts the public key to DER-encoded PKIX + // format + MarshalPKIXPublicKeyDER() (der_block []byte, err error) - evpPKey() *C.EVP_PKEY + evpPKey() *C.EVP_PKEY } type PrivateKey interface { - PublicKey + PublicKey - // MarshalPKCS1PrivateKeyPEM converts the private key to PEM-encoded PKCS1 - // format - MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error) + // MarshalPKCS1PrivateKeyPEM converts the private key to PEM-encoded PKCS1 + // format + MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error) - // MarshalPKCS1PrivateKeyDER converts the private key to DER-encoded PKCS1 - // format - MarshalPKCS1PrivateKeyDER() (der_block []byte, err error) + // MarshalPKCS1PrivateKeyDER converts the private key to DER-encoded PKCS1 + // format + MarshalPKCS1PrivateKeyDER() (der_block []byte, err error) } type pKey struct { - key *C.EVP_PKEY + key *C.EVP_PKEY } func (key *pKey) evpPKey() *C.EVP_PKEY { return key.key } func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte, - err error) { - bio := C.BIO_new(C.BIO_s_mem()) - if bio == nil { - return nil, errors.New("failed to allocate memory BIO") - } - defer C.BIO_free(bio) - rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) - if rsa == nil { - return nil, errors.New("failed getting rsa key") - } - defer C.RSA_free(rsa) - if int(C.PEM_write_bio_RSAPrivateKey(bio, rsa, nil, nil, C.int(0), nil, - nil)) != 1 { - return nil, errors.New("failed dumping private key") - } - return ioutil.ReadAll(asAnyBio(bio)) + err error) { + bio := C.BIO_new(C.BIO_s_mem()) + if bio == nil { + return nil, errors.New("failed to allocate memory BIO") + } + defer C.BIO_free(bio) + rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) + if rsa == nil { + return nil, errors.New("failed getting rsa key") + } + defer C.RSA_free(rsa) + if int(C.PEM_write_bio_RSAPrivateKey(bio, rsa, nil, nil, C.int(0), nil, + nil)) != 1 { + return nil, errors.New("failed dumping private key") + } + return ioutil.ReadAll(asAnyBio(bio)) } func (key *pKey) MarshalPKCS1PrivateKeyDER() (der_block []byte, - err error) { - bio := C.BIO_new(C.BIO_s_mem()) - if bio == nil { - return nil, errors.New("failed to allocate memory BIO") - } - defer C.BIO_free(bio) - rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) - if rsa == nil { - return nil, errors.New("failed getting rsa key") - } - defer C.RSA_free(rsa) - if int(C.i2d_RSAPrivateKey_bio(bio, rsa)) != 1 { - return nil, errors.New("failed dumping private key der") - } - return ioutil.ReadAll(asAnyBio(bio)) + err error) { + bio := C.BIO_new(C.BIO_s_mem()) + if bio == nil { + return nil, errors.New("failed to allocate memory BIO") + } + defer C.BIO_free(bio) + rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) + if rsa == nil { + return nil, errors.New("failed getting rsa key") + } + defer C.RSA_free(rsa) + if int(C.i2d_RSAPrivateKey_bio(bio, rsa)) != 1 { + return nil, errors.New("failed dumping private key der") + } + return ioutil.ReadAll(asAnyBio(bio)) } func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte, - err error) { - bio := C.BIO_new(C.BIO_s_mem()) - if bio == nil { - return nil, errors.New("failed to allocate memory BIO") - } - defer C.BIO_free(bio) - rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) - if rsa == nil { - return nil, errors.New("failed getting rsa key") - } - defer C.RSA_free(rsa) - if int(C.PEM_write_bio_RSA_PUBKEY(bio, rsa)) != 1 { - return nil, errors.New("failed dumping public key pem") - } - return ioutil.ReadAll(asAnyBio(bio)) + err error) { + bio := C.BIO_new(C.BIO_s_mem()) + if bio == nil { + return nil, errors.New("failed to allocate memory BIO") + } + defer C.BIO_free(bio) + rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) + if rsa == nil { + return nil, errors.New("failed getting rsa key") + } + defer C.RSA_free(rsa) + if int(C.PEM_write_bio_RSA_PUBKEY(bio, rsa)) != 1 { + return nil, errors.New("failed dumping public key pem") + } + return ioutil.ReadAll(asAnyBio(bio)) } func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte, - err error) { - bio := C.BIO_new(C.BIO_s_mem()) - if bio == nil { - return nil, errors.New("failed to allocate memory BIO") - } - defer C.BIO_free(bio) - rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) - if rsa == nil { - return nil, errors.New("failed getting rsa key") - } - defer C.RSA_free(rsa) - if int(C.i2d_RSA_PUBKEY_bio(bio, rsa)) != 1 { - return nil, errors.New("failed dumping public key der") - } - return ioutil.ReadAll(asAnyBio(bio)) + err error) { + bio := C.BIO_new(C.BIO_s_mem()) + if bio == nil { + return nil, errors.New("failed to allocate memory BIO") + } + defer C.BIO_free(bio) + rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) + if rsa == nil { + return nil, errors.New("failed getting rsa key") + } + defer C.RSA_free(rsa) + if int(C.i2d_RSA_PUBKEY_bio(bio, rsa)) != 1 { + return nil, errors.New("failed dumping public key der") + } + return ioutil.ReadAll(asAnyBio(bio)) } // LoadPrivateKey loads a private key from a PEM-encoded block. func LoadPrivateKey(pem_block []byte) (PrivateKey, error) { - bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), - C.int(len(pem_block))) - if bio == nil { - return nil, errors.New("failed creating bio") - } - defer C.BIO_free(bio) + bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), + C.int(len(pem_block))) + if bio == nil { + return nil, errors.New("failed creating bio") + } + defer C.BIO_free(bio) - rsakey := C.PEM_read_bio_RSAPrivateKey(bio, nil, nil, nil) - if rsakey == nil { - return nil, errors.New("failed reading rsa key") - } - defer C.RSA_free(rsakey) + rsakey := C.PEM_read_bio_RSAPrivateKey(bio, nil, nil, nil) + if rsakey == nil { + return nil, errors.New("failed reading rsa key") + } + defer C.RSA_free(rsakey) - // convert to PKEY - key := C.EVP_PKEY_new() - if key == nil { - return nil, errors.New("failed converting to evp_pkey") - } - if C.EVP_PKEY_set1_RSA(key, (*C.struct_rsa_st)(rsakey)) != 1 { - C.EVP_PKEY_free(key) - return nil, errors.New("failed converting to evp_pkey") - } + // convert to PKEY + key := C.EVP_PKEY_new() + if key == nil { + return nil, errors.New("failed converting to evp_pkey") + } + if C.EVP_PKEY_set1_RSA(key, (*C.struct_rsa_st)(rsakey)) != 1 { + C.EVP_PKEY_free(key) + return nil, errors.New("failed converting to evp_pkey") + } - p := &pKey{key: key} - runtime.SetFinalizer(p, func(p *pKey) { - C.EVP_PKEY_free(p.key) - }) - return p, nil + p := &pKey{key: key} + runtime.SetFinalizer(p, func(p *pKey) { + C.EVP_PKEY_free(p.key) + }) + return p, nil } type Certificate struct { - x *C.X509 + x *C.X509 } // LoadCertificate loads an X509 certificate from a PEM-encoded block. func LoadCertificate(pem_block []byte) (*Certificate, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), - C.int(len(pem_block))) - cert := C.PEM_read_bio_X509(bio, nil, nil, nil) - C.BIO_free(bio) - if cert == nil { - return nil, errorFromErrorQueue() - } - x := &Certificate{x: cert} - runtime.SetFinalizer(x, func(x *Certificate) { - C.X509_free(x.x) - }) - return x, nil + runtime.LockOSThread() + defer runtime.UnlockOSThread() + bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), + C.int(len(pem_block))) + cert := C.PEM_read_bio_X509(bio, nil, nil, nil) + C.BIO_free(bio) + if cert == nil { + return nil, errorFromErrorQueue() + } + x := &Certificate{x: cert} + runtime.SetFinalizer(x, func(x *Certificate) { + C.X509_free(x.x) + }) + return x, nil } // MarshalPEM converts the X509 certificate to PEM-encoded format func (c *Certificate) MarshalPEM() (pem_block []byte, err error) { - bio := C.BIO_new(C.BIO_s_mem()) - if bio == nil { - return nil, errors.New("failed to allocate memory BIO") - } - defer C.BIO_free(bio) - if int(C.PEM_write_bio_X509(bio, c.x)) != 1 { - return nil, errors.New("failed dumping certificate") - } - return ioutil.ReadAll(asAnyBio(bio)) + bio := C.BIO_new(C.BIO_s_mem()) + if bio == nil { + return nil, errors.New("failed to allocate memory BIO") + } + defer C.BIO_free(bio) + if int(C.PEM_write_bio_X509(bio, c.x)) != 1 { + return nil, errors.New("failed dumping certificate") + } + return ioutil.ReadAll(asAnyBio(bio)) } // PublicKey returns the public key embedded in the X509 certificate. func (c *Certificate) PublicKey() (PublicKey, error) { - pkey := C.X509_get_pubkey(c.x) - if pkey == nil { - return nil, errors.New("no public key found") - } - key := &pKey{key: pkey} - runtime.SetFinalizer(key, func(key *pKey) { - C.EVP_PKEY_free(key.key) - }) - return key, nil + pkey := C.X509_get_pubkey(c.x) + if pkey == nil { + return nil, errors.New("no public key found") + } + key := &pKey{key: pkey} + runtime.SetFinalizer(key, func(key *pKey) { + C.EVP_PKEY_free(key.key) + }) + return key, nil } diff --git a/pem_test.go b/pem_test.go index f29cbf1..513a672 100644 --- a/pem_test.go +++ b/pem_test.go @@ -3,74 +3,74 @@ package openssl import ( - "bytes" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "encoding/hex" - "io/ioutil" - "testing" + "bytes" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "io/ioutil" + "testing" ) func TestMarshal(t *testing.T) { - key, err := LoadPrivateKey(keyBytes) - if err != nil { - t.Fatal(err) - } - cert, err := LoadCertificate(certBytes) - if err != nil { - t.Fatal(err) - } + key, err := LoadPrivateKey(keyBytes) + if err != nil { + t.Fatal(err) + } + cert, err := LoadCertificate(certBytes) + if err != nil { + t.Fatal(err) + } - pem, err := cert.MarshalPEM() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(pem, certBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", certBytes, 0644) - t.Fatal("invalid cert pem bytes") - } + pem, err := cert.MarshalPEM() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pem, certBytes) { + ioutil.WriteFile("generated", pem, 0644) + ioutil.WriteFile("hardcoded", certBytes, 0644) + t.Fatal("invalid cert pem bytes") + } - pem, err = key.MarshalPKCS1PrivateKeyPEM() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(pem, keyBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", keyBytes, 0644) - t.Fatal("invalid private key pem bytes") - } - tls_cert, err := tls.X509KeyPair(certBytes, keyBytes) - if err != nil { - t.Fatal(err) - } - tls_key, ok := tls_cert.PrivateKey.(*rsa.PrivateKey) - if !ok { - t.Fatal("FASDFASDF") - } - _ = tls_key + pem, err = key.MarshalPKCS1PrivateKeyPEM() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pem, keyBytes) { + ioutil.WriteFile("generated", pem, 0644) + ioutil.WriteFile("hardcoded", keyBytes, 0644) + t.Fatal("invalid private key pem bytes") + } + tls_cert, err := tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + t.Fatal(err) + } + tls_key, ok := tls_cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + t.Fatal("FASDFASDF") + } + _ = tls_key - der, err := key.MarshalPKCS1PrivateKeyDER() - if err != nil { - t.Fatal(err) - } - tls_der := x509.MarshalPKCS1PrivateKey(tls_key) - if !bytes.Equal(der, tls_der) { - t.Fatal("invalid private key der bytes: %s\n v.s. %s\n", hex.Dump(der), hex.Dump(tls_der)) - } + der, err := key.MarshalPKCS1PrivateKeyDER() + if err != nil { + t.Fatal(err) + } + tls_der := x509.MarshalPKCS1PrivateKey(tls_key) + if !bytes.Equal(der, tls_der) { + t.Fatal("invalid private key der bytes: %s\n v.s. %s\n", hex.Dump(der), hex.Dump(tls_der)) + } - der, err = key.MarshalPKIXPublicKeyDER() - if err != nil { - t.Fatal(err) - } - tls_der, err = x509.MarshalPKIXPublicKey(&tls_key.PublicKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(der, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) - t.Fatal("invalid public key der bytes") - } + der, err = key.MarshalPKIXPublicKeyDER() + if err != nil { + t.Fatal(err) + } + tls_der, err = x509.MarshalPKIXPublicKey(&tls_key.PublicKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(der, tls_der) { + ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644) + ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + t.Fatal("invalid public key der bytes") + } } diff --git a/ssl_test.go b/ssl_test.go index 1d716d8..725ec26 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -3,21 +3,21 @@ package openssl import ( - "bytes" - "crypto/rand" - "crypto/tls" - "io" - "io/ioutil" - "net" - "sync" - "testing" - "time" + "bytes" + "crypto/rand" + "crypto/tls" + "io" + "io/ioutil" + "net" + "sync" + "testing" + "time" - "code.spacemonkey.com/go/openssl/utils" + "code.spacemonkey.com/go/openssl/utils" ) var ( - certBytes = []byte(`-----BEGIN CERTIFICATE----- + certBytes = []byte(`-----BEGIN CERTIFICATE----- MIIDxDCCAqygAwIBAgIVAMcK/0VWQr2O3MNfJCydqR7oVELcMA0GCSqGSIb3DQEB BQUAMIGQMUkwRwYDVQQDE0A1NjdjZGRmYzRjOWZiNTYwZTk1M2ZlZjA1N2M0NGFm MDdiYjc4MDIzODIxYTA5NThiY2RmMGMwNzJhOTdiMThhMQswCQYDVQQGEwJVUzEN @@ -41,7 +41,7 @@ sRkg/uxcJf7wC5Y0BLlp1+aPwdmZD87T3a1uQ1Ij93jmHG+2T9U20MklHAePOl0q yTqdSPnSH1c= -----END CERTIFICATE----- `) - keyBytes = []byte(`-----BEGIN RSA PRIVATE KEY----- + keyBytes = []byte(`-----BEGIN RSA PRIVATE KEY----- MIIEpQIBAAKCAQEA3X94nDbxbK5a5zS4vEqHLHKpUmxavqRL5oXEqKoAy6nm56rv C3e9xySe+DBlxIEV/MWU+RYpzjC99QkerfRP493aleqfhn3ZRS3tyKrQtP2z1Zwg wYqwcoASOLgqzKvtVYQMT1nJaw6O5fUEWG7BMR/ZX5/kcr8XjTGYjgEmrL1WTZ3G @@ -72,534 +72,534 @@ qmgvgyRayemfO2zR0CPgC6wSoGBth+xW6g+WA8y0z76ZSaWpFi8lVM4= ) func NetPipe(t testing.TB) (net.Conn, net.Conn) { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatal(err) - } - defer l.Close() - client_future := utils.NewFuture() - go func() { - client_future.Set(net.Dial(l.Addr().Network(), l.Addr().String())) - }() - var errs utils.ErrorGroup - server_conn, err := l.Accept() - errs.Add(err) - client_conn, err := client_future.Get() - errs.Add(err) - err = errs.Finalize() - if err != nil { - if server_conn != nil { - server_conn.Close() - } - if client_conn != nil { - client_conn.(net.Conn).Close() - } - t.Fatal(err) - } - return server_conn, client_conn.(net.Conn) + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + client_future := utils.NewFuture() + go func() { + client_future.Set(net.Dial(l.Addr().Network(), l.Addr().String())) + }() + var errs utils.ErrorGroup + server_conn, err := l.Accept() + errs.Add(err) + client_conn, err := client_future.Get() + errs.Add(err) + err = errs.Finalize() + if err != nil { + if server_conn != nil { + server_conn.Close() + } + if client_conn != nil { + client_conn.(net.Conn).Close() + } + t.Fatal(err) + } + return server_conn, client_conn.(net.Conn) } type HandshakingConn interface { - net.Conn - Handshake() error + net.Conn + Handshake() error } func SimpleConnTest(t testing.TB, constructor func( - t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { - server_conn, client_conn := NetPipe(t) - defer server_conn.Close() - defer client_conn.Close() + t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { + server_conn, client_conn := NetPipe(t) + defer server_conn.Close() + defer client_conn.Close() - data := "first test string\n" + data := "first test string\n" - server, client := constructor(t, server_conn, client_conn) - defer close_both(server, client) + server, client := constructor(t, server_conn, client_conn) + defer close_both(server, client) - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() - err := client.Handshake() - if err != nil { - t.Fatal(err) - } + err := client.Handshake() + if err != nil { + t.Fatal(err) + } - _, err = io.Copy(client, bytes.NewReader([]byte(data))) - if err != nil { - t.Fatal(err) - } + _, err = io.Copy(client, bytes.NewReader([]byte(data))) + if err != nil { + t.Fatal(err) + } - err = client.Close() - if err != nil { - t.Fatal(err) - } - }() - go func() { - defer wg.Done() + err = client.Close() + if err != nil { + t.Fatal(err) + } + }() + go func() { + defer wg.Done() - err := server.Handshake() - if err != nil { - t.Fatal(err) - } + err := server.Handshake() + if err != nil { + t.Fatal(err) + } - buf := bytes.NewBuffer(make([]byte, 0, len(data))) - _, err = io.CopyN(buf, server, int64(len(data))) - if err != nil { - t.Fatal(err) - } - if string(buf.Bytes()) != data { - t.Fatal("mismatched data") - } + buf := bytes.NewBuffer(make([]byte, 0, len(data))) + _, err = io.CopyN(buf, server, int64(len(data))) + if err != nil { + t.Fatal(err) + } + if string(buf.Bytes()) != data { + t.Fatal("mismatched data") + } - err = server.Close() - if err != nil { - t.Fatal(err) - } - }() - wg.Wait() + err = server.Close() + if err != nil { + t.Fatal(err) + } + }() + wg.Wait() } func close_both(closer1, closer2 io.Closer) { - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - closer1.Close() - }() - go func() { - defer wg.Done() - closer2.Close() - }() - wg.Wait() + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + closer1.Close() + }() + go func() { + defer wg.Done() + closer2.Close() + }() + wg.Wait() } func ClosingTest(t testing.TB, constructor func( - t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { + t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { - run_test := func(close_tcp bool, server_writes bool) { - server_conn, client_conn := NetPipe(t) - defer server_conn.Close() - defer client_conn.Close() - server, client := constructor(t, server_conn, client_conn) - defer close_both(server, client) + run_test := func(close_tcp bool, server_writes bool) { + server_conn, client_conn := NetPipe(t) + defer server_conn.Close() + defer client_conn.Close() + server, client := constructor(t, server_conn, client_conn) + defer close_both(server, client) - var sslconn1, sslconn2 HandshakingConn - var conn1 net.Conn - if server_writes { - sslconn1 = server - conn1 = server_conn - sslconn2 = client - } else { - sslconn1 = client - conn1 = client_conn - sslconn2 = server - } + var sslconn1, sslconn2 HandshakingConn + var conn1 net.Conn + if server_writes { + sslconn1 = server + conn1 = server_conn + sslconn2 = client + } else { + sslconn1 = client + conn1 = client_conn + sslconn2 = server + } - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - _, err := sslconn1.Write([]byte("hello")) - if err != nil { - t.Fatal(err) - } - if close_tcp { - err = conn1.Close() - } else { - err = sslconn1.Close() - } - if err != nil { - t.Fatal(err) - } - }() + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _, err := sslconn1.Write([]byte("hello")) + if err != nil { + t.Fatal(err) + } + if close_tcp { + err = conn1.Close() + } else { + err = sslconn1.Close() + } + if err != nil { + t.Fatal(err) + } + }() - go func() { - defer wg.Done() - data, err := ioutil.ReadAll(sslconn2) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(data, []byte("hello")) { - t.Fatal("bytes don't match") - } - }() + go func() { + defer wg.Done() + data, err := ioutil.ReadAll(sslconn2) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data, []byte("hello")) { + t.Fatal("bytes don't match") + } + }() - wg.Wait() - } + wg.Wait() + } - run_test(true, false) - run_test(false, false) - run_test(true, true) - run_test(false, true) + run_test(true, false) + run_test(false, false) + run_test(true, true) + run_test(false, true) } func ThroughputBenchmark(b *testing.B, constructor func( - t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { - server_conn, client_conn := NetPipe(b) - defer server_conn.Close() - defer client_conn.Close() + t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { + server_conn, client_conn := NetPipe(b) + defer server_conn.Close() + defer client_conn.Close() - server, client := constructor(b, server_conn, client_conn) - defer close_both(server, client) + server, client := constructor(b, server_conn, client_conn) + defer close_both(server, client) - b.SetBytes(1024) - data := make([]byte, b.N*1024) - _, err := io.ReadFull(rand.Reader, data[:]) - if err != nil { - b.Fatal(err) - } + b.SetBytes(1024) + data := make([]byte, b.N*1024) + _, err := io.ReadFull(rand.Reader, data[:]) + if err != nil { + b.Fatal(err) + } - b.ResetTimer() - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - _, err = io.Copy(client, bytes.NewReader([]byte(data))) - if err != nil { - b.Fatal(err) - } - }() - go func() { - defer wg.Done() + b.ResetTimer() + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _, err = io.Copy(client, bytes.NewReader([]byte(data))) + if err != nil { + b.Fatal(err) + } + }() + go func() { + defer wg.Done() - buf := &bytes.Buffer{} - _, err = io.CopyN(buf, server, int64(len(data))) - if err != nil { - b.Fatal(err) - } - if !bytes.Equal(buf.Bytes(), data) { - b.Fatal("mismatched data") - } - }() - wg.Wait() - b.StopTimer() + buf := &bytes.Buffer{} + _, err = io.CopyN(buf, server, int64(len(data))) + if err != nil { + b.Fatal(err) + } + if !bytes.Equal(buf.Bytes(), data) { + b.Fatal("mismatched data") + } + }() + wg.Wait() + b.StopTimer() } func StdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) ( - server, client HandshakingConn) { - cert, err := tls.X509KeyPair(certBytes, keyBytes) - if err != nil { - t.Fatal(err) - } - config := &tls.Config{ - Certificates: []tls.Certificate{cert}, - InsecureSkipVerify: true, - CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}} - server = tls.Server(server_conn, config) - client = tls.Client(client_conn, config) - return server, client + server, client HandshakingConn) { + cert, err := tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + t.Fatal(err) + } + config := &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}} + server = tls.Server(server_conn, config) + client = tls.Client(client_conn, config) + return server, client } func OpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) ( - server, client HandshakingConn) { - ctx, err := NewCtx() - if err != nil { - t.Fatal(err) - } - key, err := LoadPrivateKey(keyBytes) - if err != nil { - t.Fatal(err) - } - err = ctx.UsePrivateKey(key) - if err != nil { - t.Fatal(err) - } - cert, err := LoadCertificate(certBytes) - if err != nil { - t.Fatal(err) - } - err = ctx.UseCertificate(cert) - if err != nil { - t.Fatal(err) - } - err = ctx.SetCipherList("AES128-SHA") - if err != nil { - t.Fatal(err) - } - server, err = Server(server_conn, ctx) - if err != nil { - t.Fatal(err) - } - client, err = Client(client_conn, ctx) - if err != nil { - t.Fatal(err) - } - return server, client + server, client HandshakingConn) { + ctx, err := NewCtx() + if err != nil { + t.Fatal(err) + } + key, err := LoadPrivateKey(keyBytes) + if err != nil { + t.Fatal(err) + } + err = ctx.UsePrivateKey(key) + if err != nil { + t.Fatal(err) + } + cert, err := LoadCertificate(certBytes) + if err != nil { + t.Fatal(err) + } + err = ctx.UseCertificate(cert) + if err != nil { + t.Fatal(err) + } + err = ctx.SetCipherList("AES128-SHA") + if err != nil { + t.Fatal(err) + } + server, err = Server(server_conn, ctx) + if err != nil { + t.Fatal(err) + } + client, err = Client(client_conn, ctx) + if err != nil { + t.Fatal(err) + } + return server, client } func StdlibOpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) ( - server, client HandshakingConn) { - server_std, _ := StdlibConstructor(t, server_conn, client_conn) - _, client_ssl := OpenSSLConstructor(t, server_conn, client_conn) - return server_std, client_ssl + server, client HandshakingConn) { + server_std, _ := StdlibConstructor(t, server_conn, client_conn) + _, client_ssl := OpenSSLConstructor(t, server_conn, client_conn) + return server_std, client_ssl } func OpenSSLStdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) ( - server, client HandshakingConn) { - _, client_std := StdlibConstructor(t, server_conn, client_conn) - server_ssl, _ := OpenSSLConstructor(t, server_conn, client_conn) - return server_ssl, client_std + server, client HandshakingConn) { + _, client_std := StdlibConstructor(t, server_conn, client_conn) + server_ssl, _ := OpenSSLConstructor(t, server_conn, client_conn) + return server_ssl, client_std } func TestStdlibSimple(t *testing.T) { - SimpleConnTest(t, StdlibConstructor) + SimpleConnTest(t, StdlibConstructor) } func TestOpenSSLSimple(t *testing.T) { - SimpleConnTest(t, OpenSSLConstructor) + SimpleConnTest(t, OpenSSLConstructor) } func TestStdlibClosing(t *testing.T) { - ClosingTest(t, StdlibConstructor) + ClosingTest(t, StdlibConstructor) } func TestOpenSSLClosing(t *testing.T) { - ClosingTest(t, OpenSSLConstructor) + ClosingTest(t, OpenSSLConstructor) } func BenchmarkStdlibThroughput(b *testing.B) { - ThroughputBenchmark(b, StdlibConstructor) + ThroughputBenchmark(b, StdlibConstructor) } func BenchmarkOpenSSLThroughput(b *testing.B) { - ThroughputBenchmark(b, OpenSSLConstructor) + ThroughputBenchmark(b, OpenSSLConstructor) } func TestStdlibOpenSSLSimple(t *testing.T) { - SimpleConnTest(t, StdlibOpenSSLConstructor) + SimpleConnTest(t, StdlibOpenSSLConstructor) } func TestOpenSSLStdlibSimple(t *testing.T) { - SimpleConnTest(t, OpenSSLStdlibConstructor) + SimpleConnTest(t, OpenSSLStdlibConstructor) } func TestStdlibOpenSSLClosing(t *testing.T) { - ClosingTest(t, StdlibOpenSSLConstructor) + ClosingTest(t, StdlibOpenSSLConstructor) } func TestOpenSSLStdlibClosing(t *testing.T) { - ClosingTest(t, OpenSSLStdlibConstructor) + ClosingTest(t, OpenSSLStdlibConstructor) } func BenchmarkStdlibOpenSSLThroughput(b *testing.B) { - ThroughputBenchmark(b, StdlibOpenSSLConstructor) + ThroughputBenchmark(b, StdlibOpenSSLConstructor) } func BenchmarkOpenSSLStdlibThroughput(b *testing.B) { - ThroughputBenchmark(b, OpenSSLStdlibConstructor) + ThroughputBenchmark(b, OpenSSLStdlibConstructor) } func FullDuplexRenegotiationTest(t testing.TB, constructor func( - t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { + t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { - server_conn, client_conn := NetPipe(t) - defer server_conn.Close() - defer client_conn.Close() + server_conn, client_conn := NetPipe(t) + defer server_conn.Close() + defer client_conn.Close() - times := 256 - data_len := 4 * SSLRecordSize - data1 := make([]byte, data_len) - _, err := io.ReadFull(rand.Reader, data1[:]) - if err != nil { - t.Fatal(err) - } - data2 := make([]byte, data_len) - _, err = io.ReadFull(rand.Reader, data1[:]) - if err != nil { - t.Fatal(err) - } + times := 256 + data_len := 4 * SSLRecordSize + data1 := make([]byte, data_len) + _, err := io.ReadFull(rand.Reader, data1[:]) + if err != nil { + t.Fatal(err) + } + data2 := make([]byte, data_len) + _, err = io.ReadFull(rand.Reader, data1[:]) + if err != nil { + t.Fatal(err) + } - server, client := constructor(t, server_conn, client_conn) - defer close_both(server, client) + server, client := constructor(t, server_conn, client_conn) + defer close_both(server, client) - var wg sync.WaitGroup + var wg sync.WaitGroup - send_func := func(sender HandshakingConn, data []byte) { - defer wg.Done() - for i := 0; i < times; i++ { - if i == times/2 { - wg.Add(1) - go func() { - defer wg.Done() - err := sender.Handshake() - if err != nil { - t.Fatal(err) - } - }() - } - _, err := sender.Write(data) - if err != nil { - t.Fatal(err) - } - } - } + send_func := func(sender HandshakingConn, data []byte) { + defer wg.Done() + for i := 0; i < times; i++ { + if i == times/2 { + wg.Add(1) + go func() { + defer wg.Done() + err := sender.Handshake() + if err != nil { + t.Fatal(err) + } + }() + } + _, err := sender.Write(data) + if err != nil { + t.Fatal(err) + } + } + } - recv_func := func(receiver net.Conn, data []byte) { - defer wg.Done() + recv_func := func(receiver net.Conn, data []byte) { + defer wg.Done() - buf := make([]byte, len(data)) - for i := 0; i < times; i++ { - n, err := io.ReadFull(receiver, buf[:]) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(buf[:n], data) { - t.Fatal(err) - } - } - } + buf := make([]byte, len(data)) + for i := 0; i < times; i++ { + n, err := io.ReadFull(receiver, buf[:]) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf[:n], data) { + t.Fatal(err) + } + } + } - wg.Add(4) - go recv_func(server, data1) - go send_func(client, data1) - go send_func(server, data2) - go recv_func(client, data2) - wg.Wait() + wg.Add(4) + go recv_func(server, data1) + go send_func(client, data1) + go send_func(server, data2) + go recv_func(client, data2) + wg.Wait() } func TestStdlibFullDuplexRenegotiation(t *testing.T) { - FullDuplexRenegotiationTest(t, StdlibConstructor) + FullDuplexRenegotiationTest(t, StdlibConstructor) } func TestOpenSSLFullDuplexRenegotiation(t *testing.T) { - FullDuplexRenegotiationTest(t, OpenSSLConstructor) + FullDuplexRenegotiationTest(t, OpenSSLConstructor) } func TestOpenSSLStdlibFullDuplexRenegotiation(t *testing.T) { - FullDuplexRenegotiationTest(t, OpenSSLStdlibConstructor) + FullDuplexRenegotiationTest(t, OpenSSLStdlibConstructor) } func TestStdlibOpenSSLFullDuplexRenegotiation(t *testing.T) { - FullDuplexRenegotiationTest(t, StdlibOpenSSLConstructor) + FullDuplexRenegotiationTest(t, StdlibOpenSSLConstructor) } func LotsOfConns(t *testing.T, payload_size int64, loops, clients int, - sleep time.Duration, newListener func(net.Listener) net.Listener, - newClient func(net.Conn) (net.Conn, error)) { - tcp_listener, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatal(err) - } - ssl_listener := newListener(tcp_listener) - go func() { - for { - conn, err := ssl_listener.Accept() - if err != nil { - t.Fatalf("failed accept: %s", err) - continue - } - go func() { - defer func() { - err = conn.Close() - if err != nil { - t.Fatalf("failed closing: %s", err) - } - }() - for i := 0; i < loops; i++ { - _, err := io.Copy(ioutil.Discard, - io.LimitReader(conn, payload_size)) - if err != nil { - t.Fatalf("failed reading: %s", err) - return - } - _, err = io.Copy(conn, io.LimitReader(rand.Reader, - payload_size)) - if err != nil { - t.Fatalf("failed writing: %s", err) - return - } - } - time.Sleep(sleep) - }() - } - }() - var wg sync.WaitGroup - for i := 0; i < clients; i++ { - tcp_client, err := net.Dial(tcp_listener.Addr().Network(), - tcp_listener.Addr().String()) - if err != nil { - t.Fatal(err) - } - ssl_client, err := newClient(tcp_client) - if err != nil { - t.Fatal(err) - } - wg.Add(1) - go func(i int) { - defer func() { - err = ssl_client.Close() - if err != nil { - t.Fatalf("failed closing: %s", err) - } - wg.Done() - }() - for i := 0; i < loops; i++ { - _, err := io.Copy(ssl_client, io.LimitReader(rand.Reader, - payload_size)) - if err != nil { - t.Fatalf("failed writing: %s", err) - return - } - _, err = io.Copy(ioutil.Discard, - io.LimitReader(ssl_client, payload_size)) - if err != nil { - t.Fatalf("failed reading: %s", err) - return - } - } - time.Sleep(sleep) - }(i) - } - wg.Wait() + sleep time.Duration, newListener func(net.Listener) net.Listener, + newClient func(net.Conn) (net.Conn, error)) { + tcp_listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + ssl_listener := newListener(tcp_listener) + go func() { + for { + conn, err := ssl_listener.Accept() + if err != nil { + t.Fatalf("failed accept: %s", err) + continue + } + go func() { + defer func() { + err = conn.Close() + if err != nil { + t.Fatalf("failed closing: %s", err) + } + }() + for i := 0; i < loops; i++ { + _, err := io.Copy(ioutil.Discard, + io.LimitReader(conn, payload_size)) + if err != nil { + t.Fatalf("failed reading: %s", err) + return + } + _, err = io.Copy(conn, io.LimitReader(rand.Reader, + payload_size)) + if err != nil { + t.Fatalf("failed writing: %s", err) + return + } + } + time.Sleep(sleep) + }() + } + }() + var wg sync.WaitGroup + for i := 0; i < clients; i++ { + tcp_client, err := net.Dial(tcp_listener.Addr().Network(), + tcp_listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + ssl_client, err := newClient(tcp_client) + if err != nil { + t.Fatal(err) + } + wg.Add(1) + go func(i int) { + defer func() { + err = ssl_client.Close() + if err != nil { + t.Fatalf("failed closing: %s", err) + } + wg.Done() + }() + for i := 0; i < loops; i++ { + _, err := io.Copy(ssl_client, io.LimitReader(rand.Reader, + payload_size)) + if err != nil { + t.Fatalf("failed writing: %s", err) + return + } + _, err = io.Copy(ioutil.Discard, + io.LimitReader(ssl_client, payload_size)) + if err != nil { + t.Fatalf("failed reading: %s", err) + return + } + } + time.Sleep(sleep) + }(i) + } + wg.Wait() } func TestStdlibLotsOfConns(t *testing.T) { - tls_cert, err := tls.X509KeyPair(certBytes, keyBytes) - if err != nil { - t.Fatal(err) - } - tls_config := &tls.Config{ - Certificates: []tls.Certificate{tls_cert}, - InsecureSkipVerify: true, - CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}} - LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, - func(l net.Listener) net.Listener { - return tls.NewListener(l, tls_config) - }, func(c net.Conn) (net.Conn, error) { - return tls.Client(c, tls_config), nil - }) + tls_cert, err := tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + t.Fatal(err) + } + tls_config := &tls.Config{ + Certificates: []tls.Certificate{tls_cert}, + InsecureSkipVerify: true, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}} + LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, + func(l net.Listener) net.Listener { + return tls.NewListener(l, tls_config) + }, func(c net.Conn) (net.Conn, error) { + return tls.Client(c, tls_config), nil + }) } func TestOpenSSLLotsOfConns(t *testing.T) { - ctx, err := NewCtx() - if err != nil { - t.Fatal(err) - } - key, err := LoadPrivateKey(keyBytes) - if err != nil { - t.Fatal(err) - } - err = ctx.UsePrivateKey(key) - if err != nil { - t.Fatal(err) - } - cert, err := LoadCertificate(certBytes) - if err != nil { - t.Fatal(err) - } - err = ctx.UseCertificate(cert) - if err != nil { - t.Fatal(err) - } - err = ctx.SetCipherList("AES128-SHA") - if err != nil { - t.Fatal(err) - } - LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, - func(l net.Listener) net.Listener { - return NewListener(l, ctx) - }, func(c net.Conn) (net.Conn, error) { - return Client(c, ctx) - }) + ctx, err := NewCtx() + if err != nil { + t.Fatal(err) + } + key, err := LoadPrivateKey(keyBytes) + if err != nil { + t.Fatal(err) + } + err = ctx.UsePrivateKey(key) + if err != nil { + t.Fatal(err) + } + cert, err := LoadCertificate(certBytes) + if err != nil { + t.Fatal(err) + } + err = ctx.UseCertificate(cert) + if err != nil { + t.Fatal(err) + } + err = ctx.SetCipherList("AES128-SHA") + if err != nil { + t.Fatal(err) + } + LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, + func(l net.Listener) net.Listener { + return NewListener(l, ctx) + }, func(c net.Conn) (net.Conn, error) { + return Client(c, ctx) + }) } diff --git a/utils/errors.go b/utils/errors.go index eb7a473..a01f2ea 100644 --- a/utils/errors.go +++ b/utils/errors.go @@ -3,20 +3,20 @@ package utils import ( - "errors" - "strings" + "errors" + "strings" ) // ErrorGroup collates errors type ErrorGroup struct { - Errors []error + Errors []error } // Add adds an error to an existing error group func (e *ErrorGroup) Add(err error) { - if err != nil { - e.Errors = append(e.Errors, err) - } + if err != nil { + e.Errors = append(e.Errors, err) + } } // Finalize returns an error corresponding to the ErrorGroup state. If there's @@ -24,15 +24,15 @@ func (e *ErrorGroup) Add(err error) { // Finalize returns that error. Otherwise, Finalize will make a new error // consisting of the messages from the constituent errors. func (e *ErrorGroup) Finalize() error { - if len(e.Errors) == 0 { - return nil - } - if len(e.Errors) == 1 { - return e.Errors[0] - } - msgs := make([]string, 0, len(e.Errors)) - for _, err := range e.Errors { - msgs = append(msgs, err.Error()) - } - return errors.New(strings.Join(msgs, "\n")) + if len(e.Errors) == 0 { + return nil + } + if len(e.Errors) == 1 { + return e.Errors[0] + } + msgs := make([]string, 0, len(e.Errors)) + for _, err := range e.Errors { + msgs = append(msgs, err.Error()) + } + return errors.New(strings.Join(msgs, "\n")) } diff --git a/utils/future.go b/utils/future.go index 739bcb3..95168b1 100644 --- a/utils/future.go +++ b/utils/future.go @@ -3,7 +3,7 @@ package utils import ( - "sync" + "sync" ) // Future is a type that is essentially the inverse of a channel. With a @@ -13,55 +13,55 @@ import ( // results, we also capture and return error values as well. Use NewFuture // to initialize. type Future struct { - mutex *sync.Mutex - cond *sync.Cond - received bool - val interface{} - err error + mutex *sync.Mutex + cond *sync.Cond + received bool + val interface{} + err error } // NewFuture returns an initialized and ready Future. func NewFuture() *Future { - mutex := &sync.Mutex{} - return &Future{ - mutex: mutex, - cond: sync.NewCond(mutex), - received: false, - val: nil, - err: nil, - } + mutex := &sync.Mutex{} + return &Future{ + mutex: mutex, + cond: sync.NewCond(mutex), + received: false, + val: nil, + err: nil, + } } // Get blocks until the Future has a value set. func (self *Future) Get() (interface{}, error) { - self.mutex.Lock() - defer self.mutex.Unlock() - for { - if self.received { - return self.val, self.err - } - self.cond.Wait() - } + self.mutex.Lock() + defer self.mutex.Unlock() + for { + if self.received { + return self.val, self.err + } + self.cond.Wait() + } } // Fired returns whether or not a value has been set. If Fired is true, Get // won't block. func (self *Future) Fired() bool { - self.mutex.Lock() - defer self.mutex.Unlock() - return self.received + self.mutex.Lock() + defer self.mutex.Unlock() + return self.received } // Set provides the value to present and future Get calls. If Set has already // been called, this is a no-op. func (self *Future) Set(val interface{}, err error) { - self.mutex.Lock() - defer self.mutex.Unlock() - if self.received { - return - } - self.received = true - self.val = val - self.err = err - self.cond.Broadcast() + self.mutex.Lock() + defer self.mutex.Unlock() + if self.received { + return + } + self.received = true + self.val = val + self.err = err + self.cond.Broadcast() } diff --git a/utils/thread_id.go b/utils/thread_id.go index c0300e8..4e0873b 100644 --- a/utils/thread_id.go +++ b/utils/thread_id.go @@ -5,7 +5,7 @@ package utils import ( - "unsafe" + "unsafe" ) // ThreadId returns the current runtime's thread id. Thanks to Gustavo Niemeyer