space monkey internal commit export

[katamari commit: 9bd04d1d78e85304589695c66e328d23128f509c]
This commit is contained in:
Jeff Wendling 2014-02-25 13:45:14 -05:00 committed by JT Olds
parent 751143ef9c
commit fa8eb6a573
14 changed files with 1470 additions and 1470 deletions

306
bio.go
View File

@ -56,226 +56,226 @@ static void BIO_set_retry_read_not_a_macro(BIO *b) { BIO_set_retry_read(b); }
import "C" import "C"
import ( import (
"errors" "errors"
"io" "io"
"reflect" "reflect"
"sync" "sync"
"unsafe" "unsafe"
) )
const ( const (
SSLRecordSize = 16 * 1024 SSLRecordSize = 16 * 1024
) )
func nonCopyGoBytes(ptr uintptr, length int) []byte { func nonCopyGoBytes(ptr uintptr, length int) []byte {
var slice []byte var slice []byte
header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) header := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
header.Cap = length header.Cap = length
header.Len = length header.Len = length
header.Data = ptr header.Data = ptr
return slice return slice
} }
func nonCopyCString(data *C.char, size C.int) []byte { func nonCopyCString(data *C.char, size C.int) []byte {
return nonCopyGoBytes(uintptr(unsafe.Pointer(data)), int(size)) return nonCopyGoBytes(uintptr(unsafe.Pointer(data)), int(size))
} }
//export cbioNew //export cbioNew
func cbioNew(b *C.BIO) C.int { func cbioNew(b *C.BIO) C.int {
b.shutdown = 1 b.shutdown = 1
b.init = 1 b.init = 1
b.num = -1 b.num = -1
b.ptr = nil b.ptr = nil
b.flags = 0 b.flags = 0
return 1 return 1
} }
//export cbioFree //export cbioFree
func cbioFree(b *C.BIO) C.int { func cbioFree(b *C.BIO) C.int {
return 1 return 1
} }
type writeBio struct { type writeBio struct {
data_mtx sync.Mutex data_mtx sync.Mutex
op_mtx sync.Mutex op_mtx sync.Mutex
buf []byte buf []byte
} }
func loadWritePtr(b *C.BIO) *writeBio { func loadWritePtr(b *C.BIO) *writeBio {
return (*writeBio)(unsafe.Pointer(b.ptr)) return (*writeBio)(unsafe.Pointer(b.ptr))
} }
//export writeBioWrite //export writeBioWrite
func writeBioWrite(b *C.BIO, data *C.char, size C.int) C.int { func writeBioWrite(b *C.BIO, data *C.char, size C.int) C.int {
ptr := loadWritePtr(b) ptr := loadWritePtr(b)
if ptr == nil || data == nil || size < 0 { if ptr == nil || data == nil || size < 0 {
return -1 return -1
} }
ptr.data_mtx.Lock() ptr.data_mtx.Lock()
defer ptr.data_mtx.Unlock() defer ptr.data_mtx.Unlock()
C.BIO_clear_retry_flags_not_a_macro(b) C.BIO_clear_retry_flags_not_a_macro(b)
ptr.buf = append(ptr.buf, nonCopyCString(data, size)...) ptr.buf = append(ptr.buf, nonCopyCString(data, size)...)
return size return size
} }
//export writeBioCtrl //export writeBioCtrl
func writeBioCtrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long { func writeBioCtrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long {
switch cmd { switch cmd {
case C.BIO_CTRL_WPENDING: case C.BIO_CTRL_WPENDING:
return writeBioPending(b) return writeBioPending(b)
case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH: case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH:
return 1 return 1
default: default:
return 0 return 0
} }
} }
func writeBioPending(b *C.BIO) C.long { func writeBioPending(b *C.BIO) C.long {
ptr := loadWritePtr(b) ptr := loadWritePtr(b)
if ptr == nil { if ptr == nil {
return 0 return 0
} }
ptr.data_mtx.Lock() ptr.data_mtx.Lock()
defer ptr.data_mtx.Unlock() defer ptr.data_mtx.Unlock()
return C.long(len(ptr.buf)) return C.long(len(ptr.buf))
} }
func (b *writeBio) WriteTo(w io.Writer) (rv int64, err error) { func (b *writeBio) WriteTo(w io.Writer) (rv int64, err error) {
b.op_mtx.Lock() b.op_mtx.Lock()
defer b.op_mtx.Unlock() defer b.op_mtx.Unlock()
// write whatever data we currently have // write whatever data we currently have
b.data_mtx.Lock() b.data_mtx.Lock()
data := b.buf data := b.buf
b.data_mtx.Unlock() b.data_mtx.Unlock()
if len(data) == 0 { if len(data) == 0 {
return 0, nil return 0, nil
} }
n, err := w.Write(data) n, err := w.Write(data)
// subtract however much data we wrote from the buffer // subtract however much data we wrote from the buffer
b.data_mtx.Lock() b.data_mtx.Lock()
b.buf = b.buf[:copy(b.buf, b.buf[n:])] b.buf = b.buf[:copy(b.buf, b.buf[n:])]
b.data_mtx.Unlock() b.data_mtx.Unlock()
return int64(n), err return int64(n), err
} }
func (self *writeBio) Disconnect(b *C.BIO) { func (self *writeBio) Disconnect(b *C.BIO) {
if loadWritePtr(b) == self { if loadWritePtr(b) == self {
b.ptr = nil b.ptr = nil
} }
} }
func (b *writeBio) MakeCBIO() *C.BIO { func (b *writeBio) MakeCBIO() *C.BIO {
rv := C.BIO_new(C.BIO_s_writeBio()) rv := C.BIO_new(C.BIO_s_writeBio())
rv.ptr = unsafe.Pointer(b) rv.ptr = unsafe.Pointer(b)
return rv return rv
} }
type readBio struct { type readBio struct {
data_mtx sync.Mutex data_mtx sync.Mutex
op_mtx sync.Mutex op_mtx sync.Mutex
buf []byte buf []byte
eof bool eof bool
} }
func loadReadPtr(b *C.BIO) *readBio { func loadReadPtr(b *C.BIO) *readBio {
return (*readBio)(unsafe.Pointer(b.ptr)) return (*readBio)(unsafe.Pointer(b.ptr))
} }
//export readBioRead //export readBioRead
func readBioRead(b *C.BIO, data *C.char, size C.int) C.int { func readBioRead(b *C.BIO, data *C.char, size C.int) C.int {
ptr := loadReadPtr(b) ptr := loadReadPtr(b)
if ptr == nil || size < 0 { if ptr == nil || size < 0 {
return -1 return -1
} }
ptr.data_mtx.Lock() ptr.data_mtx.Lock()
defer ptr.data_mtx.Unlock() defer ptr.data_mtx.Unlock()
C.BIO_clear_retry_flags_not_a_macro(b) C.BIO_clear_retry_flags_not_a_macro(b)
if len(ptr.buf) == 0 { if len(ptr.buf) == 0 {
if ptr.eof { if ptr.eof {
return 0 return 0
} }
C.BIO_set_retry_read_not_a_macro(b) C.BIO_set_retry_read_not_a_macro(b)
return -1 return -1
} }
if size == 0 || data == nil { if size == 0 || data == nil {
return C.int(len(ptr.buf)) return C.int(len(ptr.buf))
} }
n := copy(nonCopyCString(data, size), ptr.buf) n := copy(nonCopyCString(data, size), ptr.buf)
ptr.buf = ptr.buf[:copy(ptr.buf, ptr.buf[n:])] ptr.buf = ptr.buf[:copy(ptr.buf, ptr.buf[n:])]
return C.int(n) return C.int(n)
} }
//export readBioCtrl //export readBioCtrl
func readBioCtrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long { func readBioCtrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long {
switch cmd { switch cmd {
case C.BIO_CTRL_PENDING: case C.BIO_CTRL_PENDING:
return readBioPending(b) return readBioPending(b)
case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH: case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH:
return 1 return 1
default: default:
return 0 return 0
} }
} }
func readBioPending(b *C.BIO) C.long { func readBioPending(b *C.BIO) C.long {
ptr := loadReadPtr(b) ptr := loadReadPtr(b)
if ptr == nil { if ptr == nil {
return 0 return 0
} }
ptr.data_mtx.Lock() ptr.data_mtx.Lock()
defer ptr.data_mtx.Unlock() defer ptr.data_mtx.Unlock()
return C.long(len(ptr.buf)) return C.long(len(ptr.buf))
} }
func (b *readBio) ReadFromOnce(r io.Reader) (n int, err error) { func (b *readBio) ReadFromOnce(r io.Reader) (n int, err error) {
b.op_mtx.Lock() b.op_mtx.Lock()
defer b.op_mtx.Unlock() defer b.op_mtx.Unlock()
// make sure we have a destination that fits at least one SSL record // make sure we have a destination that fits at least one SSL record
b.data_mtx.Lock() b.data_mtx.Lock()
if cap(b.buf) < len(b.buf)+SSLRecordSize { if cap(b.buf) < len(b.buf)+SSLRecordSize {
new_buf := make([]byte, len(b.buf), len(b.buf)+SSLRecordSize) new_buf := make([]byte, len(b.buf), len(b.buf)+SSLRecordSize)
copy(new_buf, b.buf) copy(new_buf, b.buf)
b.buf = new_buf b.buf = new_buf
} }
dst := b.buf[len(b.buf):cap(b.buf)] dst := b.buf[len(b.buf):cap(b.buf)]
dst_slice := b.buf dst_slice := b.buf
b.data_mtx.Unlock() b.data_mtx.Unlock()
n, err = r.Read(dst) n, err = r.Read(dst)
b.data_mtx.Lock() b.data_mtx.Lock()
defer b.data_mtx.Unlock() defer b.data_mtx.Unlock()
if n > 0 { if n > 0 {
if len(dst_slice) != len(b.buf) { if len(dst_slice) != len(b.buf) {
// someone shrunk the buffer, so we read in too far ahead and we // someone shrunk the buffer, so we read in too far ahead and we
// need to slide backwards // need to slide backwards
copy(b.buf[len(b.buf):len(b.buf)+n], dst) copy(b.buf[len(b.buf):len(b.buf)+n], dst)
} }
b.buf = b.buf[:len(b.buf)+n] b.buf = b.buf[:len(b.buf)+n]
} }
return n, err return n, err
} }
func (b *readBio) MakeCBIO() *C.BIO { func (b *readBio) MakeCBIO() *C.BIO {
rv := C.BIO_new(C.BIO_s_readBio()) rv := C.BIO_new(C.BIO_s_readBio())
rv.ptr = unsafe.Pointer(b) rv.ptr = unsafe.Pointer(b)
return rv return rv
} }
func (self *readBio) Disconnect(b *C.BIO) { func (self *readBio) Disconnect(b *C.BIO) {
if loadReadPtr(b) == self { if loadReadPtr(b) == self {
b.ptr = nil b.ptr = nil
} }
} }
func (b *readBio) MarkEOF() { func (b *readBio) MarkEOF() {
b.data_mtx.Lock() b.data_mtx.Lock()
defer b.data_mtx.Unlock() defer b.data_mtx.Unlock()
b.eof = true b.eof = true
} }
type anyBio C.BIO type anyBio C.BIO
@ -283,18 +283,18 @@ type anyBio C.BIO
func asAnyBio(b *C.BIO) *anyBio { return (*anyBio)(b) } func asAnyBio(b *C.BIO) *anyBio { return (*anyBio)(b) }
func (b *anyBio) Read(buf []byte) (n int, err error) { func (b *anyBio) Read(buf []byte) (n int, err error) {
n = int(C.BIO_read((*C.BIO)(b), unsafe.Pointer(&buf[0]), C.int(len(buf)))) n = int(C.BIO_read((*C.BIO)(b), unsafe.Pointer(&buf[0]), C.int(len(buf))))
if n <= 0 { if n <= 0 {
return 0, io.EOF return 0, io.EOF
} }
return n, nil return n, nil
} }
func (b *anyBio) Write(buf []byte) (written int, err error) { func (b *anyBio) Write(buf []byte) (written int, err error) {
n := int(C.BIO_write((*C.BIO)(b), unsafe.Pointer(&buf[0]), n := int(C.BIO_write((*C.BIO)(b), unsafe.Pointer(&buf[0]),
C.int(len(buf)))) C.int(len(buf))))
if n != len(buf) { if n != len(buf) {
return n, errors.New("BIO write failed") return n, errors.New("BIO write failed")
} }
return n, nil return n, nil
} }

570
conn.go
View File

