1
0
mirror of https://github.com/libp2p/go-openssl.git synced 2025-04-25 17:50:23 +08:00

space monkey internal commit export

[katamari commit: b70e599dbbed2ba3cd5ba278a0278c8ff8c553cb]
This commit is contained in:
JT Olds 2014-01-19 13:43:40 -07:00
parent 053d794fe5
commit 88870b4b4c
4 changed files with 108 additions and 34 deletions

27
bio.go
View File

@ -148,25 +148,18 @@ func (b *writeBio) WriteTo(w io.Writer) (rv int64, err error) {
b.data_mtx.Lock()
data := b.buf
b.data_mtx.Unlock()
total := int64(0)
for {
if len(data) == 0 {
return total, nil
}
written, err := w.Write(data)
total += int64(written)
// subtract however much data we wrote from the buffer
b.data_mtx.Lock()
n := copy(b.buf, b.buf[written:])
b.buf = b.buf[:n]
data = b.buf
b.data_mtx.Unlock()
if err != nil {
return total, err
}
if len(data) == 0 {
return 0, nil
}
n, err := w.Write(data)
// subtract however much data we wrote from the buffer
b.data_mtx.Lock()
b.buf = b.buf[:copy(b.buf, b.buf[n:])]
b.data_mtx.Unlock()
return int64(n), err
}
func (self *writeBio) Disconnect(b *C.BIO) {

28
conn.go
View File

@ -141,6 +141,7 @@ func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
return io.ErrUnexpectedEOF
}
case C.SSL_ERROR_WANT_READ:
go c.flushOutputBuffer()
if c.want_read_future != nil {
want_read_future := c.want_read_future
return func() error {
@ -157,10 +158,6 @@ func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
c.mtx.Unlock()
want_read_future.Set(nil, err)
}()
err = c.flushOutputBuffer()
if err != nil {
return err
}
err = c.fillInputBuffer()
if err != nil {
return err
@ -204,13 +201,13 @@ func (c *Conn) handleError(errcb func() error) error {
}
func (c *Conn) handshake() func() error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
c.mtx.Lock()
defer c.mtx.Unlock()
if c.is_shutdown {
return func() error { return io.ErrUnexpectedEOF }
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
rv, errno := C.SSL_do_handshake(c.ssl)
if rv > 0 {
return nil
@ -224,10 +221,8 @@ func (c *Conn) Handshake() error {
err := tryAgain
for err == tryAgain {
err = c.handleError(c.handshake())
if err == nil {
return c.flushOutputBuffer()
}
}
go c.flushOutputBuffer()
return err
}
@ -251,10 +246,10 @@ func (c *Conn) PeerCertificate() (*Certificate, error) {
}
func (c *Conn) shutdown() func() error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
c.mtx.Lock()
defer c.mtx.Unlock()
runtime.LockOSThread()
defer runtime.UnlockOSThread()
rv, errno := C.SSL_shutdown(c.ssl)
if rv > 0 {
return nil
@ -311,13 +306,13 @@ func (c *Conn) Close() error {
}
func (c *Conn) read(b []byte) (int, func() error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
c.mtx.Lock()
defer c.mtx.Unlock()
if c.is_shutdown {
return 0, func() error { return io.EOF }
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
if rv > 0 {
return int(rv), nil
@ -337,7 +332,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
n, errcb := c.read(b)
err = c.handleError(errcb)
if err == nil {
return n, c.flushOutputBuffer()
go c.flushOutputBuffer()
return n, nil
}
if err == io.ErrUnexpectedEOF {
err = io.EOF
@ -347,14 +343,14 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
func (c *Conn) write(b []byte) (int, func() error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
c.mtx.Lock()
defer c.mtx.Unlock()
if c.is_shutdown {
err := errors.New("connection closed")
return 0, func() error { return err }
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
if rv > 0 {
return int(rv), nil

View File

@ -46,7 +46,8 @@
// }
// conn, err := openssl.Dial("tcp", "localhost:7777", ctx, 0)
//
// TODO/Help wanted: make an easy interface to the net/http client library
// TODO/Help wanted: make an easy interface to the net/http client library that
// supports all the fiddly bits like proxies and connection pools and what-not.
package openssl
/*

View File

@ -392,6 +392,90 @@ func BenchmarkOpenSSLStdlibThroughput(b *testing.B) {
ThroughputBenchmark(b, OpenSSLStdlibConstructor)
}
func FullDuplexRenegotiationTest(t testing.TB, constructor func(
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
server_conn, client_conn := NetPipe(t)
defer server_conn.Close()
defer client_conn.Close()
times := 256
data_len := 4 * SSLRecordSize
data1 := make([]byte, data_len)
_, err := io.ReadFull(rand.Reader, data1[:])
if err != nil {
t.Fatal(err)
}
data2 := make([]byte, data_len)
_, err = io.ReadFull(rand.Reader, data1[:])
if err != nil {
t.Fatal(err)
}
server, client := constructor(t, server_conn, client_conn)
defer close_both(server, client)
var wg sync.WaitGroup
send_func := func(sender HandshakingConn, data []byte) {
defer wg.Done()
for i := 0; i < times; i++ {
if i == times/2 {
wg.Add(1)
go func() {
defer wg.Done()
err := sender.Handshake()
if err != nil {
t.Fatal(err)
}
}()
}
_, err := sender.Write(data)
if err != nil {
t.Fatal(err)
}
}
}
recv_func := func(receiver net.Conn, data []byte) {
defer wg.Done()
buf := make([]byte, len(data))
for i := 0; i < times; i++ {
n, err := io.ReadFull(receiver, buf[:])
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf[:n], data) {
t.Fatal(err)
}
}
}
wg.Add(4)
go recv_func(server, data1)
go send_func(client, data1)
go send_func(server, data2)
go recv_func(client, data2)
wg.Wait()
}
func TestStdlibFullDuplexRenegotiation(t *testing.T) {
FullDuplexRenegotiationTest(t, StdlibConstructor)
}
func TestOpenSSLFullDuplexRenegotiation(t *testing.T) {
FullDuplexRenegotiationTest(t, OpenSSLConstructor)
}
func TestOpenSSLStdlibFullDuplexRenegotiation(t *testing.T) {
FullDuplexRenegotiationTest(t, OpenSSLStdlibConstructor)
}
func TestStdlibOpenSSLFullDuplexRenegotiation(t *testing.T) {
FullDuplexRenegotiationTest(t, StdlibOpenSSLConstructor)
}
func LotsOfConns(t *testing.T, payload_size int64, loops, clients int,
sleep time.Duration, newListener func(net.Listener) net.Listener,
newClient func(net.Conn) (net.Conn, error)) {