mirror of
https://github.com/libp2p/go-openssl.git
synced 2024-12-26 23:40:07 +08:00
space monkey internal commit export
[katamari commit: 66d3bf715795d3696ca37003fba5dba1af7ffacf]
This commit is contained in:
parent
a9b372afa5
commit
053d794fe5
11
bio.go
11
bio.go
@ -56,6 +56,7 @@ static void BIO_set_retry_read_not_a_macro(BIO *b) { BIO_set_retry_read(b); }
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
@ -63,7 +64,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
sslMaxRecord = 16 * 1024
|
SSLRecordSize = 16 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
func nonCopyGoBytes(ptr uintptr, length int) []byte {
|
func nonCopyGoBytes(ptr uintptr, length int) []byte {
|
||||||
@ -243,8 +244,8 @@ func (b *readBio) ReadFromOnce(r io.Reader) (n int, err error) {
|
|||||||
|
|
||||||
// make sure we have a destination that fits at least one SSL record
|
// make sure we have a destination that fits at least one SSL record
|
||||||
b.data_mtx.Lock()
|
b.data_mtx.Lock()
|
||||||
if cap(b.buf) < len(b.buf)+sslMaxRecord {
|
if cap(b.buf) < len(b.buf)+SSLRecordSize {
|
||||||
new_buf := make([]byte, len(b.buf), len(b.buf)+sslMaxRecord)
|
new_buf := make([]byte, len(b.buf), len(b.buf)+SSLRecordSize)
|
||||||
copy(new_buf, b.buf)
|
copy(new_buf, b.buf)
|
||||||
b.buf = new_buf
|
b.buf = new_buf
|
||||||
}
|
}
|
||||||
@ -257,7 +258,7 @@ func (b *readBio) ReadFromOnce(r io.Reader) (n int, err error) {
|
|||||||
defer b.data_mtx.Unlock()
|
defer b.data_mtx.Unlock()
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
if len(dst_slice) != len(b.buf) {
|
if len(dst_slice) != len(b.buf) {
|
||||||
// someone shrunk the buffer, so we read in to far ahead and we
|
// someone shrunk the buffer, so we read in too far ahead and we
|
||||||
// need to slide backwards
|
// need to slide backwards
|
||||||
copy(b.buf[len(b.buf):len(b.buf)+n], dst)
|
copy(b.buf[len(b.buf):len(b.buf)+n], dst)
|
||||||
}
|
}
|
||||||
@ -300,7 +301,7 @@ func (b *anyBio) Write(buf []byte) (written int, err error) {
|
|||||||
n := int(C.BIO_write((*C.BIO)(b), unsafe.Pointer(&buf[0]),
|
n := int(C.BIO_write((*C.BIO)(b), unsafe.Pointer(&buf[0]),
|
||||||
C.int(len(buf))))
|
C.int(len(buf))))
|
||||||
if n != len(buf) {
|
if n != len(buf) {
|
||||||
return n, SSLError.New("BIO write failed")
|
return n, errors.New("BIO write failed")
|
||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
67
conn.go
67
conn.go
@ -9,6 +9,7 @@ package openssl
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -16,18 +17,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"code.spacemonkey.com/go/errors"
|
"code.spacemonkey.com/go/openssl/utils"
|
||||||
space_sync "code.spacemonkey.com/go/space/sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrnoError = errors.New(SSLError, "Errno")
|
zeroReturn = errors.New("zero return")
|
||||||
|
wantRead = errors.New("want read")
|
||||||
internalConnError = errors.New(SSLError, "Unhandled internal error")
|
wantWrite = errors.New("want write")
|
||||||
zeroReturn = internalConnError.New("zero return")
|
tryAgain = errors.New("try again")
|
||||||
wantRead = internalConnError.New("want read")
|
|
||||||
wantWrite = internalConnError.New("want write")
|
|
||||||
tryAgain = internalConnError.New("try again")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
@ -37,7 +34,7 @@ type Conn struct {
|
|||||||
from_ssl *writeBio
|
from_ssl *writeBio
|
||||||
is_shutdown bool
|
is_shutdown bool
|
||||||
mtx sync.Mutex
|
mtx sync.Mutex
|
||||||
want_read_future *space_sync.Future
|
want_read_future *utils.Future
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSSL(ctx *C.SSL_CTX) (*C.SSL, error) {
|
func newSSL(ctx *C.SSL_CTX) (*C.SSL, error) {
|
||||||
@ -66,7 +63,7 @@ func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
|||||||
C.BIO_free(into_ssl_cbio)
|
C.BIO_free(into_ssl_cbio)
|
||||||
C.BIO_free(from_ssl_cbio)
|
C.BIO_free(from_ssl_cbio)
|
||||||
C.SSL_free(ssl)
|
C.SSL_free(ssl)
|
||||||
return nil, SSLError.New("failed to allocate memory BIO")
|
return nil, errors.New("failed to allocate memory BIO")
|
||||||
}
|
}
|
||||||
|
|
||||||
// the ssl object takes ownership of these objects now
|
// the ssl object takes ownership of these objects now
|
||||||
@ -85,6 +82,17 @@ func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Client wraps an existing stream connection and puts it in the connect state
|
||||||
|
// for any subsequent handshakes.
|
||||||
|
//
|
||||||
|
// IMPORTANT NOTE: if you use this method instead of Dial to construct an SSL
|
||||||
|
// connection, you are responsible for verifying the peer's hostname.
|
||||||
|
// Otherwise, you are vulnerable to MITM attacks.
|
||||||
|
//
|
||||||
|
// Client connections probably won't work for you unless you set a verify
|
||||||
|
// location or add some certs to the certificate store of the client context
|
||||||
|
// you're using. This library is not nice enough to use the system certificate
|
||||||
|
// store by default for you yet.
|
||||||
func Client(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
func Client(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
||||||
c, err := newConn(conn, ctx)
|
c, err := newConn(conn, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -94,6 +102,8 @@ func Client(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Server wraps an existing stream connection and puts it in the accept state
|
||||||
|
// for any subsequent handshakes.
|
||||||
func Server(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
func Server(conn net.Conn, ctx *Ctx) (*Conn, error) {
|
||||||
c, err := newConn(conn, ctx)
|
c, err := newConn(conn, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -138,7 +148,7 @@ func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.want_read_future = space_sync.NewFuture()
|
c.want_read_future = utils.NewFuture()
|
||||||
want_read_future := c.want_read_future
|
want_read_future := c.want_read_future
|
||||||
return func() (err error) {
|
return func() (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -170,9 +180,9 @@ func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
|
|||||||
if C.ERR_peek_error() == 0 {
|
if C.ERR_peek_error() == 0 {
|
||||||
switch rv {
|
switch rv {
|
||||||
case 0:
|
case 0:
|
||||||
err = SSLError.New("Unexpected EOF")
|
err = errors.New("protocol-violating EOF")
|
||||||
case -1:
|
case -1:
|
||||||
err = ErrnoError.Wrap(errno)
|
err = errno
|
||||||
default:
|
default:
|
||||||
err = errorFromErrorQueue()
|
err = errorFromErrorQueue()
|
||||||
}
|
}
|
||||||
@ -208,6 +218,8 @@ func (c *Conn) handshake() func() error {
|
|||||||
return c.getErrorHandler(rv, errno)
|
return c.getErrorHandler(rv, errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handshake performs an SSL handshake. If a handshake is not manually
|
||||||
|
// triggered, it will run before the first I/O on the encrypted stream.
|
||||||
func (c *Conn) Handshake() error {
|
func (c *Conn) Handshake() error {
|
||||||
err := tryAgain
|
err := tryAgain
|
||||||
for err == tryAgain {
|
for err == tryAgain {
|
||||||
@ -219,15 +231,17 @@ func (c *Conn) Handshake() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerCertificate returns the Certificate of the peer with which you're
|
||||||
|
// communicating. Only valid after a handshake.
|
||||||
func (c *Conn) PeerCertificate() (*Certificate, error) {
|
func (c *Conn) PeerCertificate() (*Certificate, error) {
|
||||||
c.mtx.Lock()
|
c.mtx.Lock()
|
||||||
if c.is_shutdown {
|
if c.is_shutdown {
|
||||||
return nil, SSLError.New("connection closed")
|
return nil, errors.New("connection closed")
|
||||||
}
|
}
|
||||||
x := C.SSL_get_peer_certificate(c.ssl)
|
x := C.SSL_get_peer_certificate(c.ssl)
|
||||||
c.mtx.Unlock()
|
c.mtx.Unlock()
|
||||||
if x == nil {
|
if x == nil {
|
||||||
return nil, SSLError.New("no peer certificate found")
|
return nil, errors.New("no peer certificate found")
|
||||||
}
|
}
|
||||||
cert := &Certificate{x: x}
|
cert := &Certificate{x: x}
|
||||||
runtime.SetFinalizer(cert, func(cert *Certificate) {
|
runtime.SetFinalizer(cert, func(cert *Certificate) {
|
||||||
@ -271,7 +285,7 @@ func (c *Conn) shutdownLoop() error {
|
|||||||
return c.flushOutputBuffer()
|
return c.flushOutputBuffer()
|
||||||
}
|
}
|
||||||
if err == tryAgain && shutdown_tries >= 2 {
|
if err == tryAgain && shutdown_tries >= 2 {
|
||||||
return SSLError.New("shutdown requested a third time?")
|
return errors.New("shutdown requested a third time?")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err == io.ErrUnexpectedEOF {
|
if err == io.ErrUnexpectedEOF {
|
||||||
@ -280,6 +294,8 @@ func (c *Conn) shutdownLoop() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close shuts down the SSL connection and closes the underlying wrapped
|
||||||
|
// connection.
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
c.mtx.Lock()
|
c.mtx.Lock()
|
||||||
if c.is_shutdown {
|
if c.is_shutdown {
|
||||||
@ -288,7 +304,7 @@ func (c *Conn) Close() error {
|
|||||||
}
|
}
|
||||||
c.is_shutdown = true
|
c.is_shutdown = true
|
||||||
c.mtx.Unlock()
|
c.mtx.Unlock()
|
||||||
errs := errors.NewErrorGroup()
|
var errs utils.ErrorGroup
|
||||||
errs.Add(c.shutdownLoop())
|
errs.Add(c.shutdownLoop())
|
||||||
errs.Add(c.conn.Close())
|
errs.Add(c.conn.Close())
|
||||||
return errs.Finalize()
|
return errs.Finalize()
|
||||||
@ -309,6 +325,9 @@ func (c *Conn) read(b []byte) (int, func() error) {
|
|||||||
return 0, c.getErrorHandler(rv, errno)
|
return 0, c.getErrorHandler(rv, errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read reads up to len(b) bytes into b. It returns the number of bytes read
|
||||||
|
// and an error if applicable. io.EOF is returned when the caller can expect
|
||||||
|
// to see no more data.
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
@ -333,7 +352,7 @@ func (c *Conn) write(b []byte) (int, func() error) {
|
|||||||
c.mtx.Lock()
|
c.mtx.Lock()
|
||||||
defer c.mtx.Unlock()
|
defer c.mtx.Unlock()
|
||||||
if c.is_shutdown {
|
if c.is_shutdown {
|
||||||
err := SSLError.New("connection closed")
|
err := errors.New("connection closed")
|
||||||
return 0, func() error { return err }
|
return 0, func() error { return err }
|
||||||
}
|
}
|
||||||
rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
|
rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
|
||||||
@ -343,6 +362,9 @@ func (c *Conn) write(b []byte) (int, func() error) {
|
|||||||
return 0, c.getErrorHandler(rv, errno)
|
return 0, c.getErrorHandler(rv, errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write will encrypt the contents of b and write it to the underlying stream.
|
||||||
|
// Performance will be vastly improved if the size of b is a multiple of
|
||||||
|
// SSLRecordSize.
|
||||||
func (c *Conn) Write(b []byte) (written int, err error) {
|
func (c *Conn) Write(b []byte) (written int, err error) {
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
@ -358,6 +380,8 @@ func (c *Conn) Write(b []byte) (written int, err error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VerifyHostname pulls the PeerCertificate and calls VerifyHostname on the
|
||||||
|
// certificate.
|
||||||
func (c *Conn) VerifyHostname(host string) error {
|
func (c *Conn) VerifyHostname(host string) error {
|
||||||
cert, err := c.PeerCertificate()
|
cert, err := c.PeerCertificate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -366,22 +390,27 @@ func (c *Conn) VerifyHostname(host string) error {
|
|||||||
return cert.VerifyHostname(host)
|
return cert.VerifyHostname(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the underlying connection's local address
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
return c.conn.LocalAddr()
|
return c.conn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the underlying connection's remote address
|
||||||
func (c *Conn) RemoteAddr() net.Addr {
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
return c.conn.RemoteAddr()
|
return c.conn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetDeadline calls SetDeadline on the underlying connection.
|
||||||
func (c *Conn) SetDeadline(t time.Time) error {
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
return c.conn.SetDeadline(t)
|
return c.conn.SetDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline calls SetReadDeadline on the underlying connection.
|
||||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
return c.conn.SetReadDeadline(t)
|
return c.conn.SetReadDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline calls SetWriteDeadline on the underlying connection.
|
||||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
return c.conn.SetWriteDeadline(t)
|
return c.conn.SetWriteDeadline(t)
|
||||||
}
|
}
|
||||||
|
82
ctx.go
82
ctx.go
@ -21,6 +21,8 @@ package openssl
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"io/ioutil"
|
||||||
"runtime"
|
"runtime"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
@ -47,12 +49,14 @@ type SSLVersion int
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
SSLv3 SSLVersion = 0x02
|
SSLv3 SSLVersion = 0x02
|
||||||
TLSv1 SSLVersion = 0x03
|
TLSv1 = 0x03
|
||||||
TLSv1_1 SSLVersion = 0x04
|
TLSv1_1 = 0x04
|
||||||
TLSv1_2 SSLVersion = 0x05
|
TLSv1_2 = 0x05
|
||||||
AnyVersion SSLVersion = 0x06
|
AnyVersion = 0x06
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NewCtxWithVersion creates an SSL context that is specific to the provided
|
||||||
|
// SSL version. See http://www.openssl.org/docs/ssl/SSL_CTX_new.html for more.
|
||||||
func NewCtxWithVersion(version SSLVersion) (*Ctx, error) {
|
func NewCtxWithVersion(version SSLVersion) (*Ctx, error) {
|
||||||
switch version {
|
switch version {
|
||||||
case SSLv3:
|
case SSLv3:
|
||||||
@ -66,10 +70,11 @@ func NewCtxWithVersion(version SSLVersion) (*Ctx, error) {
|
|||||||
case AnyVersion:
|
case AnyVersion:
|
||||||
return newCtx(C.SSLv23_method())
|
return newCtx(C.SSLv23_method())
|
||||||
default:
|
default:
|
||||||
return nil, SSLError.New("unknown version")
|
return nil, errors.New("unknown ssl/tls version")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewCtx creates a context that supports any TLS version 1.0 and newer.
|
||||||
func NewCtx() (*Ctx, error) {
|
func NewCtx() (*Ctx, error) {
|
||||||
c, err := NewCtxWithVersion(AnyVersion)
|
c, err := NewCtxWithVersion(AnyVersion)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -78,6 +83,49 @@ func NewCtx() (*Ctx, error) {
|
|||||||
return c, err
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewCtxFromFiles calls NewCtx, loads the provided files, and configures the
|
||||||
|
// context to use them.
|
||||||
|
func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) {
|
||||||
|
ctx, err := NewCtx()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cert_bytes, err := ioutil.ReadFile(cert_file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := LoadCertificate(cert_bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ctx.UseCertificate(cert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
key_bytes, err := ioutil.ReadFile(key_file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := LoadPrivateKey(key_bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ctx.UsePrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseCertificate configures the context to present the given certificate to
|
||||||
|
// peers.
|
||||||
func (c *Ctx) UseCertificate(cert *Certificate) error {
|
func (c *Ctx) UseCertificate(cert *Certificate) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
@ -87,6 +135,8 @@ func (c *Ctx) UseCertificate(cert *Certificate) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UsePrivateKey configures the context to use the given private key for SSL
|
||||||
|
// handshakes.
|
||||||
func (c *Ctx) UsePrivateKey(key PrivateKey) error {
|
func (c *Ctx) UsePrivateKey(key PrivateKey) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
@ -101,6 +151,8 @@ type CertificateStore struct {
|
|||||||
ctx *Ctx // for gc
|
ctx *Ctx // for gc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCertificateStore returns the context's certificate store that will be
|
||||||
|
// used for peer validation.
|
||||||
func (c *Ctx) GetCertificateStore() *CertificateStore {
|
func (c *Ctx) GetCertificateStore() *CertificateStore {
|
||||||
// we don't need to dealloc the cert store pointer here, because it points
|
// we don't need to dealloc the cert store pointer here, because it points
|
||||||
// to a ctx internal. so we do need to keep the ctx around
|
// to a ctx internal. so we do need to keep the ctx around
|
||||||
@ -109,6 +161,8 @@ func (c *Ctx) GetCertificateStore() *CertificateStore {
|
|||||||
ctx: c}
|
ctx: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddCertificate marks the provided Certificate as a trusted certificate in
|
||||||
|
// the given CertificateStore.
|
||||||
func (s *CertificateStore) AddCertificate(cert *Certificate) error {
|
func (s *CertificateStore) AddCertificate(cert *Certificate) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
@ -118,6 +172,10 @@ func (s *CertificateStore) AddCertificate(cert *Certificate) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadVerifyLocations tells the context to trust all certificate authorities
|
||||||
|
// provided in either the ca_file or the ca_path.
|
||||||
|
// See http://www.openssl.org/docs/ssl/SSL_CTX_load_verify_locations.html for
|
||||||
|
// more.
|
||||||
func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error {
|
func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
@ -148,6 +206,8 @@ const (
|
|||||||
NoTicket = C.SSL_OP_NO_TICKET
|
NoTicket = C.SSL_OP_NO_TICKET
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SetOptions sets context options. See
|
||||||
|
// http://www.openssl.org/docs/ssl/SSL_CTX_set_options.html
|
||||||
func (c *Ctx) SetOptions(options Options) Options {
|
func (c *Ctx) SetOptions(options Options) Options {
|
||||||
return Options(C.SSL_CTX_set_options_not_a_macro(
|
return Options(C.SSL_CTX_set_options_not_a_macro(
|
||||||
c.ctx, C.long(options)))
|
c.ctx, C.long(options)))
|
||||||
@ -159,6 +219,8 @@ const (
|
|||||||
ReleaseBuffers Modes = C.SSL_MODE_RELEASE_BUFFERS
|
ReleaseBuffers Modes = C.SSL_MODE_RELEASE_BUFFERS
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SetMode sets context modes. See
|
||||||
|
// http://www.openssl.org/docs/ssl/SSL_CTX_set_mode.html
|
||||||
func (c *Ctx) SetMode(modes Modes) Modes {
|
func (c *Ctx) SetMode(modes Modes) Modes {
|
||||||
return Modes(C.SSL_CTX_set_mode_not_a_macro(c.ctx, C.long(modes)))
|
return Modes(C.SSL_CTX_set_mode_not_a_macro(c.ctx, C.long(modes)))
|
||||||
}
|
}
|
||||||
@ -172,11 +234,16 @@ const (
|
|||||||
VerifyClientOnce = C.SSL_VERIFY_CLIENT_ONCE
|
VerifyClientOnce = C.SSL_VERIFY_CLIENT_ONCE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SetVerify controls peer verification settings. See
|
||||||
|
// http://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html
|
||||||
func (c *Ctx) SetVerify(options VerifyOptions) {
|
func (c *Ctx) SetVerify(options VerifyOptions) {
|
||||||
// TODO: take a callback
|
// TODO: take a callback
|
||||||
C.SSL_CTX_set_verify(c.ctx, C.int(options), nil)
|
C.SSL_CTX_set_verify(c.ctx, C.int(options), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetVerifyDepth controls how many certificates deep the certificate
|
||||||
|
// verification logic is willing to follow a certificate chain. See
|
||||||
|
// https://www.openssl.org/docs/ssl/SSL_CTX_set_verify.html
|
||||||
func (c *Ctx) SetVerifyDepth(depth int) {
|
func (c *Ctx) SetVerifyDepth(depth int) {
|
||||||
C.SSL_CTX_set_verify_depth(c.ctx, C.int(depth))
|
C.SSL_CTX_set_verify_depth(c.ctx, C.int(depth))
|
||||||
}
|
}
|
||||||
@ -192,6 +259,9 @@ func (c *Ctx) SetSessionId(session_id []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCipherList sets the list of available ciphers. The format of the list is
|
||||||
|
// described at http://www.openssl.org/docs/apps/ciphers.html, but see
|
||||||
|
// http://www.openssl.org/docs/ssl/SSL_CTX_set_cipher_list.html for more.
|
||||||
func (c *Ctx) SetCipherList(list string) error {
|
func (c *Ctx) SetCipherList(list string) error {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
@ -216,6 +286,8 @@ const (
|
|||||||
NoInternal = C.SSL_SESS_CACHE_NO_INTERNAL
|
NoInternal = C.SSL_SESS_CACHE_NO_INTERNAL
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SetSessionCacheMode enables or disables session caching. See
|
||||||
|
// http://www.openssl.org/docs/ssl/SSL_CTX_set_session_cache_mode.html
|
||||||
func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes {
|
func (c *Ctx) SetSessionCacheMode(modes SessionCacheModes) SessionCacheModes {
|
||||||
return SessionCacheModes(
|
return SessionCacheModes(
|
||||||
C.SSL_CTX_set_session_cache_mode_not_a_macro(c.ctx, C.long(modes)))
|
C.SSL_CTX_set_session_cache_mode_not_a_macro(c.ctx, C.long(modes)))
|
||||||
|
40
hostname.go
40
hostname.go
@ -23,14 +23,13 @@ extern int X509_check_ip(X509 *x, const unsigned char *chk, size_t chklen,
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"code.spacemonkey.com/go/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ValidationError = errors.New(SSLError, "Host validation error")
|
ValidationError = errors.New("Host validation error")
|
||||||
)
|
)
|
||||||
|
|
||||||
type CheckFlags int
|
type CheckFlags int
|
||||||
@ -40,6 +39,11 @@ const (
|
|||||||
NoWildcards CheckFlags = C.X509_CHECK_FLAG_NO_WILDCARDS
|
NoWildcards CheckFlags = C.X509_CHECK_FLAG_NO_WILDCARDS
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// CheckHost checks that the X509 certificate is signed for the provided
|
||||||
|
// host name. See http://www.openssl.org/docs/crypto/X509_check_host.html for
|
||||||
|
// more. Note that CheckHost does not check the IP field. See VerifyHostname.
|
||||||
|
// Specifically returns ValidationError if the Certificate didn't match but
|
||||||
|
// there was no internal error.
|
||||||
func (c *Certificate) CheckHost(host string, flags CheckFlags) error {
|
func (c *Certificate) CheckHost(host string, flags CheckFlags) error {
|
||||||
chost := unsafe.Pointer(C.CString(host))
|
chost := unsafe.Pointer(C.CString(host))
|
||||||
defer C.free(chost)
|
defer C.free(chost)
|
||||||
@ -49,12 +53,16 @@ func (c *Certificate) CheckHost(host string, flags CheckFlags) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if rv == 0 {
|
if rv == 0 {
|
||||||
return ValidationError.New(
|
return ValidationError
|
||||||
"cert failed validation for host %s", host)
|
|
||||||
}
|
}
|
||||||
return SSLError.New("hostname validation failed")
|
return errors.New("hostname validation had an internal failure")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckEmail checks that the X509 certificate is signed for the provided
|
||||||
|
// email address. See http://www.openssl.org/docs/crypto/X509_check_host.html
|
||||||
|
// for more.
|
||||||
|
// Specifically returns ValidationError if the Certificate didn't match but
|
||||||
|
// there was no internal error.
|
||||||
func (c *Certificate) CheckEmail(email string, flags CheckFlags) error {
|
func (c *Certificate) CheckEmail(email string, flags CheckFlags) error {
|
||||||
cemail := unsafe.Pointer(C.CString(email))
|
cemail := unsafe.Pointer(C.CString(email))
|
||||||
defer C.free(cemail)
|
defer C.free(cemail)
|
||||||
@ -64,12 +72,16 @@ func (c *Certificate) CheckEmail(email string, flags CheckFlags) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if rv == 0 {
|
if rv == 0 {
|
||||||
return ValidationError.New(
|
return ValidationError
|
||||||
"cert failed validation for email %s", email)
|
|
||||||
}
|
}
|
||||||
return SSLError.New("email validation failed")
|
return errors.New("email validation had an internal failure")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckIP checks that the X509 certificate is signed for the provided
|
||||||
|
// IP address. See http://www.openssl.org/docs/crypto/X509_check_host.html
|
||||||
|
// for more.
|
||||||
|
// Specifically returns ValidationError if the Certificate didn't match but
|
||||||
|
// there was no internal error.
|
||||||
func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error {
|
func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error {
|
||||||
cip := unsafe.Pointer(&ip[0])
|
cip := unsafe.Pointer(&ip[0])
|
||||||
rv := C.X509_check_ip(c.x, (*C.uchar)(cip), C.size_t(len(ip)),
|
rv := C.X509_check_ip(c.x, (*C.uchar)(cip), C.size_t(len(ip)),
|
||||||
@ -78,12 +90,16 @@ func (c *Certificate) CheckIP(ip net.IP, flags CheckFlags) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if rv == 0 {
|
if rv == 0 {
|
||||||
return ValidationError.New(
|
return ValidationError
|
||||||
"cert failed validation for ip %s", ip.String())
|
|
||||||
}
|
}
|
||||||
return SSLError.New("ip validation failed")
|
return errors.New("ip validation had an internal failure")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VerifyHostname is a combination of CheckHost and CheckIP. If the provided
|
||||||
|
// hostname looks like an IP address, it will be checked as an IP address,
|
||||||
|
// otherwise it will be checked as a hostname.
|
||||||
|
// Specifically returns ValidationError if the Certificate didn't match but
|
||||||
|
// there was no internal error.
|
||||||
func (c *Certificate) VerifyHostname(host string) error {
|
func (c *Certificate) VerifyHostname(host string) error {
|
||||||
var ip net.IP
|
var ip net.IP
|
||||||
if len(host) >= 3 && host[0] == '[' && host[len(host)-1] == ']' {
|
if len(host) >= 3 && host[0] == '[' && host[len(host)-1] == ']' {
|
||||||
|
37
http.go
37
http.go
@ -3,16 +3,19 @@
|
|||||||
package openssl
|
package openssl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ListenAndServeTLS will take an http.Handler and serve it using OpenSSL over
|
||||||
|
// the given tcp address, configured to use the provided cert and key files.
|
||||||
func ListenAndServeTLS(addr string, cert_file string, key_file string,
|
func ListenAndServeTLS(addr string, cert_file string, key_file string,
|
||||||
handler http.Handler) error {
|
handler http.Handler) error {
|
||||||
return ServerListenAndServeTLS(
|
return ServerListenAndServeTLS(
|
||||||
&http.Server{Addr: addr, Handler: handler}, cert_file, key_file)
|
&http.Server{Addr: addr, Handler: handler}, cert_file, key_file)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ServerListenAndServeTLS will take an http.Server and serve it using OpenSSL
|
||||||
|
// configured to use the provided cert and key files.
|
||||||
func ServerListenAndServeTLS(srv *http.Server,
|
func ServerListenAndServeTLS(srv *http.Server,
|
||||||
cert_file, key_file string) error {
|
cert_file, key_file string) error {
|
||||||
addr := srv.Addr
|
addr := srv.Addr
|
||||||
@ -20,37 +23,7 @@ func ServerListenAndServeTLS(srv *http.Server,
|
|||||||
addr = ":https"
|
addr = ":https"
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, err := NewCtx()
|
ctx, err := NewCtxFromFiles(cert_file, key_file)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
key_bytes, err := ioutil.ReadFile(key_file)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := LoadPrivateKey(key_bytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = ctx.UsePrivateKey(key)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cert_bytes, err := ioutil.ReadFile(cert_file)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := LoadCertificate(cert_bytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = ctx.UseCertificate(cert)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
60
init.go
60
init.go
@ -1,5 +1,52 @@
|
|||||||
// Copyright (C) 2014 Space Monkey, Inc.
|
// Copyright (C) 2014 Space Monkey, Inc.
|
||||||
|
|
||||||
|
// Package openssl is a light wrapper around OpenSSL for Go.
|
||||||
|
// It strives to provide a near-drop-in replacement for the Go standard library
|
||||||
|
// tls package, while allowing for:
|
||||||
|
// * Performance - OpenSSL is battle-tested and optimized C. While Go's built-
|
||||||
|
// in library shows great promise, it is still young and in some places,
|
||||||
|
// inefficient. This simple OpenSSL wrapper can often do at least 2x with
|
||||||
|
// the same cipher and protocol.
|
||||||
|
//
|
||||||
|
// On my lappytop, I get the following speeds for AES128-SHA
|
||||||
|
// BenchmarkStdlibThroughput 50000 58685 ns/op 17.45 MB/s
|
||||||
|
// BenchmarkOpenSSLThroughput 100000 20772 ns/op 49.30 MB/s
|
||||||
|
//
|
||||||
|
// * Interoperability - many systems support OpenSSL with a variety of plugins
|
||||||
|
// and modules for things, such as hardware acceleration in embedded devices
|
||||||
|
//
|
||||||
|
// * Greater flexibility and configuration - OpenSSL allows for far greater
|
||||||
|
// configuration of corner cases and backwards compatibility (such as
|
||||||
|
// support of SSLv2)
|
||||||
|
//
|
||||||
|
// * Security - OpenSSL has been reviewed by security experts thoroughly.
|
||||||
|
// According to its author, the same can not be said of the standard
|
||||||
|
// library. Though this wrapper has not received equal scrutiny, it is very
|
||||||
|
// small and easy to check.
|
||||||
|
//
|
||||||
|
// Starting an HTTP server that uses OpenSSL is very easy. It's as simple as:
|
||||||
|
// log.Fatal(openssl.ListenAndServeTLS(
|
||||||
|
// ":8443", "my_server.crt", "my_server.key", myHandler))
|
||||||
|
//
|
||||||
|
// Getting a net.Listener that uses OpenSSL is also easy:
|
||||||
|
// ctx, err := openssl.NewCtxFromFiles("my_server.crt", "my_server.key")
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// l, err := openssl.Listen("tcp", ":7777", ctx)
|
||||||
|
//
|
||||||
|
// Making a client connection is straightforward too:
|
||||||
|
// ctx, err := NewCtx()
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// err = ctx.LoadVerifyLocations("/etc/ssl/certs/ca-certificates.crt", "")
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// conn, err := openssl.Dial("tcp", "localhost:7777", ctx, 0)
|
||||||
|
//
|
||||||
|
// TODO/Help wanted: make an easy interface to the net/http client library
|
||||||
package openssl
|
package openssl
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -20,16 +67,15 @@ static void OpenSSL_add_all_algorithms_not_a_macro() {
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"code.spacemonkey.com/go/errors"
|
"code.spacemonkey.com/go/openssl/utils"
|
||||||
"code.spacemonkey.com/go/openssl/thread_id"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
SSLError = errors.New(nil, "SSL Error")
|
|
||||||
sslMutexes []sync.Mutex
|
sslMutexes []sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -49,18 +95,18 @@ func init() {
|
|||||||
// errorFromErrorQueue needs to run in the same OS thread as the operation
|
// errorFromErrorQueue needs to run in the same OS thread as the operation
|
||||||
// that caused the possible error
|
// that caused the possible error
|
||||||
func errorFromErrorQueue() error {
|
func errorFromErrorQueue() error {
|
||||||
var errors []string
|
var errs []string
|
||||||
for {
|
for {
|
||||||
err := C.ERR_get_error()
|
err := C.ERR_get_error()
|
||||||
if err == 0 {
|
if err == 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
errors = append(errors, fmt.Sprintf("%s:%s:%s",
|
errs = append(errs, fmt.Sprintf("%s:%s:%s",
|
||||||
C.GoString(C.ERR_lib_error_string(err)),
|
C.GoString(C.ERR_lib_error_string(err)),
|
||||||
C.GoString(C.ERR_func_error_string(err)),
|
C.GoString(C.ERR_func_error_string(err)),
|
||||||
C.GoString(C.ERR_reason_error_string(err))))
|
C.GoString(C.ERR_reason_error_string(err))))
|
||||||
}
|
}
|
||||||
return SSLError.New("errors: %s", strings.Join(errors, "\n"))
|
return errors.New(fmt.Sprintf("SSL errors: %s", strings.Join(errs, "\n")))
|
||||||
}
|
}
|
||||||
|
|
||||||
//export sslMutexOp
|
//export sslMutexOp
|
||||||
@ -74,5 +120,5 @@ func sslMutexOp(mode, n C.int, file *C.char, line C.int) {
|
|||||||
|
|
||||||
//export sslThreadId
|
//export sslThreadId
|
||||||
func sslThreadId(id *C.CRYPTO_THREADID) {
|
func sslThreadId(id *C.CRYPTO_THREADID) {
|
||||||
C.CRYPTO_THREADID_set_pointer(id, thread_id.Id())
|
C.CRYPTO_THREADID_set_pointer(id, utils.ThreadId())
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
package openssl
|
package openssl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,15 +20,20 @@ func (l *listener) Accept() (c net.Conn, err error) {
|
|||||||
return Server(c, l.ctx)
|
return Server(c, l.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewListener wraps an existing net.Listener such that all accepted
|
||||||
|
// connections are wrapped as OpenSSL server connections using the provided
|
||||||
|
// context ctx.
|
||||||
func NewListener(inner net.Listener, ctx *Ctx) net.Listener {
|
func NewListener(inner net.Listener, ctx *Ctx) net.Listener {
|
||||||
return &listener{
|
return &listener{
|
||||||
Listener: inner,
|
Listener: inner,
|
||||||
ctx: ctx}
|
ctx: ctx}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Listen is a wrapper around net.Listen that wraps incoming connections with
|
||||||
|
// an OpenSSL server connection using the provided context ctx.
|
||||||
func Listen(network, laddr string, ctx *Ctx) (net.Listener, error) {
|
func Listen(network, laddr string, ctx *Ctx) (net.Listener, error) {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return nil, SSLError.New("no ssl context provided")
|
return nil, errors.New("no ssl context provided")
|
||||||
}
|
}
|
||||||
l, err := net.Listen(network, laddr)
|
l, err := net.Listen(network, laddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -39,9 +45,19 @@ func Listen(network, laddr string, ctx *Ctx) (net.Listener, error) {
|
|||||||
type DialFlags int
|
type DialFlags int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
InsecureSkipHostVerification DialFlags = 0
|
InsecureSkipHostVerification DialFlags = 0x01
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Dial will connect to network/address and then wrap the corresponding
|
||||||
|
// underlying connection with an OpenSSL client connection using context ctx.
|
||||||
|
// If flags includes InsecureSkipHostVerification, the server certificate's
|
||||||
|
// hostname will not be checked to match the hostname in addr. Otherwise, flags
|
||||||
|
// should be 0.
|
||||||
|
//
|
||||||
|
// Dial probably won't work for you unless you set a verify location or add
|
||||||
|
// some certs to the certificate store of the client context you're using.
|
||||||
|
// This library is not nice enough to use the system certificate store by
|
||||||
|
// default for you yet.
|
||||||
func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
|
func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
var err error
|
var err error
|
||||||
@ -66,7 +82,12 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if flags&InsecureSkipHostVerification == 0 {
|
if flags&InsecureSkipHostVerification == 0 {
|
||||||
err = conn.VerifyHostname(addr)
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = conn.VerifyHostname(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
101
pem.go
101
pem.go
@ -8,19 +8,21 @@ package openssl
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"runtime"
|
"runtime"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PublicKey interface {
|
type PublicKey interface {
|
||||||
|
// MarshalPKIXPublicKeyPEM converts the public key to PEM-encoded PKIX
|
||||||
|
// format
|
||||||
MarshalPKIXPublicKeyPEM() (pem_block []byte, err error)
|
MarshalPKIXPublicKeyPEM() (pem_block []byte, err error)
|
||||||
|
|
||||||
|
// MarshalPKIXPublicKeyDER converts the public key to DER-encoded PKIX
|
||||||
|
// format
|
||||||
MarshalPKIXPublicKeyDER() (der_block []byte, err error)
|
MarshalPKIXPublicKeyDER() (der_block []byte, err error)
|
||||||
StdlibPublicKey() (*rsa.PublicKey, error)
|
|
||||||
|
|
||||||
evpPKey() *C.EVP_PKEY
|
evpPKey() *C.EVP_PKEY
|
||||||
}
|
}
|
||||||
@ -28,9 +30,13 @@ type PublicKey interface {
|
|||||||
type PrivateKey interface {
|
type PrivateKey interface {
|
||||||
PublicKey
|
PublicKey
|
||||||
|
|
||||||
|
// MarshalPKCS1PrivateKeyPEM converts the private key to PEM-encoded PKCS1
|
||||||
|
// format
|
||||||
MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error)
|
MarshalPKCS1PrivateKeyPEM() (pem_block []byte, err error)
|
||||||
|
|
||||||
|
// MarshalPKCS1PrivateKeyDER converts the private key to DER-encoded PKCS1
|
||||||
|
// format
|
||||||
MarshalPKCS1PrivateKeyDER() (der_block []byte, err error)
|
MarshalPKCS1PrivateKeyDER() (der_block []byte, err error)
|
||||||
StdlibPrivateKey() (*rsa.PrivateKey, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type pKey struct {
|
type pKey struct {
|
||||||
@ -43,25 +49,12 @@ func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte,
|
|||||||
err error) {
|
err error) {
|
||||||
bio := C.BIO_new(C.BIO_s_mem())
|
bio := C.BIO_new(C.BIO_s_mem())
|
||||||
if bio == nil {
|
if bio == nil {
|
||||||
return nil, SSLError.New("failed to allocate memory BIO")
|
return nil, errors.New("failed to allocate memory BIO")
|
||||||
}
|
}
|
||||||
defer C.BIO_free(bio)
|
defer C.BIO_free(bio)
|
||||||
if int(C.PEM_write_bio_PrivateKey(bio, key.key, nil, nil, C.int(0), nil,
|
if int(C.PEM_write_bio_PrivateKey(bio, key.key, nil, nil, C.int(0), nil,
|
||||||
nil)) != 1 {
|
nil)) != 1 {
|
||||||
return nil, SSLError.New("failed dumping private key")
|
return nil, errors.New("failed dumping private key")
|
||||||
}
|
|
||||||
return ioutil.ReadAll(asAnyBio(bio))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte,
|
|
||||||
err error) {
|
|
||||||
bio := C.BIO_new(C.BIO_s_mem())
|
|
||||||
if bio == nil {
|
|
||||||
return nil, SSLError.New("failed to allocate memory BIO")
|
|
||||||
}
|
|
||||||
defer C.BIO_free(bio)
|
|
||||||
if int(C.PEM_write_bio_PUBKEY(bio, key.key)) != 1 {
|
|
||||||
return nil, SSLError.New("failed dumping public key")
|
|
||||||
}
|
}
|
||||||
return ioutil.ReadAll(asAnyBio(bio))
|
return ioutil.ReadAll(asAnyBio(bio))
|
||||||
}
|
}
|
||||||
@ -78,11 +71,24 @@ func (key *pKey) MarshalPKCS1PrivateKeyDER() (der_block []byte,
|
|||||||
var p *pem.Block
|
var p *pem.Block
|
||||||
p, pem_block = pem.Decode(pem_block)
|
p, pem_block = pem.Decode(pem_block)
|
||||||
if len(pem_block) > 0 || p == nil {
|
if len(pem_block) > 0 || p == nil {
|
||||||
return nil, SSLError.New("something went wrong with PEM generation")
|
return nil, errors.New("something went wrong with PEM generation")
|
||||||
}
|
}
|
||||||
return p.Bytes, nil
|
return p.Bytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte,
|
||||||
|
err error) {
|
||||||
|
bio := C.BIO_new(C.BIO_s_mem())
|
||||||
|
if bio == nil {
|
||||||
|
return nil, errors.New("failed to allocate memory BIO")
|
||||||
|
}
|
||||||
|
defer C.BIO_free(bio)
|
||||||
|
if int(C.PEM_write_bio_PUBKEY(bio, key.key)) != 1 {
|
||||||
|
return nil, errors.New("failed dumping public key")
|
||||||
|
}
|
||||||
|
return ioutil.ReadAll(asAnyBio(bio))
|
||||||
|
}
|
||||||
|
|
||||||
func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte,
|
func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte,
|
||||||
err error) {
|
err error) {
|
||||||
// TODO: i can't decipher how to get a generic PKIX Public Key in DER
|
// TODO: i can't decipher how to get a generic PKIX Public Key in DER
|
||||||
@ -95,35 +101,12 @@ func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte,
|
|||||||
var p *pem.Block
|
var p *pem.Block
|
||||||
p, pem_block = pem.Decode(pem_block)
|
p, pem_block = pem.Decode(pem_block)
|
||||||
if len(pem_block) > 0 || p == nil {
|
if len(pem_block) > 0 || p == nil {
|
||||||
return nil, SSLError.New("something went wrong with PEM generation")
|
return nil, errors.New("something went wrong with PEM generation")
|
||||||
}
|
}
|
||||||
return p.Bytes, nil
|
return p.Bytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *pKey) StdlibPrivateKey() (*rsa.PrivateKey, error) {
|
// LoadPrivateKey loads a private key from a PEM-encoded block.
|
||||||
der_block, err := key.MarshalPKCS1PrivateKeyDER()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return x509.ParsePKCS1PrivateKey(der_block)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (key *pKey) StdlibPublicKey() (*rsa.PublicKey, error) {
|
|
||||||
der_block, err := key.MarshalPKIXPublicKeyDER()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
k, err := x509.ParsePKIXPublicKey(der_block)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
rk, ok := k.(*rsa.PublicKey)
|
|
||||||
if !ok {
|
|
||||||
return nil, SSLError.New("not an rsa public key")
|
|
||||||
}
|
|
||||||
return rk, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadPrivateKey(pem_block []byte) (PrivateKey, error) {
|
func LoadPrivateKey(pem_block []byte) (PrivateKey, error) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
@ -145,6 +128,7 @@ type Certificate struct {
|
|||||||
x *C.X509
|
x *C.X509
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadCertificate loads an X509 certificate from a PEM-encoded block.
|
||||||
func LoadCertificate(pem_block []byte) (*Certificate, error) {
|
func LoadCertificate(pem_block []byte) (*Certificate, error) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
defer runtime.UnlockOSThread()
|
defer runtime.UnlockOSThread()
|
||||||
@ -162,22 +146,24 @@ func LoadCertificate(pem_block []byte) (*Certificate, error) {
|
|||||||
return x, nil
|
return x, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalPEM converts the X509 certificate to PEM-encoded format
|
||||||
func (c *Certificate) MarshalPEM() (pem_block []byte, err error) {
|
func (c *Certificate) MarshalPEM() (pem_block []byte, err error) {
|
||||||
bio := C.BIO_new(C.BIO_s_mem())
|
bio := C.BIO_new(C.BIO_s_mem())
|
||||||
if bio == nil {
|
if bio == nil {
|
||||||
return nil, SSLError.New("failed to allocate memory BIO")
|
return nil, errors.New("failed to allocate memory BIO")
|
||||||
}
|
}
|
||||||
defer C.BIO_free(bio)
|
defer C.BIO_free(bio)
|
||||||
if int(C.PEM_write_bio_X509(bio, c.x)) != 1 {
|
if int(C.PEM_write_bio_X509(bio, c.x)) != 1 {
|
||||||
return nil, SSLError.New("failed dumping certificate")
|
return nil, errors.New("failed dumping certificate")
|
||||||
}
|
}
|
||||||
return ioutil.ReadAll(asAnyBio(bio))
|
return ioutil.ReadAll(asAnyBio(bio))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PublicKey returns the public key embedded in the X509 certificate.
|
||||||
func (c *Certificate) PublicKey() (PublicKey, error) {
|
func (c *Certificate) PublicKey() (PublicKey, error) {
|
||||||
pkey := C.X509_get_pubkey(c.x)
|
pkey := C.X509_get_pubkey(c.x)
|
||||||
if pkey == nil {
|
if pkey == nil {
|
||||||
return nil, SSLError.New("no public key found")
|
return nil, errors.New("no public key found")
|
||||||
}
|
}
|
||||||
key := &pKey{key: pkey}
|
key := &pKey{key: pkey}
|
||||||
runtime.SetFinalizer(key, func(key *pKey) {
|
runtime.SetFinalizer(key, func(key *pKey) {
|
||||||
@ -185,20 +171,3 @@ func (c *Certificate) PublicKey() (PublicKey, error) {
|
|||||||
})
|
})
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyPair struct {
|
|
||||||
Certificate *Certificate
|
|
||||||
PrivateKey PrivateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func X509KeyPair(key PrivateKey, cert *Certificate) (tls.Certificate, error) {
|
|
||||||
key_pem_bytes, err := key.MarshalPKCS1PrivateKeyPEM()
|
|
||||||
if err != nil {
|
|
||||||
return tls.Certificate{}, err
|
|
||||||
}
|
|
||||||
cert_pem_bytes, err := cert.MarshalPEM()
|
|
||||||
if err != nil {
|
|
||||||
return tls.Certificate{}, err
|
|
||||||
}
|
|
||||||
return tls.X509KeyPair(cert_pem_bytes, key_pem_bytes)
|
|
||||||
}
|
|
||||||
|
22
ssl_test.go
22
ssl_test.go
@ -13,8 +13,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"code.spacemonkey.com/go/errors"
|
"code.spacemonkey.com/go/openssl/utils"
|
||||||
space_sync "code.spacemonkey.com/go/space/sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -80,11 +79,11 @@ func NetPipe(t testing.TB) (net.Conn, net.Conn) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
client_future := space_sync.NewFuture()
|
client_future := utils.NewFuture()
|
||||||
go func() {
|
go func() {
|
||||||
client_future.Set(net.Dial(l.Addr().Network(), l.Addr().String()))
|
client_future.Set(net.Dial(l.Addr().Network(), l.Addr().String()))
|
||||||
}()
|
}()
|
||||||
errs := errors.NewErrorGroup()
|
var errs utils.ErrorGroup
|
||||||
server_conn, err := l.Accept()
|
server_conn, err := l.Accept()
|
||||||
errs.Add(err)
|
errs.Add(err)
|
||||||
client_conn, err := client_future.Get()
|
client_conn, err := client_future.Get()
|
||||||
@ -287,7 +286,8 @@ func StdlibConstructor(t testing.TB, server_conn, client_conn net.Conn) (
|
|||||||
}
|
}
|
||||||
config := &tls.Config{
|
config := &tls.Config{
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
InsecureSkipVerify: true}
|
InsecureSkipVerify: true,
|
||||||
|
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}}
|
||||||
server = tls.Server(server_conn, config)
|
server = tls.Server(server_conn, config)
|
||||||
client = tls.Client(client_conn, config)
|
client = tls.Client(client_conn, config)
|
||||||
return server, client
|
return server, client
|
||||||
@ -315,6 +315,10 @@ func OpenSSLConstructor(t testing.TB, server_conn, client_conn net.Conn) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
err = ctx.SetCipherList("AES128-SHA")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
server, err = Server(server_conn, ctx)
|
server, err = Server(server_conn, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -475,7 +479,8 @@ func TestStdlibLotsOfConns(t *testing.T) {
|
|||||||
}
|
}
|
||||||
tls_config := &tls.Config{
|
tls_config := &tls.Config{
|
||||||
Certificates: []tls.Certificate{tls_cert},
|
Certificates: []tls.Certificate{tls_cert},
|
||||||
InsecureSkipVerify: true}
|
InsecureSkipVerify: true,
|
||||||
|
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}}
|
||||||
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
|
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
|
||||||
func(l net.Listener) net.Listener {
|
func(l net.Listener) net.Listener {
|
||||||
return tls.NewListener(l, tls_config)
|
return tls.NewListener(l, tls_config)
|
||||||
@ -505,7 +510,10 @@ func TestOpenSSLLotsOfConns(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
err = ctx.SetCipherList("AES128-SHA")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
|
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
|
||||||
func(l net.Listener) net.Listener {
|
func(l net.Listener) net.Listener {
|
||||||
return NewListener(l, ctx)
|
return NewListener(l, ctx)
|
||||||
|
@ -1,9 +0,0 @@
|
|||||||
// Copyright (C) 2014 Space Monkey, Inc.
|
|
||||||
|
|
||||||
package thread_id
|
|
||||||
|
|
||||||
import (
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Id() unsafe.Pointer
|
|
38
utils/errors.go
Normal file
38
utils/errors.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
// Copyright (C) 2014 Space Monkey, Inc.
|
||||||
|
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrorGroup collates errors
|
||||||
|
type ErrorGroup struct {
|
||||||
|
Errors []error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds an error to an existing error group
|
||||||
|
func (e *ErrorGroup) Add(err error) {
|
||||||
|
if err != nil {
|
||||||
|
e.Errors = append(e.Errors, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finalize returns an error corresponding to the ErrorGroup state. If there's
|
||||||
|
// no errors in the group, finalize returns nil. If there's only one error,
|
||||||
|
// Finalize returns that error. Otherwise, Finalize will make a new error
|
||||||
|
// consisting of the messages from the constituent errors.
|
||||||
|
func (e *ErrorGroup) Finalize() error {
|
||||||
|
if len(e.Errors) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(e.Errors) == 1 {
|
||||||
|
return e.Errors[0]
|
||||||
|
}
|
||||||
|
msgs := make([]string, 0, len(e.Errors))
|
||||||
|
for _, err := range e.Errors {
|
||||||
|
msgs = append(msgs, err.Error())
|
||||||
|
}
|
||||||
|
return errors.New(strings.Join(msgs, "\n"))
|
||||||
|
}
|
67
utils/future.go
Normal file
67
utils/future.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
// Copyright (C) 2014 Space Monkey, Inc.
|
||||||
|
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Future is a type that is essentially the inverse of a channel. With a
|
||||||
|
// channel, you have multiple senders and one receiver. With a future, you can
|
||||||
|
// have multiple receivers and one sender. Additionally, a future protects
|
||||||
|
// against double-sends. Since this is usually used for returning function
|
||||||
|
// results, we also capture and return error values as well. Use NewFuture
|
||||||
|
// to initialize.
|
||||||
|
type Future struct {
|
||||||
|
mutex *sync.Mutex
|
||||||
|
cond *sync.Cond
|
||||||
|
received bool
|
||||||
|
val interface{}
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFuture returns an initialized and ready Future.
|
||||||
|
func NewFuture() *Future {
|
||||||
|
mutex := &sync.Mutex{}
|
||||||
|
return &Future{
|
||||||
|
mutex: mutex,
|
||||||
|
cond: sync.NewCond(mutex),
|
||||||
|
received: false,
|
||||||
|
val: nil,
|
||||||
|
err: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get blocks until the Future has a value set.
|
||||||
|
func (self *Future) Get() (interface{}, error) {
|
||||||
|
self.mutex.Lock()
|
||||||
|
defer self.mutex.Unlock()
|
||||||
|
for {
|
||||||
|
if self.received {
|
||||||
|
return self.val, self.err
|
||||||
|
}
|
||||||
|
self.cond.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fired returns whether or not a value has been set. If Fired is true, Get
|
||||||
|
// won't block.
|
||||||
|
func (self *Future) Fired() bool {
|
||||||
|
self.mutex.Lock()
|
||||||
|
defer self.mutex.Unlock()
|
||||||
|
return self.received
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set provides the value to present and future Get calls. If Set has already
|
||||||
|
// been called, this is a no-op.
|
||||||
|
func (self *Future) Set(val interface{}, err error) {
|
||||||
|
self.mutex.Lock()
|
||||||
|
defer self.mutex.Unlock()
|
||||||
|
if self.received {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
self.received = true
|
||||||
|
self.val = val
|
||||||
|
self.err = err
|
||||||
|
self.cond.Broadcast()
|
||||||
|
}
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include "runtime.h"
|
#include "runtime.h"
|
||||||
|
|
||||||
void ·Id(void *ref) {
|
void ·ThreadId(void *id) {
|
||||||
ref = (void *)m;
|
id = (void *)m;
|
||||||
FLUSH(&ref);
|
FLUSH(&id);
|
||||||
}
|
}
|
13
utils/thread_id.go
Normal file
13
utils/thread_id.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
// Copyright (C) 2014 Space Monkey, Inc.
|
||||||
|
|
||||||
|
// Package utils provides some small things that implementation of the OpenSSL
|
||||||
|
// wrapper library needed.
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ThreadId returns the current runtime's thread id. Thanks to Gustavo Niemeyer
|
||||||
|
// for this. https://github.com/niemeyer/qml/blob/master/tref/tref.go
|
||||||
|
func ThreadId() unsafe.Pointer
|
Loading…
Reference in New Issue
Block a user