@ -9,77 +9,77 @@ package openssl
import "C" import "C"
import ( import (
"errors" "errors"
"io" "io"
"net" "net"
"runtime" "runtime"
"sync" "sync"
"time" "time"
"unsafe" "unsafe"
"code.spacemonkey.com/go/openssl/utils" "code.spacemonkey.com/go/openssl/utils"
) )
var ( var (
zeroReturn = errors.New("zero return") zeroReturn = errors.New("zero return")
wantRead = errors.New("want read") wantRead = errors.New("want read")
wantWrite = errors.New("want write") wantWrite = errors.New("want write")
tryAgain = errors.New("try again") tryAgain = errors.New("try again")
) )
type Conn struct { type Conn struct {
conn net.Conn conn net.Conn
ssl *C.SSL ssl *C.SSL
into_ssl *readBio into_ssl *readBio
from_ssl *writeBio from_ssl *writeBio
is_shutdown bool is_shutdown bool
mtx sync.Mutex mtx sync.Mutex
want_read_future *utils.Future want_read_future *utils.Future
} }
func newSSL(ctx *C.SSL_CTX) (*C.SSL, error) { func newSSL(ctx *C.SSL_CTX) (*C.SSL, error) {
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
ssl := C.SSL_new(ctx) ssl := C.SSL_new(ctx)
if ssl == nil { if ssl == nil {
return nil, errorFromErrorQueue() return nil, errorFromErrorQueue()
} }
return ssl, nil return ssl, nil
} }
func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) { func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) {
ssl, err := newSSL(ctx.ctx) ssl, err := newSSL(ctx.ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
into_ssl := &readBio{} into_ssl := &readBio{}
from_ssl := &writeBio{} from_ssl := &writeBio{}
into_ssl_cbio := into_ssl.MakeCBIO() into_ssl_cbio := into_ssl.MakeCBIO()
from_ssl_cbio := from_ssl.MakeCBIO() from_ssl_cbio := from_ssl.MakeCBIO()
if into_ssl_cbio == nil || from_ssl_cbio == nil { if into_ssl_cbio == nil || from_ssl_cbio == nil {
// these frees are null safe // these frees are null safe
C.BIO_free(into_ssl_cbio) C.BIO_free(into_ssl_cbio)
C.BIO_free(from_ssl_cbio) C.BIO_free(from_ssl_cbio)
C.SSL_free(ssl) C.SSL_free(ssl)
return nil, errors.New("failed to allocate memory BIO") return nil, errors.New("failed to allocate memory BIO")
} }
// the ssl object takes ownership of these objects now // the ssl object takes ownership of these objects now
C.SSL_set_bio(ssl, into_ssl_cbio, from_ssl_cbio) C.SSL_set_bio(ssl, into_ssl_cbio, from_ssl_cbio)
c := &Conn{ c := &Conn{
conn: conn, conn: conn,
ssl: ssl, ssl: ssl,
into_ssl: into_ssl, into_ssl: into_ssl,
from_ssl: from_ssl} from_ssl: from_ssl}
runtime.SetFinalizer(c, func(c *Conn) { runtime.SetFinalizer(c, func(c *Conn) {
c.into_ssl.Disconnect(into_ssl_cbio) c.into_ssl.Disconnect(into_ssl_cbio)
c.from_ssl.Disconnect(from_ssl_cbio) c.from_ssl.Disconnect(from_ssl_cbio)
C.SSL_free(c.ssl) C.SSL_free(c.ssl)
}) })
return c, nil return c, nil
} }
// Client wraps an existing stream connection and puts it in the connect state // Client wraps an existing stream connection and puts it in the connect state
@ -94,319 +94,319 @@ func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) {
// you're using. This library is not nice enough to use the system certificate // you're using. This library is not nice enough to use the system certificate
// store by default for you yet. // store by default for you yet.
func Client(conn net.Conn, ctx *Ctx) (*Conn, error) { func Client(conn net.Conn, ctx *Ctx) (*Conn, error) {
c, err := newConn(conn, ctx) c, err := newConn(conn, ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
C.SSL_set_connect_state(c.ssl) C.SSL_set_connect_state(c.ssl)
return c, nil return c, nil
} }
// Server wraps an existing stream connection and puts it in the accept state // Server wraps an existing stream connection and puts it in the accept state
// for any subsequent handshakes. // for any subsequent handshakes.
func Server(conn net.Conn, ctx *Ctx) (*Conn, error) { func Server(conn net.Conn, ctx *Ctx) (*Conn, error) {
c, err := newConn(conn, ctx) c, err := newConn(conn, ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
C.SSL_set_accept_state(c.ssl) C.SSL_set_accept_state(c.ssl)
return c, nil return c, nil
} }
func (c *Conn) fillInputBuffer() error { func (c *Conn) fillInputBuffer() error {
for { for {
n, err := c.into_ssl.ReadFromOnce(c.conn) n, err := c.into_ssl.ReadFromOnce(c.conn)
if n == 0 && err == nil { if n == 0 && err == nil {
continue continue
} }
if err == io.EOF { if err == io.EOF {
c.into_ssl.MarkEOF() c.into_ssl.MarkEOF()
return c.Close() return c.Close()
} }
return err return err
} }
} }
func (c *Conn) flushOutputBuffer() error { func (c *Conn) flushOutputBuffer() error {
_, err := c.from_ssl.WriteTo(c.conn) _, err := c.from_ssl.WriteTo(c.conn)
return err return err
} }
func (c *Conn) getErrorHandler(rv C.int, errno error) func() error { func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
errcode := C.SSL_get_error(c.ssl, rv) errcode := C.SSL_get_error(c.ssl, rv)
switch errcode { switch errcode {
case C.SSL_ERROR_ZERO_RETURN: case C.SSL_ERROR_ZERO_RETURN:
return func() error { return func() error {
c.Close() c.Close()
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
case C.SSL_ERROR_WANT_READ: case C.SSL_ERROR_WANT_READ:
go c.flushOutputBuffer() go c.flushOutputBuffer()
if c.want_read_future != nil { if c.want_read_future != nil {
want_read_future := c.want_read_future want_read_future := c.want_read_future
return func() error { return func() error {
_, err := want_read_future.Get() _, err := want_read_future.Get()
return err return err
} }
} }
c.want_read_future = utils.NewFuture() c.want_read_future = utils.NewFuture()
want_read_future := c.want_read_future want_read_future := c.want_read_future
return func() (err error) { return func() (err error) {
defer func() { defer func() {
c.mtx.Lock() c.mtx.Lock()
c.want_read_future = nil c.want_read_future = nil
c.mtx.Unlock() c.mtx.Unlock()
want_read_future.Set(nil, err) want_read_future.Set(nil, err)
}() }()
err = c.fillInputBuffer() err = c.fillInputBuffer()
if err != nil { if err != nil {
return err return err
} }
return tryAgain return tryAgain
} }
case C.SSL_ERROR_WANT_WRITE: case C.SSL_ERROR_WANT_WRITE:
return func() error { return func() error {
err := c.flushOutputBuffer() err := c.flushOutputBuffer()
if err != nil { if err != nil {
return err return err
} }
return tryAgain return tryAgain
} }
case C.SSL_ERROR_SYSCALL: case C.SSL_ERROR_SYSCALL:
var err error var err error
if C.ERR_peek_error() == 0 { if C.ERR_peek_error() == 0 {
switch rv { switch rv {
case 0: case 0:
err = errors.New("protocol-violating EOF") err = errors.New("protocol-violating EOF")
case -1: case -1:
err = errno err = errno
default: default:
err = errorFromErrorQueue() err = errorFromErrorQueue()
} }
} else { } else {
err = errorFromErrorQueue() err = errorFromErrorQueue()
} }
return func() error { return err } return func() error { return err }
default: default:
err := errorFromErrorQueue() err := errorFromErrorQueue()
return func() error { return err } return func() error { return err }
} }
} }
func (c *Conn) handleError(errcb func() error) error { func (c *Conn) handleError(errcb func() error) error {
if errcb != nil { if errcb != nil {
return errcb() return errcb()
} }
return nil return nil
} }
func (c *Conn) handshake() func() error { func (c *Conn) handshake() func() error {
c.mtx.Lock() c.mtx.Lock()
defer c.mtx.Unlock() defer c.mtx.Unlock()
if c.is_shutdown { if c.is_shutdown {
return func() error { return io.ErrUnexpectedEOF } return func() error { return io.ErrUnexpectedEOF }
} }
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
rv, errno := C.SSL_do_handshake(c.ssl) rv, errno := C.SSL_do_handshake(c.ssl)
if rv > 0 { if rv > 0 {
return nil return nil
} }
return c.getErrorHandler(rv, errno) return c.getErrorHandler(rv, errno)
} }
// Handshake performs an SSL handshake. If a handshake is not manually // Handshake performs an SSL handshake. If a handshake is not manually
// triggered, it will run before the first I/O on the encrypted stream. // triggered, it will run before the first I/O on the encrypted stream.
func (c *Conn) Handshake() error { func (c *Conn) Handshake() error {
err := tryAgain err := tryAgain
for err == tryAgain { for err == tryAgain {
err = c.handleError(c.handshake()) err = c.handleError(c.handshake())
} }
go c.flushOutputBuffer() go c.flushOutputBuffer()
return err return err
} }
// PeerCertificate returns the Certificate of the peer with which you're // PeerCertificate returns the Certificate of the peer with which you're
// communicating. Only valid after a handshake. // communicating. Only valid after a handshake.
func (c *Conn) PeerCertificate() (*Certificate, error) { func (c *Conn) PeerCertificate() (*Certificate, error) {
c.mtx.Lock() c.mtx.Lock()
if c.is_shutdown { if c.is_shutdown {
return nil, errors.New("connection closed") return nil, errors.New("connection closed")
} }
x := C.SSL_get_peer_certificate(c.ssl) x := C.SSL_get_peer_certificate(c.ssl)
c.mtx.Unlock() c.mtx.Unlock()
if x == nil { if x == nil {
return nil, errors.New("no peer certificate found") return nil, errors.New("no peer certificate found")
} }
cert := &Certificate{x: x} cert := &Certificate{x: x}
runtime.SetFinalizer(cert, func(cert *Certificate) { runtime.SetFinalizer(cert, func(cert *Certificate) {
C.X509_free(cert.x) C.X509_free(cert.x)
}) })
return cert, nil return cert, nil
} }
func (c *Conn) shutdown() func() error { func (c *Conn) shutdown() func() error {
c.mtx.Lock() c.mtx.Lock()
defer c.mtx.Unlock() defer c.mtx.Unlock()
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
rv, errno := C.SSL_shutdown(c.ssl) rv, errno := C.SSL_shutdown(c.ssl)
if rv > 0 { if rv > 0 {
return nil return nil
} }
if rv == 0 { if rv == 0 {
// The OpenSSL docs say that in this case, the shutdown is not // The OpenSSL docs say that in this case, the shutdown is not
// finished, and we should call SSL_shutdown() a second time, if a // finished, and we should call SSL_shutdown() a second time, if a
// bidirectional shutdown is going to be performed. Further, the // bidirectional shutdown is going to be performed. Further, the
// output of SSL_get_error may be misleading, as an erroneous // output of SSL_get_error may be misleading, as an erroneous
// SSL_ERROR_SYSCALL may be flagged even though no error occurred. // SSL_ERROR_SYSCALL may be flagged even though no error occurred.
// So, TODO: revisit bidrectional shutdown, possibly trying again. // So, TODO: revisit bidrectional shutdown, possibly trying again.
// Note: some broken clients won't engage in bidirectional shutdown // Note: some broken clients won't engage in bidirectional shutdown
// without tickling them to close by sending a TCP_FIN packet, or // without tickling them to close by sending a TCP_FIN packet, or
// shutting down the write-side of the connection. // shutting down the write-side of the connection.
return nil return nil
} else { } else {
return c.getErrorHandler(rv, errno) return c.getErrorHandler(rv, errno)
} }
} }
func (c *Conn) shutdownLoop() error { func (c *Conn) shutdownLoop() error {
err := tryAgain err := tryAgain
shutdown_tries := 0 shutdown_tries := 0
for err == tryAgain { for err == tryAgain {
shutdown_tries = shutdown_tries + 1 shutdown_tries = shutdown_tries + 1
err = c.handleError(c.shutdown()) err = c.handleError(c.shutdown())
if err == nil { if err == nil {
return c.flushOutputBuffer() return c.flushOutputBuffer()
} }
if err == tryAgain && shutdown_tries >= 2 { if err == tryAgain && shutdown_tries >= 2 {
return errors.New("shutdown requested a third time?") return errors.New("shutdown requested a third time?")
} }
} }
if err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF {
err = nil err = nil
} }
return err return err
} }
// Close shuts down the SSL connection and closes the underlying wrapped // Close shuts down the SSL connection and closes the underlying wrapped
// connection. // connection.
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.mtx.Lock() c.mtx.Lock()
if c.is_shutdown { if c.is_shutdown {
c.mtx.Unlock() c.mtx.Unlock()
return nil return nil
} }
c.is_shutdown = true c.is_shutdown = true
c.mtx.Unlock() c.mtx.Unlock()
var errs utils.ErrorGroup var errs utils.ErrorGroup
errs.Add(c.shutdownLoop()) errs.Add(c.shutdownLoop())
errs.Add(c.conn.Close()) errs.Add(c.conn.Close())
return errs.Finalize() return errs.Finalize()
} }
func (c *Conn) read(b []byte) (int, func() error) { func (c *Conn) read(b []byte) (int, func() error) {
c.mtx.Lock() c.mtx.Lock()
defer c.mtx.Unlock() defer c.mtx.Unlock()
if c.is_shutdown { if c.is_shutdown {
return 0, func() error { return io.EOF } return 0, func() error { return io.EOF }
} }
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b))) rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
if rv > 0 { if rv > 0 {
return int(rv), nil return int(rv), nil
} }
return 0, c.getErrorHandler(rv, errno) return 0, c.getErrorHandler(rv, errno)
} }
// Read reads up to len(b) bytes into b. It returns the number of bytes read // Read reads up to len(b) bytes into b. It returns the number of bytes read
// and an error if applicable. io.EOF is returned when the caller can expect // and an error if applicable. io.EOF is returned when the caller can expect
// to see no more data. // to see no more data.
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
if len(b) == 0 { if len(b) == 0 {
return 0, nil return 0, nil
} }
err = tryAgain err = tryAgain
for err == tryAgain { for err == tryAgain {
n, errcb := c.read(b) n, errcb := c.read(b)
err = c.handleError(errcb) err = c.handleError(errcb)
if err == nil { if err == nil {
go c.flushOutputBuffer() go c.flushOutputBuffer()
return n, nil return n, nil
} }
if err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF {
err = io.EOF err = io.EOF
} }
} }
return 0, err return 0, err
} }
func (c *Conn) write(b []byte) (int, func() error) { func (c *Conn) write(b []byte) (int, func() error) {
c.mtx.Lock() c.mtx.Lock()
defer c.mtx.Unlock() defer c.mtx.Unlock()
if c.is_shutdown { if c.is_shutdown {
err := errors.New("connection closed") err := errors.New("connection closed")
return 0, func() error { return err } return 0, func() error { return err }
} }
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b))) rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
if rv > 0 { if rv > 0 {
return int(rv), nil return int(rv), nil
} }
return 0, c.getErrorHandler(rv, errno) return 0, c.getErrorHandler(rv, errno)
} }
// Write will encrypt the contents of b and write it to the underlying stream. // Write will encrypt the contents of b and write it to the underlying stream.
// Performance will be vastly improved if the size of b is a multiple of // Performance will be vastly improved if the size of b is a multiple of
// SSLRecordSize. // SSLRecordSize.
func (c *Conn) Write(b []byte) (written int, err error) { func (c *Conn) Write(b []byte) (written int, err error) {
if len(b) == 0 { if len(b) == 0 {
return 0, nil return 0, nil
} }
err = tryAgain err = tryAgain
for err == tryAgain { for err == tryAgain {
n, errcb := c.write(b) n, errcb := c.write(b)
err = c.handleError(errcb) err = c.handleError(errcb)
if err == nil { if err == nil {
return n, c.flushOutputBuffer() return n, c.flushOutputBuffer()
} }
} }
return 0, err return 0, err
} }
// VerifyHostname pulls the PeerCertificate and calls VerifyHostname on the // VerifyHostname pulls the PeerCertificate and calls VerifyHostname on the
// certificate. // certificate.
func (c *Conn) VerifyHostname(host string) error { func (c *Conn) VerifyHostname(host string) error {
cert, err := c.PeerCertificate() cert, err := c.PeerCertificate()
if err != nil { if err != nil {
return err return err
} }
return cert.VerifyHostname(host) return cert.VerifyHostname(host)
} }
// LocalAddr returns the underlying connection's local address // LocalAddr returns the underlying connection's local address
func (c *Conn) LocalAddr() net.Addr { func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr() return c.conn.LocalAddr()
} }
// RemoteAddr returns the underlying connection's remote address // RemoteAddr returns the underlying connection's remote address
func (c *Conn) RemoteAddr() net.Addr { func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr() return c.conn.RemoteAddr()
} }
// SetDeadline calls SetDeadline on the underlying connection. // SetDeadline calls SetDeadline on the underlying connection.
func (c *Conn) SetDeadline(t time.Time) error { func (c *Conn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t) return c.conn.SetDeadline(t)
} }
// SetReadDeadline calls SetReadDeadline on the underlying connection. // SetReadDeadline calls SetReadDeadline on the underlying connection.
func (c *Conn) SetReadDeadline(t time.Time) error { func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t) return c.conn.SetReadDeadline(t)
} }
// SetWriteDeadline calls SetWriteDeadline on the underlying connection. // SetWriteDeadline calls SetWriteDeadline on the underlying connection.
func (c *Conn) SetWriteDeadline(t time.Time) error { func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t) return c.conn.SetWriteDeadline(t)
} }

