mirror of
https://github.com/libp2p/go-openssl.git
synced 2025-04-25 17:50:23 +08:00
space monkey internal commit export
[katamari commit: b70e599dbbed2ba3cd5ba278a0278c8ff8c553cb]
This commit is contained in:
parent
053d794fe5
commit
88870b4b4c
27
bio.go
27
bio.go
@ -148,25 +148,18 @@ func (b *writeBio) WriteTo(w io.Writer) (rv int64, err error) {
|
||||
b.data_mtx.Lock()
|
||||
data := b.buf
|
||||
b.data_mtx.Unlock()
|
||||
total := int64(0)
|
||||
|
||||
for {
|
||||
if len(data) == 0 {
|
||||
return total, nil
|
||||
}
|
||||
written, err := w.Write(data)
|
||||
total += int64(written)
|
||||
|
||||
// subtract however much data we wrote from the buffer
|
||||
b.data_mtx.Lock()
|
||||
n := copy(b.buf, b.buf[written:])
|
||||
b.buf = b.buf[:n]
|
||||
data = b.buf
|
||||
b.data_mtx.Unlock()
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
n, err := w.Write(data)
|
||||
|
||||
// subtract however much data we wrote from the buffer
|
||||
b.data_mtx.Lock()
|
||||
b.buf = b.buf[:copy(b.buf, b.buf[n:])]
|
||||
b.data_mtx.Unlock()
|
||||
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
func (self *writeBio) Disconnect(b *C.BIO) {
|
||||
|
28
conn.go
28
conn.go
@ -141,6 +141,7 @@ func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
case C.SSL_ERROR_WANT_READ:
|
||||
go c.flushOutputBuffer()
|
||||
if c.want_read_future != nil {
|
||||
want_read_future := c.want_read_future
|
||||
return func() error {
|
||||
@ -157,10 +158,6 @@ func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
|
||||
c.mtx.Unlock()
|
||||
want_read_future.Set(nil, err)
|
||||
}()
|
||||
err = c.flushOutputBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.fillInputBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -204,13 +201,13 @@ func (c *Conn) handleError(errcb func() error) error {
|
||||
}
|
||||
|
||||
func (c *Conn) handshake() func() error {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
if c.is_shutdown {
|
||||
return func() error { return io.ErrUnexpectedEOF }
|
||||
}
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
rv, errno := C.SSL_do_handshake(c.ssl)
|
||||
if rv > 0 {
|
||||
return nil
|
||||
@ -224,10 +221,8 @@ func (c *Conn) Handshake() error {
|
||||
err := tryAgain
|
||||
for err == tryAgain {
|
||||
err = c.handleError(c.handshake())
|
||||
if err == nil {
|
||||
return c.flushOutputBuffer()
|
||||
}
|
||||
}
|
||||
go c.flushOutputBuffer()
|
||||
return err
|
||||
}
|
||||
|
||||
@ -251,10 +246,10 @@ func (c *Conn) PeerCertificate() (*Certificate, error) {
|
||||
}
|
||||
|
||||
func (c *Conn) shutdown() func() error {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
rv, errno := C.SSL_shutdown(c.ssl)
|
||||
if rv > 0 {
|
||||
return nil
|
||||
@ -311,13 +306,13 @@ func (c *Conn) Close() error {
|
||||
}
|
||||
|
||||
func (c *Conn) read(b []byte) (int, func() error) {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
if c.is_shutdown {
|
||||
return 0, func() error { return io.EOF }
|
||||
}
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
rv, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
|
||||
if rv > 0 {
|
||||
return int(rv), nil
|
||||
@ -337,7 +332,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
n, errcb := c.read(b)
|
||||
err = c.handleError(errcb)
|
||||
if err == nil {
|
||||
return n, c.flushOutputBuffer()
|
||||
go c.flushOutputBuffer()
|
||||
return n, nil
|
||||
}
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
err = io.EOF
|
||||
@ -347,14 +343,14 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *Conn) write(b []byte) (int, func() error) {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
if c.is_shutdown {
|
||||
err := errors.New("connection closed")
|
||||
return 0, func() error { return err }
|
||||
}
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
rv, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
|
||||
if rv > 0 {
|
||||
return int(rv), nil
|
||||
|
3
init.go
3
init.go
@ -46,7 +46,8 @@
|
||||
// }
|
||||
// conn, err := openssl.Dial("tcp", "localhost:7777", ctx, 0)
|
||||
//
|
||||
// TODO/Help wanted: make an easy interface to the net/http client library
|
||||
// TODO/Help wanted: make an easy interface to the net/http client library that
|
||||
// supports all the fiddly bits like proxies and connection pools and what-not.
|
||||
package openssl
|
||||
|
||||
/*
|
||||
|
84
ssl_test.go
84
ssl_test.go
@ -392,6 +392,90 @@ func BenchmarkOpenSSLStdlibThroughput(b *testing.B) {
|
||||
ThroughputBenchmark(b, OpenSSLStdlibConstructor)
|
||||
}
|
||||
|
||||
func FullDuplexRenegotiationTest(t testing.TB, constructor func(
|
||||
t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) {
|
||||
|
||||
server_conn, client_conn := NetPipe(t)
|
||||
defer server_conn.Close()
|
||||
defer client_conn.Close()
|
||||
|
||||
times := 256
|
||||
data_len := 4 * SSLRecordSize
|
||||
data1 := make([]byte, data_len)
|
||||
_, err := io.ReadFull(rand.Reader, data1[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data2 := make([]byte, data_len)
|
||||
_, err = io.ReadFull(rand.Reader, data1[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
server, client := constructor(t, server_conn, client_conn)
|
||||
defer close_both(server, client)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
send_func := func(sender HandshakingConn, data []byte) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < times; i++ {
|
||||
if i == times/2 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := sender.Handshake()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
_, err := sender.Write(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
recv_func := func(receiver net.Conn, data []byte) {
|
||||
defer wg.Done()
|
||||
|
||||
buf := make([]byte, len(data))
|
||||
for i := 0; i < times; i++ {
|
||||
n, err := io.ReadFull(receiver, buf[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(buf[:n], data) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(4)
|
||||
go recv_func(server, data1)
|
||||
go send_func(client, data1)
|
||||
go send_func(server, data2)
|
||||
go recv_func(client, data2)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestStdlibFullDuplexRenegotiation(t *testing.T) {
|
||||
FullDuplexRenegotiationTest(t, StdlibConstructor)
|
||||
}
|
||||
|
||||
func TestOpenSSLFullDuplexRenegotiation(t *testing.T) {
|
||||
FullDuplexRenegotiationTest(t, OpenSSLConstructor)
|
||||
}
|
||||
|
||||
func TestOpenSSLStdlibFullDuplexRenegotiation(t *testing.T) {
|
||||
FullDuplexRenegotiationTest(t, OpenSSLStdlibConstructor)
|
||||
}
|
||||
|
||||
func TestStdlibOpenSSLFullDuplexRenegotiation(t *testing.T) {
|
||||
FullDuplexRenegotiationTest(t, StdlibOpenSSLConstructor)
|
||||
}
|
||||
|
||||
func LotsOfConns(t *testing.T, payload_size int64, loops, clients int,
|
||||
sleep time.Duration, newListener func(net.Listener) net.Listener,
|
||||
newClient func(net.Conn) (net.Conn, error)) {
|
||||
|
Loading…
Reference in New Issue
Block a user