mirror of
https://github.com/libp2p/go-openssl.git
synced 2025-01-30 05:20:08 +08:00
space monkey internal commit export
[katamari commit: 9bd04d1d78e85304589695c66e328d23128f509c]
This commit is contained in:
parent
751143ef9c
commit
fa8eb6a573
306
bio.go
306
bio.go
@ -56,226 +56,226 @@ static void BIO_set_retry_read_not_a_macro(BIO *b) { BIO_set_retry_read(b); }
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SSLRecordSize = 16 * 1024
|
SSLRecordSize = 16 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
func nonCopyGoBytes(ptr uintptr, length int) []byte {
|
func nonCopyGoBytes(ptr uintptr, length int) []byte {
|
||||||
var slice []byte
|
var slice []byte
|
||||||
header := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
|
header := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
|
||||||
header.Cap = length
|
header.Cap = length
|
||||||
header.Len = length
|
header.Len = length
|
||||||
header.Data = ptr
|
header.Data = ptr
|
||||||
return slice
|
return slice
|
||||||
}
|
}
|
||||||
|
|
||||||
func nonCopyCString(data *C.char, size C.int) []byte {
|
func nonCopyCString(data *C.char, size C.int) []byte {
|
||||||
return nonCopyGoBytes(uintptr(unsafe.Pointer(data)), int(size))
|
return nonCopyGoBytes(uintptr(unsafe.Pointer(data)), int(size))
|
||||||
}
|
}
|
||||||
|
|
||||||
//export cbioNew
|
//export cbioNew
|
||||||
func cbioNew(b *C.BIO) C.int {
|
func cbioNew(b *C.BIO) C.int {
|
||||||
b.shutdown = 1
|
b.shutdown = 1
|
||||||
b.init = 1
|
b.init = 1
|
||||||
b.num = -1
|
b.num = -1
|
||||||
b.ptr = nil
|
b.ptr = nil
|
||||||
b.flags = 0
|
b.flags = 0
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
//export cbioFree
|
//export cbioFree
|
||||||
func cbioFree(b *C.BIO) C.int {
|
func cbioFree(b *C.BIO) C.int {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
type writeBio struct {
|
type writeBio struct {
|
||||||
data_mtx sync.Mutex
|
data_mtx sync.Mutex
|
||||||
op_mtx sync.Mutex
|
op_mtx sync.Mutex
|
||||||
buf []byte
|
buf []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadWritePtr(b *C.BIO) *writeBio {
|
func loadWritePtr(b *C.BIO) *writeBio {
|
||||||
return (*writeBio)(unsafe.Pointer(b.ptr))
|
return (*writeBio)(unsafe.Pointer(b.ptr))
|
||||||
}
|
}
|
||||||
|
|
||||||
//export writeBioWrite
|
//export writeBioWrite
|
||||||
func writeBioWrite(b *C.BIO, data *C.char, size C.int) C.int {
|
func writeBioWrite(b *C.BIO, data *C.char, size C.int) C.int {
|
||||||
ptr := loadWritePtr(b)
|
ptr := loadWritePtr(b)
|
||||||
if ptr == nil || data == nil || size < 0 {
|
if ptr == nil || data == nil || size < 0 {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
ptr.data_mtx.Lock()
|
ptr.data_mtx.Lock()
|
||||||
defer ptr.data_mtx.Unlock()
|
defer ptr.data_mtx.Unlock()
|
||||||
C.BIO_clear_retry_flags_not_a_macro(b)
|
C.BIO_clear_retry_flags_not_a_macro(b)
|
||||||
ptr.buf = append(ptr.buf, nonCopyCString(data, size)...)
|
ptr.buf = append(ptr.buf, nonCopyCString(data, size)...)
|
||||||
return size
|
return size
|
||||||
}
|
}
|
||||||
|
|
||||||
//export writeBioCtrl
|
//export writeBioCtrl
|
||||||
func writeBioCtrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long {
|
func writeBioCtrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long {
|
||||||
switch cmd {
|
switch cmd {
|
||||||
case C.BIO_CTRL_WPENDING:
|
case C.BIO_CTRL_WPENDING:
|
||||||
return writeBioPending(b)
|
return writeBioPending(b)
|
||||||
case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH:
|
case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH:
|
||||||
return 1
|
return 1
|
||||||
default:
|
default:
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeBioPending(b *C.BIO) C.long {
|
func writeBioPending(b *C.BIO) C.long {
|
||||||
ptr := loadWritePtr(b)
|
ptr := loadWritePtr(b)
|
||||||
if ptr == nil {
|
if ptr == nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
ptr.data_mtx.Lock()
|
ptr.data_mtx.Lock()
|
||||||
defer ptr.data_mtx.Unlock()
|
defer ptr.data_mtx.Unlock()
|
||||||
return C.long(len(ptr.buf))
|
return C.long(len(ptr.buf))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *writeBio) WriteTo(w io.Writer) (rv int64, err error) {
|
func (b *writeBio) WriteTo(w io.Writer) (rv int64, err error) {
|
||||||
b.op_mtx.Lock()
|
b.op_mtx.Lock()
|
||||||
defer b.op_mtx.Unlock()
|
defer b.op_mtx.Unlock()
|
||||||
|
|
||||||
// write whatever data we currently have
|
// write whatever data we currently have
|
||||||
b.data_mtx.Lock()
|
b.data_mtx.Lock()
|
||||||
data := b.buf
|
data := b.buf
|
||||||
b.data_mtx.Unlock()
|
b.data_mtx.Unlock()
|
||||||
|
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
n, err := w.Write(data)
|
n, err := w.Write(data)
|
||||||
|
|
||||||
// subtract however much data we wrote from the buffer
|
// subtract however much data we wrote from the buffer
|
||||||
b.data_mtx.Lock()
|
b.data_mtx.Lock()
|
||||||
b.buf = b.buf[:copy(b.buf, b.buf[n:])]
|
b.buf = b.buf[:copy(b.buf, b.buf[n:])]
|
||||||
b.data_mtx.Unlock()
|
b.data_mtx.Unlock()
|
||||||
|
|
||||||
return int64(n), err
|
return int64(n), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *writeBio) Disconnect(b *C.BIO) {
|
func (self *writeBio) Disconnect(b *C.BIO) {
|
||||||
if loadWritePtr(b) == self {
|
if loadWritePtr(b) == self {
|
||||||
b.ptr = nil
|
b.ptr = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *writeBio) MakeCBIO() *C.BIO {
|
func (b *writeBio) MakeCBIO() *C.BIO {
|
||||||
rv := C.BIO_new(C.BIO_s_writeBio())
|
rv := C.BIO_new(C.BIO_s_writeBio())
|
||||||
rv.ptr = unsafe.Pointer(b)
|
rv.ptr = unsafe.Pointer(b)
|
||||||
return rv
|
return rv
|
||||||
}
|
}
|
||||||
|
|
||||||
type readBio struct {
|
type readBio struct {
|
||||||
data_mtx sync.Mutex
|
data_mtx sync.Mutex
|
||||||
op_mtx sync.Mutex
|
op_mtx sync.Mutex
|
||||||
buf []byte
|
buf []byte
|
||||||
eof bool
|
eof bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadReadPtr(b *C.BIO) *readBio {
|
func loadReadPtr(b *C.BIO) *readBio {
|
||||||
return (*readBio)(unsafe.Pointer(b.ptr))
|
return (*readBio)(unsafe.Pointer(b.ptr))
|
||||||
}
|
}
|
||||||
|
|
||||||
//export readBioRead
|
//export readBioRead
|
||||||
func readBioRead(b *C.BIO, data *C.char, size C.int) C.int {
|
func readBioRead(b *C.BIO, data *C.char, size C.int) C.int {
|
||||||
ptr := loadReadPtr(b)
|
ptr := loadReadPtr(b)
|
||||||
if ptr == nil || size < 0 {
|
if ptr == nil || size < 0 {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
ptr.data_mtx.Lock()
|
ptr.data_mtx.Lock()
|
||||||
defer ptr.data_mtx.Unlock()
|
defer ptr.data_mtx.Unlock()
|
||||||
C.BIO_clear_retry_flags_not_a_macro(b)
|
C.BIO_clear_retry_flags_not_a_macro(b)
|
||||||
if len(ptr.buf) == 0 {
|
if len(ptr.buf) == 0 {
|
||||||
if ptr.eof {
|
if ptr.eof {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
C.BIO_set_retry_read_not_a_macro(b)
|
C.BIO_set_retry_read_not_a_macro(b)
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
if size == 0 || data == nil {
|
if size == 0 || data == nil {
|
||||||
return C.int(len(ptr.buf))
|
return C.int(len(ptr.buf))
|
||||||
}
|
}
|
||||||
n := copy(nonCopyCString(data, size), ptr.buf)
|
n := copy(nonCopyCString(data, size), ptr.buf)
|
||||||
ptr.buf = ptr.buf[:copy(ptr.buf, ptr.buf[n:])]
|
ptr.buf = ptr.buf[:copy(ptr.buf, ptr.buf[n:])]
|
||||||
return C.int(n)
|
return C.int(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
//export readBioCtrl
|
//export readBioCtrl
|
||||||
func readBioCtrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long {
|
func readBioCtrl(b *C.BIO, cmd C.int, arg1 C.long, arg2 unsafe.Pointer) C.long {
|
||||||
switch cmd {
|
switch cmd {
|
||||||
case C.BIO_CTRL_PENDING:
|
case C.BIO_CTRL_PENDING:
|
||||||
return readBioPending(b)
|
return readBioPending(b)
|
||||||
case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH:
|
case C.BIO_CTRL_DUP, C.BIO_CTRL_FLUSH:
|
||||||
return 1
|
return 1
|
||||||
default:
|
default:
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func readBioPending(b *C.BIO) C.long {
|
func readBioPending(b *C.BIO) C.long {
|
||||||
ptr := loadReadPtr(b)
|
ptr := loadReadPtr(b)
|
||||||
if ptr == nil {
|
if ptr == nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
ptr.data_mtx.Lock()
|
ptr.data_mtx.Lock()
|
||||||
defer ptr.data_mtx.Unlock()
|
defer ptr.data_mtx.Unlock()
|
||||||
return C.long(len(ptr.buf))
|
return C.long(len(ptr.buf))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *readBio) ReadFromOnce(r io.Reader) (n int, err error) {
|
func (b *readBio) ReadFromOnce(r io.Reader) (n int, err error) {
|
||||||
b.op_mtx.Lock()
|
b.op_mtx.Lock()
|
||||||
defer b.op_mtx.Unlock()
|
defer b.op_mtx.Unlock()
|
||||||
|
|
||||||
// make sure we have a destination that fits at least one SSL record
|
// make sure we have a destination that fits at least one SSL record
|
||||||
b.data_mtx.Lock()
|
b.data_mtx.Lock()
|
||||||
if cap(b.buf) < len(b.buf)+SSLRecordSize {
|
if cap(b.buf) < len(b.buf)+SSLRecordSize {
|
||||||
new_buf := make([]byte, len(b.buf), len(b.buf)+SSLRecordSize)
|
new_buf := make([]byte, len(b.buf), len(b.buf)+SSLRecordSize)
|
||||||
copy(new_buf, b.buf)
|
copy(new_buf, b.buf)
|
||||||
b.buf = new_buf
|
b.buf = new_buf
|
||||||
}
|
}
|
||||||
dst := b.buf[len(b.buf):cap(b.buf)]
|
dst := b.buf[len(b.buf):cap(b.buf)]
|
||||||
dst_slice := b.buf
|
dst_slice := b.buf
|
||||||
b.data_mtx.Unlock()
|
b.data_mtx.Unlock()
|
||||||
|
|
||||||
n, err = r.Read(dst)
|
n, err = r.Read(dst)
|
||||||
b.data_mtx.Lock()
|
b.data_mtx.Lock()
|
||||||
defer b.data_mtx.Unlock()
|
defer b.data_mtx.Unlock()
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
if len(dst_slice) != len(b.buf) {
|
if len(dst_slice) != len(b.buf) {
|
||||||
// someone shrunk the buffer, so we read in too far ahead and we
|
// someone shrunk the buffer, so we read in too far ahead and we
|
||||||
// need to slide backwards
|
// need to slide backwards
|
||||||
copy(b.buf[len(b.buf):len(b.buf)+n], dst)
|
copy(b.buf[len(b.buf):len(b.buf)+n], dst)
|
||||||
}
|
}
|
||||||
b.buf = b.buf[:len(b.buf)+n]
|
b.buf = b.buf[:len(b.buf)+n]
|
||||||
}
|
}
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *readBio) MakeCBIO() *C.BIO {
|
func (b *readBio) MakeCBIO() *C.BIO {
|
||||||
rv := C.BIO_new(C.BIO_s_readBio())
|
rv := C.BIO_new(C.BIO_s_readBio())
|
||||||
rv.ptr = unsafe.Pointer(b)
|
rv.ptr = unsafe.Pointer(b)
|
||||||
return rv
|
return rv
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *readBio) Disconnect(b *C.BIO) {
|
func (self *readBio) Disconnect(b *C.BIO) {
|
||||||
if loadReadPtr(b) == self {
|
if loadReadPtr(b) == self {
|
||||||
b.ptr = nil
|
b.ptr = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *readBio) MarkEOF() {
|
func (b *readBio) MarkEOF() {
|
||||||
b.data_mtx.Lock()
|
b.data_mtx.Lock()
|
||||||
defer b.data_mtx.Unlock()
|
defer b.data_mtx.Unlock()
|
||||||
b.eof = true
|
b.eof = true
|
||||||
}
|
}
|
||||||
|
|
||||||
type anyBio C.BIO
|
type anyBio C.BIO
|
||||||
@ -283,18 +283,18 @@ type anyBio C.BIO
|
|||||||
func asAnyBio(b *C.BIO) *anyBio { return (*anyBio)(b) }
|
func asAnyBio(b *C.BIO) *anyBio { return (*anyBio)(b) }
|
||||||
|
|
||||||
func (b *anyBio) Read(buf []byte) (n int, err error) {
|
func (b *anyBio) Read(buf []byte) (n int, err error) {
|
||||||
n = int(C.BIO_read((*C.BIO)(b), unsafe.Pointer(&buf[0]), C.int(len(buf))))
|
n = int(C.BIO_read((*C.BIO)(b), unsafe.Pointer(&buf[0]), C.int(len(buf))))
|
||||||
if n <= 0 {
|
if n <= 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *anyBio) Write(buf []byte) (written int, err error) {
|
func (b *anyBio) Write(buf []byte) (written int, err error) {
|
||||||
n := int(C.BIO_write((*C.BIO)(b), unsafe.Pointer(&buf[0]),
|
n := int(C.BIO_write((*C.BIO)(b), unsafe.Pointer(&buf[0]),
|
||||||
C.int(len(buf))))
|
C.int(len(buf))))
|
||||||
if n != len(buf) {
|
if n != len(buf) {
|
||||||
return n, errors.New("BIO write failed")
|
return n, errors.New("BIO write failed")
|
||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
570
conn.go
570
conn.go
@ -9,77 +9,77 @@ package openssl
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"code.spacemonkey.com/go/openssl/utils"
|
"code.spacemonkey.com/go/openssl/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
zeroReturn = errors.New("zero return")
|
zeroReturn = errors.New("zero return")
|
||||||
wantRead = errors.New("want read")
|
wantRead = errors.New("want read")
|
||||||
wantWrite = errors.New("want write")
|
wantWrite = errors.New("want write")
|
||||||
tryAgain = errors.New("try again")
|
tryAgain = errors.New("try again")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
ssl *C.SSL
|
ssl *C.SSL
|
||||||
into_ssl *readBio
|
into_ssl *readBio
|
||||||
from_ssl *writeBio
|
from_ssl *writeBio
|
||||||
is_shutdown bool
|
is_shutdown bool
|
||||||
mtx sync.Mutex
|
mtx sync.Mutex
|
||||||
want_read_future *utils.Future
|
want_read_future *utils.Future
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSSL(ctx *C.SSL_CTX) (*C.SSL, error) {
|
func newSSL(ctx *C.SSL_CTX) (*C.SSL, error) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
ssl := C.SSL_new(ctx)
|
ssl := C.SSL_new(ctx)
|
||||||
if ssl == nil {
|
if ssl == nil {
|
||||||
return nil, errorFromErrorQueue()
|
return nil, errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
return ssl, nil
|
return ssl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
||||||
ssl, err := newSSL(ctx.ctx)
|
ssl, err := newSSL(ctx.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
into_ssl := &readBio{}
|
into_ssl := &readBio{}
|
||||||
from_ssl := &writeBio{}
|
from_ssl := &writeBio{}
|
||||||
|
|
||||||
into_ssl_cbio := into_ssl.MakeCBIO()
|
into_ssl_cbio := into_ssl.MakeCBIO()
|
||||||
from_ssl_cbio := from_ssl.MakeCBIO()
|
from_ssl_cbio := from_ssl.MakeCBIO()
|
||||||
if into_ssl_cbio == nil || from_ssl_cbio == nil {
|
if into_ssl_cbio == nil || from_ssl_cbio == nil {
|
||||||
// these frees are null safe
|
// these frees are null safe
|
||||||
C.BIO_free(into_ssl_cbio)
|
C.BIO_free(into_ssl_cbio)
|
||||||
C.BIO_free(from_ssl_cbio)
|
C.BIO_free(from_ssl_cbio)
|
||||||
C.SSL_free(ssl)
|
C.SSL_free(ssl)
|
||||||
return nil, errors.New("failed to allocate memory BIO")
|
return nil, errors.New("failed to allocate memory BIO")
|
||||||
}
|
}
|
||||||
|
|
||||||
// the ssl object takes ownership of these objects now
|
// the ssl object takes ownership of these objects now
|
||||||
C.SSL_set_bio(ssl, into_ssl_cbio, from_ssl_cbio)
|
C.SSL_set_bio(ssl, into_ssl_cbio, from_ssl_cbio)
|
||||||
|
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
ssl: ssl,
|
ssl: ssl,
|
||||||
into_ssl: into_ssl,
|
into_ssl: into_ssl,
|
||||||
from_ssl: from_ssl}
|
from_ssl: from_ssl}
|
||||||
runtime.SetFinalizer(c, func(c *Conn) {
|
runtime.SetFinalizer(c, func(c *Conn) {
|
||||||
c.into_ssl.Disconnect(into_ssl_cbio)
|
c.into_ssl.Disconnect(into_ssl_cbio)
|
||||||
c.from_ssl.Disconnect(from_ssl_cbio)
|
c.from_ssl.Disconnect(from_ssl_cbio)
|
||||||
C.SSL_free(c.ssl)
|
C.SSL_free(c.ssl)
|
||||||
})
|
})
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client wraps an existing stream connection and puts it in the connect state
|
// Client wraps an existing stream connection and puts it in the connect state
|
||||||
@ -94,319 +94,319 @@ func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
|||||||
// you're using. This library is not nice enough to use the system certificate
|
// you're using. This library is not nice enough to use the system certificate
|
||||||
// store by default for you yet.
|
// store by default for you yet.
|
||||||
func Client(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
func Client(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
||||||
c, err := newConn(conn, ctx)
|
c, err := newConn(conn, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
C.SSL_set_connect_state(c.ssl)
|
C.SSL_set_connect_state(c.ssl)
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Server wraps an existing stream connection and puts it in the accept state
|
// Server wraps an existing stream connection and puts it in the accept state
|
||||||
// for any subsequent handshakes.
|
// for any subsequent handshakes.
|
||||||
func Server(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
func Server(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
||||||
c, err := newConn(conn, ctx)
|
c, err := newConn(conn, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
C.SSL_set_accept_state(c.ssl)
|
C.SSL_set_accept_state(c.ssl)
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) fillInputBuffer() error {
|
func (c *Conn) fillInputBuffer() error {
|
||||||
for {
|
for {
|
||||||
n, err := c.into_ssl.ReadFromOnce(c.conn)
|
n, err := c.into_ssl.ReadFromOnce(c.conn)
|
||||||
if n == 0 && err == nil {
|
if n == 0 && err == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
c.into_ssl.MarkEOF()
|
c.into_ssl.MarkEOF()
|
||||||
return c.Close()
|
return c.Close()
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) flushOutputBuffer() error {
|
func (c *Conn) flushOutputBuffer() error {
|
||||||
_, err := c.from_ssl.WriteTo(c.conn)
|
_, err := c.from_ssl.WriteTo(c.conn)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
|
func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
|
||||||
errcode := C.SSL_get_error(c.ssl, rv)
|
errcode := C.SSL_get_error(c.ssl, rv)
|
||||||
switch errcode {
|
switch errcode {
|
||||||
case C.SSL_ERROR_ZERO_RETURN:
|
case C.SSL_ERROR_ZERO_RETURN:
|
||||||
return func() error {
|
return func() error {
|
||||||
c.Close()
|
c.Close()
|
||||||
return io.ErrUnexpectedEOF
|
return io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
case C.SSL_ERROR_WANT_READ:
|
case C.SSL_ERROR_WANT_READ:
|
||||||
go c.flushOutputBuffer()
|
go c.flushOutputBuffer()
|
||||||
if c.want_read_future != nil {
|
if c.want_read_future != nil {
|
||||||
want_read_future := c.want_read_future
|
want_read_future := c.want_read_future
|
||||||
return func() error {
|
return func() error {
|
||||||
_, err := want_read_future.Get()
|
_, err := want_read_future.Get()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.want_read_future = utils.NewFuture()
|
c.want_read_future = utils.NewFuture()
|
||||||
want_read_future := c.want_read_future
|
want_read_future := c.want_read_future
|
||||||
return func() (err error) {
|
return func() (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
c.mtx.Lock()
|
c.mtx.Lock()
|
||||||
c.want_read_future = nil
|
c.want_read_future = nil
|
||||||
c.mtx.Unlock()
|
c.mtx.Unlock()
|
||||||
want_read_future.Set(nil, err)
|
want_read_future.Set(nil, err)
|
||||||
}()
|
}()
|
||||||
err = c.fillInputBuffer()
|
err = c.fillInputBuffer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tryAgain
|
return tryAgain
|
||||||
}
|
}
|
||||||
case C.SSL_ERROR_WANT_WRITE:
|
case C.SSL_ERROR_WANT_WRITE:
|
||||||
return func() error {
|
return func() error {
|
||||||
err := c.flushOutputBuffer()
|
err := c.flushOutputBuffer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tryAgain
|
return tryAgain
|
||||||
}
|
}
|
||||||
case C.SSL_ERROR_SYSCALL:
|
case C.SSL_ERROR_SYSCALL:
|
||||||
var err error
|
var err error
|
||||||
if C.ERR_peek_error() == 0 {
|
if C.ERR_peek_error() == 0 {
|
||||||
switch rv {
|
switch rv {
|
||||||
case 0:
|
case 0:
|
||||||
err = errors.New("protocol-violating EOF")
|
err = errors.New("protocol-violating EOF")
|
||||||
case -1:
|
case -1:
|
||||||
err = errno
|
err = errno
|
||||||
default:
|
default:
|
||||||
err = errorFromErrorQueue()
|
err = errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = errorFromErrorQueue()
|
err = errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
return func() error { return err }
|
return func() error { return err }
|
||||||
default:
|
default:
|
||||||
err := errorFromErrorQueue()
|
err := errorFromErrorQueue()
|
||||||
return func() error { return err }
|
return func() error { return err }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) handleError(errcb func() error) error {
|
func (c *Conn) handleError(errcb func() error) error {
|
||||||
if errcb != nil {
|
if errcb != nil {
|
||||||
return errcb()
|
return errcb()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) handshake() func() error {
|
func (c *Conn) handshake() func() error {
|
||||||
c.mtx.Lock()
|
c.mtx.Lock()
|
||||||
defer c.mtx.Unlock()
|
defer c.mtx.Unlock()
|
||||||
if c.is_shutdown {
|
if c.is_shutdown {
|
||||||
return func() error { return io.ErrUnexpectedEOF }
|
return func() error { return io.ErrUnexpectedEOF }
|
||||||
}
|
}
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
rv, errno := C.SSL_do_handshake(c.ssl)
|
rv, errno := C.SSL_do_handshake(c.ssl)
|
||||||
if rv > 0 {
|
if rv > 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return c.getErrorHandler(rv, errno)
|
return c.getErrorHandler(rv, errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake performs an SSL handshake. If a handshake is not manually
|
// Handshake performs an SSL handshake. If a handshake is not manually
|
||||||
// triggered, it will run before the first I/O on the encrypted stream.
|
// triggered, it will run before the first I/O on the encrypted stream.
|
||||||
func (c *Conn) Handshake() error {
|
func (c *Conn) Handshake() error {
|
||||||
err := tryAgain
|
err := tryAgain
|
||||||
for err == tryAgain {
|
for err == tryAgain {
|
||||||
err = c.handleError(c.handshake())
|
err = c.handleError(c.handshake())
|
||||||
}
|
}
|
||||||
go c.flushOutputBuffer()
|
go c.flushOutputBuffer()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerCertificate returns the Certificate of the peer with which you're
|
// PeerCertificate returns the Certificate of the peer with which you're
|
||||||
// communicating. Only valid after a handshake.
|
// communicating. Only valid after a handshake.
|
||||||
func (c *Conn) PeerCertificate() (*Certificate, error) {
|
func (c *Conn) PeerCertificate() (*Certificate, error) {
|
||||||
c.mtx.Lock()
|
c.mtx.Lock()
|
||||||
if c.is_shutdown {
|
if c.is_shutdown {
|
||||||
return nil, errors.New("connection closed")
|
return nil, errors.New("connection closed")
|
||||||
}
|
}
|
||||||
x := C.SSL_get_peer_certificate(c.ssl)
|
x := C.SSL_get_peer_certificate(c.ssl)
|
||||||
c.mtx.Unlock()
|
c.mtx.Unlock()
|
||||||
if x == nil {
|
if x == nil {
|
||||||
return nil, errors.New("no peer certificate found")
|
return nil, errors.New("no peer certificate found")
|
||||||
}
|
}
|
||||||
cert := &Certificate{x: x}
|
cert := &Certificate{x: x}
|
||||||
runtime.SetFinalizer(cert, func(cert *Certificate) {
|
runtime.SetFinalizer(cert, func(cert *Certificate) {
|
||||||
C.X509_free(cert.x)
|
C.X509_free(cert.x)
|
||||||
})
|
})
|
||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) shutdown() func() error {
|
func (c *Conn) shutdown() func() error {
|
||||||
c.mtx.Lock()
|
c.mtx.Lock()
|
||||||
defer c.mtx.Unlock()
|
defer c.mtx.Unlock()
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
rv, errno := C.SSL_shutdown(c.ssl)
|
rv, errno := C.SSL_shutdown(c.ssl)
|
||||||
if rv > 0 {
|
if rv > 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if rv == 0 {
|
if rv == 0 {
|
||||||
// The OpenSSL docs say that in this case, the shutdown is not
|
// The OpenSSL docs say that in this case, the shutdown is not
|
||||||
// finished, and we should call SSL_shutdown() a second time, if a
|
// finished, and we should call SSL_shutdown() a second time, if a
|
||||||
// bidirectional shutdown is going to be performed. Further, the
|
// bidirectional shutdown is going to be performed. Further, the
|
||||||
// output of SSL_get_error may be misleading, as an erroneous
|
// output of SSL_get_error may be misleading, as an erroneous
|
||||||
// SSL_ERROR_SYSCALL may be flagged even though no error occurred.
|
// SSL_ERROR_SYSCALL may be flagged even though no error occurred.
|
||||||
// So, TODO: revisit bidrectional shutdown, possibly trying again.
|
// So, TODO: revisit bidrectional shutdown, possibly trying again.
|
||||||
// Note: some broken clients won't engage in bidirectional shutdown
|
// Note: some broken clients won't engage in bidirectional shutdown
|
||||||
// without tickling them to close by sending a TCP_FIN packet, or
|
// without tickling them to close by sending a TCP_FIN packet, or
|
||||||
// shutting down the write-side of the connection.
|
// shutting down the write-side of the connection.
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
return c.getErrorHandler(rv, errno)
|
return c.getErrorHandler(rv, errno)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) shutdownLoop() error {
|
func (c *Conn) shutdownLoop() error {
|
||||||
err := tryAgain
|
err := tryAgain
|
||||||
shutdown_tries := 0
|
shutdown_tries := 0
|
||||||
for err == tryAgain {
|
for err == tryAgain {
|
||||||
shutdown_tries = shutdown_tries + 1
|
shutdown_tries = shutdown_tries + 1
|
||||||
err = c.handleError(c.shutdown())
|
err = c.handleError(c.shutdown())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return c.flushOutputBuffer()
|
return c.flushOutputBuffer()
|
||||||
}
|
}
|
||||||
if err == tryAgain && shutdown_tries >= 2 {
|
if err == tryAgain && shutdown_tries >= 2 {
|
||||||
return errors.New("shutdown requested a third time?")
|
return errors.New("shutdown requested a third time?")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err == io.ErrUnexpectedEOF {
|
if err == io.ErrUnexpectedEOF {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close shuts down the SSL connection and closes the underlying wrapped
|
// Close shuts down the SSL connection and closes the underlying wrapped
|
||||||
// connection.
|
// connection.
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
c.mtx.Lock()
|
c.mtx.Lock()
|
||||||
if c.is_shutdown {
|
if c.is_shutdown {
|
||||||
c.mtx.Unlock()
|
c.mtx.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c.is_shutdown = true
|
c.is_shutdown = true
|
||||||
c.mtx.Unlock()
|
c.mtx.Unlock()
|
||||||
var errs utils.ErrorGroup
|
var errs utils.ErrorGroup
|
||||||
errs.Add(c.shutdownLoop())
|
errs.Add(c.shutdownLoop())
|
||||||
errs.Add(c.conn.Close())
|
errs.Add(c.conn.Close())
|
||||||
return errs.Finalize()
|
return errs.Finalize()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) read(b []byte) (int, func() error) {
|
func (c *Conn) read(b []byte) (int, func() error) {
|
||||||
c.mtx.Lock()
|
c.mtx.Lock()
|
||||||
defer c.mtx.Unlock()
|
defer c.mtx.Unlock()
|
||||||
if c.is_shutdown {
|
if c.is_shutdown {
|
||||||
return 0, func() error { return io.EOF }
|
return 0, func() error { return io.EOF }
|
||||||
}
|
}
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
|
rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
|
||||||
if rv > 0 {
|
if rv > 0 {
|
||||||
return int(rv), nil
|
return int(rv), nil
|
||||||
}
|
}
|
||||||
return 0, c.getErrorHandler(rv, errno)
|
return 0, c.getErrorHandler(rv, errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads up to len(b) bytes into b. It returns the number of bytes read
|
// Read reads up to len(b) bytes into b. It returns the number of bytes read
|
||||||
// and an error if applicable. io.EOF is returned when the caller can expect
|
// and an error if applicable. io.EOF is returned when the caller can expect
|
||||||
// to see no more data.
|
// to see no more data.
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
err = tryAgain
|
err = tryAgain
|
||||||
for err == tryAgain {
|
for err == tryAgain {
|
||||||
n, errcb := c.read(b)
|
n, errcb := c.read(b)
|
||||||
err = c.handleError(errcb)
|
err = c.handleError(errcb)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
go c.flushOutputBuffer()
|
go c.flushOutputBuffer()
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
if err == io.ErrUnexpectedEOF {
|
if err == io.ErrUnexpectedEOF {
|
||||||
err = io.EOF
|
err = io.EOF
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) write(b []byte) (int, func() error) {
|
func (c *Conn) write(b []byte) (int, func() error) {
|
||||||
c.mtx.Lock()
|
c.mtx.Lock()
|
||||||
defer c.mtx.Unlock()
|
defer c.mtx.Unlock()
|
||||||
if c.is_shutdown {
|
if c.is_shutdown {
|
||||||
err := errors.New("connection closed")
|
err := errors.New("connection closed")
|
||||||
return 0, func() error { return err }
|
return 0, func() error { return err }
|
||||||
}
|
}
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
|
rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
|
||||||
if rv > 0 {
|
if rv > 0 {
|
||||||
return int(rv), nil
|
return int(rv), nil
|
||||||
}
|
}
|
||||||
return 0, c.getErrorHandler(rv, errno)
|
return 0, c.getErrorHandler(rv, errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write will encrypt the contents of b and write it to the underlying stream.
|
// Write will encrypt the contents of b and write it to the underlying stream.
|
||||||
// Performance will be vastly improved if the size of b is a multiple of
|
// Performance will be vastly improved if the size of b is a multiple of
|
||||||
// SSLRecordSize.
|
// SSLRecordSize.
|
||||||
func (c *Conn) Write(b []byte) (written int, err error) {
|
func (c *Conn) Write(b []byte) (written int, err error) {
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
err = tryAgain
|
err = tryAgain
|
||||||
for err == tryAgain {
|
for err == tryAgain {
|
||||||
n, errcb := c.write(b)
|
n, errcb := c.write(b)
|
||||||
err = c.handleError(errcb)
|
err = c.handleError(errcb)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return n, c.flushOutputBuffer()
|
return n, c.flushOutputBuffer()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyHostname pulls the PeerCertificate and calls VerifyHostname on the
|
// VerifyHostname pulls the PeerCertificate and calls VerifyHostname on the
|
||||||
// certificate.
|
// certificate.
|
||||||
func (c *Conn) VerifyHostname(host string) error {
|
func (c *Conn) VerifyHostname(host string) error {
|
||||||
cert, err := c.PeerCertificate()
|
cert, err := c.PeerCertificate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return cert.VerifyHostname(host)
|
return cert.VerifyHostname(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAddr returns the underlying connection's local address
|
// LocalAddr returns the underlying connection's local address
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
return c.conn.LocalAddr()
|
return c.conn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoteAddr returns the underlying connection's remote address
|
// RemoteAddr returns the underlying connection's remote address
|
||||||
func (c *Conn) RemoteAddr() net.Addr {
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
return c.conn.RemoteAddr()
|
return c.conn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetDeadline calls SetDeadline on the underlying connection.
|
// SetDeadline calls SetDeadline on the underlying connection.
|
||||||
func (c *Conn) SetDeadline(t time.Time) error {
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
return c.conn.SetDeadline(t)
|
return c.conn.SetDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetReadDeadline calls SetReadDeadline on the underlying connection.
|
// SetReadDeadline calls SetReadDeadline on the underlying connection.
|
||||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
return c.conn.SetReadDeadline(t)
|
return c.conn.SetReadDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWriteDeadline calls SetWriteDeadline on the underlying connection.
|
// SetWriteDeadline calls SetWriteDeadline on the underlying connection.
|
||||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
return c.conn.SetWriteDeadline(t)
|
return c.conn.SetWriteDeadline(t)
|
||||||
}
|
}
|
||||||
|
316
ctx.go
316
ctx.go
@ -33,158 +33,158 @@ package openssl
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"runtime"
|
"runtime"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Ctx struct {
|
type Ctx struct {
|
||||||
ctx *C.SSL_CTX
|
ctx *C.SSL_CTX
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCtx(method *C.SSL_METHOD) (*Ctx, error) {
|
func newCtx(method *C.SSL_METHOD) (*Ctx, error) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
ctx := C.SSL_CTX_new(method)
|
ctx := C.SSL_CTX_new(method)
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return nil, errorFromErrorQueue()
|
return nil, errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
c := &Ctx{ctx: ctx}
|
c := &Ctx{ctx: ctx}
|
||||||
runtime.SetFinalizer(c, func(c *Ctx) {
|
runtime.SetFinalizer(c, func(c *Ctx) {
|
||||||
C.SSL_CTX_free(c.ctx)
|
C.SSL_CTX_free(c.ctx)
|
||||||
})
|
})
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type SSLVersion int
|
type SSLVersion int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SSLv3 SSLVersion = 0x02
|
SSLv3 SSLVersion = 0x02
|
||||||
TLSv1 SSLVersion = 0x03
|
TLSv1 SSLVersion = 0x03
|
||||||
TLSv1_1 SSLVersion = 0x04
|
TLSv1_1 SSLVersion = 0x04
|
||||||
TLSv1_2 SSLVersion = 0x05
|
TLSv1_2 SSLVersion = 0x05
|
||||||
AnyVersion SSLVersion = 0x06
|
AnyVersion SSLVersion = 0x06
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewCtxWithVersion creates an SSL context that is specific to the provided
|
// NewCtxWithVersion creates an SSL context that is specific to the provided
|
||||||
// SSL version. See http://www.openssl.org/docs/ssl/SSL_CTX_new.html for more.
|
// SSL version. See http://www.openssl.org/docs/ssl/SSL_CTX_new.html for more.
|
||||||
func NewCtxWithVersion(version SSLVersion) (*Ctx, error) {
|
func NewCtxWithVersion(version SSLVersion) (*Ctx, error) {
|
||||||
var method *C.SSL_METHOD
|
var method *C.SSL_METHOD
|
||||||
switch version {
|
switch version {
|
||||||
case SSLv3:
|
case SSLv3:
|
||||||
method = C.SSLv3_method()
|
method = C.SSLv3_method()
|
||||||
case TLSv1:
|
case TLSv1:
|
||||||
method = C.TLSv1_method()
|
method = C.TLSv1_method()
|
||||||
case TLSv1_1:
|
case TLSv1_1:
|
||||||
method = C.TLSv1_1_method()
|
method = C.TLSv1_1_method()
|
||||||
case TLSv1_2:
|
case TLSv1_2:
|
||||||
method = C.TLSv1_2_method()
|
method = C.TLSv1_2_method()
|
||||||
case AnyVersion:
|
case AnyVersion:
|
||||||
method = C.SSLv23_method()
|
method = C.SSLv23_method()
|
||||||
}
|
}
|
||||||
if method == nil {
|
if method == nil {
|
||||||
return nil, errors.New("unknown ssl/tls version")
|
return nil, errors.New("unknown ssl/tls version")
|
||||||
}
|
}
|
||||||
return newCtx(method)
|
return newCtx(method)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCtx creates a context that supports any TLS version 1.0 and newer.
|
// NewCtx creates a context that supports any TLS version 1.0 and newer.
|
||||||
func NewCtx() (*Ctx, error) {
|
func NewCtx() (*Ctx, error) {
|
||||||
c, err := NewCtxWithVersion(AnyVersion)
|
c, err := NewCtxWithVersion(AnyVersion)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.SetOptions(NoSSLv2 | NoSSLv3)
|
c.SetOptions(NoSSLv2 | NoSSLv3)
|
||||||
}
|
}
|
||||||
return c, err
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCtxFromFiles calls NewCtx, loads the provided files, and configures the
|
// NewCtxFromFiles calls NewCtx, loads the provided files, and configures the
|
||||||
// context to use them.
|
// context to use them.
|
||||||
func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) {
|
func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) {
|
||||||
ctx, err := NewCtx()
|
ctx, err := NewCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cert_bytes, err := ioutil.ReadFile(cert_file)
|
cert_bytes, err := ioutil.ReadFile(cert_file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, err := LoadCertificate(cert_bytes)
|
cert, err := LoadCertificate(cert_bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ctx.UseCertificate(cert)
|
err = ctx.UseCertificate(cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
key_bytes, err := ioutil.ReadFile(key_file)
|
key_bytes, err := ioutil.ReadFile(key_file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := LoadPrivateKey(key_bytes)
|
key, err := LoadPrivateKey(key_bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ctx.UsePrivateKey(key)
|
err = ctx.UsePrivateKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UseCertificate configures the context to present the given certificate to
|
// UseCertificate configures the context to present the given certificate to
|
||||||
// peers.
|
// peers.
|
||||||
func (c *Ctx) UseCertificate(cert *Certificate) error {
|
func (c *Ctx) UseCertificate(cert *Certificate) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
if int(C.SSL_CTX_use_certificate(c.ctx, cert.x)) != 1 {
|
if int(C.SSL_CTX_use_certificate(c.ctx, cert.x)) != 1 {
|
||||||
return errorFromErrorQueue()
|
return errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsePrivateKey configures the context to use the given private key for SSL
|
// UsePrivateKey configures the context to use the given private key for SSL
|
||||||
// handshakes.
|
// handshakes.
|
||||||
func (c *Ctx) UsePrivateKey(key PrivateKey) error {
|
func (c *Ctx) UsePrivateKey(key PrivateKey) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
if int(C.SSL_CTX_use_PrivateKey(c.ctx, key.evpPKey())) != 1 {
|
if int(C.SSL_CTX_use_PrivateKey(c.ctx, key.evpPKey())) != 1 {
|
||||||
return errorFromErrorQueue()
|
return errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type CertificateStore struct {
|
type CertificateStore struct {
|
||||||
store *C.X509_STORE
|
store *C.X509_STORE
|
||||||
ctx *Ctx // for gc
|
ctx *Ctx // for gc
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCertificateStore returns the context's certificate store that will be
|
// GetCertificateStore returns the context's certificate store that will be
|
||||||
// used for peer validation.
|
// used for peer validation.
|
||||||
func (c *Ctx) GetCertificateStore() *CertificateStore {
|
func (c *Ctx) GetCertificateStore() *CertificateStore {
|
||||||
// we don't need to dealloc the cert store pointer here, because it points
|
// we don't need to dealloc the cert store pointer here, because it points
|
||||||
// to a ctx internal. so we do need to keep the ctx around
|
// to a ctx internal. so we do need to keep the ctx around
|
||||||
return &CertificateStore{
|
return &CertificateStore{
|
||||||
store: C.SSL_CTX_get_cert_store(c.ctx),
|
store: C.SSL_CTX_get_cert_store(c.ctx),
|
||||||
ctx: c}
|
ctx: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddCertificate marks the provided Certificate as a trusted certificate in
|
// AddCertificate marks the provided Certificate as a trusted certificate in
|
||||||
// the given CertificateStore.
|
// the given CertificateStore.
|
||||||
func (s *CertificateStore) AddCertificate(cert *Certificate) error {
|
func (s *CertificateStore) AddCertificate(cert *Certificate) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
if int(C.X509_STORE_add_cert(s.store, cert.x)) != 1 {
|
if int(C.X509_STORE_add_cert(s.store, cert.x)) != 1 {
|
||||||
return errorFromErrorQueue()
|
return errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadVerifyLocations tells the context to trust all certificate authorities
|
// LoadVerifyLocations tells the context to trust all certificate authorities
|
||||||
@ -192,120 +192,120 @@ func (s *CertificateStore) AddCertificate(cert *Certificate) error {
|
|||||||
// See http://www.openssl.org/docs/ssl/SSL_CTX_load_verify_locations.html for
|
// See http://www.openssl.org/docs/ssl/SSL_CTX_load_verify_locations.html for
|
||||||
// more.
|
// more.
|
||||||
func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error {
|
func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
var c_ca_file, c_ca_path *C.char
|
var c_ca_file, c_ca_path *C.char
|
||||||
if ca_file != "" {
|
if ca_file != "" {
|
||||||
c_ca_file = C.CString(ca_file)
|
c_ca_file = C.CString(ca_file)
|
||||||
defer C.free(unsafe.Pointer(c_ca_file))
|
defer C.free(unsafe.Pointer(c_ca_file))
|
||||||
}
|
}
|
||||||
if ca_path != "" {
|
if ca_path != "" {
|
||||||
c_ca_path = C.CString(ca_path)
|
c_ca_path = C.CString(ca_path)
|
||||||
defer C.free(unsafe.Pointer(c_ca_path))
|
defer C.free(unsafe.Pointer(c_ca_path))
|
||||||
}
|
}
|
||||||
if C.SSL_CTX_load_verify_locations(c.ctx, c_ca_file, c_ca_path) != 1 {
|
if C.SSL_CTX_load_verify_locations(c.ctx, c_ca_file, c_ca_path) != 1 {
|
||||||
return errorFromErrorQueue()
|
return errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Options int
|
type Options int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// NoCompression is only valid if you are using OpenSSL 1.0.1 or newer
|
// NoCompression is only valid if you are using OpenSSL 1.0.1 or newer
|
||||||
NoCompression Options = C.SSL_OP_NO_COMPRESSION
|
NoCompression Options = C.SSL_OP_NO_COMPRESSION
|
||||||
NoSSLv2 Options = C.SSL_OP_NO_SSLv2
|
NoSSLv2 Options = C.SSL_OP_NO_SSLv2
|
||||||
NoSSLv3 Options = C.SSL_OP_NO_SSLv3
|
NoSSLv3 Options = C.SSL_OP_NO_SSLv3
|
||||||
NoTLSv1 Options = C.SSL_OP_NO_TLSv1
|
NoTLSv1 Options = C.SSL_OP_NO_TLSv1
|
||||||
CipherServerPreference Options = C.SSL_OP_CIPHER_SERVER_PREFERENCE
|
CipherServerPreference Options = C.SSL_OP_CIPHER_SERVER_PREFERENCE
|
||||||
NoSessionResumptionOrRenegotiation Options = C.SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION
|
NoSessionResumptionOrRenegotiation Options = C.SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION
|
||||||
NoTicket Options = C.SSL_OP_NO_TICKET
|
NoTicket Options = C.SSL_OP_NO_TICKET
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetOptions sets context options. See
|
// SetOptions sets context options. See
|
||||||
// http://www.openssl.org/docs/ssl/SSL_CTX_set_options.html
|
// http://www.openssl.org/docs/ssl/SSL_CTX_set_options.html
|
||||||
func (c *Ctx) SetOptions(options Options) Options {
|
func (c *Ctx) SetOptions(options Options) Options {
|
||||||
return Options(C.SSL_CTX_set_options_not_a_macro(
|
return Options(C.SSL_CTX_set_options_not_a_macro(
|
||||||
c.ctx, C.long(options)))
|
c.ctx, C.long(options)))
|
||||||
}
|
}
|
||||||
|
|
||||||
type Modes int
|
type Modes int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ReleaseBuffers is only valid if you are using OpenSSL 1.0.1 or newer
|
// ReleaseBuffers is only valid if you are using OpenSSL 1.0.1 or newer
|
||||||
ReleaseBuffers Modes = C.SSL_MODE_RELEASE_BUFFERS
|
ReleaseBuffers Modes = C.SSL_MODE_RELEASE_BUFFERS
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetMode sets context modes. See
|
// SetMode sets context modes. See
|
||||||
// http://www.openssl.org/docs/ssl/SSL_CTX_set_mode.html
|
// http://www.openssl.org/docs/ssl/SSL_CTX_set_mode.html
|
||||||
func (c *Ctx) SetMode(modes Modes) Modes {
|
func (c *Ctx) SetMode(modes Modes) Modes {
|
||||||
return Modes(C.SSL_CTX_set_mode_not_a_macro(c.ctx, C.long(modes)))
|
return Modes(C.SSL_CTX_set_mode_not_a_macro(c.ctx, C.long(modes)))
|
||||||
}
|
}
|
||||||
|
|
||||||
type VerifyOptions int
|
type VerifyOptions int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
VerifyNone VerifyOptions = C.SSL_VERIFY_NONE
|
VerifyNone VerifyOptions = C.SSL_VERIFY_NONE
|
||||||
VerifyPeer VerifyOptions = C.SSL_VERIFY_PEER
|
VerifyPeer VerifyOptions = C.SSL_VERIFY_PEER
|
||||||
VerifyFailIfNoPeerCert VerifyOptions = C.SSL_VERIFY_FAIL_IF_NO_PEER_CERT
|
VerifyFailIfNoPeerCert VerifyOptions = C.SSL_VERIFY_FAIL_IF_NO_PEER_CERT
|
||||||
VerifyClientOnce VerifyOptions = C.SSL_VERIFY_CLIENT_ONCE
|
VerifyClientOnce VerifyOptions = C.SSL_VERIFY_CLIENT_ONCE
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetVerify controls peer verification settings. See
|
// SetVerify controls peer verification settings. See
|
||||||
// http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html
|
// http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html
|
||||||
func (c *Ctx) SetVerify(options VerifyOptions) {
|
func (c *Ctx) SetVerify(options VerifyOptions) {
|
||||||
// TODO: take a callback
|
// TODO: take a callback
|
||||||
C.SSL_CTX_set_verify(c.ctx, C.int(options), nil)
|
C.SSL_CTX_set_verify(c.ctx, C.int(options), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetVerifyDepth controls how many certificates deep the certificate
|
// SetVerifyDepth controls how many certificates deep the certificate
|
||||||
// verification logic is willing to follow a certificate chain. See
|
// verification logic is willing to follow a certificate chain. See
|
||||||
// https://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html
|
// https://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html
|
||||||
func (c *Ctx) SetVerifyDepth(depth int) {
|
func (c *Ctx) SetVerifyDepth(depth int) {
|
||||||
C.SSL_CTX_set_verify_depth(c.ctx, C.int(depth))
|
C.SSL_CTX_set_verify_depth(c.ctx, C.int(depth))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Ctx) SetSessionId(session_id []byte) error {
|
func (c *Ctx) SetSessionId(session_id []byte) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
if int(C.SSL_CTX_set_session_id_context(c.ctx,
|
if int(C.SSL_CTX_set_session_id_context(c.ctx,
|
||||||
(*C.uchar)(unsafe.Pointer(&session_id[0])),
|
(*C.uchar)(unsafe.Pointer(&session_id[0])),
|
||||||
C.uint(len(session_id)))) == 0 {
|
C.uint(len(session_id)))) == 0 {
|
||||||
return errorFromErrorQueue()
|
return errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCipherList sets the list of available ciphers. The format of the list is
|
// SetCipherList sets the list of available ciphers. The format of the list is
|
||||||
// described at http://www.openssl.org/docs/apps/ciphers.html, but see
|
// described at http://www.openssl.org/docs/apps/ciphers.html, but see
|
||||||
// http://www.openssl.org/docs/ssl/SSL_CTX_set_cipher_list.html for more.
|
// http://www.openssl.org/docs/ssl/SSL_CTX_set_cipher_list.html for more.
|
||||||
func (c *Ctx) SetCipherList(list string) error {
|
func (c *Ctx) SetCipherList(list string) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
clist := C.CString(list)
|
clist := C.CString(list)
|
||||||
defer C.free(unsafe.Pointer(clist))
|
defer C.free(unsafe.Pointer(clist))
|
||||||
if int(C.SSL_CTX_set_cipher_list(c.ctx, clist)) == 0 {
|
if int(C.SSL_CTX_set_cipher_list(c.ctx, clist)) == 0 {
|
||||||
return errorFromErrorQueue()
|
return errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionCacheModes int
|
type SessionCacheModes int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SessionCacheOff SessionCacheModes = C.SSL_SESS_CACHE_OFF
|
SessionCacheOff SessionCacheModes = C.SSL_SESS_CACHE_OFF
|
||||||
SessionCacheClient SessionCacheModes = C.SSL_SESS_CACHE_CLIENT
|
SessionCacheClient SessionCacheModes = C.SSL_SESS_CACHE_CLIENT
|
||||||
SessionCacheServer SessionCacheModes = C.SSL_SESS_CACHE_SERVER
|
SessionCacheServer SessionCacheModes = C.SSL_SESS_CACHE_SERVER
|
||||||
SessionCacheBoth SessionCacheModes = C.SSL_SESS_CACHE_BOTH
|
SessionCacheBoth SessionCacheModes = C.SSL_SESS_CACHE_BOTH
|
||||||
NoAutoClear SessionCacheModes = C.SSL_SESS_CACHE_NO_AUTO_CLEAR
|
NoAutoClear SessionCacheModes = C.SSL_SESS_CACHE_NO_AUTO_CLEAR
|
||||||
NoInternalLookup SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_LOOKUP
|
NoInternalLookup SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_LOOKUP
|
||||||
NoInternalStore SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_STORE
|
NoInternalStore SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL_STORE
|
||||||
NoInternal SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL
|
NoInternal SessionCacheModes = C.SSL_SESS_CACHE_NO_INTERNAL
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetSessionCacheMode enables or disables session caching. See
|
// SetSessionCacheMode enables or disables session caching. See
|
||||||
// http://www.openssl.org/docs/ssl/SSL_CTX_set_session_cache_mode.html
|
// http://www.openssl.org/docs/ssl/SSL_CTX_set_session_cache_mode.html
|
||||||
func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes {
|
func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes {
|
||||||
return SessionCacheModes(
|
return SessionCacheModes(
|
||||||
C.SSL_CTX_set_session_cache_mode_not_a_macro(c.ctx, C.long(modes)))
|
C.SSL_CTX_set_session_cache_mode_not_a_macro(c.ctx, C.long(modes)))
|
||||||
}
|
}
|
||||||
|
96
hostname.go
96
hostname.go
@ -23,20 +23,20 @@ extern int X509_check_ip(X509 *x, const unsigned char *chk, size_t chklen,
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ValidationError = errors.New("Host validation error")
|
ValidationError = errors.New("Host validation error")
|
||||||
)
|
)
|
||||||
|
|
||||||
type CheckFlags int
|
type CheckFlags int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AlwaysCheckSubject CheckFlags = C.X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT
|
AlwaysCheckSubject CheckFlags = C.X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT
|
||||||
NoWildcards CheckFlags = C.X509_CHECK_FLAG_NO_WILDCARDS
|
NoWildcards CheckFlags = C.X509_CHECK_FLAG_NO_WILDCARDS
|
||||||
)
|
)
|
||||||
|
|
||||||
// CheckHost checks that the X509 certificate is signed for the provided
|
// CheckHost checks that the X509 certificate is signed for the provided
|
||||||
@ -45,17 +45,17 @@ const (
|
|||||||
// Specifically returns ValidationError if the Certificate didn't match but
|
// Specifically returns ValidationError if the Certificate didn't match but
|
||||||
// there was no internal error.
|
// there was no internal error.
|
||||||
func (c *Certificate) CheckHost(host string, flags CheckFlags) error {
|
func (c *Certificate) CheckHost(host string, flags CheckFlags) error {
|
||||||
chost := unsafe.Pointer(C.CString(host))
|
chost := unsafe.Pointer(C.CString(host))
|
||||||
defer C.free(chost)
|
defer C.free(chost)
|
||||||
rv := C.X509_check_host(c.x, (*C.uchar)(chost), C.size_t(len(host)),
|
rv := C.X509_check_host(c.x, (*C.uchar)(chost), C.size_t(len(host)),
|
||||||
C.uint(flags))
|
C.uint(flags))
|
||||||
if rv > 0 {
|
if rv > 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if rv == 0 {
|
if rv == 0 {
|
||||||
return ValidationError
|
return ValidationError
|
||||||
}
|
}
|
||||||
return errors.New("hostname validation had an internal failure")
|
return errors.New("hostname validation had an internal failure")
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckEmail checks that the X509 certificate is signed for the provided
|
// CheckEmail checks that the X509 certificate is signed for the provided
|
||||||
@ -64,17 +64,17 @@ func (c *Certificate) CheckHost(host string, flags CheckFlags) error {
|
|||||||
// Specifically returns ValidationError if the Certificate didn't match but
|
// Specifically returns ValidationError if the Certificate didn't match but
|
||||||
// there was no internal error.
|
// there was no internal error.
|
||||||
func (c *Certificate) CheckEmail(email string, flags CheckFlags) error {
|
func (c *Certificate) CheckEmail(email string, flags CheckFlags) error {
|
||||||
cemail := unsafe.Pointer(C.CString(email))
|
cemail := unsafe.Pointer(C.CString(email))
|
||||||
defer C.free(cemail)
|
defer C.free(cemail)
|
||||||
rv := C.X509_check_email(c.x, (*C.uchar)(cemail), C.size_t(len(email)),
|
rv := C.X509_check_email(c.x, (*C.uchar)(cemail), C.size_t(len(email)),
|
||||||
C.uint(flags))
|
C.uint(flags))
|
||||||
if rv > 0 {
|
if rv > 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if rv == 0 {
|
if rv == 0 {
|
||||||
return ValidationError
|
return ValidationError
|
||||||
}
|
}
|
||||||
return errors.New("email validation had an internal failure")
|
return errors.New("email validation had an internal failure")
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckIP checks that the X509 certificate is signed for the provided
|
// CheckIP checks that the X509 certificate is signed for the provided
|
||||||
@ -83,16 +83,16 @@ func (c *Certificate) CheckEmail(email string, flags CheckFlags) error {
|
|||||||
// Specifically returns ValidationError if the Certificate didn't match but
|
// Specifically returns ValidationError if the Certificate didn't match but
|
||||||
// there was no internal error.
|
// there was no internal error.
|
||||||
func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error {
|
func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error {
|
||||||
cip := unsafe.Pointer(&ip[0])
|
cip := unsafe.Pointer(&ip[0])
|
||||||
rv := C.X509_check_ip(c.x, (*C.uchar)(cip), C.size_t(len(ip)),
|
rv := C.X509_check_ip(c.x, (*C.uchar)(cip), C.size_t(len(ip)),
|
||||||
C.uint(flags))
|
C.uint(flags))
|
||||||
if rv > 0 {
|
if rv > 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if rv == 0 {
|
if rv == 0 {
|
||||||
return ValidationError
|
return ValidationError
|
||||||
}
|
}
|
||||||
return errors.New("ip validation had an internal failure")
|
return errors.New("ip validation had an internal failure")
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyHostname is a combination of CheckHost and CheckIP. If the provided
|
// VerifyHostname is a combination of CheckHost and CheckIP. If the provided
|
||||||
@ -101,14 +101,14 @@ func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error {
|
|||||||
// Specifically returns ValidationError if the Certificate didn't match but
|
// Specifically returns ValidationError if the Certificate didn't match but
|
||||||
// there was no internal error.
|
// there was no internal error.
|
||||||
func (c *Certificate) VerifyHostname(host string) error {
|
func (c *Certificate) VerifyHostname(host string) error {
|
||||||
var ip net.IP
|
var ip net.IP
|
||||||
if len(host) >= 3 && host[0] == '[' && host[len(host)-1] == ']' {
|
if len(host) >= 3 && host[0] == '[' && host[len(host)-1] == ']' {
|
||||||
ip = net.ParseIP(host[1 : len(host)-1])
|
ip = net.ParseIP(host[1 : len(host)-1])
|
||||||
} else {
|
} else {
|
||||||
ip = net.ParseIP(host)
|
ip = net.ParseIP(host)
|
||||||
}
|
}
|
||||||
if ip != nil {
|
if ip != nil {
|
||||||
return c.CheckIP(ip, 0)
|
return c.CheckIP(ip, 0)
|
||||||
}
|
}
|
||||||
return c.CheckHost(host, 0)
|
return c.CheckHost(host, 0)
|
||||||
}
|
}
|
||||||
|
36
http.go
36
http.go
@ -3,37 +3,37 @@
|
|||||||
package openssl
|
package openssl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ListenAndServeTLS will take an http.Handler and serve it using OpenSSL over
|
// ListenAndServeTLS will take an http.Handler and serve it using OpenSSL over
|
||||||
// the given tcp address, configured to use the provided cert and key files.
|
// the given tcp address, configured to use the provided cert and key files.
|
||||||
func ListenAndServeTLS(addr string, cert_file string, key_file string,
|
func ListenAndServeTLS(addr string, cert_file string, key_file string,
|
||||||
handler http.Handler) error {
|
handler http.Handler) error {
|
||||||
return ServerListenAndServeTLS(
|
return ServerListenAndServeTLS(
|
||||||
&http.Server{Addr: addr, Handler: handler}, cert_file, key_file)
|
&http.Server{Addr: addr, Handler: handler}, cert_file, key_file)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerListenAndServeTLS will take an http.Server and serve it using OpenSSL
|
// ServerListenAndServeTLS will take an http.Server and serve it using OpenSSL
|
||||||
// configured to use the provided cert and key files.
|
// configured to use the provided cert and key files.
|
||||||
func ServerListenAndServeTLS(srv *http.Server,
|
func ServerListenAndServeTLS(srv *http.Server,
|
||||||
cert_file, key_file string) error {
|
cert_file, key_file string) error {
|
||||||
addr := srv.Addr
|
addr := srv.Addr
|
||||||
if addr == "" {
|
if addr == "" {
|
||||||
addr = ":https"
|
addr = ":https"
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, err := NewCtxFromFiles(cert_file, key_file)
|
ctx, err := NewCtxFromFiles(cert_file, key_file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
l, err := Listen("tcp", addr, ctx)
|
l, err := Listen("tcp", addr, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return srv.Serve(l)
|
return srv.Serve(l)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: http client integration
|
// TODO: http client integration
|
||||||
|
66
init.go
66
init.go
@ -68,58 +68,58 @@ static void OpenSSL_add_all_algorithms_not_a_macro() {
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"code.spacemonkey.com/go/openssl/utils"
|
"code.spacemonkey.com/go/openssl/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
sslMutexes []sync.Mutex
|
sslMutexes []sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
C.OPENSSL_config(nil)
|
C.OPENSSL_config(nil)
|
||||||
C.ENGINE_load_builtin_engines()
|
C.ENGINE_load_builtin_engines()
|
||||||
C.SSL_load_error_strings()
|
C.SSL_load_error_strings()
|
||||||
C.SSL_library_init()
|
C.SSL_library_init()
|
||||||
C.OpenSSL_add_all_algorithms_not_a_macro()
|
C.OpenSSL_add_all_algorithms_not_a_macro()
|
||||||
sslMutexes = make([]sync.Mutex, int(C.CRYPTO_num_locks()))
|
sslMutexes = make([]sync.Mutex, int(C.CRYPTO_num_locks()))
|
||||||
C.CRYPTO_set_id_callback((*[0]byte)(C.sslThreadId))
|
C.CRYPTO_set_id_callback((*[0]byte)(C.sslThreadId))
|
||||||
C.CRYPTO_set_locking_callback((*[0]byte)(C.sslMutexOp))
|
C.CRYPTO_set_locking_callback((*[0]byte)(C.sslMutexOp))
|
||||||
|
|
||||||
// TODO: support dynlock callbacks
|
// TODO: support dynlock callbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
// errorFromErrorQueue needs to run in the same OS thread as the operation
|
// errorFromErrorQueue needs to run in the same OS thread as the operation
|
||||||
// that caused the possible error
|
// that caused the possible error
|
||||||
func errorFromErrorQueue() error {
|
func errorFromErrorQueue() error {
|
||||||
var errs []string
|
var errs []string
|
||||||
for {
|
for {
|
||||||
err := C.ERR_get_error()
|
err := C.ERR_get_error()
|
||||||
if err == 0 {
|
if err == 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
errs = append(errs, fmt.Sprintf("%s:%s:%s",
|
errs = append(errs, fmt.Sprintf("%s:%s:%s",
|
||||||
C.GoString(C.ERR_lib_error_string(err)),
|
C.GoString(C.ERR_lib_error_string(err)),
|
||||||
C.GoString(C.ERR_func_error_string(err)),
|
C.GoString(C.ERR_func_error_string(err)),
|
||||||
C.GoString(C.ERR_reason_error_string(err))))
|
C.GoString(C.ERR_reason_error_string(err))))
|
||||||
}
|
}
|
||||||
return errors.New(fmt.Sprintf("SSL errors: %s", strings.Join(errs, "\n")))
|
return errors.New(fmt.Sprintf("SSL errors: %s", strings.Join(errs, "\n")))
|
||||||
}
|
}
|
||||||
|
|
||||||
//export sslMutexOp
|
//export sslMutexOp
|
||||||
func sslMutexOp(mode, n C.int, file *C.char, line C.int) {
|
func sslMutexOp(mode, n C.int, file *C.char, line C.int) {
|
||||||
if mode&C.CRYPTO_LOCK > 0 {
|
if mode&C.CRYPTO_LOCK > 0 {
|
||||||
sslMutexes[n].Lock()
|
sslMutexes[n].Lock()
|
||||||
} else {
|
} else {
|
||||||
sslMutexes[n].Unlock()
|
sslMutexes[n].Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//export sslThreadId
|
//export sslThreadId
|
||||||
func sslThreadId() C.ulong {
|
func sslThreadId() C.ulong {
|
||||||
return C.ulong(uintptr(utils.ThreadId()))
|
return C.ulong(uintptr(utils.ThreadId()))
|
||||||
}
|
}
|
||||||
|
112
net.go
112
net.go
@ -3,49 +3,49 @@
|
|||||||
package openssl
|
package openssl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type listener struct {
|
type listener struct {
|
||||||
net.Listener
|
net.Listener
|
||||||
ctx *Ctx
|
ctx *Ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *listener) Accept() (c net.Conn, err error) {
|
func (l *listener) Accept() (c net.Conn, err error) {
|
||||||
c, err = l.Listener.Accept()
|
c, err = l.Listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return Server(c, l.ctx)
|
return Server(c, l.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewListener wraps an existing net.Listener such that all accepted
|
// NewListener wraps an existing net.Listener such that all accepted
|
||||||
// connections are wrapped as OpenSSL server connections using the provided
|
// connections are wrapped as OpenSSL server connections using the provided
|
||||||
// context ctx.
|
// context ctx.
|
||||||
func NewListener(inner net.Listener, ctx *Ctx) net.Listener {
|
func NewListener(inner net.Listener, ctx *Ctx) net.Listener {
|
||||||
return &listener{
|
return &listener{
|
||||||
Listener: inner,
|
Listener: inner,
|
||||||
ctx: ctx}
|
ctx: ctx}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen is a wrapper around net.Listen that wraps incoming connections with
|
// Listen is a wrapper around net.Listen that wraps incoming connections with
|
||||||
// an OpenSSL server connection using the provided context ctx.
|
// an OpenSSL server connection using the provided context ctx.
|
||||||
func Listen(network, laddr string, ctx *Ctx) (net.Listener, error) {
|
func Listen(network, laddr string, ctx *Ctx) (net.Listener, error) {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return nil, errors.New("no ssl context provided")
|
return nil, errors.New("no ssl context provided")
|
||||||
}
|
}
|
||||||
l, err := net.Listen(network, laddr)
|
l, err := net.Listen(network, laddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewListener(l, ctx), nil
|
return NewListener(l, ctx), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type DialFlags int
|
type DialFlags int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
InsecureSkipHostVerification DialFlags = 0x01
|
InsecureSkipHostVerification DialFlags = 0x01
|
||||||
)
|
)
|
||||||
|
|
||||||
// Dial will connect to network/address and then wrap the corresponding
|
// Dial will connect to network/address and then wrap the corresponding
|
||||||
@ -59,39 +59,39 @@ const (
|
|||||||
// This library is not nice enough to use the system certificate store by
|
// This library is not nice enough to use the system certificate store by
|
||||||
// default for you yet.
|
// default for you yet.
|
||||||
func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
|
func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
var err error
|
var err error
|
||||||
ctx, err = NewCtx()
|
ctx, err = NewCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// TODO: use operating system default certificate chain?
|
// TODO: use operating system default certificate chain?
|
||||||
}
|
}
|
||||||
c, err := net.Dial(network, addr)
|
c, err := net.Dial(network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
conn, err := Client(c, ctx)
|
conn, err := Client(c, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Close()
|
c.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = conn.Handshake()
|
err = conn.Handshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Close()
|
c.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if flags&InsecureSkipHostVerification == 0 {
|
if flags&InsecureSkipHostVerification == 0 {
|
||||||
host, _, err := net.SplitHostPort(addr)
|
host, _, err := net.SplitHostPort(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = conn.VerifyHostname(host)
|
err = conn.VerifyHostname(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
@ -4,13 +4,13 @@
|
|||||||
package openssl
|
package openssl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SSLRecordSize = 16 * 1024
|
SSLRecordSize = 16 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn struct{}
|
type Conn struct{}
|
||||||
@ -37,11 +37,11 @@ type Ctx struct{}
|
|||||||
type SSLVersion int
|
type SSLVersion int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SSLv3 SSLVersion = 0x02
|
SSLv3 SSLVersion = 0x02
|
||||||
TLSv1 SSLVersion = 0x03
|
TLSv1 SSLVersion = 0x03
|
||||||
TLSv1_1 SSLVersion = 0x04
|
TLSv1_1 SSLVersion = 0x04
|
||||||
TLSv1_2 SSLVersion = 0x05
|
TLSv1_2 SSLVersion = 0x05
|
||||||
AnyVersion SSLVersion = 0x06
|
AnyVersion SSLVersion = 0x06
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewCtxWithVersion(version SSLVersion) (*Ctx, error)
|
func NewCtxWithVersion(version SSLVersion) (*Ctx, error)
|
||||||
@ -61,13 +61,13 @@ func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error
|
|||||||
type Options int
|
type Options int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NoCompression Options = 0
|
NoCompression Options = 0
|
||||||
NoSSLv2 Options = 0
|
NoSSLv2 Options = 0
|
||||||
NoSSLv3 Options = 0
|
NoSSLv3 Options = 0
|
||||||
NoTLSv1 Options = 0
|
NoTLSv1 Options = 0
|
||||||
CipherServerPreference Options = 0
|
CipherServerPreference Options = 0
|
||||||
NoSessionResumptionOrRenegotiation Options = 0
|
NoSessionResumptionOrRenegotiation Options = 0
|
||||||
NoTicket Options = 0
|
NoTicket Options = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Ctx) SetOptions(options Options) Options
|
func (c *Ctx) SetOptions(options Options) Options
|
||||||
@ -75,7 +75,7 @@ func (c *Ctx) SetOptions(options Options) Options
|
|||||||
type Modes int
|
type Modes int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ReleaseBuffers Modes = 0
|
ReleaseBuffers Modes = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Ctx) SetMode(modes Modes) Modes
|
func (c *Ctx) SetMode(modes Modes) Modes
|
||||||
@ -83,10 +83,10 @@ func (c *Ctx) SetMode(modes Modes) Modes
|
|||||||
type VerifyOptions int
|
type VerifyOptions int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
VerifyNone VerifyOptions = 0
|
VerifyNone VerifyOptions = 0
|
||||||
VerifyPeer VerifyOptions = 0
|
VerifyPeer VerifyOptions = 0
|
||||||
VerifyFailIfNoPeerCert VerifyOptions = 0
|
VerifyFailIfNoPeerCert VerifyOptions = 0
|
||||||
VerifyClientOnce VerifyOptions = 0
|
VerifyClientOnce VerifyOptions = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Ctx) SetVerify(options VerifyOptions)
|
func (c *Ctx) SetVerify(options VerifyOptions)
|
||||||
@ -98,27 +98,27 @@ func (c *Ctx) SetCipherList(list string) error
|
|||||||
type SessionCacheModes int
|
type SessionCacheModes int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SessionCacheOff SessionCacheModes = 0
|
SessionCacheOff SessionCacheModes = 0
|
||||||
SessionCacheClient SessionCacheModes = 0
|
SessionCacheClient SessionCacheModes = 0
|
||||||
SessionCacheServer SessionCacheModes = 0
|
SessionCacheServer SessionCacheModes = 0
|
||||||
SessionCacheBoth SessionCacheModes = 0
|
SessionCacheBoth SessionCacheModes = 0
|
||||||
NoAutoClear SessionCacheModes = 0
|
NoAutoClear SessionCacheModes = 0
|
||||||
NoInternalLookup SessionCacheModes = 0
|
NoInternalLookup SessionCacheModes = 0
|
||||||
NoInternalStore SessionCacheModes = 0
|
NoInternalStore SessionCacheModes = 0
|
||||||
NoInternal SessionCacheModes = 0
|
NoInternal SessionCacheModes = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes
|
func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ValidationError = errors.New("Host validation error")
|
ValidationError = errors.New("Host validation error")
|
||||||
)
|
)
|
||||||
|
|
||||||
type CheckFlags int
|
type CheckFlags int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AlwaysCheckSubject CheckFlags = 0
|
AlwaysCheckSubject CheckFlags = 0
|
||||||
NoWildcards CheckFlags = 0
|
NoWildcards CheckFlags = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Certificate) CheckHost(host string, flags CheckFlags) error
|
func (c *Certificate) CheckHost(host string, flags CheckFlags) error
|
||||||
@ -127,15 +127,15 @@ func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error
|
|||||||
func (c *Certificate) VerifyHostname(host string) error
|
func (c *Certificate) VerifyHostname(host string) error
|
||||||
|
|
||||||
type PublicKey interface {
|
type PublicKey interface {
|
||||||
MarshalPKIXPublicKeyPEM() (pem_block []byte, err error)
|
MarshalPKIXPublicKeyPEM() (pem_block []byte, err error)
|
||||||
MarshalPKIXPublicKeyDER() (der_block []byte, err error)
|
MarshalPKIXPublicKeyDER() (der_block []byte, err error)
|
||||||
evpPKey() struct{}
|
evpPKey() struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type PrivateKey interface {
|
type PrivateKey interface {
|
||||||
PublicKey
|
PublicKey
|
||||||
MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error)
|
MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error)
|
||||||
MarshalPKCS1PrivateKeyDER() (der_block []byte, err error)
|
MarshalPKCS1PrivateKeyDER() (der_block []byte, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadPrivateKey(pem_block []byte) (PrivateKey, error)
|
func LoadPrivateKey(pem_block []byte) (PrivateKey, error)
|
||||||
|
276
pem.go
276
pem.go
@ -8,191 +8,191 @@ package openssl
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"runtime"
|
"runtime"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PublicKey interface {
|
type PublicKey interface {
|
||||||
// MarshalPKIXPublicKeyPEM converts the public key to PEM-encoded PKIX
|
// MarshalPKIXPublicKeyPEM converts the public key to PEM-encoded PKIX
|
||||||
// format
|
// format
|
||||||
MarshalPKIXPublicKeyPEM() (pem_block []byte, err error)
|
MarshalPKIXPublicKeyPEM() (pem_block []byte, err error)
|
||||||
|
|
||||||
// MarshalPKIXPublicKeyDER converts the public key to DER-encoded PKIX
|
// MarshalPKIXPublicKeyDER converts the public key to DER-encoded PKIX
|
||||||
// format
|
// format
|
||||||
MarshalPKIXPublicKeyDER() (der_block []byte, err error)
|
MarshalPKIXPublicKeyDER() (der_block []byte, err error)
|
||||||
|
|
||||||
evpPKey() *C.EVP_PKEY
|
evpPKey() *C.EVP_PKEY
|
||||||
}
|
}
|
||||||
|
|
||||||
type PrivateKey interface {
|
type PrivateKey interface {
|
||||||
PublicKey
|
PublicKey
|
||||||
|
|
||||||
// MarshalPKCS1PrivateKeyPEM converts the private key to PEM-encoded PKCS1
|
// MarshalPKCS1PrivateKeyPEM converts the private key to PEM-encoded PKCS1
|
||||||
// format
|
// format
|
||||||
MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error)
|
MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error)
|
||||||
|
|
||||||
// MarshalPKCS1PrivateKeyDER converts the private key to DER-encoded PKCS1
|
// MarshalPKCS1PrivateKeyDER converts the private key to DER-encoded PKCS1
|
||||||
// format
|
// format
|
||||||
MarshalPKCS1PrivateKeyDER() (der_block []byte, err error)
|
MarshalPKCS1PrivateKeyDER() (der_block []byte, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type pKey struct {
|
type pKey struct {
|
||||||
key *C.EVP_PKEY
|
key *C.EVP_PKEY
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *pKey) evpPKey() *C.EVP_PKEY { return key.key }
|
func (key *pKey) evpPKey() *C.EVP_PKEY { return key.key }
|
||||||
|
|
||||||
func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte,
|
func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte,
|
||||||
err error) {
|
err error) {
|
||||||
bio := C.BIO_new(C.BIO_s_mem())
|
bio := C.BIO_new(C.BIO_s_mem())
|
||||||
if bio == nil {
|
if bio == nil {
|
||||||
return nil, errors.New("failed to allocate memory BIO")
|
return nil, errors.New("failed to allocate memory BIO")
|
||||||
}
|
}
|
||||||
defer C.BIO_free(bio)
|
defer C.BIO_free(bio)
|
||||||
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
|
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
|
||||||
if rsa == nil {
|
if rsa == nil {
|
||||||
return nil, errors.New("failed getting rsa key")
|
return nil, errors.New("failed getting rsa key")
|
||||||
}
|
}
|
||||||
defer C.RSA_free(rsa)
|
defer C.RSA_free(rsa)
|
||||||
if int(C.PEM_write_bio_RSAPrivateKey(bio, rsa, nil, nil, C.int(0), nil,
|
if int(C.PEM_write_bio_RSAPrivateKey(bio, rsa, nil, nil, C.int(0), nil,
|
||||||
nil)) != 1 {
|
nil)) != 1 {
|
||||||
return nil, errors.New("failed dumping private key")
|
return nil, errors.New("failed dumping private key")
|
||||||
}
|
}
|
||||||
return ioutil.ReadAll(asAnyBio(bio))
|
return ioutil.ReadAll(asAnyBio(bio))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *pKey) MarshalPKCS1PrivateKeyDER() (der_block []byte,
|
func (key *pKey) MarshalPKCS1PrivateKeyDER() (der_block []byte,
|
||||||
err error) {
|
err error) {
|
||||||
bio := C.BIO_new(C.BIO_s_mem())
|
bio := C.BIO_new(C.BIO_s_mem())
|
||||||
if bio == nil {
|
if bio == nil {
|
||||||
return nil, errors.New("failed to allocate memory BIO")
|
return nil, errors.New("failed to allocate memory BIO")
|
||||||
}
|
}
|
||||||
defer C.BIO_free(bio)
|
defer C.BIO_free(bio)
|
||||||
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
|
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
|
||||||
if rsa == nil {
|
if rsa == nil {
|
||||||
return nil, errors.New("failed getting rsa key")
|
return nil, errors.New("failed getting rsa key")
|
||||||
}
|
}
|
||||||
defer C.RSA_free(rsa)
|
defer C.RSA_free(rsa)
|
||||||
if int(C.i2d_RSAPrivateKey_bio(bio, rsa)) != 1 {
|
if int(C.i2d_RSAPrivateKey_bio(bio, rsa)) != 1 {
|
||||||
return nil, errors.New("failed dumping private key der")
|
return nil, errors.New("failed dumping private key der")
|
||||||
}
|
}
|
||||||
return ioutil.ReadAll(asAnyBio(bio))
|
return ioutil.ReadAll(asAnyBio(bio))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte,
|
func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte,
|
||||||
err error) {
|
err error) {
|
||||||
bio := C.BIO_new(C.BIO_s_mem())
|
bio := C.BIO_new(C.BIO_s_mem())
|
||||||
if bio == nil {
|
if bio == nil {
|
||||||
return nil, errors.New("failed to allocate memory BIO")
|
return nil, errors.New("failed to allocate memory BIO")
|
||||||
}
|
}
|
||||||
defer C.BIO_free(bio)
|
defer C.BIO_free(bio)
|
||||||
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
|
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
|
||||||
if rsa == nil {
|
if rsa == nil {
|
||||||
return nil, errors.New("failed getting rsa key")
|
return nil, errors.New("failed getting rsa key")
|
||||||
}
|
}
|
||||||
defer C.RSA_free(rsa)
|
defer C.RSA_free(rsa)
|
||||||
if int(C.PEM_write_bio_RSA_PUBKEY(bio, rsa)) != 1 {
|
if int(C.PEM_write_bio_RSA_PUBKEY(bio, rsa)) != 1 {
|
||||||
return nil, errors.New("failed dumping public key pem")
|
return nil, errors.New("failed dumping public key pem")
|
||||||
}
|
}
|
||||||
return ioutil.ReadAll(asAnyBio(bio))
|
return ioutil.ReadAll(asAnyBio(bio))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte,
|
func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte,
|
||||||
err error) {
|
err error) {
|
||||||
bio := C.BIO_new(C.BIO_s_mem())
|
bio := C.BIO_new(C.BIO_s_mem())
|
||||||
if bio == nil {
|
if bio == nil {
|
||||||
return nil, errors.New("failed to allocate memory BIO")
|
return nil, errors.New("failed to allocate memory BIO")
|
||||||
}
|
}
|
||||||
defer C.BIO_free(bio)
|
defer C.BIO_free(bio)
|
||||||
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
|
rsa := (*C.RSA)(C.EVP_PKEY_get1_RSA(key.key))
|
||||||
if rsa == nil {
|
if rsa == nil {
|
||||||
return nil, errors.New("failed getting rsa key")
|
return nil, errors.New("failed getting rsa key")
|
||||||
}
|
}
|
||||||
defer C.RSA_free(rsa)
|
defer C.RSA_free(rsa)
|
||||||
if int(C.i2d_RSA_PUBKEY_bio(bio, rsa)) != 1 {
|
if int(C.i2d_RSA_PUBKEY_bio(bio, rsa)) != 1 {
|
||||||
return nil, errors.New("failed dumping public key der")
|
return nil, errors.New("failed dumping public key der")
|
||||||
}
|
}
|
||||||
return ioutil.ReadAll(asAnyBio(bio))
|
return ioutil.ReadAll(asAnyBio(bio))
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadPrivateKey loads a private key from a PEM-encoded block.
|
// LoadPrivateKey loads a private key from a PEM-encoded block.
|
||||||
func LoadPrivateKey(pem_block []byte) (PrivateKey, error) {
|
func LoadPrivateKey(pem_block []byte) (PrivateKey, error) {
|
||||||
bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]),
|
bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]),
|
||||||
C.int(len(pem_block)))
|
C.int(len(pem_block)))
|
||||||
if bio == nil {
|
if bio == nil {
|
||||||
return nil, errors.New("failed creating bio")
|
return nil, errors.New("failed creating bio")
|
||||||
}
|
}
|
||||||
defer C.BIO_free(bio)
|
defer C.BIO_free(bio)
|
||||||
|
|
||||||
rsakey := C.PEM_read_bio_RSAPrivateKey(bio, nil, nil, nil)
|
rsakey := C.PEM_read_bio_RSAPrivateKey(bio, nil, nil, nil)
|
||||||
if rsakey == nil {
|
if rsakey == nil {
|
||||||
return nil, errors.New("failed reading rsa key")
|
return nil, errors.New("failed reading rsa key")
|
||||||
}
|
}
|
||||||
defer C.RSA_free(rsakey)
|
defer C.RSA_free(rsakey)
|
||||||
|
|
||||||
// convert to PKEY
|
// convert to PKEY
|
||||||
key := C.EVP_PKEY_new()
|
key := C.EVP_PKEY_new()
|
||||||
if key == nil {
|
if key == nil {
|
||||||
return nil, errors.New("failed converting to evp_pkey")
|
return nil, errors.New("failed converting to evp_pkey")
|
||||||
}
|
}
|
||||||
if C.EVP_PKEY_set1_RSA(key, (*C.struct_rsa_st)(rsakey)) != 1 {
|
if C.EVP_PKEY_set1_RSA(key, (*C.struct_rsa_st)(rsakey)) != 1 {
|
||||||
C.EVP_PKEY_free(key)
|
C.EVP_PKEY_free(key)
|
||||||
return nil, errors.New("failed converting to evp_pkey")
|
return nil, errors.New("failed converting to evp_pkey")
|
||||||
}
|
}
|
||||||
|
|
||||||
p := &pKey{key: key}
|
p := &pKey{key: key}
|
||||||
runtime.SetFinalizer(p, func(p *pKey) {
|
runtime.SetFinalizer(p, func(p *pKey) {
|
||||||
C.EVP_PKEY_free(p.key)
|
C.EVP_PKEY_free(p.key)
|
||||||
})
|
})
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Certificate struct {
|
type Certificate struct {
|
||||||
x *C.X509
|
x *C.X509
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadCertificate loads an X509 certificate from a PEM-encoded block.
|
// LoadCertificate loads an X509 certificate from a PEM-encoded block.
|
||||||
func LoadCertificate(pem_block []byte) (*Certificate, error) {
|
func LoadCertificate(pem_block []byte) (*Certificate, error) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]),
|
bio := C.BIO_new_mem_buf(unsafe.Pointer(&pem_block[0]),
|
||||||
C.int(len(pem_block)))
|
C.int(len(pem_block)))
|
||||||
cert := C.PEM_read_bio_X509(bio, nil, nil, nil)
|
cert := C.PEM_read_bio_X509(bio, nil, nil, nil)
|
||||||
C.BIO_free(bio)
|
C.BIO_free(bio)
|
||||||
if cert == nil {
|
if cert == nil {
|
||||||
return nil, errorFromErrorQueue()
|
return nil, errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
x := &Certificate{x: cert}
|
x := &Certificate{x: cert}
|
||||||
runtime.SetFinalizer(x, func(x *Certificate) {
|
runtime.SetFinalizer(x, func(x *Certificate) {
|
||||||
C.X509_free(x.x)
|
C.X509_free(x.x)
|
||||||
})
|
})
|
||||||
return x, nil
|
return x, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalPEM converts the X509 certificate to PEM-encoded format
|
// MarshalPEM converts the X509 certificate to PEM-encoded format
|
||||||
func (c *Certificate) MarshalPEM() (pem_block []byte, err error) {
|
func (c *Certificate) MarshalPEM() (pem_block []byte, err error) {
|
||||||
bio := C.BIO_new(C.BIO_s_mem())
|
bio := C.BIO_new(C.BIO_s_mem())
|
||||||
if bio == nil {
|
if bio == nil {
|
||||||
return nil, errors.New("failed to allocate memory BIO")
|
return nil, errors.New("failed to allocate memory BIO")
|
||||||
}
|
}
|
||||||
defer C.BIO_free(bio)
|
defer C.BIO_free(bio)
|
||||||
if int(C.PEM_write_bio_X509(bio, c.x)) != 1 {
|
if int(C.PEM_write_bio_X509(bio, c.x)) != 1 {
|
||||||
return nil, errors.New("failed dumping certificate")
|
return nil, errors.New("failed dumping certificate")
|
||||||
}
|
}
|
||||||
return ioutil.ReadAll(asAnyBio(bio))
|
return ioutil.ReadAll(asAnyBio(bio))
|
||||||
}
|
}
|
||||||
|
|
||||||
// PublicKey returns the public key embedded in the X509 certificate.
|
// PublicKey returns the public key embedded in the X509 certificate.
|
||||||
func (c *Certificate) PublicKey() (PublicKey, error) {
|
func (c *Certificate) PublicKey() (PublicKey, error) {
|
||||||
pkey := C.X509_get_pubkey(c.x)
|
pkey := C.X509_get_pubkey(c.x)
|
||||||
if pkey == nil {
|
if pkey == nil {
|
||||||
return nil, errors.New("no public key found")
|
return nil, errors.New("no public key found")
|
||||||
}
|
}
|
||||||
key := &pKey{key: pkey}
|
key := &pKey{key: pkey}
|
||||||
runtime.SetFinalizer(key, func(key *pKey) {
|
runtime.SetFinalizer(key, func(key *pKey) {
|
||||||
C.EVP_PKEY_free(key.key)
|
C.EVP_PKEY_free(key.key)
|
||||||
})
|
})
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
126
pem_test.go
126
pem_test.go
@ -3,74 +3,74 @@
|
|||||||
package openssl
|
package openssl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMarshal(t *testing.T) {
|
func TestMarshal(t *testing.T) {
|
||||||
key, err := LoadPrivateKey(keyBytes)
|
key, err := LoadPrivateKey(keyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
cert, err := LoadCertificate(certBytes)
|
cert, err := LoadCertificate(certBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pem, err := cert.MarshalPEM()
|
pem, err := cert.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(pem, certBytes) {
|
if !bytes.Equal(pem, certBytes) {
|
||||||
ioutil.WriteFile("generated", pem, 0644)
|
ioutil.WriteFile("generated", pem, 0644)
|
||||||
ioutil.WriteFile("hardcoded", certBytes, 0644)
|
ioutil.WriteFile("hardcoded", certBytes, 0644)
|
||||||
t.Fatal("invalid cert pem bytes")
|
t.Fatal("invalid cert pem bytes")
|
||||||
}
|
}
|
||||||
|
|
||||||
pem, err = key.MarshalPKCS1PrivateKeyPEM()
|
pem, err = key.MarshalPKCS1PrivateKeyPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(pem, keyBytes) {
|
if !bytes.Equal(pem, keyBytes) {
|
||||||
ioutil.WriteFile("generated", pem, 0644)
|
ioutil.WriteFile("generated", pem, 0644)
|
||||||
ioutil.WriteFile("hardcoded", keyBytes, 0644)
|
ioutil.WriteFile("hardcoded", keyBytes, 0644)
|
||||||
t.Fatal("invalid private key pem bytes")
|
t.Fatal("invalid private key pem bytes")
|
||||||
}
|
}
|
||||||
tls_cert, err := tls.X509KeyPair(certBytes, keyBytes)
|
tls_cert, err := tls.X509KeyPair(certBytes, keyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
tls_key, ok := tls_cert.PrivateKey.(*rsa.PrivateKey)
|
tls_key, ok := tls_cert.PrivateKey.(*rsa.PrivateKey)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("FASDFASDF")
|
t.Fatal("FASDFASDF")
|
||||||
}
|
}
|
||||||
_ = tls_key
|
_ = tls_key
|
||||||
|
|
||||||
der, err := key.MarshalPKCS1PrivateKeyDER()
|
der, err := key.MarshalPKCS1PrivateKeyDER()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
tls_der := x509.MarshalPKCS1PrivateKey(tls_key)
|
tls_der := x509.MarshalPKCS1PrivateKey(tls_key)
|
||||||
if !bytes.Equal(der, tls_der) {
|
if !bytes.Equal(der, tls_der) {
|
||||||
t.Fatal("invalid private key der bytes: %s\n v.s. %s\n", hex.Dump(der), hex.Dump(tls_der))
|
t.Fatal("invalid private key der bytes: %s\n v.s. %s\n", hex.Dump(der), hex.Dump(tls_der))
|
||||||
}
|
}
|
||||||
|
|
||||||
der, err = key.MarshalPKIXPublicKeyDER()
|
der, err = key.MarshalPKIXPublicKeyDER()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
tls_der, err = x509.MarshalPKIXPublicKey(&tls_key.PublicKey)
|
tls_der, err = x509.MarshalPKIXPublicKey(&tls_key.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(der, tls_der) {
|
if !bytes.Equal(der, tls_der) {
|
||||||
ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644)
|
ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644)
|
||||||
ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644)
|
ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644)
|
||||||
t.Fatal("invalid public key der bytes")
|
t.Fatal("invalid public key der bytes")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
856
ssl_test.go
856
ssl_test.go
@ -3,21 +3,21 @@
|
|||||||
package openssl
|
package openssl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"code.spacemonkey.com/go/openssl/utils"
|
"code.spacemonkey.com/go/openssl/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
certBytes = []byte(`-----BEGIN CERTIFICATE-----
|
certBytes = []byte(`-----BEGIN CERTIFICATE-----
|
||||||
MIIDxDCCAqygAwIBAgIVAMcK/0VWQr2O3MNfJCydqR7oVELcMA0GCSqGSIb3DQEB
|
MIIDxDCCAqygAwIBAgIVAMcK/0VWQr2O3MNfJCydqR7oVELcMA0GCSqGSIb3DQEB
|
||||||
BQUAMIGQMUkwRwYDVQQDE0A1NjdjZGRmYzRjOWZiNTYwZTk1M2ZlZjA1N2M0NGFm
|
BQUAMIGQMUkwRwYDVQQDE0A1NjdjZGRmYzRjOWZiNTYwZTk1M2ZlZjA1N2M0NGFm
|
||||||
MDdiYjc4MDIzODIxYTA5NThiY2RmMGMwNzJhOTdiMThhMQswCQYDVQQGEwJVUzEN
|
MDdiYjc4MDIzODIxYTA5NThiY2RmMGMwNzJhOTdiMThhMQswCQYDVQQGEwJVUzEN
|
||||||
@ -41,7 +41,7 @@ sRkg/uxcJf7wC5Y0BLlp1+aPwdmZD87T3a1uQ1Ij93jmHG+2T9U20MklHAePOl0q
|
|||||||
yTqdSPnSH1c=
|
yTqdSPnSH1c=
|
||||||
-----END CERTIFICATE-----
|
-----END CERTIFICATE-----
|
||||||
`)
|
`)
|
||||||
keyBytes = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
keyBytes = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||||
MIIEpQIBAAKCAQEA3X94nDbxbK5a5zS4vEqHLHKpUmxavqRL5oXEqKoAy6nm56rv
|
MIIEpQIBAAKCAQEA3X94nDbxbK5a5zS4vEqHLHKpUmxavqRL5oXEqKoAy6nm56rv
|
||||||
C3e9xySe+DBlxIEV/MWU+RYpzjC99QkerfRP493aleqfhn3ZRS3tyKrQtP2z1Zwg
|
C3e9xySe+DBlxIEV/MWU+RYpzjC99QkerfRP493aleqfhn3ZRS3tyKrQtP2z1Zwg
|
||||||
wYqwcoASOLgqzKvtVYQMT1nJaw6O5fUEWG7BMR/ZX5/kcr8XjTGYjgEmrL1WTZ3G
|
wYqwcoASOLgqzKvtVYQMT1nJaw6O5fUEWG7BMR/ZX5/kcr8XjTGYjgEmrL1WTZ3G
|
||||||
@ -72,534 +72,534 @@ qmgvgyRayemfO2zR0CPgC6wSoGBth+xW6g+WA8y0z76ZSaWpFi8lVM4=
|
|||||||
)
|
)
|
||||||
|
|
||||||
func NetPipe(t testing.TB) (net.Conn, net.Conn) {
|
func NetPipe(t testing.TB) (net.Conn, net.Conn) {
|
||||||
l, err := net.Listen("tcp", "localhost:0")
|
l, err := net.Listen("tcp", "localhost:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
client_future := utils.NewFuture()
|
client_future := utils.NewFuture()
|
||||||
go func() {
|
go func() {
|
||||||
client_future.Set(net.Dial(l.Addr().Network(), l.Addr().String()))
|
client_future.Set(net.Dial(l.Addr().Network(), l.Addr().String()))
|
||||||
}()
|
}()
|
||||||
var errs utils.ErrorGroup
|
var errs utils.ErrorGroup
|
||||||
server_conn, err := l.Accept()
|
server_conn, err := l.Accept()
|
||||||
errs.Add(err)
|
errs.Add(err)
|
||||||
client_conn, err := client_future.Get()
|
client_conn, err := client_future.Get()
|
||||||
errs.Add(err)
|
errs.Add(err)
|
||||||
err = errs.Finalize()
|
err = errs.Finalize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if server_conn != nil {
|
if server_conn != nil {
|
||||||
server_conn.Close()
|
server_conn.Close()
|
||||||
}
|
}
|
||||||
if client_conn != nil {
|
if client_conn != nil {
|
||||||
client_conn.(net.Conn).Close()
|
client_conn.(net.Conn).Close()
|
||||||
}
|
}
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
return server_conn, client_conn.(net.Conn)
|
return server_conn, client_conn.(net.Conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
type HandshakingConn interface {
|
type HandshakingConn interface {
|
||||||
net.Conn
|
net.Conn
|
||||||
Handshake() error
|
Handshake() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func SimpleConnTest(t testing.TB, constructor func(
|
func SimpleConnTest(t testing.TB, constructor func(
|
||||||
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
|
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
|
||||||
server_conn, client_conn := NetPipe(t)
|
server_conn, client_conn := NetPipe(t)
|
||||||
defer server_conn.Close()
|
defer server_conn.Close()
|
||||||
defer client_conn.Close()
|
defer client_conn.Close()
|
||||||
|
|
||||||
data := "first test string\n"
|
data := "first test string\n"
|
||||||
|
|
||||||
server, client := constructor(t, server_conn, client_conn)
|
server, client := constructor(t, server_conn, client_conn)
|
||||||
defer close_both(server, client)
|
defer close_both(server, client)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
err := client.Handshake()
|
err := client.Handshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = io.Copy(client, bytes.NewReader([]byte(data)))
|
_, err = io.Copy(client, bytes.NewReader([]byte(data)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = client.Close()
|
err = client.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
err := server.Handshake()
|
err := server.Handshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := bytes.NewBuffer(make([]byte, 0, len(data)))
|
buf := bytes.NewBuffer(make([]byte, 0, len(data)))
|
||||||
_, err = io.CopyN(buf, server, int64(len(data)))
|
_, err = io.CopyN(buf, server, int64(len(data)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if string(buf.Bytes()) != data {
|
if string(buf.Bytes()) != data {
|
||||||
t.Fatal("mismatched data")
|
t.Fatal("mismatched data")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = server.Close()
|
err = server.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func close_both(closer1, closer2 io.Closer) {
|
func close_both(closer1, closer2 io.Closer) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
closer1.Close()
|
closer1.Close()
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
closer2.Close()
|
closer2.Close()
|
||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClosingTest(t testing.TB, constructor func(
|
func ClosingTest(t testing.TB, constructor func(
|
||||||
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
|
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
|
||||||
|
|
||||||
run_test := func(close_tcp bool, server_writes bool) {
|
run_test := func(close_tcp bool, server_writes bool) {
|
||||||
server_conn, client_conn := NetPipe(t)
|
server_conn, client_conn := NetPipe(t)
|
||||||
defer server_conn.Close()
|
defer server_conn.Close()
|
||||||
defer client_conn.Close()
|
defer client_conn.Close()
|
||||||
server, client := constructor(t, server_conn, client_conn)
|
server, client := constructor(t, server_conn, client_conn)
|
||||||
defer close_both(server, client)
|
defer close_both(server, client)
|
||||||
|
|
||||||
var sslconn1, sslconn2 HandshakingConn
|
var sslconn1, sslconn2 HandshakingConn
|
||||||
var conn1 net.Conn
|
var conn1 net.Conn
|
||||||
if server_writes {
|
if server_writes {
|
||||||
sslconn1 = server
|
sslconn1 = server
|
||||||
conn1 = server_conn
|
conn1 = server_conn
|
||||||
sslconn2 = client
|
sslconn2 = client
|
||||||
} else {
|
} else {
|
||||||
sslconn1 = client
|
sslconn1 = client
|
||||||
conn1 = client_conn
|
conn1 = client_conn
|
||||||
sslconn2 = server
|
sslconn2 = server
|
||||||
}
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
_, err := sslconn1.Write([]byte("hello"))
|
_, err := sslconn1.Write([]byte("hello"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if close_tcp {
|
if close_tcp {
|
||||||
err = conn1.Close()
|
err = conn1.Close()
|
||||||
} else {
|
} else {
|
||||||
err = sslconn1.Close()
|
err = sslconn1.Close()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
data, err := ioutil.ReadAll(sslconn2)
|
data, err := ioutil.ReadAll(sslconn2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(data, []byte("hello")) {
|
if !bytes.Equal(data, []byte("hello")) {
|
||||||
t.Fatal("bytes don't match")
|
t.Fatal("bytes don't match")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
run_test(true, false)
|
run_test(true, false)
|
||||||
run_test(false, false)
|
run_test(false, false)
|
||||||
run_test(true, true)
|
run_test(true, true)
|
||||||
run_test(false, true)
|
run_test(false, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ThroughputBenchmark(b *testing.B, constructor func(
|
func ThroughputBenchmark(b *testing.B, constructor func(
|
||||||
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
|
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
|
||||||
server_conn, client_conn := NetPipe(b)
|
server_conn, client_conn := NetPipe(b)
|
||||||
defer server_conn.Close()
|
defer server_conn.Close()
|
||||||
defer client_conn.Close()
|
defer client_conn.Close()
|
||||||
|
|
||||||
server, client := constructor(b, server_conn, client_conn)
|
server, client := constructor(b, server_conn, client_conn)
|
||||||
defer close_both(server, client)
|
defer close_both(server, client)
|
||||||
|
|
||||||
b.SetBytes(1024)
|
b.SetBytes(1024)
|
||||||
data := make([]byte, b.N*1024)
|
data := make([]byte, b.N*1024)
|
||||||
_, err := io.ReadFull(rand.Reader, data[:])
|
_, err := io.ReadFull(rand.Reader, data[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
_, err = io.Copy(client, bytes.NewReader([]byte(data)))
|
_, err = io.Copy(client, bytes.NewReader([]byte(data)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
_, err = io.CopyN(buf, server, int64(len(data)))
|
_, err = io.CopyN(buf, server, int64(len(data)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(buf.Bytes(), data) {
|
if !bytes.Equal(buf.Bytes(), data) {
|
||||||
b.Fatal("mismatched data")
|
b.Fatal("mismatched data")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
b.StopTimer()
|
b.StopTimer()
|
||||||
}
|
}
|
||||||
|
|
||||||
func StdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) (
|
func StdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) (
|
||||||
server, client HandshakingConn) {
|
server, client HandshakingConn) {
|
||||||
cert, err := tls.X509KeyPair(certBytes, keyBytes)
|
cert, err := tls.X509KeyPair(certBytes, keyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
config := &tls.Config{
|
config := &tls.Config{
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}}
|
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}}
|
||||||
server = tls.Server(server_conn, config)
|
server = tls.Server(server_conn, config)
|
||||||
client = tls.Client(client_conn, config)
|
client = tls.Client(client_conn, config)
|
||||||
return server, client
|
return server, client
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) (
|
func OpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) (
|
||||||
server, client HandshakingConn) {
|
server, client HandshakingConn) {
|
||||||
ctx, err := NewCtx()
|
ctx, err := NewCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
key, err := LoadPrivateKey(keyBytes)
|
key, err := LoadPrivateKey(keyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
err = ctx.UsePrivateKey(key)
|
err = ctx.UsePrivateKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
cert, err := LoadCertificate(certBytes)
|
cert, err := LoadCertificate(certBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
err = ctx.UseCertificate(cert)
|
err = ctx.UseCertificate(cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
err = ctx.SetCipherList("AES128-SHA")
|
err = ctx.SetCipherList("AES128-SHA")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
server, err = Server(server_conn, ctx)
|
server, err = Server(server_conn, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
client, err = Client(client_conn, ctx)
|
client, err = Client(client_conn, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
return server, client
|
return server, client
|
||||||
}
|
}
|
||||||
|
|
||||||
func StdlibOpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) (
|
func StdlibOpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) (
|
||||||
server, client HandshakingConn) {
|
server, client HandshakingConn) {
|
||||||
server_std, _ := StdlibConstructor(t, server_conn, client_conn)
|
server_std, _ := StdlibConstructor(t, server_conn, client_conn)
|
||||||
_, client_ssl := OpenSSLConstructor(t, server_conn, client_conn)
|
_, client_ssl := OpenSSLConstructor(t, server_conn, client_conn)
|
||||||
return server_std, client_ssl
|
return server_std, client_ssl
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenSSLStdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) (
|
func OpenSSLStdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) (
|
||||||
server, client HandshakingConn) {
|
server, client HandshakingConn) {
|
||||||
_, client_std := StdlibConstructor(t, server_conn, client_conn)
|
_, client_std := StdlibConstructor(t, server_conn, client_conn)
|
||||||
server_ssl, _ := OpenSSLConstructor(t, server_conn, client_conn)
|
server_ssl, _ := OpenSSLConstructor(t, server_conn, client_conn)
|
||||||
return server_ssl, client_std
|
return server_ssl, client_std
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStdlibSimple(t *testing.T) {
|
func TestStdlibSimple(t *testing.T) {
|
||||||
SimpleConnTest(t, StdlibConstructor)
|
SimpleConnTest(t, StdlibConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenSSLSimple(t *testing.T) {
|
func TestOpenSSLSimple(t *testing.T) {
|
||||||
SimpleConnTest(t, OpenSSLConstructor)
|
SimpleConnTest(t, OpenSSLConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStdlibClosing(t *testing.T) {
|
func TestStdlibClosing(t *testing.T) {
|
||||||
ClosingTest(t, StdlibConstructor)
|
ClosingTest(t, StdlibConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenSSLClosing(t *testing.T) {
|
func TestOpenSSLClosing(t *testing.T) {
|
||||||
ClosingTest(t, OpenSSLConstructor)
|
ClosingTest(t, OpenSSLConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkStdlibThroughput(b *testing.B) {
|
func BenchmarkStdlibThroughput(b *testing.B) {
|
||||||
ThroughputBenchmark(b, StdlibConstructor)
|
ThroughputBenchmark(b, StdlibConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkOpenSSLThroughput(b *testing.B) {
|
func BenchmarkOpenSSLThroughput(b *testing.B) {
|
||||||
ThroughputBenchmark(b, OpenSSLConstructor)
|
ThroughputBenchmark(b, OpenSSLConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStdlibOpenSSLSimple(t *testing.T) {
|
func TestStdlibOpenSSLSimple(t *testing.T) {
|
||||||
SimpleConnTest(t, StdlibOpenSSLConstructor)
|
SimpleConnTest(t, StdlibOpenSSLConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenSSLStdlibSimple(t *testing.T) {
|
func TestOpenSSLStdlibSimple(t *testing.T) {
|
||||||
SimpleConnTest(t, OpenSSLStdlibConstructor)
|
SimpleConnTest(t, OpenSSLStdlibConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStdlibOpenSSLClosing(t *testing.T) {
|
func TestStdlibOpenSSLClosing(t *testing.T) {
|
||||||
ClosingTest(t, StdlibOpenSSLConstructor)
|
ClosingTest(t, StdlibOpenSSLConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenSSLStdlibClosing(t *testing.T) {
|
func TestOpenSSLStdlibClosing(t *testing.T) {
|
||||||
ClosingTest(t, OpenSSLStdlibConstructor)
|
ClosingTest(t, OpenSSLStdlibConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkStdlibOpenSSLThroughput(b *testing.B) {
|
func BenchmarkStdlibOpenSSLThroughput(b *testing.B) {
|
||||||
ThroughputBenchmark(b, StdlibOpenSSLConstructor)
|
ThroughputBenchmark(b, StdlibOpenSSLConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkOpenSSLStdlibThroughput(b *testing.B) {
|
func BenchmarkOpenSSLStdlibThroughput(b *testing.B) {
|
||||||
ThroughputBenchmark(b, OpenSSLStdlibConstructor)
|
ThroughputBenchmark(b, OpenSSLStdlibConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func FullDuplexRenegotiationTest(t testing.TB, constructor func(
|
func FullDuplexRenegotiationTest(t testing.TB, constructor func(
|
||||||
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
|
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
|
||||||
|
|
||||||
server_conn, client_conn := NetPipe(t)
|
server_conn, client_conn := NetPipe(t)
|
||||||
defer server_conn.Close()
|
defer server_conn.Close()
|
||||||
defer client_conn.Close()
|
defer client_conn.Close()
|
||||||
|
|
||||||
times := 256
|
times := 256
|
||||||
data_len := 4 * SSLRecordSize
|
data_len := 4 * SSLRecordSize
|
||||||
data1 := make([]byte, data_len)
|
data1 := make([]byte, data_len)
|
||||||
_, err := io.ReadFull(rand.Reader, data1[:])
|
_, err := io.ReadFull(rand.Reader, data1[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
data2 := make([]byte, data_len)
|
data2 := make([]byte, data_len)
|
||||||
_, err = io.ReadFull(rand.Reader, data1[:])
|
_, err = io.ReadFull(rand.Reader, data1[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server, client := constructor(t, server_conn, client_conn)
|
server, client := constructor(t, server_conn, client_conn)
|
||||||
defer close_both(server, client)
|
defer close_both(server, client)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
send_func := func(sender HandshakingConn, data []byte) {
|
send_func := func(sender HandshakingConn, data []byte) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for i := 0; i < times; i++ {
|
for i := 0; i < times; i++ {
|
||||||
if i == times/2 {
|
if i == times/2 {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err := sender.Handshake()
|
err := sender.Handshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
_, err := sender.Write(data)
|
_, err := sender.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
recv_func := func(receiver net.Conn, data []byte) {
|
recv_func := func(receiver net.Conn, data []byte) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
buf := make([]byte, len(data))
|
buf := make([]byte, len(data))
|
||||||
for i := 0; i < times; i++ {
|
for i := 0; i < times; i++ {
|
||||||
n, err := io.ReadFull(receiver, buf[:])
|
n, err := io.ReadFull(receiver, buf[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(buf[:n], data) {
|
if !bytes.Equal(buf[:n], data) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(4)
|
wg.Add(4)
|
||||||
go recv_func(server, data1)
|
go recv_func(server, data1)
|
||||||
go send_func(client, data1)
|
go send_func(client, data1)
|
||||||
go send_func(server, data2)
|
go send_func(server, data2)
|
||||||
go recv_func(client, data2)
|
go recv_func(client, data2)
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStdlibFullDuplexRenegotiation(t *testing.T) {
|
func TestStdlibFullDuplexRenegotiation(t *testing.T) {
|
||||||
FullDuplexRenegotiationTest(t, StdlibConstructor)
|
FullDuplexRenegotiationTest(t, StdlibConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenSSLFullDuplexRenegotiation(t *testing.T) {
|
func TestOpenSSLFullDuplexRenegotiation(t *testing.T) {
|
||||||
FullDuplexRenegotiationTest(t, OpenSSLConstructor)
|
FullDuplexRenegotiationTest(t, OpenSSLConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenSSLStdlibFullDuplexRenegotiation(t *testing.T) {
|
func TestOpenSSLStdlibFullDuplexRenegotiation(t *testing.T) {
|
||||||
FullDuplexRenegotiationTest(t, OpenSSLStdlibConstructor)
|
FullDuplexRenegotiationTest(t, OpenSSLStdlibConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStdlibOpenSSLFullDuplexRenegotiation(t *testing.T) {
|
func TestStdlibOpenSSLFullDuplexRenegotiation(t *testing.T) {
|
||||||
FullDuplexRenegotiationTest(t, StdlibOpenSSLConstructor)
|
FullDuplexRenegotiationTest(t, StdlibOpenSSLConstructor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LotsOfConns(t *testing.T, payload_size int64, loops, clients int,
|
func LotsOfConns(t *testing.T, payload_size int64, loops, clients int,
|
||||||
sleep time.Duration, newListener func(net.Listener) net.Listener,
|
sleep time.Duration, newListener func(net.Listener) net.Listener,
|
||||||
newClient func(net.Conn) (net.Conn, error)) {
|
newClient func(net.Conn) (net.Conn, error)) {
|
||||||
tcp_listener, err := net.Listen("tcp", "localhost:0")
|
tcp_listener, err := net.Listen("tcp", "localhost:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
ssl_listener := newListener(tcp_listener)
|
ssl_listener := newListener(tcp_listener)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
conn, err := ssl_listener.Accept()
|
conn, err := ssl_listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed accept: %s", err)
|
t.Fatalf("failed accept: %s", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
err = conn.Close()
|
err = conn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed closing: %s", err)
|
t.Fatalf("failed closing: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for i := 0; i < loops; i++ {
|
for i := 0; i < loops; i++ {
|
||||||
_, err := io.Copy(ioutil.Discard,
|
_, err := io.Copy(ioutil.Discard,
|
||||||
io.LimitReader(conn, payload_size))
|
io.LimitReader(conn, payload_size))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed reading: %s", err)
|
t.Fatalf("failed reading: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = io.Copy(conn, io.LimitReader(rand.Reader,
|
_, err = io.Copy(conn, io.LimitReader(rand.Reader,
|
||||||
payload_size))
|
payload_size))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed writing: %s", err)
|
t.Fatalf("failed writing: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep(sleep)
|
time.Sleep(sleep)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < clients; i++ {
|
for i := 0; i < clients; i++ {
|
||||||
tcp_client, err := net.Dial(tcp_listener.Addr().Network(),
|
tcp_client, err := net.Dial(tcp_listener.Addr().Network(),
|
||||||
tcp_listener.Addr().String())
|
tcp_listener.Addr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
ssl_client, err := newClient(tcp_client)
|
ssl_client, err := newClient(tcp_client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
defer func() {
|
defer func() {
|
||||||
err = ssl_client.Close()
|
err = ssl_client.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed closing: %s", err)
|
t.Fatalf("failed closing: %s", err)
|
||||||
}
|
}
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
for i := 0; i < loops; i++ {
|
for i := 0; i < loops; i++ {
|
||||||
_, err := io.Copy(ssl_client, io.LimitReader(rand.Reader,
|
_, err := io.Copy(ssl_client, io.LimitReader(rand.Reader,
|
||||||
payload_size))
|
payload_size))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed writing: %s", err)
|
t.Fatalf("failed writing: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = io.Copy(ioutil.Discard,
|
_, err = io.Copy(ioutil.Discard,
|
||||||
io.LimitReader(ssl_client, payload_size))
|
io.LimitReader(ssl_client, payload_size))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed reading: %s", err)
|
t.Fatalf("failed reading: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep(sleep)
|
time.Sleep(sleep)
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStdlibLotsOfConns(t *testing.T) {
|
func TestStdlibLotsOfConns(t *testing.T) {
|
||||||
tls_cert, err := tls.X509KeyPair(certBytes, keyBytes)
|
tls_cert, err := tls.X509KeyPair(certBytes, keyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
tls_config := &tls.Config{
|
tls_config := &tls.Config{
|
||||||
Certificates: []tls.Certificate{tls_cert},
|
Certificates: []tls.Certificate{tls_cert},
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}}
|
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}}
|
||||||
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
|
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
|
||||||
func(l net.Listener) net.Listener {
|
func(l net.Listener) net.Listener {
|
||||||
return tls.NewListener(l, tls_config)
|
return tls.NewListener(l, tls_config)
|
||||||
}, func(c net.Conn) (net.Conn, error) {
|
}, func(c net.Conn) (net.Conn, error) {
|
||||||
return tls.Client(c, tls_config), nil
|
return tls.Client(c, tls_config), nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenSSLLotsOfConns(t *testing.T) {
|
func TestOpenSSLLotsOfConns(t *testing.T) {
|
||||||
ctx, err := NewCtx()
|
ctx, err := NewCtx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
key, err := LoadPrivateKey(keyBytes)
|
key, err := LoadPrivateKey(keyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
err = ctx.UsePrivateKey(key)
|
err = ctx.UsePrivateKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
cert, err := LoadCertificate(certBytes)
|
cert, err := LoadCertificate(certBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
err = ctx.UseCertificate(cert)
|
err = ctx.UseCertificate(cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
err = ctx.SetCipherList("AES128-SHA")
|
err = ctx.SetCipherList("AES128-SHA")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
|
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
|
||||||
func(l net.Listener) net.Listener {
|
func(l net.Listener) net.Listener {
|
||||||
return NewListener(l, ctx)
|
return NewListener(l, ctx)
|
||||||
}, func(c net.Conn) (net.Conn, error) {
|
}, func(c net.Conn) (net.Conn, error) {
|
||||||
return Client(c, ctx)
|
return Client(c, ctx)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -3,20 +3,20 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrorGroup collates errors
|
// ErrorGroup collates errors
|
||||||
type ErrorGroup struct {
|
type ErrorGroup struct {
|
||||||
Errors []error
|
Errors []error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add adds an error to an existing error group
|
// Add adds an error to an existing error group
|
||||||
func (e *ErrorGroup) Add(err error) {
|
func (e *ErrorGroup) Add(err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.Errors = append(e.Errors, err)
|
e.Errors = append(e.Errors, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finalize returns an error corresponding to the ErrorGroup state. If there's
|
// Finalize returns an error corresponding to the ErrorGroup state. If there's
|
||||||
@ -24,15 +24,15 @@ func (e *ErrorGroup) Add(err error) {
|
|||||||
// Finalize returns that error. Otherwise, Finalize will make a new error
|
// Finalize returns that error. Otherwise, Finalize will make a new error
|
||||||
// consisting of the messages from the constituent errors.
|
// consisting of the messages from the constituent errors.
|
||||||
func (e *ErrorGroup) Finalize() error {
|
func (e *ErrorGroup) Finalize() error {
|
||||||
if len(e.Errors) == 0 {
|
if len(e.Errors) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if len(e.Errors) == 1 {
|
if len(e.Errors) == 1 {
|
||||||
return e.Errors[0]
|
return e.Errors[0]
|
||||||
}
|
}
|
||||||
msgs := make([]string, 0, len(e.Errors))
|
msgs := make([]string, 0, len(e.Errors))
|
||||||
for _, err := range e.Errors {
|
for _, err := range e.Errors {
|
||||||
msgs = append(msgs, err.Error())
|
msgs = append(msgs, err.Error())
|
||||||
}
|
}
|
||||||
return errors.New(strings.Join(msgs, "\n"))
|
return errors.New(strings.Join(msgs, "\n"))
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Future is a type that is essentially the inverse of a channel. With a
|
// Future is a type that is essentially the inverse of a channel. With a
|
||||||
@ -13,55 +13,55 @@ import (
|
|||||||
// results, we also capture and return error values as well. Use NewFuture
|
// results, we also capture and return error values as well. Use NewFuture
|
||||||
// to initialize.
|
// to initialize.
|
||||||
type Future struct {
|
type Future struct {
|
||||||
mutex *sync.Mutex
|
mutex *sync.Mutex
|
||||||
cond *sync.Cond
|
cond *sync.Cond
|
||||||
received bool
|
received bool
|
||||||
val interface{}
|
val interface{}
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFuture returns an initialized and ready Future.
|
// NewFuture returns an initialized and ready Future.
|
||||||
func NewFuture() *Future {
|
func NewFuture() *Future {
|
||||||
mutex := &sync.Mutex{}
|
mutex := &sync.Mutex{}
|
||||||
return &Future{
|
return &Future{
|
||||||
mutex: mutex,
|
mutex: mutex,
|
||||||
cond: sync.NewCond(mutex),
|
cond: sync.NewCond(mutex),
|
||||||
received: false,
|
received: false,
|
||||||
val: nil,
|
val: nil,
|
||||||
err: nil,
|
err: nil,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get blocks until the Future has a value set.
|
// Get blocks until the Future has a value set.
|
||||||
func (self *Future) Get() (interface{}, error) {
|
func (self *Future) Get() (interface{}, error) {
|
||||||
self.mutex.Lock()
|
self.mutex.Lock()
|
||||||
defer self.mutex.Unlock()
|
defer self.mutex.Unlock()
|
||||||
for {
|
for {
|
||||||
if self.received {
|
if self.received {
|
||||||
return self.val, self.err
|
return self.val, self.err
|
||||||
}
|
}
|
||||||
self.cond.Wait()
|
self.cond.Wait()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fired returns whether or not a value has been set. If Fired is true, Get
|
// Fired returns whether or not a value has been set. If Fired is true, Get
|
||||||
// won't block.
|
// won't block.
|
||||||
func (self *Future) Fired() bool {
|
func (self *Future) Fired() bool {
|
||||||
self.mutex.Lock()
|
self.mutex.Lock()
|
||||||
defer self.mutex.Unlock()
|
defer self.mutex.Unlock()
|
||||||
return self.received
|
return self.received
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set provides the value to present and future Get calls. If Set has already
|
// Set provides the value to present and future Get calls. If Set has already
|
||||||
// been called, this is a no-op.
|
// been called, this is a no-op.
|
||||||
func (self *Future) Set(val interface{}, err error) {
|
func (self *Future) Set(val interface{}, err error) {
|
||||||
self.mutex.Lock()
|
self.mutex.Lock()
|
||||||
defer self.mutex.Unlock()
|
defer self.mutex.Unlock()
|
||||||
if self.received {
|
if self.received {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
self.received = true
|
self.received = true
|
||||||
self.val = val
|
self.val = val
|
||||||
self.err = err
|
self.err = err
|
||||||
self.cond.Broadcast()
|
self.cond.Broadcast()
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ThreadId returns the current runtime's thread id. Thanks to Gustavo Niemeyer
|
// ThreadId returns the current runtime's thread id. Thanks to Gustavo Niemeyer
|
||||||
|
Loading…
Reference in New Issue
Block a user