316
ctx.go
View File

@ -33,158 +33,158 @@ package openssl
import "C" import "C"
import ( import (
"errors" "errors"
"io/ioutil" "io/ioutil"
"runtime" "runtime"
"unsafe" "unsafe"
) )
type Ctx struct { type Ctx struct {
ctx *C.SSL_CTX ctx *C.SSL_CTX
} }
func newCtx(method *C.SSL_METHOD) (*Ctx, error) { func newCtx(method *C.SSL_METHOD) (*Ctx, error) {
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
ctx := C.SSL_CTX_new(method) ctx := C.SSL_CTX_new(method)
if ctx == nil { if ctx == nil {
return nil, errorFromErrorQueue() return nil, errorFromErrorQueue()
} }
c := &Ctx{ctx: ctx} c := &Ctx{ctx: ctx}
runtime.SetFinalizer(c, func(c *Ctx) { runtime.SetFinalizer(c, func(c *Ctx) {
C.SSL_CTX_free(c.ctx) C.SSL_CTX_free(c.ctx)
}) })
return c, nil return c, nil
} }
type SSLVersion int type SSLVersion int
const ( const (
SSLv3 SSLVersion = 0x02 SSLv3 SSLVersion = 0x02
TLSv1 SSLVersion = 0x03 TLSv1 SSLVersion = 0x03
TLSv1_1 SSLVersion = 0x04 TLSv1_1 SSLVersion = 0x04
TLSv1_2 SSLVersion = 0x05 TLSv1_2 SSLVersion = 0x05
AnyVersion SSLVersion = 0x06 AnyVersion SSLVersion = 0x06
) )
// NewCtxWithVersion creates an SSL context that is specific to the provided // NewCtxWithVersion creates an SSL context that is specific to the provided
// SSL version. See http://www.openssl.org/docs/ssl/SSL_CTX_new.html for more. // SSL version. See http://www.openssl.org/docs/ssl/SSL_CTX_new.html for more.
func NewCtxWithVersion(version SSLVersion) (*Ctx, error) { func NewCtxWithVersion(version SSLVersion) (*Ctx, error) {
var method *C.SSL_METHOD var method *C.SSL_METHOD
switch version { switch version {
case SSLv3: case SSLv3:
method = C.SSLv3_method() method = C.SSLv3_method()
case TLSv1: case TLSv1:
method = C.TLSv1_method() method = C.TLSv1_method()
case TLSv1_1: case TLSv1_1:
method = C.TLSv1_1_method() method = C.TLSv1_1_method()
case TLSv1_2: case TLSv1_2:
method = C.TLSv1_2_method() method = C.TLSv1_2_method()
case AnyVersion: case AnyVersion:
method = C.SSLv23_method() method = C.SSLv23_method()
} }
if method == nil { if method == nil {
return nil, errors.New("unknown ssl/tls version") return nil, errors.New("unknown ssl/tls version")
} }
return newCtx(method) return newCtx(method)
} }
// NewCtx creates a context that supports any TLS version 1.0 and newer. // NewCtx creates a context that supports any TLS version 1.0 and newer.
func NewCtx() (*Ctx, error) { func NewCtx() (*Ctx, error) {
c, err := NewCtxWithVersion(AnyVersion) c, err := NewCtxWithVersion(AnyVersion)
if err == nil { if err == nil {
c.SetOptions(NoSSLv2 | NoSSLv3) c.SetOptions(NoSSLv2 | NoSSLv3)
} }
return c, err return c, err
} }
// NewCtxFromFiles calls NewCtx, loads the provided files, and configures the // NewCtxFromFiles calls NewCtx, loads the provided files, and configures the
// context to use them. // context to use them.
func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) { func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) {
ctx, err := NewCtx() ctx, err := NewCtx()
if err != nil { if err != nil {
return nil, err return nil, err
} }
cert_bytes, err := ioutil.ReadFile(cert_file) cert_bytes, err := ioutil.ReadFile(cert_file)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cert, err := LoadCertificate(cert_bytes) cert, err := LoadCertificate(cert_bytes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.UseCertificate(cert) err = ctx.UseCertificate(cert)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key_bytes, err := ioutil.ReadFile(key_file) key_bytes, err := ioutil.ReadFile(key_file)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key, err := LoadPrivateKey(key_bytes) key, err := LoadPrivateKey(key_bytes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.UsePrivateKey(key) err = ctx.UsePrivateKey(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ctx, nil return ctx, nil
} }
// UseCertificate configures the context to present the given certificate to // UseCertificate configures the context to present the given certificate to
// peers. // peers.
func (c *Ctx) UseCertificate(cert *Certificate) error { func (c *Ctx) UseCertificate(cert *Certificate) error {
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
if int(C.SSL_CTX_use_certificate(c.ctx, cert.x)) != 1 { if int(C.SSL_CTX_use_certificate(c.ctx, cert.x)) != 1 {
return errorFromErrorQueue() return errorFromErrorQueue()
} }
return nil return nil
} }
// UsePrivateKey configures the context to use the given private key for SSL // UsePrivateKey configures the context to use the given private key for SSL
// handshakes. // handshakes.
func (c *Ctx) UsePrivateKey(key PrivateKey) error { func (c *Ctx) UsePrivateKey(key PrivateKey) error {
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
if int(C.SSL_CTX_use_PrivateKey(c.ctx, key.evpPKey())) != 1 { if int(C.SSL_CTX_use_PrivateKey(c.ctx, key.evpPKey())) != 1 {
return errorFromErrorQueue() return errorFromErrorQueue()
} }
return nil return nil
} }
type CertificateStore struct { type CertificateStore struct {
store *C.X509_STORE store *C.X509_STORE
ctx *Ctx // for gc ctx *Ctx // for gc
} }
// GetCertificateStore returns the context's certificate store that will be // GetCertificateStore returns the context's certificate store that will be
// used for peer validation. // used for peer validation.
func (c *Ctx) GetCertificateStore() *CertificateStore { func (c *Ctx) GetCertificateStore() *CertificateStore {
// we don't need to dealloc the cert store pointer here, because it points // we don't need to dealloc the cert store pointer here, because it points
// to a ctx internal. so we do need to keep the ctx around // to a ctx internal. so we do need to keep the ctx around
return &CertificateStore{ return &CertificateStore{
store: C.SSL_CTX_get_cert_store(c.ctx), store: C.SSL_CTX_get_cert_store(c.ctx),
ctx: c} ctx: c}
} }
// AddCertificate marks the provided Certificate as a trusted certificate in // AddCertificate marks the provided Certificate as a trusted certificate in
// the given CertificateStore. // the given CertificateStore.
func (s *CertificateStore) AddCertificate(cert *Certificate) error { func (s *CertificateStore) AddCertificate(cert *Certificate) error {
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
if int(C.X509_STORE_add_cert(s.store, cert.x)) != 1 { if int(C.X509_STORE_add_cert(s.store, cert.x)) != 1 {
return errorFromErrorQueue() return errorFromErrorQueue()
} }
return nil return nil
} }
// LoadVerifyLocations tells the context to trust all certificate authorities // LoadVerifyLocations tells the context to trust all certificate authorities
@ -192,120 +192,120 @@ func (s *CertificateStore) AddCertificate(cert *Certificate) error {
// See http://www.openssl.org/docs/ssl/SSL_CTX_load_verify_locations.html for // See http://www.openssl.org/docs/ssl/SSL_CTX_load_verify_locations.html for
// more. // more.
func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error { func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error {
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
var c_ca_file, c_ca_path *C.char var c_ca_file, c_ca_path *C.char
if ca_file != "" { if ca_file != "" {
c_ca_file = C.CString(ca_file) c_ca_file = C.CString(ca_file)
defer C.free(unsafe.Pointer(c_ca_file)) defer C.free(unsafe.Pointer(c_ca_file))
} }
if ca_path != "" { if ca_path != "" {
c_ca_path = C.CString(ca_path) c_ca_path = C.CString(ca_path)
defer C.free(unsafe.Pointer(c_ca_path)) defer C.free(unsafe.Pointer(c_ca_path))
} }
if C.SSL_CTX_load_verify_locations(c.ctx, c_ca_file, c_ca_path) != 1 { if C.SSL_CTX_load_verify_locations(c.ctx, c_ca_file, c_ca_path) != 1 {
return errorFromErrorQueue() return errorFromErrorQueue()
} }
return nil return nil
} }
type Options int type Options int
const ( const (
// NoCompression is only valid if you are using OpenSSL 1.0.1 or newer // NoCompression is only valid if you are using OpenSSL 1.0.1 or newer
NoCompression Options = C.SSL_OP_NO_COMPRESSION NoCompression Options = C.SSL_OP_NO_COMPRESSION
NoSSLv2 Options = C.SSL_OP_NO_SSLv2 NoSSLv2 Options = C.SSL_OP_NO_SSLv2
NoSSLv3 Options = C.SSL_OP_NO_SSLv3 NoSSLv3 Options = C.SSL_OP_NO_SSLv3
NoTLSv1 Options = C.SSL_OP_NO_TLSv1 NoTLSv1 Options = C.SSL_OP_NO_TLSv1
CipherServerPreference Options = C.SSL_OP_CIPHER_SERVER_PREFERENCE CipherServerPreference Options = C.SSL_OP_CIPHER_SERVER_PREFERENCE
NoSessionResumptionOrRenegotiation Options = C.SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION NoSessionResumptionOrRenegotiation Options = C.SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION
NoTicket Options = C.SSL_OP_NO_TICKET NoTicket Options = C.SSL_OP_NO_TICKET
) )
// SetOptions sets context options. See // SetOptions sets context options. See
// http://www.openssl.org/docs/ssl/SSL_CTX_set_options.html // http://www.openssl.org/docs/ssl/SSL_CTX_set_options.html
func (c *Ctx) SetOptions(options Options) Options { func (c *Ctx) SetOptions(options Options) Options {
return Options(C.SSL_CTX_set_options_not_a_macro( return Options(C.SSL_CTX_set_options_not_a_macro(
c.ctx, C.long(options))) c.ctx, C.long(options)))
} }
type Modes int type Modes int
const ( const (
// ReleaseBuffers is only valid if you are using OpenSSL 1.0.1 or newer // ReleaseBuffers is only valid if you are using OpenSSL 1.0.1 or newer
ReleaseBuffers Modes = C.SSL_MODE_RELEASE_BUFFERS ReleaseBuffers Modes = C.SSL_MODE_RELEASE_BUFFERS
) )
// SetMode sets context modes. See // SetMode sets context modes. See
// http://www.openssl.org/docs/ssl/SSL_CTX_set_mode.html // http://www.openssl.org/docs/ssl/SSL_CTX_set_mode.html
func (c *Ctx) SetMode(modes Modes) Modes { func (c *Ctx) SetMode(modes Modes) Modes {
return Modes(C.SSL_CTX_set_mode_not_a_macro(c.ctx, C.long(modes))) return Modes(C.SSL_CTX_set_mode_not_a_macro(c.ctx, C.long(modes)))
} }
type VerifyOptions int type VerifyOptions int
const ( const (
VerifyNone VerifyOptions = C.SSL_VERIFY_NONE VerifyNone VerifyOptions = C.SSL_VERIFY_NONE
VerifyPeer VerifyOptions = C.SSL_VERIFY_PEER VerifyPeer VerifyOptions = C.SSL_VERIFY_PEER
VerifyFailIfNoPeerCert VerifyOptions = C.SSL_VERIFY_FAIL_IF_NO_PEER_CERT VerifyFailIfNoPeerCert VerifyOptions = C.SSL_VERIFY_FAIL_IF_NO_PEER_CERT
VerifyClientOnce VerifyOptions = C.SSL_VERIFY_CLIENT_ONCE VerifyClientOnce VerifyOptions = C.SSL_VERIFY_CLIENT_ONCE
) )
// SetVerify controls peer verification settings. See // SetVerify controls peer verification settings. See
// http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html // http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html
func (c *Ctx) SetVerify(options VerifyOptions) { func (c *Ctx) SetVerify(options VerifyOptions) {
// TODO: take a callback // TODO: take a callback
C.SSL_CTX_set_verify(c.ctx, C.int(options), nil) C.SSL_CTX_set_verify(c.ctx, C.int(options), nil)
} }
// SetVerifyDepth controls how many certificates deep the certificate // SetVerifyDepth controls how many certificates deep the certificate
// verification logic is willing to follow a certificate chain. See // verification logic is willing to follow a certificate chain. See
// https://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html // https://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html
func (c *Ctx) SetVerifyDepth(depth int) { func (c *Ctx) SetVerifyDepth(depth int) {
C.SSL_CTX_set_verify_depth(c.ctx, C.int(depth)) C.SSL_CTX_set_verify_depth(c.ctx, C.int(depth))
} }
func (c *Ctx) SetSessionId(session_id []byte) error { func (c *Ctx) SetSessionId(session_id []byte) error {
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
if int(C.SSL_CTX_set_session_id_context(c.ctx, if int(C.SSL_CTX_set_session_id_context(c.ctx,
(*C.uchar)(unsafe.Pointer(&session_id[0])), (*C.uchar)(unsafe.Pointer(&session_id[0])),
C.uint(len(session_id)))) == 0 { C.uint(len(session_id)))) == 0 {
return errorFromErrorQueue() return errorFromErrorQueue()
} }
return nil return nil
} }
// SetCipherList sets the list of available ciphers. The format of the list is // SetCipherList sets the list of available ciphers. The format of the list is
// described at http://www.openssl.org/docs/apps/ciphers.html, but see // described at http://www.openssl.org/docs/apps/ciphers.html, but see
// http://www.openssl.org/docs/ssl/SSL_CTX_set_cipher_list.html for more. // http://www.openssl.org/docs/ssl/SSL_CTX_set_cipher_list.html for more.
func (c *Ctx) SetCipherList(list string) error { func (c *Ctx) SetCipherList(list string) error {
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
clist := C.CString(list) clist := C.CString(list)
defer C.free(unsafe.Pointer(clist)) defer C.free(unsafe.Pointer(clist))
if int(C.SSL_CTX_set_cipher_list(c.ctx, clist)) == 0 { if int(C.SSL_CTX_set_cipher_list(c.ctx, clist)) == 0 {
return errorFromErrorQueue() return errorFromErrorQueue()
} }
return nil return nil
} }
type SessionCacheModes int type SessionCacheModes int
const ( const (
SessionCacheOff SessionCacheModes = C.SSL_SESS_CACHE_OFF SessionCacheOff SessionCacheModes = C.SSL_SESS_CACHE_OFF
SessionCacheClient SessionCacheModes = C.SSL_SESS_CACHE_CLIENT SessionCacheClient SessionCacheModes = C.SSL_SESS_CACHE_CLIENT
SessionCacheServer SessionCacheModes = C.SSL_SESS_CACHE_SERVER SessionCacheServer SessionCacheModes = C.SSL_SESS_CACHE_SERVER
SessionCacheBoth SessionCacheModes = C.SSL_SESS_CACHE_BOTH SessionCacheBoth SessionCacheModes = C.SSL_SESS_CACHE_BOTH
NoAutoClear SessionCacheModes = C.SSL_SESS_CACHE_NO_AUTO_CLEAR NoAutoClear SessionCacheModes = C.SSL_SESS_CACHE_NO_AUTO_CLEAR
NoInternalLookup SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_LOOKUP NoInternalLookup SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_LOOKUP
NoInternalStore SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_STORE NoInternalStore SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_STORE
NoInternal SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL NoInternal SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL
) )
// SetSessionCacheMode enables or disables session caching. See // SetSessionCacheMode enables or disables session caching. See
// http://www.openssl.org/docs/ssl/SSL_CTX_set_session_cache_mode.html // http://www.openssl.org/docs/ssl/SSL_CTX_set_session_cache_mode.html
func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes { func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes {
return SessionCacheModes( return SessionCacheModes(
C.SSL_CTX_set_session_cache_mode_not_a_macro(c.ctx, C.long(modes))) C.SSL_CTX_set_session_cache_mode_not_a_macro(c.ctx, C.long(modes)))
} }

View File

@ -23,20 +23,20 @@ extern int X509_check_ip(X509 *x, const unsigned char *chk, size_t chklen,
import "C" import "C"
import ( import (
"errors" "errors"
"net" "net"
"unsafe" "unsafe"
) )
var ( var (
ValidationError = errors.New("Host validation error") ValidationError = errors.New("Host validation error")
) )
type CheckFlags int type CheckFlags int
const ( const (
AlwaysCheckSubject CheckFlags = C.X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT AlwaysCheckSubject CheckFlags = C.X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT
NoWildcards CheckFlags = C.X509_CHECK_FLAG_NO_WILDCARDS NoWildcards CheckFlags = C.X509_CHECK_FLAG_NO_WILDCARDS
) )
// CheckHost checks that the X509 certificate is signed for the provided // CheckHost checks that the X509 certificate is signed for the provided
@ -45,17 +45,17 @@ const (
// Specifically returns ValidationError if the Certificate didn't match but // Specifically returns ValidationError if the Certificate didn't match but
// there was no internal error. // there was no internal error.
func (c *Certificate) CheckHost(host string, flags CheckFlags) error { func (c *Certificate) CheckHost(host string, flags CheckFlags) error {
chost := unsafe.Pointer(C.CString(host)) chost := unsafe.Pointer(C.CString(host))
defer C.free(chost) defer C.free(chost)
rv := C.X509_check_host(c.x, (*C.uchar)(chost), C.size_t(len(host)), rv := C.X509_check_host(c.x, (*C.uchar)(chost), C.size_t(len(host)),
C.uint(flags)) C.uint(flags))
if rv > 0 { if rv > 0 {
return nil return nil
} }
if rv == 0 { if rv == 0 {
return ValidationError return ValidationError
} }
return errors.New("hostname validation had an internal failure") return errors.New("hostname validation had an internal failure")
} }
// CheckEmail checks that the X509 certificate is signed for the provided // CheckEmail checks that the X509 certificate is signed for the provided
@ -64,17 +64,17 @@ func (c *Certificate) CheckHost(host string, flags CheckFlags) error {
// Specifically returns ValidationError if the Certificate didn't match but // Specifically returns ValidationError if the Certificate didn't match but
// there was no internal error. // there was no internal error.
func (c *Certificate) CheckEmail(email string, flags CheckFlags) error { func (c *Certificate) CheckEmail(email string, flags CheckFlags) error {
cemail := unsafe.Pointer(C.CString(email)) cemail := unsafe.Pointer(C.CString(email))
defer C.free(cemail) defer C.free(cemail)
rv := C.X509_check_email(c.x, (*C.uchar)(cemail), C.size_t(len(email)), rv := C.X509_check_email(c.x, (*C.uchar)(cemail), C.size_t(len(email)),
C.uint(flags)) C.uint(flags))
if rv > 0 { if rv > 0 {
return nil return nil
} }
if rv == 0 { if rv == 0 {
return ValidationError return ValidationError
} }
return errors.New("email validation had an internal failure") return errors.New("email validation had an internal failure")
} }
// CheckIP checks that the X509 certificate is signed for the provided // CheckIP checks that the X509 certificate is signed for the provided
@ -83,16 +83,16 @@ func (c *Certificate) CheckEmail(email string, flags CheckFlags) error {
// Specifically returns ValidationError if the Certificate didn't match but // Specifically returns ValidationError if the Certificate didn't match but
// there was no internal error. // there was no internal error.
func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error { func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error {
cip := unsafe.Pointer(&ip[0]) cip := unsafe.Pointer(&ip[0])
rv := C.X509_check_ip(c.x, (*C.uchar)(cip), C.size_t(len(ip)), rv := C.X509_check_ip(c.x, (*C.uchar)(cip), C.size_t(len(ip)),
C.uint(flags)) C.uint(flags))
if rv > 0 { if rv > 0 {
return nil return nil
} }
if rv == 0 { if rv == 0 {
return ValidationError return ValidationError
} }
return errors.New("ip validation had an internal failure") return errors.New("ip validation had an internal failure")
} }
// VerifyHostname is a combination of CheckHost and CheckIP. If the provided // VerifyHostname is a combination of CheckHost and CheckIP. If the provided
@ -101,14 +101,14 @@ func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error {
// Specifically returns ValidationError if the Certificate didn't match but // Specifically returns ValidationError if the Certificate didn't match but
// there was no internal error. // there was no internal error.
func (c *Certificate) VerifyHostname(host string) error { func (c *Certificate) VerifyHostname(host string) error {
var ip net.IP var ip net.IP
if len(host) >= 3 && host[0] == '[' && host[len(host)-1] == ']' { if len(host) >= 3 && host[0] == '[' && host[len(host)-1] == ']' {
ip = net.ParseIP(host[1 : len(host)-1]) ip = net.ParseIP(host[1 : len(host)-1])
} else { } else {
ip = net.ParseIP(host) ip = net.ParseIP(host)
} }
if ip != nil { if ip != nil {
return c.CheckIP(ip, 0) return c.CheckIP(ip, 0)
} }
return c.CheckHost(host, 0) return c.CheckHost(host, 0)
} }

36
http.go
View File

@ -3,37 +3,37 @@
package openssl package openssl
import ( import (
"net/http" "net/http"
) )
// ListenAndServeTLS will take an http.Handler and serve it using OpenSSL over // ListenAndServeTLS will take an http.Handler and serve it using OpenSSL over
// the given tcp address, configured to use the provided cert and key files. // the given tcp address, configured to use the provided cert and key files.
func ListenAndServeTLS(addr string, cert_file string, key_file string, func ListenAndServeTLS(addr string, cert_file string, key_file string,
handler http.Handler) error { handler http.Handler) error {
return ServerListenAndServeTLS( return ServerListenAndServeTLS(
&http.Server{Addr: addr, Handler: handler}, cert_file, key_file) &http.Server{Addr: addr, Handler: handler}, cert_file, key_file)
} }
// ServerListenAndServeTLS will take an http.Server and serve it using OpenSSL // ServerListenAndServeTLS will take an http.Server and serve it using OpenSSL
// configured to use the provided cert and key files. // configured to use the provided cert and key files.
func ServerListenAndServeTLS(srv *http.Server, func ServerListenAndServeTLS(srv *http.Server,
cert_file, key_file string) error { cert_file, key_file string) error {
addr := srv.Addr addr := srv.Addr
if addr == "" { if addr == "" {
addr = ":https" addr = ":https"
} }
ctx, err := NewCtxFromFiles(cert_file, key_file) ctx, err := NewCtxFromFiles(cert_file, key_file)
if err != nil { if err != nil {
return err return err
} }
l, err := Listen("tcp", addr, ctx) l, err := Listen("tcp", addr, ctx)
if err != nil { if err != nil {
return err return err
} }
return srv.Serve(l) return srv.Serve(l)
} }
// TODO: http client integration // TODO: http client integration

66
init.go
View File

@ -68,58 +68,58 @@ static void OpenSSL_add_all_algorithms_not_a_macro() {
import "C" import "C"
import ( import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
"code.spacemonkey.com/go/openssl/utils" "code.spacemonkey.com/go/openssl/utils"
) )
var ( var (
sslMutexes []sync.Mutex sslMutexes []sync.Mutex
) )
func init() { func init() {
C.OPENSSL_config(nil) C.OPENSSL_config(nil)
C.ENGINE_load_builtin_engines() C.ENGINE_load_builtin_engines()
C.SSL_load_error_strings() C.SSL_load_error_strings()
C.SSL_library_init() C.SSL_library_init()
C.OpenSSL_add_all_algorithms_not_a_macro() C.OpenSSL_add_all_algorithms_not_a_macro()
sslMutexes = make([]sync.Mutex, int(C.CRYPTO_num_locks())) sslMutexes = make([]sync.Mutex, int(C.CRYPTO_num_locks()))
C.CRYPTO_set_id_callback((*[0]byte)(C.sslThreadId)) C.CRYPTO_set_id_callback((*[0]byte)(C.sslThreadId))
C.CRYPTO_set_locking_callback((*[0]byte)(C.sslMutexOp)) C.CRYPTO_set_locking_callback((*[0]byte)(C.sslMutexOp))
// TODO: support dynlock callbacks // TODO: support dynlock callbacks
} }
// errorFromErrorQueue needs to run in the same OS thread as the operation // errorFromErrorQueue needs to run in the same OS thread as the operation
// that caused the possible error // that caused the possible error
func errorFromErrorQueue() error { func errorFromErrorQueue() error {
var errs []string var errs []string
for { for {
err := C.ERR_get_error() err := C.ERR_get_error()
if err == 0 { if err == 0 {
break break
} }
errs = append(errs, fmt.Sprintf("%s:%s:%s", errs = append(errs, fmt.Sprintf("%s:%s:%s",
C.GoString(C.ERR_lib_error_string(err)), C.GoString(C.ERR_lib_error_string(err)),
C.GoString(C.ERR_func_error_string(err)), C.GoString(C.ERR_func_error_string(err)),
C.GoString(C.ERR_reason_error_string(err)))) C.GoString(C.ERR_reason_error_string(err))))
} }
return errors.New(fmt.Sprintf("SSL errors: %s", strings.Join(errs, "\n"))) return errors.New(fmt.Sprintf("SSL errors: %s", strings.Join(errs, "\n")))
} }
//export sslMutexOp //export sslMutexOp
func sslMutexOp(mode, n C.int, file *C.char, line C.int) { func sslMutexOp(mode, n C.int, file *C.char, line C.int) {
if mode&C.CRYPTO_LOCK > 0 { if mode&C.CRYPTO_LOCK > 0 {
sslMutexes[n].Lock() sslMutexes[n].Lock()
} else { } else {
sslMutexes[n].Unlock() sslMutexes[n].Unlock()
} }
} }
//export sslThreadId //export sslThreadId
func sslThreadId() C.ulong { func sslThreadId() C.ulong {
return C.ulong(uintptr(utils.ThreadId())) return C.ulong(uintptr(utils.ThreadId()))
} }

112
net.go
View File

@ -3,49 +3,49 @@
package openssl package openssl
import ( import (
"errors" "errors"
"net" "net"
) )
type listener struct { type listener struct {
net.Listener net.Listener
ctx *Ctx ctx *Ctx
} }
func (l *listener) Accept() (c net.Conn, err error) { func (l *listener) Accept() (c net.Conn, err error) {
c, err = l.Listener.Accept() c, err = l.Listener.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Server(c, l.ctx) return Server(c, l.ctx)
} }
// NewListener wraps an existing net.Listener such that all accepted // NewListener wraps an existing net.Listener such that all accepted
// connections are wrapped as OpenSSL server connections using the provided // connections are wrapped as OpenSSL server connections using the provided
// context ctx. // context ctx.
func NewListener(inner net.Listener, ctx *Ctx) net.Listener { func NewListener(inner net.Listener, ctx *Ctx) net.Listener {
return &listener{ return &listener{
Listener: inner, Listener: inner,
ctx: ctx} ctx: ctx}
} }
// Listen is a wrapper around net.Listen that wraps incoming connections with // Listen is a wrapper around net.Listen that wraps incoming connections with
// an OpenSSL server connection using the provided context ctx. // an OpenSSL server connection using the provided context ctx.
func Listen(network, laddr string, ctx *Ctx) (net.Listener, error) { func Listen(network, laddr string, ctx *Ctx) (net.Listener, error) {
if ctx == nil { if ctx == nil {
return nil, errors.New("no ssl context provided") return nil, errors.New("no ssl context provided")
} }
l, err := net.Listen(network, laddr) l, err := net.Listen(network, laddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewListener(l, ctx), nil return NewListener(l, ctx), nil
} }
type DialFlags int type DialFlags int
const ( const (
InsecureSkipHostVerification DialFlags = 0x01 InsecureSkipHostVerification DialFlags = 0x01
) )
// Dial will connect to network/address and then wrap the corresponding // Dial will connect to network/address and then wrap the corresponding
@ -59,39 +59,39 @@ const (
// This library is not nice enough to use the system certificate store by // This library is not nice enough to use the system certificate store by
// default for you yet. // default for you yet.
func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) { func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
if ctx == nil { if ctx == nil {
var err error var err error
ctx, err = NewCtx() ctx, err = NewCtx()
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: use operating system default certificate chain? // TODO: use operating system default certificate chain?
} }
c, err := net.Dial(network, addr) c, err := net.Dial(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn, err := Client(c, ctx) conn, err := Client(c, ctx)
if err != nil { if err != nil {
c.Close() c.Close()
return nil, err return nil, err
} }
err = conn.Handshake() err = conn.Handshake()
if err != nil { if err != nil {
c.Close() c.Close()
return nil, err return nil, err
} }
if flags&InsecureSkipHostVerification == 0 { if flags&InsecureSkipHostVerification == 0 {
host, _, err := net.SplitHostPort(addr) host, _, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }
err = conn.VerifyHostname(host) err = conn.VerifyHostname(host)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }
} }
return conn, nil return conn, nil
} }

View File

@ -4,13 +4,13 @@
package openssl package openssl
import ( import (
"errors" "errors"
"net" "net"
"time" "time"
) )
const ( const (
SSLRecordSize = 16 * 1024 SSLRecordSize = 16 * 1024
) )
type Conn struct{} type Conn struct{}
@ -37,11 +37,11 @@ type Ctx struct{}
type SSLVersion int type SSLVersion int
const ( const (
SSLv3 SSLVersion = 0x02 SSLv3 SSLVersion = 0x02
TLSv1 SSLVersion = 0x03 TLSv1 SSLVersion = 0x03
TLSv1_1 SSLVersion = 0x04 TLSv1_1 SSLVersion = 0x04
TLSv1_2 SSLVersion = 0x05 TLSv1_2 SSLVersion = 0x05
AnyVersion SSLVersion = 0x06 AnyVersion SSLVersion = 0x06
) )
func NewCtxWithVersion(version SSLVersion) (*Ctx, error) func NewCtxWithVersion(version SSLVersion) (*Ctx, error)
@ -61,13 +61,13 @@ func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error
type Options int type Options int
const ( const (
NoCompression Options = 0 NoCompression Options = 0
NoSSLv2 Options = 0 NoSSLv2 Options = 0
NoSSLv3 Options = 0 NoSSLv3 Options = 0
NoTLSv1 Options = 0 NoTLSv1 Options = 0
CipherServerPreference Options = 0 CipherServerPreference Options = 0
NoSessionResumptionOrRenegotiation Options = 0 NoSessionResumptionOrRenegotiation Options = 0
NoTicket Options = 0 NoTicket Options = 0
) )
func (c *Ctx) SetOptions(options Options) Options func (c *Ctx) SetOptions(options Options) Options
@ -75,7 +75,7 @@ func (c *Ctx) SetOptions(options Options) Options
type Modes int type Modes int
const ( const (
ReleaseBuffers Modes = 0 ReleaseBuffers Modes = 0
) )
func (c *Ctx) SetMode(modes Modes) Modes func (c *Ctx) SetMode(modes Modes) Modes
@ -83,10 +83,10 @@ func (c *Ctx) SetMode(modes Modes) Modes
type VerifyOptions int type VerifyOptions int
const ( const (
VerifyNone VerifyOptions = 0 VerifyNone VerifyOptions = 0
VerifyPeer VerifyOptions = 0 VerifyPeer VerifyOptions = 0
VerifyFailIfNoPeerCert VerifyOptions = 0 VerifyFailIfNoPeerCert VerifyOptions = 0
VerifyClientOnce VerifyOptions = 0 VerifyClientOnce VerifyOptions = 0
) )
func (c *Ctx) SetVerify(options VerifyOptions) func (c *Ctx) SetVerify(options VerifyOptions)
@ -98,27 +98,27 @@ func (c *Ctx) SetCipherList(list string) error
type SessionCacheModes int type SessionCacheModes int
const ( const (
SessionCacheOff SessionCacheModes = 0 SessionCacheOff SessionCacheModes = 0
SessionCacheClient SessionCacheModes = 0 SessionCacheClient SessionCacheModes = 0
SessionCacheServer SessionCacheModes = 0 SessionCacheServer SessionCacheModes = 0
SessionCacheBoth SessionCacheModes = 0 SessionCacheBoth SessionCacheModes = 0
NoAutoClear SessionCacheModes = 0 NoAutoClear SessionCacheModes = 0
NoInternalLookup SessionCacheModes = 0 NoInternalLookup SessionCacheModes = 0
NoInternalStore SessionCacheModes = 0 NoInternalStore SessionCacheModes = 0
NoInternal SessionCacheModes = 0 NoInternal SessionCacheModes = 0
) )
func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes
var ( var (
ValidationError = errors.New("Host validation error") ValidationError = errors.New("Host validation error")
) )
type CheckFlags int type CheckFlags int
const ( const (
AlwaysCheckSubject CheckFlags = 0 AlwaysCheckSubject CheckFlags = 0
NoWildcards CheckFlags = 0 NoWildcards CheckFlags = 0
) )
func (c *Certificate) CheckHost(host string, flags CheckFlags) error func (c *Certificate) CheckHost(host string, flags CheckFlags) error
@ -127,15 +127,15 @@ func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error
func (c *Certificate) VerifyHostname(host string) error func (c *Certificate) VerifyHostname(host string) error
type PublicKey interface { type PublicKey interface {
MarshalPKIXPublicKeyPEM() (pem_block []byte, err error) MarshalPKIXPublicKeyPEM() (pem_block []byte, err error)
MarshalPKIXPublicKeyDER() (der_block []byte, err error) MarshalPKIXPublicKeyDER() (der_block []byte, err error)
evpPKey() struct{} evpPKey() struct{}
} }
type PrivateKey interface { type PrivateKey interface {
PublicKey PublicKey
MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error) MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error)
MarshalPKCS1PrivateKeyDER() (der_block []byte, err error) MarshalPKCS1PrivateKeyDER() (der_block []byte, err error)
} }
func LoadPrivateKey(pem_block []byte) (PrivateKey, error) func LoadPrivateKey(pem_block []byte) (PrivateKey, error)

276
pem.go
View File

@ -8,191 +8,191 @@ package openssl
import "C" import "C"
import ( import (
"errors" "errors"
"io/ioutil" "io/ioutil"
"runtime" "runtime"
"unsafe" "unsafe"
) )
type PublicKey interface { type PublicKey interface {
// MarshalPKIXPublicKeyPEM converts the public key to PEM-encoded PKIX // MarshalPKIXPublicKeyPEM converts the public key to PEM-encoded PKIX
// format // format
MarshalPKIXPublicKeyPEM() (pem_block []byte, err error) MarshalPKIXPublicKeyPEM() (pem_block []byte, err error)
// MarshalPKIXPublicKeyDER converts the public key to DER-encoded PKIX // MarshalPKIXPublicKeyDER converts the public key to DER-encoded PKIX
// format // format
MarshalPKIXPublicKeyDER() (der_block []byte, err error) MarshalPKIXPublicKeyDER() (der_block []byte, err error)
evpPKey() *C.EVP_PKEY evpPKey() *C.EVP_PKEY
} }
type PrivateKey interface { type PrivateKey interface {
PublicKey PublicKey
// MarshalPKCS1PrivateKeyPEM converts the private key to PEM-encoded PKCS1 // MarshalPKCS1PrivateKeyPEM converts the private key to PEM-encoded PKCS1
// format // format
MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error) MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error)
// MarshalPKCS1PrivateKeyDER converts the private key to DER-encoded PKCS1 // MarshalPKCS1PrivateKeyDER converts the private key to DER-encoded PKCS1
// format // format
MarshalPKCS1PrivateKeyDER() (der_block []byte, err error) MarshalPKCS1PrivateKeyDER() (der_block []byte, err error)
} }
type pKey struct { type pKey struct {
key *C.EVP_PKEY key *C.EVP_PKEY
} }
func (key *pKey) evpPKey() *C.EVP_PKEY { return key.key } func (key *pKey) evpPKey() *C.EVP_PKEY { return key.key }
func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte, func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte,
err error) { err error) {
bio := C.BIO_new(C.BIO_s_mem()) bio := C.BIO_new(C.BIO_s_mem())
if bio == nil { if bio == nil {
return nil, errors.New("failed to allocate memory BIO") return nil, errors.New("failed to allocate memory BIO")
} }
defer C.BIO_free(bio) defer C.BIO_free(bio)
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
if rsa == nil { if rsa == nil {
return nil, errors.New("failed getting rsa key") return nil, errors.New("failed getting rsa key")
} }
defer C.RSA_free(rsa) defer C.RSA_free(rsa)
if int(C.PEM_write_bio_RSAPrivateKey(bio, rsa, nil, nil, C.int(0), nil, if int(C.PEM_write_bio_RSAPrivateKey(bio, rsa, nil, nil, C.int(0), nil,
nil)) != 1 { nil)) != 1 {
return nil, errors.New("failed dumping private key") return nil, errors.New("failed dumping private key")
} }
return ioutil.ReadAll(asAnyBio(bio)) return ioutil.ReadAll(asAnyBio(bio))
} }
func (key *pKey) MarshalPKCS1PrivateKeyDER() (der_block []byte, func (key *pKey) MarshalPKCS1PrivateKeyDER() (der_block []byte,
err error) { err error) {
bio := C.BIO_new(C.BIO_s_mem()) bio := C.BIO_new(C.BIO_s_mem())
if bio == nil { if bio == nil {
return nil, errors.New("failed to allocate memory BIO") return nil, errors.New("failed to allocate memory BIO")
} }
defer C.BIO_free(bio) defer C.BIO_free(bio)
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
if rsa == nil { if rsa == nil {
return nil, errors.New("failed getting rsa key") return nil, errors.New("failed getting rsa key")
} }
defer C.RSA_free(rsa) defer C.RSA_free(rsa)
if int(C.i2d_RSAPrivateKey_bio(bio, rsa)) != 1 { if int(C.i2d_RSAPrivateKey_bio(bio, rsa)) != 1 {
return nil, errors.New("failed dumping private key der") return nil, errors.New("failed dumping private key der")
} }
return ioutil.ReadAll(asAnyBio(bio)) return ioutil.ReadAll(asAnyBio(bio))
} }
func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte, func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte,
err error) { err error) {
bio := C.BIO_new(C.BIO_s_mem()) bio := C.BIO_new(C.BIO_s_mem())
if bio == nil { if bio == nil {
return nil, errors.New("failed to allocate memory BIO") return nil, errors.New("failed to allocate memory BIO")
} }
defer C.BIO_free(bio) defer C.BIO_free(bio)
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
if rsa == nil { if rsa == nil {
return nil, errors.New("failed getting rsa key") return nil, errors.New("failed getting rsa key")
} }
defer C.RSA_free(rsa) defer C.RSA_free(rsa)
if int(C.PEM_write_bio_RSA_PUBKEY(bio, rsa)) != 1 { if int(C.PEM_write_bio_RSA_PUBKEY(bio, rsa)) != 1 {
return nil, errors.New("failed dumping public key pem") return nil, errors.New("failed dumping public key pem")
} }
return ioutil.ReadAll(asAnyBio(bio)) return ioutil.ReadAll(asAnyBio(bio))
} }
func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte, func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte,
err error) { err error) {
bio := C.BIO_new(C.BIO_s_mem()) bio := C.BIO_new(C.BIO_s_mem())
if bio == nil { if bio == nil {
return nil, errors.New("failed to allocate memory BIO") return nil, errors.New("failed to allocate memory BIO")
} }
defer C.BIO_free(bio) defer C.BIO_free(bio)
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key)) rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
if rsa == nil { if rsa == nil {
return nil, errors.New("failed getting rsa key") return nil, errors.New("failed getting rsa key")
} }
defer C.RSA_free(rsa) defer C.RSA_free(rsa)
if int(C.i2d_RSA_PUBKEY_bio(bio, rsa)) != 1 { if int(C.i2d_RSA_PUBKEY_bio(bio, rsa)) != 1 {
return nil, errors.New("failed dumping public key der") return nil, errors.New("failed dumping public key der")
} }
return ioutil.ReadAll(asAnyBio(bio)) return ioutil.ReadAll(asAnyBio(bio))
} }
// LoadPrivateKey loads a private key from a PEM-encoded block. // LoadPrivateKey loads a private key from a PEM-encoded block.
func LoadPrivateKey(pem_block []byte) (PrivateKey, error) { func LoadPrivateKey(pem_block []byte) (PrivateKey, error) {
bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]),
C.int(len(pem_block))) C.int(len(pem_block)))
if bio == nil { if bio == nil {
return nil, errors.New("failed creating bio") return nil, errors.New("failed creating bio")
} }
defer C.BIO_free(bio) defer C.BIO_free(bio)
rsakey := C.PEM_read_bio_RSAPrivateKey(bio, nil, nil, nil) rsakey := C.PEM_read_bio_RSAPrivateKey(bio, nil, nil, nil)
if rsakey == nil { if rsakey == nil {
return nil, errors.New("failed reading rsa key") return nil, errors.New("failed reading rsa key")
} }
defer C.RSA_free(rsakey) defer C.RSA_free(rsakey)
// convert to PKEY // convert to PKEY
key := C.EVP_PKEY_new() key := C.EVP_PKEY_new()
if key == nil { if key == nil {
return nil, errors.New("failed converting to evp_pkey") return nil, errors.New("failed converting to evp_pkey")
} }
if C.EVP_PKEY_set1_RSA(key, (*C.struct_rsa_st)(rsakey)) != 1 { if C.EVP_PKEY_set1_RSA(key, (*C.struct_rsa_st)(rsakey)) != 1 {
C.EVP_PKEY_free(key) C.EVP_PKEY_free(key)
return nil, errors.New("failed converting to evp_pkey") return nil, errors.New("failed converting to evp_pkey")
} }
p := &pKey{key: key} p := &pKey{key: key}
runtime.SetFinalizer(p, func(p *pKey) { runtime.SetFinalizer(p, func(p *pKey) {
C.EVP_PKEY_free(p.key) C.EVP_PKEY_free(p.key)
}) })
return p, nil return p, nil
} }
type Certificate struct { type Certificate struct {
x *C.X509 x *C.X509
} }
// LoadCertificate loads an X509 certificate from a PEM-encoded block. // LoadCertificate loads an X509 certificate from a PEM-encoded block.
func LoadCertificate(pem_block []byte) (*Certificate, error) { func LoadCertificate(pem_block []byte) (*Certificate, error) {
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]), bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]),
C.int(len(pem_block))) C.int(len(pem_block)))
cert := C.PEM_read_bio_X509(bio, nil, nil, nil) cert := C.PEM_read_bio_X509(bio, nil, nil, nil)
C.BIO_free(bio) C.BIO_free(bio)
if cert == nil { if cert == nil {
return nil, errorFromErrorQueue() return nil, errorFromErrorQueue()
} }
x := &Certificate{x: cert} x := &Certificate{x: cert}
runtime.SetFinalizer(x, func(x *Certificate) { runtime.SetFinalizer(x, func(x *Certificate) {
C.X509_free(x.x) C.X509_free(x.x)
}) })
return x, nil return x, nil
} }
// MarshalPEM converts the X509 certificate to PEM-encoded format // MarshalPEM converts the X509 certificate to PEM-encoded format
func (c *Certificate) MarshalPEM() (pem_block []byte, err error) { func (c *Certificate) MarshalPEM() (pem_block []byte, err error) {
bio := C.BIO_new(C.BIO_s_mem()) bio := C.BIO_new(C.BIO_s_mem())
if bio == nil { if bio == nil {
return nil, errors.New("failed to allocate memory BIO") return nil, errors.New("failed to allocate memory BIO")
} }
defer C.BIO_free(bio) defer C.BIO_free(bio)
if int(C.PEM_write_bio_X509(bio, c.x)) != 1 { if int(C.PEM_write_bio_X509(bio, c.x)) != 1 {
return nil, errors.New("failed dumping certificate") return nil, errors.New("failed dumping certificate")
} }
return ioutil.ReadAll(asAnyBio(bio)) return ioutil.ReadAll(asAnyBio(bio))
} }
// PublicKey returns the public key embedded in the X509 certificate. // PublicKey returns the public key embedded in the X509 certificate.
func (c *Certificate) PublicKey() (PublicKey, error) { func (c *Certificate) PublicKey() (PublicKey, error) {
pkey := C.X509_get_pubkey(c.x) pkey := C.X509_get_pubkey(c.x)
if pkey == nil { if pkey == nil {
return nil, errors.New("no public key found") return nil, errors.New("no public key found")
} }
key := &pKey{key: pkey} key := &pKey{key: pkey}
runtime.SetFinalizer(key, func(key *pKey) { runtime.SetFinalizer(key, func(key *pKey) {
C.EVP_PKEY_free(key.key) C.EVP_PKEY_free(key.key)
}) })
return key, nil return key, nil
} }

View File

@ -3,74 +3,74 @@
package openssl package openssl
import ( import (
"bytes" "bytes"
"crypto/rsa" "crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"io/ioutil" "io/ioutil"
"testing" "testing"
) )
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {
key, err := LoadPrivateKey(keyBytes) key, err := LoadPrivateKey(keyBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cert, err := LoadCertificate(certBytes) cert, err := LoadCertificate(certBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
pem, err := cert.MarshalPEM() pem, err := cert.MarshalPEM()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(pem, certBytes) { if !bytes.Equal(pem, certBytes) {
ioutil.WriteFile("generated", pem, 0644) ioutil.WriteFile("generated", pem, 0644)
ioutil.WriteFile("hardcoded", certBytes, 0644) ioutil.WriteFile("hardcoded", certBytes, 0644)
t.Fatal("invalid cert pem bytes") t.Fatal("invalid cert pem bytes")
} }
pem, err = key.MarshalPKCS1PrivateKeyPEM() pem, err = key.MarshalPKCS1PrivateKeyPEM()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(pem, keyBytes) { if !bytes.Equal(pem, keyBytes) {
ioutil.WriteFile("generated", pem, 0644) ioutil.WriteFile("generated", pem, 0644)
ioutil.WriteFile("hardcoded", keyBytes, 0644) ioutil.WriteFile("hardcoded", keyBytes, 0644)
t.Fatal("invalid private key pem bytes") t.Fatal("invalid private key pem bytes")
} }
tls_cert, err := tls.X509KeyPair(certBytes, keyBytes) tls_cert, err := tls.X509KeyPair(certBytes, keyBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tls_key, ok := tls_cert.PrivateKey.(*rsa.PrivateKey) tls_key, ok := tls_cert.PrivateKey.(*rsa.PrivateKey)
if !ok { if !ok {
t.Fatal("FASDFASDF") t.Fatal("FASDFASDF")
} }
_ = tls_key _ = tls_key
der, err := key.MarshalPKCS1PrivateKeyDER() der, err := key.MarshalPKCS1PrivateKeyDER()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tls_der := x509.MarshalPKCS1PrivateKey(tls_key) tls_der := x509.MarshalPKCS1PrivateKey(tls_key)
if !bytes.Equal(der, tls_der) { if !bytes.Equal(der, tls_der) {
t.Fatal("invalid private key der bytes: %s\n v.s. %s\n", hex.Dump(der), hex.Dump(tls_der)) t.Fatal("invalid private key der bytes: %s\n v.s. %s\n", hex.Dump(der), hex.Dump(tls_der))
} }
der, err = key.MarshalPKIXPublicKeyDER() der, err = key.MarshalPKIXPublicKeyDER()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tls_der, err = x509.MarshalPKIXPublicKey(&tls_key.PublicKey) tls_der, err = x509.MarshalPKIXPublicKey(&tls_key.PublicKey)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(der, tls_der) { if !bytes.Equal(der, tls_der) {
ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644) ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644)
ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644)
t.Fatal("invalid public key der bytes") t.Fatal("invalid public key der bytes")
} }
} }

View File

@ -3,21 +3,21 @@
package openssl package openssl
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"sync" "sync"
"testing" "testing"
"time" "time"
"code.spacemonkey.com/go/openssl/utils" "code.spacemonkey.com/go/openssl/utils"
) )
var ( var (
certBytes = []byte(`-----BEGIN CERTIFICATE----- certBytes = []byte(`-----BEGIN CERTIFICATE-----
MIIDxDCCAqygAwIBAgIVAMcK/0VWQr2O3MNfJCydqR7oVELcMA0GCSqGSIb3DQEB MIIDxDCCAqygAwIBAgIVAMcK/0VWQr2O3MNfJCydqR7oVELcMA0GCSqGSIb3DQEB
BQUAMIGQMUkwRwYDVQQDE0A1NjdjZGRmYzRjOWZiNTYwZTk1M2ZlZjA1N2M0NGFm BQUAMIGQMUkwRwYDVQQDE0A1NjdjZGRmYzRjOWZiNTYwZTk1M2ZlZjA1N2M0NGFm
MDdiYjc4MDIzODIxYTA5NThiY2RmMGMwNzJhOTdiMThhMQswCQYDVQQGEwJVUzEN MDdiYjc4MDIzODIxYTA5NThiY2RmMGMwNzJhOTdiMThhMQswCQYDVQQGEwJVUzEN
@ -41,7 +41,7 @@ sRkg/uxcJf7wC5Y0BLlp1+aPwdmZD87T3a1uQ1Ij93jmHG+2T9U20MklHAePOl0q
yTqdSPnSH1c= yTqdSPnSH1c=
-----END CERTIFICATE----- -----END CERTIFICATE-----
`) `)
keyBytes = []byte(`-----BEGIN RSA PRIVATE KEY----- keyBytes = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEA3X94nDbxbK5a5zS4vEqHLHKpUmxavqRL5oXEqKoAy6nm56rv MIIEpQIBAAKCAQEA3X94nDbxbK5a5zS4vEqHLHKpUmxavqRL5oXEqKoAy6nm56rv
C3e9xySe+DBlxIEV/MWU+RYpzjC99QkerfRP493aleqfhn3ZRS3tyKrQtP2z1Zwg C3e9xySe+DBlxIEV/MWU+RYpzjC99QkerfRP493aleqfhn3ZRS3tyKrQtP2z1Zwg
wYqwcoASOLgqzKvtVYQMT1nJaw6O5fUEWG7BMR/ZX5/kcr8XjTGYjgEmrL1WTZ3G wYqwcoASOLgqzKvtVYQMT1nJaw6O5fUEWG7BMR/ZX5/kcr8XjTGYjgEmrL1WTZ3G
@ -72,534 +72,534 @@ qmgvgyRayemfO2zR0CPgC6wSoGBth+xW6g+WA8y0z76ZSaWpFi8lVM4=
) )
func NetPipe(t testing.TB) (net.Conn, net.Conn) { func NetPipe(t testing.TB) (net.Conn, net.Conn) {
l, err := net.Listen("tcp", "localhost:0") l, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
client_future := utils.NewFuture() client_future := utils.NewFuture()
go func() { go func() {
client_future.Set(net.Dial(l.Addr().Network(), l.Addr().String())) client_future.Set(net.Dial(l.Addr().Network(), l.Addr().String()))
}() }()
var errs utils.ErrorGroup var errs utils.ErrorGroup
server_conn, err := l.Accept() server_conn, err := l.Accept()
errs.Add(err) errs.Add(err)
client_conn, err := client_future.Get() client_conn, err := client_future.Get()
errs.Add(err) errs.Add(err)
err = errs.Finalize() err = errs.Finalize()
if err != nil { if err != nil {
if server_conn != nil { if server_conn != nil {
server_conn.Close() server_conn.Close()
} }
if client_conn != nil { if client_conn != nil {
client_conn.(net.Conn).Close() client_conn.(net.Conn).Close()
} }
t.Fatal(err) t.Fatal(err)
} }
return server_conn, client_conn.(net.Conn) return server_conn, client_conn.(net.Conn)
} }
type HandshakingConn interface { type HandshakingConn interface {
net.Conn net.Conn
Handshake() error Handshake() error
} }
func SimpleConnTest(t testing.TB, constructor func( func SimpleConnTest(t testing.TB, constructor func(
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
server_conn, client_conn := NetPipe(t) server_conn, client_conn := NetPipe(t)
defer server_conn.Close() defer server_conn.Close()
defer client_conn.Close() defer client_conn.Close()
data := "first test string\n" data := "first test string\n"
server, client := constructor(t, server_conn, client_conn) server, client := constructor(t, server_conn, client_conn)
defer close_both(server, client) defer close_both(server, client)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go func() { go func() {
defer wg.Done() defer wg.Done()
err := client.Handshake() err := client.Handshake()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = io.Copy(client, bytes.NewReader([]byte(data))) _, err = io.Copy(client, bytes.NewReader([]byte(data)))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = client.Close() err = client.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
err := server.Handshake() err := server.Handshake()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
buf := bytes.NewBuffer(make([]byte, 0, len(data))) buf := bytes.NewBuffer(make([]byte, 0, len(data)))
_, err = io.CopyN(buf, server, int64(len(data))) _, err = io.CopyN(buf, server, int64(len(data)))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if string(buf.Bytes()) != data { if string(buf.Bytes()) != data {
t.Fatal("mismatched data") t.Fatal("mismatched data")
} }
err = server.Close() err = server.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
}() }()
wg.Wait() wg.Wait()
} }
func close_both(closer1, closer2 io.Closer) { func close_both(closer1, closer2 io.Closer) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go func() { go func() {
defer wg.Done() defer wg.Done()
closer1.Close() closer1.Close()
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
closer2.Close() closer2.Close()
}() }()
wg.Wait() wg.Wait()
} }
func ClosingTest(t testing.TB, constructor func( func ClosingTest(t testing.TB, constructor func(
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
run_test := func(close_tcp bool, server_writes bool) { run_test := func(close_tcp bool, server_writes bool) {
server_conn, client_conn := NetPipe(t) server_conn, client_conn := NetPipe(t)
defer server_conn.Close() defer server_conn.Close()
defer client_conn.Close() defer client_conn.Close()
server, client := constructor(t, server_conn, client_conn) server, client := constructor(t, server_conn, client_conn)
defer close_both(server, client) defer close_both(server, client)
var sslconn1, sslconn2 HandshakingConn var sslconn1, sslconn2 HandshakingConn
var conn1 net.Conn var conn1 net.Conn
if server_writes { if server_writes {
sslconn1 = server sslconn1 = server
conn1 = server_conn conn1 = server_conn
sslconn2 = client sslconn2 = client
} else { } else {
sslconn1 = client sslconn1 = client
conn1 = client_conn conn1 = client_conn
sslconn2 = server sslconn2 = server
} }
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go func() { go func() {
defer wg.Done() defer wg.Done()
_, err := sslconn1.Write([]byte("hello")) _, err := sslconn1.Write([]byte("hello"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if close_tcp { if close_tcp {
err = conn1.Close() err = conn1.Close()
} else { } else {
err = sslconn1.Close() err = sslconn1.Close()
} }
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
data, err := ioutil.ReadAll(sslconn2) data, err := ioutil.ReadAll(sslconn2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(data, []byte("hello")) { if !bytes.Equal(data, []byte("hello")) {
t.Fatal("bytes don't match") t.Fatal("bytes don't match")
} }
}() }()
wg.Wait() wg.Wait()
} }
run_test(true, false) run_test(true, false)
run_test(false, false) run_test(false, false)
run_test(true, true) run_test(true, true)
run_test(false, true) run_test(false, true)
} }
func ThroughputBenchmark(b *testing.B, constructor func( func ThroughputBenchmark(b *testing.B, constructor func(
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
server_conn, client_conn := NetPipe(b) server_conn, client_conn := NetPipe(b)
defer server_conn.Close() defer server_conn.Close()
defer client_conn.Close() defer client_conn.Close()
server, client := constructor(b, server_conn, client_conn) server, client := constructor(b, server_conn, client_conn)
defer close_both(server, client) defer close_both(server, client)
b.SetBytes(1024) b.SetBytes(1024)
data := make([]byte, b.N*1024) data := make([]byte, b.N*1024)
_, err := io.ReadFull(rand.Reader, data[:]) _, err := io.ReadFull(rand.Reader, data[:])
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
b.ResetTimer() b.ResetTimer()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go func() { go func() {
defer wg.Done() defer wg.Done()
_, err = io.Copy(client, bytes.NewReader([]byte(data))) _, err = io.Copy(client, bytes.NewReader([]byte(data)))
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
_, err = io.CopyN(buf, server, int64(len(data))) _, err = io.CopyN(buf, server, int64(len(data)))
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
if !bytes.Equal(buf.Bytes(), data) { if !bytes.Equal(buf.Bytes(), data) {
b.Fatal("mismatched data") b.Fatal("mismatched data")
} }
}() }()
wg.Wait() wg.Wait()
b.StopTimer() b.StopTimer()
} }
func StdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) ( func StdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) (
server, client HandshakingConn) { server, client HandshakingConn) {
cert, err := tls.X509KeyPair(certBytes, keyBytes) cert, err := tls.X509KeyPair(certBytes, keyBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
config := &tls.Config{ config := &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true, InsecureSkipVerify: true,
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}} CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}}
server = tls.Server(server_conn, config) server = tls.Server(server_conn, config)
client = tls.Client(client_conn, config) client = tls.Client(client_conn, config)
return server, client return server, client
} }
func OpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) ( func OpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) (
server, client HandshakingConn) { server, client HandshakingConn) {
ctx, err := NewCtx() ctx, err := NewCtx()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
key, err := LoadPrivateKey(keyBytes) key, err := LoadPrivateKey(keyBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = ctx.UsePrivateKey(key) err = ctx.UsePrivateKey(key)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cert, err := LoadCertificate(certBytes) cert, err := LoadCertificate(certBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = ctx.UseCertificate(cert) err = ctx.UseCertificate(cert)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = ctx.SetCipherList("AES128-SHA") err = ctx.SetCipherList("AES128-SHA")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
server, err = Server(server_conn, ctx) server, err = Server(server_conn, ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
client, err = Client(client_conn, ctx) client, err = Client(client_conn, ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return server, client return server, client
} }
func StdlibOpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) ( func StdlibOpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) (
server, client HandshakingConn) { server, client HandshakingConn) {
server_std, _ := StdlibConstructor(t, server_conn, client_conn) server_std, _ := StdlibConstructor(t, server_conn, client_conn)
_, client_ssl := OpenSSLConstructor(t, server_conn, client_conn) _, client_ssl := OpenSSLConstructor(t, server_conn, client_conn)
return server_std, client_ssl return server_std, client_ssl
} }
func OpenSSLStdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) ( func OpenSSLStdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) (
server, client HandshakingConn) { server, client HandshakingConn) {
_, client_std := StdlibConstructor(t, server_conn, client_conn) _, client_std := StdlibConstructor(t, server_conn, client_conn)
server_ssl, _ := OpenSSLConstructor(t, server_conn, client_conn) server_ssl, _ := OpenSSLConstructor(t, server_conn, client_conn)
return server_ssl, client_std return server_ssl, client_std
} }
func TestStdlibSimple(t *testing.T) { func TestStdlibSimple(t *testing.T) {
SimpleConnTest(t, StdlibConstructor) SimpleConnTest(t, StdlibConstructor)
} }
func TestOpenSSLSimple(t *testing.T) { func TestOpenSSLSimple(t *testing.T) {
SimpleConnTest(t, OpenSSLConstructor) SimpleConnTest(t, OpenSSLConstructor)
} }
func TestStdlibClosing(t *testing.T) { func TestStdlibClosing(t *testing.T) {
ClosingTest(t, StdlibConstructor) ClosingTest(t, StdlibConstructor)
} }
func TestOpenSSLClosing(t *testing.T) { func TestOpenSSLClosing(t *testing.T) {
ClosingTest(t, OpenSSLConstructor) ClosingTest(t, OpenSSLConstructor)
} }
func BenchmarkStdlibThroughput(b *testing.B) { func BenchmarkStdlibThroughput(b *testing.B) {
ThroughputBenchmark(b, StdlibConstructor) ThroughputBenchmark(b, StdlibConstructor)
} }
func BenchmarkOpenSSLThroughput(b *testing.B) { func BenchmarkOpenSSLThroughput(b *testing.B) {
ThroughputBenchmark(b, OpenSSLConstructor) ThroughputBenchmark(b, OpenSSLConstructor)
} }
func TestStdlibOpenSSLSimple(t *testing.T) { func TestStdlibOpenSSLSimple(t *testing.T) {
SimpleConnTest(t, StdlibOpenSSLConstructor) SimpleConnTest(t, StdlibOpenSSLConstructor)
} }
func TestOpenSSLStdlibSimple(t *testing.T) { func TestOpenSSLStdlibSimple(t *testing.T) {
SimpleConnTest(t, OpenSSLStdlibConstructor) SimpleConnTest(t, OpenSSLStdlibConstructor)
} }
func TestStdlibOpenSSLClosing(t *testing.T) { func TestStdlibOpenSSLClosing(t *testing.T) {
ClosingTest(t, StdlibOpenSSLConstructor) ClosingTest(t, StdlibOpenSSLConstructor)
} }
func TestOpenSSLStdlibClosing(t *testing.T) { func TestOpenSSLStdlibClosing(t *testing.T) {
ClosingTest(t, OpenSSLStdlibConstructor) ClosingTest(t, OpenSSLStdlibConstructor)
} }
func BenchmarkStdlibOpenSSLThroughput(b *testing.B) { func BenchmarkStdlibOpenSSLThroughput(b *testing.B) {
ThroughputBenchmark(b, StdlibOpenSSLConstructor) ThroughputBenchmark(b, StdlibOpenSSLConstructor)
} }
func BenchmarkOpenSSLStdlibThroughput(b *testing.B) { func BenchmarkOpenSSLStdlibThroughput(b *testing.B) {
ThroughputBenchmark(b, OpenSSLStdlibConstructor) ThroughputBenchmark(b, OpenSSLStdlibConstructor)
} }
func FullDuplexRenegotiationTest(t testing.TB, constructor func( func FullDuplexRenegotiationTest(t testing.TB, constructor func(
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
server_conn, client_conn := NetPipe(t) server_conn, client_conn := NetPipe(t)
defer server_conn.Close() defer server_conn.Close()
defer client_conn.Close() defer client_conn.Close()
times := 256 times := 256
data_len := 4 * SSLRecordSize data_len := 4 * SSLRecordSize
data1 := make([]byte, data_len) data1 := make([]byte, data_len)
_, err := io.ReadFull(rand.Reader, data1[:]) _, err := io.ReadFull(rand.Reader, data1[:])
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
data2 := make([]byte, data_len) data2 := make([]byte, data_len)
_, err = io.ReadFull(rand.Reader, data1[:]) _, err = io.ReadFull(rand.Reader, data1[:])
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
server, client := constructor(t, server_conn, client_conn) server, client := constructor(t, server_conn, client_conn)
defer close_both(server, client) defer close_both(server, client)
var wg sync.WaitGroup var wg sync.WaitGroup
send_func := func(sender HandshakingConn, data []byte) { send_func := func(sender HandshakingConn, data []byte) {
defer wg.Done() defer wg.Done()
for i := 0; i < times; i++ { for i := 0; i < times; i++ {
if i == times/2 { if i == times/2 {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
err := sender.Handshake() err := sender.Handshake()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
}() }()
} }
_, err := sender.Write(data) _, err := sender.Write(data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
} }
recv_func := func(receiver net.Conn, data []byte) { recv_func := func(receiver net.Conn, data []byte) {
defer wg.Done() defer wg.Done()
buf := make([]byte, len(data)) buf := make([]byte, len(data))
for i := 0; i < times; i++ { for i := 0; i < times; i++ {
n, err := io.ReadFull(receiver, buf[:]) n, err := io.ReadFull(receiver, buf[:])
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(buf[:n], data) { if !bytes.Equal(buf[:n], data) {
t.Fatal(err) t.Fatal(err)
} }
} }
} }
wg.Add(4) wg.Add(4)
go recv_func(server, data1) go recv_func(server, data1)
go send_func(client, data1) go send_func(client, data1)
go send_func(server, data2) go send_func(server, data2)
go recv_func(client, data2) go recv_func(client, data2)
wg.Wait() wg.Wait()
} }
func TestStdlibFullDuplexRenegotiation(t *testing.T) { func TestStdlibFullDuplexRenegotiation(t *testing.T) {
FullDuplexRenegotiationTest(t, StdlibConstructor) FullDuplexRenegotiationTest(t, StdlibConstructor)
} }
func TestOpenSSLFullDuplexRenegotiation(t *testing.T) { func TestOpenSSLFullDuplexRenegotiation(t *testing.T) {
FullDuplexRenegotiationTest(t, OpenSSLConstructor) FullDuplexRenegotiationTest(t, OpenSSLConstructor)
} }
func TestOpenSSLStdlibFullDuplexRenegotiation(t *testing.T) { func TestOpenSSLStdlibFullDuplexRenegotiation(t *testing.T) {
FullDuplexRenegotiationTest(t, OpenSSLStdlibConstructor) FullDuplexRenegotiationTest(t, OpenSSLStdlibConstructor)
} }
func TestStdlibOpenSSLFullDuplexRenegotiation(t *testing.T) { func TestStdlibOpenSSLFullDuplexRenegotiation(t *testing.T) {
FullDuplexRenegotiationTest(t, StdlibOpenSSLConstructor) FullDuplexRenegotiationTest(t, StdlibOpenSSLConstructor)
} }
func LotsOfConns(t *testing.T, payload_size int64, loops, clients int, func LotsOfConns(t *testing.T, payload_size int64, loops, clients int,
sleep time.Duration, newListener func(net.Listener) net.Listener, sleep time.Duration, newListener func(net.Listener) net.Listener,
newClient func(net.Conn) (net.Conn, error)) { newClient func(net.Conn) (net.Conn, error)) {
tcp_listener, err := net.Listen("tcp", "localhost:0") tcp_listener, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ssl_listener := newListener(tcp_listener) ssl_listener := newListener(tcp_listener)
go func() { go func() {
for { for {
conn, err := ssl_listener.Accept() conn, err := ssl_listener.Accept()
if err != nil { if err != nil {
t.Fatalf("failed accept: %s", err) t.Fatalf("failed accept: %s", err)
continue continue
} }
go func() { go func() {
defer func() { defer func() {
err = conn.Close() err = conn.Close()
if err != nil { if err != nil {
t.Fatalf("failed closing: %s", err) t.Fatalf("failed closing: %s", err)
} }
}() }()
for i := 0; i < loops; i++ { for i := 0; i < loops; i++ {
_, err := io.Copy(ioutil.Discard, _, err := io.Copy(ioutil.Discard,
io.LimitReader(conn, payload_size)) io.LimitReader(conn, payload_size))
if err != nil { if err != nil {
t.Fatalf("failed reading: %s", err) t.Fatalf("failed reading: %s", err)
return return
} }
_, err = io.Copy(conn, io.LimitReader(rand.Reader, _, err = io.Copy(conn, io.LimitReader(rand.Reader,
payload_size)) payload_size))
if err != nil { if err != nil {
t.Fatalf("failed writing: %s", err) t.Fatalf("failed writing: %s", err)
return return
} }
} }
time.Sleep(sleep) time.Sleep(sleep)
}() }()
} }
}() }()
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < clients; i++ { for i := 0; i < clients; i++ {
tcp_client, err := net.Dial(tcp_listener.Addr().Network(), tcp_client, err := net.Dial(tcp_listener.Addr().Network(),
tcp_listener.Addr().String()) tcp_listener.Addr().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ssl_client, err := newClient(tcp_client) ssl_client, err := newClient(tcp_client)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wg.Add(1) wg.Add(1)
go func(i int) { go func(i int) {
defer func() { defer func() {
err = ssl_client.Close() err = ssl_client.Close()
if err != nil { if err != nil {
t.Fatalf("failed closing: %s", err) t.Fatalf("failed closing: %s", err)
} }
wg.Done() wg.Done()
}() }()
for i := 0; i < loops; i++ { for i := 0; i < loops; i++ {
_, err := io.Copy(ssl_client, io.LimitReader(rand.Reader, _, err := io.Copy(ssl_client, io.LimitReader(rand.Reader,
payload_size)) payload_size))
if err != nil { if err != nil {
t.Fatalf("failed writing: %s", err) t.Fatalf("failed writing: %s", err)
return return
} }
_, err = io.Copy(ioutil.Discard, _, err = io.Copy(ioutil.Discard,
io.LimitReader(ssl_client, payload_size)) io.LimitReader(ssl_client, payload_size))
if err != nil { if err != nil {
t.Fatalf("failed reading: %s", err) t.Fatalf("failed reading: %s", err)
return return
} }
} }
time.Sleep(sleep) time.Sleep(sleep)
}(i) }(i)
} }
wg.Wait() wg.Wait()
} }
func TestStdlibLotsOfConns(t *testing.T) { func TestStdlibLotsOfConns(t *testing.T) {
tls_cert, err := tls.X509KeyPair(certBytes, keyBytes) tls_cert, err := tls.X509KeyPair(certBytes, keyBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tls_config := &tls.Config{ tls_config := &tls.Config{
Certificates: []tls.Certificate{tls_cert}, Certificates: []tls.Certificate{tls_cert},
InsecureSkipVerify: true, InsecureSkipVerify: true,
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}} CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}}
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
func(l net.Listener) net.Listener { func(l net.Listener) net.Listener {
return tls.NewListener(l, tls_config) return tls.NewListener(l, tls_config)
}, func(c net.Conn) (net.Conn, error) { }, func(c net.Conn) (net.Conn, error) {
return tls.Client(c, tls_config), nil return tls.Client(c, tls_config), nil
}) })
} }
func TestOpenSSLLotsOfConns(t *testing.T) { func TestOpenSSLLotsOfConns(t *testing.T) {
ctx, err := NewCtx() ctx, err := NewCtx()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
key, err := LoadPrivateKey(keyBytes) key, err := LoadPrivateKey(keyBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = ctx.UsePrivateKey(key) err = ctx.UsePrivateKey(key)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cert, err := LoadCertificate(certBytes) cert, err := LoadCertificate(certBytes)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = ctx.UseCertificate(cert) err = ctx.UseCertificate(cert)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = ctx.SetCipherList("AES128-SHA") err = ctx.SetCipherList("AES128-SHA")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
func(l net.Listener) net.Listener { func(l net.Listener) net.Listener {
return NewListener(l, ctx) return NewListener(l, ctx)
}, func(c net.Conn) (net.Conn, error) { }, func(c net.Conn) (net.Conn, error) {
return Client(c, ctx) return Client(c, ctx)
}) })
} }

View File

@ -3,20 +3,20 @@
package utils package utils
import ( import (
"errors" "errors"
"strings" "strings"
) )
// ErrorGroup collates errors // ErrorGroup collates errors
type ErrorGroup struct { type ErrorGroup struct {
Errors []error Errors []error
} }
// Add adds an error to an existing error group // Add adds an error to an existing error group
func (e *ErrorGroup) Add(err error) { func (e *ErrorGroup) Add(err error) {
if err != nil { if err != nil {
e.Errors = append(e.Errors, err) e.Errors = append(e.Errors, err)
} }
} }
// Finalize returns an error corresponding to the ErrorGroup state. If there's // Finalize returns an error corresponding to the ErrorGroup state. If there's
@ -24,15 +24,15 @@ func (e *ErrorGroup) Add(err error) {
// Finalize returns that error. Otherwise, Finalize will make a new error // Finalize returns that error. Otherwise, Finalize will make a new error
// consisting of the messages from the constituent errors. // consisting of the messages from the constituent errors.
func (e *ErrorGroup) Finalize() error { func (e *ErrorGroup) Finalize() error {
if len(e.Errors) == 0 { if len(e.Errors) == 0 {
return nil return nil
} }
if len(e.Errors) == 1 { if len(e.Errors) == 1 {
return e.Errors[0] return e.Errors[0]
} }
msgs := make([]string, 0, len(e.Errors)) msgs := make([]string, 0, len(e.Errors))
for _, err := range e.Errors { for _, err := range e.Errors {
msgs = append(msgs, err.Error()) msgs = append(msgs, err.Error())
} }
return errors.New(strings.Join(msgs, "\n")) return errors.New(strings.Join(msgs, "\n"))
} }

View File

@ -3,7 +3,7 @@
package utils package utils
import ( import (
"sync" "sync"
) )
// Future is a type that is essentially the inverse of a channel. With a // Future is a type that is essentially the inverse of a channel. With a
@ -13,55 +13,55 @@ import (
// results, we also capture and return error values as well. Use NewFuture // results, we also capture and return error values as well. Use NewFuture
// to initialize. // to initialize.
type Future struct { type Future struct {
mutex *sync.Mutex mutex *sync.Mutex
cond *sync.Cond cond *sync.Cond
received bool received bool
val interface{} val interface{}
err error err error
} }
// NewFuture returns an initialized and ready Future. // NewFuture returns an initialized and ready Future.
func NewFuture() *Future { func NewFuture() *Future {
mutex := &sync.Mutex{} mutex := &sync.Mutex{}
return &Future{ return &Future{
mutex: mutex, mutex: mutex,
cond: sync.NewCond(mutex), cond: sync.NewCond(mutex),
received: false, received: false,
val: nil, val: nil,
err: nil, err: nil,
} }
} }
// Get blocks until the Future has a value set. // Get blocks until the Future has a value set.
func (self *Future) Get() (interface{}, error) { func (self *Future) Get() (interface{}, error) {
self.mutex.Lock() self.mutex.Lock()
defer self.mutex.Unlock() defer self.mutex.Unlock()
for { for {
if self.received { if self.received {
return self.val, self.err return self.val, self.err
} }
self.cond.Wait() self.cond.Wait()
} }
} }
// Fired returns whether or not a value has been set. If Fired is true, Get // Fired returns whether or not a value has been set. If Fired is true, Get
// won't block. // won't block.
func (self *Future) Fired() bool { func (self *Future) Fired() bool {
self.mutex.Lock() self.mutex.Lock()
defer self.mutex.Unlock() defer self.mutex.Unlock()
return self.received return self.received
} }
// Set provides the value to present and future Get calls. If Set has already // Set provides the value to present and future Get calls. If Set has already
// been called, this is a no-op. // been called, this is a no-op.
func (self *Future) Set(val interface{}, err error) { func (self *Future) Set(val interface{}, err error) {
self.mutex.Lock() self.mutex.Lock()
defer self.mutex.Unlock() defer self.mutex.Unlock()
if self.received { if self.received {
return return
} }
self.received = true self.received = true
self.val = val self.val = val
self.err = err self.err = err
self.cond.Broadcast() self.cond.Broadcast()
} }

View File

@ -5,7 +5,7 @@
package utils package utils
import ( import (
"unsafe" "unsafe"
) )
// ThreadId returns the current runtime's thread id. Thanks to Gustavo Niemeyer // ThreadId returns the current runtime's thread id. Thanks to Gustavo Niemeyer