// Copyright (C) 2017. See AUTHORS.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package openssl

// #include "shim.h"
import "C"

import (
	"errors"
	"fmt"
	"io"
	"net"
	"runtime"
	"sync"
	"time"
	"unsafe"

	"github.com/libp2p/go-openssl/utils"
	"github.com/mattn/go-pointer"
)

// Sentinel errors used by the SSL I/O state machine in this file.
//
// errTryAgain is produced by the deferred handlers returned from
// getErrorHandler once pending input has been read or pending output flushed;
// Handshake, Read, Write and shutdownLoop retry the SSL call when they see it.
//
// errZeroReturn, errWantRead and errWantWrite are named after OpenSSL's
// SSL_ERROR_ZERO_RETURN / SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE codes.
// NOTE(review): none of these three is referenced in this file; confirm
// whether other files in the package still use them.
var (
	errZeroReturn = errors.New("zero return")
	errWantRead   = errors.New("want read")
	errWantWrite  = errors.New("want write")
	errTryAgain   = errors.New("try again")
)

// Conn is an OpenSSL-backed TLS/SSL connection layered over an existing
// net.Conn. Encrypted bytes move between the transport and OpenSSL through
// the into_ssl/from_ssl Go-side buffers.
type Conn struct {
	// SSL wraps the underlying C.SSL object created by newConn.
	*SSL

	// conn is the wrapped transport that carries the encrypted stream.
	conn             net.Conn
	// ctx keeps the Ctx that created the C.SSL reachable for as long as this
	// Conn is alive.
	ctx              *Ctx // for gc
	// into_ssl buffers bytes read from conn until OpenSSL consumes them.
	into_ssl         *readBio
	// from_ssl buffers bytes produced by OpenSSL until they are written to conn.
	from_ssl         *writeBio
	// is_shutdown is set once by Close; guarded by mtx.
	is_shutdown      bool
	// mtx serializes SSL_* calls and guards is_shutdown and want_read_future.
	mtx              sync.Mutex
	// want_read_future is non-nil while one goroutine is refilling into_ssl
	// after SSL_ERROR_WANT_READ; other goroutines hitting the same condition
	// wait on it instead of reading from conn concurrently.
	want_read_future *utils.Future
}

// VerifyResult is the outcome of verifying the peer's certificate chain, as
// reported by Conn.VerifyResult. Each constant below is the corresponding
// OpenSSL X509_V_* code; Ok (X509_V_OK) means verification succeeded.
type VerifyResult int

// VerifyResult values, one per OpenSSL X509_V_* verification code.
const (
	Ok                            VerifyResult = C.X509_V_OK
	UnableToGetIssuerCert         VerifyResult = C.X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT
	UnableToGetCrl                VerifyResult = C.X509_V_ERR_UNABLE_TO_GET_CRL
	UnableToDecryptCertSignature  VerifyResult = C.X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE
	UnableToDecryptCrlSignature   VerifyResult = C.X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE
	UnableToDecodeIssuerPublicKey VerifyResult = C.X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY
	CertSignatureFailure          VerifyResult = C.X509_V_ERR_CERT_SIGNATURE_FAILURE
	CrlSignatureFailure           VerifyResult = C.X509_V_ERR_CRL_SIGNATURE_FAILURE
	CertNotYetValid               VerifyResult = C.X509_V_ERR_CERT_NOT_YET_VALID
	CertHasExpired                VerifyResult = C.X509_V_ERR_CERT_HAS_EXPIRED
	CrlNotYetValid                VerifyResult = C.X509_V_ERR_CRL_NOT_YET_VALID
	CrlHasExpired                 VerifyResult = C.X509_V_ERR_CRL_HAS_EXPIRED
	ErrorInCertNotBeforeField     VerifyResult = C.X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD
	ErrorInCertNotAfterField      VerifyResult = C.X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD
	ErrorInCrlLastUpdateField     VerifyResult = C.X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD
	ErrorInCrlNextUpdateField     VerifyResult = C.X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD
	OutOfMem                      VerifyResult = C.X509_V_ERR_OUT_OF_MEM
	DepthZeroSelfSignedCert       VerifyResult = C.X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT
	SelfSignedCertInChain         VerifyResult = C.X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN
	UnableToGetIssuerCertLocally  VerifyResult = C.X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY
	UnableToVerifyLeafSignature   VerifyResult = C.X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE
	CertChainTooLong              VerifyResult = C.X509_V_ERR_CERT_CHAIN_TOO_LONG
	CertRevoked                   VerifyResult = C.X509_V_ERR_CERT_REVOKED
	InvalidCa                     VerifyResult = C.X509_V_ERR_INVALID_CA
	PathLengthExceeded            VerifyResult = C.X509_V_ERR_PATH_LENGTH_EXCEEDED
	InvalidPurpose                VerifyResult = C.X509_V_ERR_INVALID_PURPOSE
	CertUntrusted                 VerifyResult = C.X509_V_ERR_CERT_UNTRUSTED
	CertRejected                  VerifyResult = C.X509_V_ERR_CERT_REJECTED
	SubjectIssuerMismatch         VerifyResult = C.X509_V_ERR_SUBJECT_ISSUER_MISMATCH
	AkidSkidMismatch              VerifyResult = C.X509_V_ERR_AKID_SKID_MISMATCH
	AkidIssuerSerialMismatch      VerifyResult = C.X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH
	KeyusageNoCertsign            VerifyResult = C.X509_V_ERR_KEYUSAGE_NO_CERTSIGN
	UnableToGetCrlIssuer          VerifyResult = C.X509_V_ERR_UNABLE_TO_GET_CRL_ISSUER
	UnhandledCriticalExtension    VerifyResult = C.X509_V_ERR_UNHANDLED_CRITICAL_EXTENSION
	KeyusageNoCrlSign             VerifyResult = C.X509_V_ERR_KEYUSAGE_NO_CRL_SIGN
	UnhandledCriticalCrlExtension VerifyResult = C.X509_V_ERR_UNHANDLED_CRITICAL_CRL_EXTENSION
	InvalidNonCa                  VerifyResult = C.X509_V_ERR_INVALID_NON_CA
	ProxyPathLengthExceeded       VerifyResult = C.X509_V_ERR_PROXY_PATH_LENGTH_EXCEEDED
	KeyusageNoDigitalSignature    VerifyResult = C.X509_V_ERR_KEYUSAGE_NO_DIGITAL_SIGNATURE
	ProxyCertificatesNotAllowed   VerifyResult = C.X509_V_ERR_PROXY_CERTIFICATES_NOT_ALLOWED
	InvalidExtension              VerifyResult = C.X509_V_ERR_INVALID_EXTENSION
	InvalidPolicyExtension        VerifyResult = C.X509_V_ERR_INVALID_POLICY_EXTENSION
	NoExplicitPolicy              VerifyResult = C.X509_V_ERR_NO_EXPLICIT_POLICY
	UnnestedResource              VerifyResult = C.X509_V_ERR_UNNESTED_RESOURCE
	ApplicationVerification       VerifyResult = C.X509_V_ERR_APPLICATION_VERIFICATION
)

// newSSL creates a fresh C.SSL bound to ctx.
//
// The goroutine is pinned to its OS thread for the duration so that, on
// failure, errorFromErrorQueue reads OpenSSL's per-thread error queue from the
// same thread on which SSL_new recorded the error.
func newSSL(ctx *C.SSL_CTX) (*C.SSL, error) {
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	if handle := C.SSL_new(ctx); handle != nil {
		return handle, nil
	}
	return nil, errorFromErrorQueue()
}

// newConn allocates a C.SSL from ctx and wires it to conn through a pair of
// Go-side BIOs: into_ssl feeds bytes read from conn into OpenSSL, and
// from_ssl collects bytes OpenSSL wants written to conn. Client and Server
// then put the result into the connect or accept state respectively.
//
// On success a finalizer is installed that detaches the BIOs and frees the
// C.SSL when the Conn becomes unreachable. On failure every C object
// allocated so far is released before returning.
func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) {
	ssl, err := newSSL(ctx.ctx)
	if err != nil {
		return nil, err
	}

	into_ssl := &readBio{}
	from_ssl := &writeBio{}

	// Mirror the context's ReleaseBuffers mode onto both Go-side buffers.
	if ctx.GetMode()&ReleaseBuffers > 0 {
		into_ssl.release_buffers = true
		from_ssl.release_buffers = true
	}

	into_ssl_cbio := into_ssl.MakeCBIO()
	from_ssl_cbio := from_ssl.MakeCBIO()
	if into_ssl_cbio == nil || from_ssl_cbio == nil {
		// these frees are null safe
		C.BIO_free(into_ssl_cbio)
		C.BIO_free(from_ssl_cbio)
		C.SSL_free(ssl)
		return nil, errors.New("failed to allocate memory BIO")
	}

	// the ssl object takes ownership of these objects now
	C.SSL_set_bio(ssl, into_ssl_cbio, from_ssl_cbio)

	// Store a handle to the Go-side SSL wrapper as ex_data on the C object,
	// presumably so C callbacks can recover it. NOTE(review): the matching
	// pointer.Unref is not in this file; confirm an ex_data free callback
	// elsewhere in the package releases it.
	s := &SSL{ssl: ssl}
	C.SSL_set_ex_data(s.ssl, get_ssl_idx(), pointer.Save(s))

	c := &Conn{
		SSL: s,

		conn:     conn,
		ctx:      ctx,
		into_ssl: into_ssl,
		from_ssl: from_ssl}
	// Detach the Go buffers from their C BIOs before SSL_free (which also
	// frees the BIOs it took ownership of above).
	runtime.SetFinalizer(c, func(c *Conn) {
		c.into_ssl.Disconnect(into_ssl_cbio)
		c.from_ssl.Disconnect(from_ssl_cbio)
		C.SSL_free(c.ssl)
	})
	return c, nil
}

// Client wraps an existing stream connection in a Conn configured for the
// client ("connect") role, so any subsequent handshake is initiated from this
// side.
//
// IMPORTANT: unlike Dial, Client does not verify the peer's hostname. Callers
// that construct connections this way must do so themselves (for example with
// Conn.VerifyHostname after the handshake); otherwise the connection is open
// to man-in-the-middle attacks.
//
// Client also does not configure SNI the way Dial does; see
// SetTlsExtHostName.
//
// Certificate verification will usually fail unless the client Ctx has been
// given a verify location or certificates in its certificate store. This
// library does not load the system certificate store by default.
func Client(conn net.Conn, ctx *Ctx) (*Conn, error) {
	tlsConn, err := newConn(conn, ctx)
	if err != nil {
		return nil, err
	}
	// Any handshake on this connection is driven from the initiating side.
	C.SSL_set_connect_state(tlsConn.ssl)
	return tlsConn, nil
}

// Server wraps an existing stream connection in a Conn configured for the
// server ("accept") role, so any subsequent handshake waits for the peer to
// initiate it.
func Server(conn net.Conn, ctx *Ctx) (*Conn, error) {
	tlsConn, err := newConn(conn, ctx)
	if err != nil {
		return nil, err
	}
	// Any handshake on this connection responds to the peer's hello.
	C.SSL_set_accept_state(tlsConn.ssl)
	return tlsConn, nil
}

// GetCtx returns the Ctx this connection was created from.
func (c *Conn) GetCtx() *Ctx {
	return c.ctx
}

// CurrentCipher returns the name of the cipher negotiated for this
// connection, or an error if no cipher is available yet (for example, before
// the handshake has completed).
func (c *Conn) CurrentCipher() (string, error) {
	name := C.X_SSL_get_cipher_name(c.ssl)
	if name != nil {
		return C.GoString(name), nil
	}
	return "", errors.New("session not established")
}

// fillInputBuffer performs one successful read from the underlying transport
// into into_ssl so OpenSSL has more ciphertext to process.
//
// Empty reads that report no error are retried. When the transport reports
// io.EOF, into_ssl is marked as exhausted and the connection is closed; the
// result of Close is returned. Any other read error is returned unchanged.
func (c *Conn) fillInputBuffer() error {
	for {
		n, err := c.into_ssl.ReadFromOnce(c.conn)
		switch {
		case n == 0 && err == nil:
			// Nothing arrived and nothing failed: poll the transport again.
			continue
		case err == io.EOF:
			c.into_ssl.MarkEOF()
			return c.Close()
		default:
			return err
		}
	}
}

// flushOutputBuffer writes whatever OpenSSL has queued in from_ssl to the
// underlying transport. Only the write error is reported; the byte count is
// discarded.
func (c *Conn) flushOutputBuffer() error {
	if _, err := c.from_ssl.WriteTo(c.conn); err != nil {
		return err
	}
	return nil
}

// getErrorHandler classifies the non-success result rv of an SSL_* call,
// together with the errno cgo captured alongside it, and returns a deferred
// action for the caller to run after releasing c.mtx (via handleError).
//
// Every caller in this file invokes it while holding c.mtx; the
// SSL_ERROR_WANT_READ branch reads and writes c.want_read_future and relies
// on that lock.
//
// The returned closure yields:
//   - io.ErrUnexpectedEOF, after closing the connection, on
//     SSL_ERROR_ZERO_RETURN;
//   - errTryAgain once more input has been read (SSL_ERROR_WANT_READ) or
//     pending output has been flushed (SSL_ERROR_WANT_WRITE), telling the
//     caller to retry the SSL call;
//   - a non-nil error describing the failure in every other case.
func (c *Conn) getErrorHandler(rv C.int, errno error) func() error {
	errcode := C.SSL_get_error(c.ssl, rv)
	switch errcode {
	case C.SSL_ERROR_ZERO_RETURN:
		return func() error {
			c.Close()
			return io.ErrUnexpectedEOF
		}
	case C.SSL_ERROR_WANT_READ:
		// OpenSSL may have queued output (e.g. handshake records) that the peer
		// must receive before it sends what we are waiting for.
		go c.flushOutputBuffer()
		if c.want_read_future != nil {
			// Another goroutine is already refilling into_ssl; share its outcome
			// instead of reading from the transport concurrently.
			want_read_future := c.want_read_future
			return func() error {
				_, err := want_read_future.Get()
				return err
			}
		}
		c.want_read_future = utils.NewFuture()
		want_read_future := c.want_read_future
		return func() (err error) {
			// Publish the result (errTryAgain on success) to any waiters and
			// clear the slot so the next WANT_READ starts a fresh fill.
			defer func() {
				c.mtx.Lock()
				c.want_read_future = nil
				c.mtx.Unlock()
				want_read_future.Set(nil, err)
			}()
			err = c.fillInputBuffer()
			if err != nil {
				return err
			}
			return errTryAgain
		}
	case C.SSL_ERROR_WANT_WRITE:
		return func() error {
			err := c.flushOutputBuffer()
			if err != nil {
				return err
			}
			return errTryAgain
		}
	case C.SSL_ERROR_SYSCALL:
		var err error
		if C.ERR_peek_error() == 0 {
			switch rv {
			case 0:
				err = errors.New("protocol-violating EOF")
			case -1:
				err = errno
			default:
				err = errorFromErrorQueue()
			}
		} else {
			err = errorFromErrorQueue()
		}
		if err == nil {
			// cgo yields a nil error when the C errno was 0. This connection's
			// I/O goes through Go-side BIOs rather than a socket, so errno is
			// often unset here. Returning nil would make callers treat this
			// failure as success, so substitute a descriptive error.
			err = errors.New("SSL_ERROR_SYSCALL without errno")
		}
		return func() error { return err }
	default:
		err := errorFromErrorQueue()
		return func() error { return err }
	}
}

// handleError runs the deferred action produced by getErrorHandler (or by one
// of the SSL step helpers). A nil action means the step succeeded.
func (c *Conn) handleError(errcb func() error) error {
	if errcb == nil {
		return nil
	}
	return errcb()
}

// handshake runs a single SSL_do_handshake step while holding c.mtx.
//
// It returns nil when the handshake has completed. Otherwise it returns the
// deferred action from getErrorHandler. On an already-closed connection that
// action reports io.ErrUnexpectedEOF.
func (c *Conn) handshake() func() error {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.is_shutdown {
		return func() error { return io.ErrUnexpectedEOF }
	}
	// Keep OpenSSL's per-thread error queue on the thread that made the call.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	if rv, errno := C.SSL_do_handshake(c.ssl); rv <= 0 {
		return c.getErrorHandler(rv, errno)
	}
	return nil
}

// Handshake drives the SSL handshake to completion, retrying after each
// read or flush that OpenSSL asks for. If it is never called explicitly, the
// handshake happens implicitly before the first I/O on the encrypted stream.
func (c *Conn) Handshake() error {
	var err error
	for {
		err = c.handleError(c.handshake())
		if err != errTryAgain {
			break
		}
	}
	// Push out any final handshake records without blocking the caller.
	go c.flushOutputBuffer()
	return err
}

// PeerCertificate returns the certificate presented by the remote peer. It is
// only meaningful once the handshake has completed, and fails if the
// connection is closed or the peer sent no certificate.
func (c *Conn) PeerCertificate() (*Certificate, error) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.is_shutdown {
		return nil, errors.New("connection closed")
	}
	peer := C.SSL_get_peer_certificate(c.ssl)
	if peer == nil {
		return nil, errors.New("no peer certificate found")
	}
	cert := &Certificate{x: peer}
	// The X509 returned above carries its own reference; drop it once the Go
	// wrapper is collected.
	runtime.SetFinalizer(cert, func(owned *Certificate) { C.X509_free(owned.x) })
	return cert, nil
}

// loadCertificateStack converts an OpenSSL STACK_OF(X509) into a slice of
// Certificates, one per stack entry in order.
//
// The X509 pointers are borrowed from the stack, not copied or ref-counted.
// Each Certificate keeps c reachable through its ref field, which keeps the
// connection memory that owns the stack alive.
func (c *Conn) loadCertificateStack(sk *C.struct_stack_st_X509) (rv []*Certificate) {
	count := int(C.X_sk_X509_num(sk))
	rv = make([]*Certificate, count)
	for idx := range rv {
		rv[idx] = &Certificate{x: C.X_sk_X509_value(sk, C.int(idx)), ref: c}
	}
	return rv
}

// PeerCertificateChain returns the certificate chain sent by the peer.
//
// On the client side the chain includes the peer's own certificate. On the
// server side it does not, and the peer's certificate must be fetched
// separately with PeerCertificate.
func (c *Conn) PeerCertificateChain() (rv []*Certificate, err error) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.is_shutdown {
		return nil, errors.New("connection closed")
	}
	chain := C.SSL_get_peer_cert_chain(c.ssl)
	if chain == nil {
		return nil, errors.New("no peer certificates found")
	}
	rv = c.loadCertificateStack(chain)
	return rv, nil
}

// ConnectionState is a snapshot of peer-related connection details, as
// gathered by Conn.ConnectionState.
type ConnectionState struct {
	// Certificate and CertificateError are the results of Conn.PeerCertificate.
	Certificate           *Certificate
	CertificateError      error
	// CertificateChain and CertificateChainError are the results of
	// Conn.PeerCertificateChain.
	CertificateChain      []*Certificate
	CertificateChainError error
	// SessionReused is the result of Conn.SessionReused.
	SessionReused         bool
}

// ConnectionState collects the peer certificate, the peer certificate chain
// (each with its own lookup error) and the session-reuse flag into a single
// value. Lookup failures are recorded in the struct rather than returned.
func (c *Conn) ConnectionState() (rv ConnectionState) {
	cert, certErr := c.PeerCertificate()
	chain, chainErr := c.PeerCertificateChain()
	return ConnectionState{
		Certificate:           cert,
		CertificateError:      certErr,
		CertificateChain:      chain,
		CertificateChainError: chainErr,
		SessionReused:         c.SessionReused(),
	}
}

// shutdown performs one SSL_shutdown step while holding c.mtx. It returns nil
// when SSL_shutdown reports a non-negative result, and otherwise the deferred
// action from getErrorHandler.
func (c *Conn) shutdown() func() error {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	rv, errno := C.SSL_shutdown(c.ssl)
	if rv < 0 {
		return c.getErrorHandler(rv, errno)
	}
	// rv > 0: shutdown finished.
	//
	// rv == 0: per the OpenSSL docs the shutdown is not yet finished, and
	// SSL_shutdown should be called a second time if a bidirectional shutdown
	// is wanted. SSL_get_error may also report a spurious SSL_ERROR_SYSCALL in
	// this state, so it is treated as success here.
	// TODO: revisit bidirectional shutdown, possibly trying again. Some broken
	// clients will not engage in a bidirectional shutdown unless prompted by a
	// TCP FIN, i.e. shutting down the write side of the connection.
	return nil
}

// shutdownLoop drives shutdown until it succeeds or fails, then flushes any
// alert records OpenSSL produced.
//
// At most one retry is allowed. A second errTryAgain is reported as an error.
// An io.ErrUnexpectedEOF from the peer is treated as a clean shutdown.
func (c *Conn) shutdownLoop() error {
	for attempt := 1; ; attempt++ {
		err := c.handleError(c.shutdown())
		switch {
		case err == nil:
			return c.flushOutputBuffer()
		case err != errTryAgain:
			if err == io.ErrUnexpectedEOF {
				return nil
			}
			return err
		case attempt >= 2:
			return errors.New("shutdown requested a third time?")
		}
	}
}

// Close shuts down the SSL session and then closes the wrapped connection.
// Only the first call does any work; later calls return nil. Errors from the
// SSL shutdown and from closing the transport are combined.
func (c *Conn) Close() error {
	c.mtx.Lock()
	alreadyClosed := c.is_shutdown
	c.is_shutdown = true
	c.mtx.Unlock()
	if alreadyClosed {
		return nil
	}

	var errs utils.ErrorGroup
	errs.Add(c.shutdownLoop())
	errs.Add(c.conn.Close())
	return errs.Finalize()
}

// read performs a single SSL_read into b while holding c.mtx.
//
// It returns the number of plaintext bytes read with a nil action, or zero
// with a deferred action from getErrorHandler. On a closed connection the
// action reports io.EOF. An empty b is a no-op.
func (c *Conn) read(b []byte) (int, func() error) {
	if len(b) == 0 {
		return 0, nil
	}
	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.is_shutdown {
		return 0, func() error { return io.EOF }
	}
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	got, errno := C.SSL_read(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
	if got <= 0 {
		return 0, c.getErrorHandler(got, errno)
	}
	return int(got), nil
}

// Read decrypts up to len(b) bytes of application data into b and returns how
// many were read. It returns io.EOF once the peer has closed the stream and no
// more data can be expected.
func (c *Conn) Read(b []byte) (n int, err error) {
	if len(b) == 0 {
		return 0, nil
	}
	for {
		count, errcb := c.read(b)
		switch err = c.handleError(errcb); err {
		case nil:
			// Reading may have generated protocol output; send it in the background.
			go c.flushOutputBuffer()
			return count, nil
		case errTryAgain:
			continue
		case io.ErrUnexpectedEOF:
			return 0, io.EOF
		default:
			return 0, err
		}
	}
}

// write performs a single SSL_write of b while holding c.mtx.
//
// It returns the number of bytes accepted with a nil action, or zero with a
// deferred action from getErrorHandler. On a closed connection the action
// reports a "connection closed" error. An empty b is a no-op.
func (c *Conn) write(b []byte) (int, func() error) {
	if len(b) == 0 {
		return 0, nil
	}
	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.is_shutdown {
		closedErr := errors.New("connection closed")
		return 0, func() error { return closedErr }
	}
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	sent, errno := C.SSL_write(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
	if sent <= 0 {
		return 0, c.getErrorHandler(sent, errno)
	}
	return int(sent), nil
}

// Write encrypts b and sends the resulting records over the underlying stream,
// retrying whenever OpenSSL asks for more I/O first. Throughput is best when
// len(b) is a multiple of SSLRecordSize.
func (c *Conn) Write(b []byte) (written int, err error) {
	if len(b) == 0 {
		return 0, nil
	}
	for {
		sent, errcb := c.write(b)
		if err = c.handleError(errcb); err == nil {
			return sent, c.flushOutputBuffer()
		}
		if err != errTryAgain {
			return 0, err
		}
	}
}

// VerifyHostname checks the peer's certificate against host by fetching it
// with PeerCertificate and delegating to Certificate.VerifyHostname. Any
// failure to obtain the certificate is returned as-is.
func (c *Conn) VerifyHostname(host string) error {
	peer, err := c.PeerCertificate()
	if err == nil {
		err = peer.VerifyHostname(host)
	}
	return err
}

// LocalAddr reports the local network address of the wrapped connection.
func (c *Conn) LocalAddr() net.Addr {
	return c.UnderlyingConn().LocalAddr()
}

// RemoteAddr reports the remote network address of the wrapped connection.
func (c *Conn) RemoteAddr() net.Addr {
	return c.UnderlyingConn().RemoteAddr()
}

// SetDeadline forwards t to SetDeadline on the wrapped connection.
func (c *Conn) SetDeadline(t time.Time) error {
	return c.UnderlyingConn().SetDeadline(t)
}

// SetReadDeadline forwards t to SetReadDeadline on the wrapped connection.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return c.UnderlyingConn().SetReadDeadline(t)
}

// SetWriteDeadline forwards t to SetWriteDeadline on the wrapped connection.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return c.UnderlyingConn().SetWriteDeadline(t)
}

// UnderlyingConn returns the raw transport that carries the encrypted stream.
// I/O performed on it directly is not encrypted or decrypted by this Conn.
func (c *Conn) UnderlyingConn() net.Conn {
	return c.conn
}

// SetTlsExtHostName sets the TLS SNI (server_name extension) value that this
// connection sends in its handshake. Failures are reported from OpenSSL's
// error queue.
func (c *Conn) SetTlsExtHostName(name string) error {
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	hostname := C.CString(name)
	defer C.free(unsafe.Pointer(hostname))
	if rc := C.X_SSL_set_tlsext_host_name(c.ssl, hostname); rc == 0 {
		return errorFromErrorQueue()
	}
	return nil
}

// VerifyResult reports the outcome of verifying the peer's certificate, as
// recorded by OpenSSL (SSL_get_verify_result).
func (c *Conn) VerifyResult() VerifyResult {
	code := C.SSL_get_verify_result(c.ssl)
	return VerifyResult(code)
}

// SessionReused reports whether the handshake resumed a previously
// established session.
func (c *Conn) SessionReused() bool {
	reused := C.X_SSL_session_reused(c.ssl)
	return reused == 1
}

// GetSession serializes the connection's current SSL session to DER bytes
// (via i2d_SSL_SESSION) so it can later be restored with setSession.
//
// It returns an error if the connection has no session, or if the session
// cannot be encoded.
func (c *Conn) GetSession() ([]byte, error) {
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	// get1 increases the refcount of the session, so we have to free it.
	session := (*C.SSL_SESSION)(C.SSL_get1_session(c.ssl))
	if session == nil {
		return nil, errors.New("failed to get session")
	}
	defer C.SSL_SESSION_free(session)

	// A first pass with a nil output pointer only reports the encoded size.
	slen := C.i2d_SSL_SESSION(session, nil)
	if slen <= 0 {
		// i2d_* returns a non-positive length on failure. Continuing with it
		// would crash: a negative value converted to size_t makes cgo's
		// C.malloc abort the process, and C.GoBytes panics on a negative length.
		return nil, errors.New("failed to encode session")
	}

	buf := (*C.uchar)(C.malloc(C.size_t(slen)))
	defer C.free(unsafe.Pointer(buf))

	// i2d advances the pointer it is given past the bytes it writes, so pass a
	// temporary and keep buf pointing at the start of the encoding.
	tmp := buf
	slen2 := C.i2d_SSL_SESSION(session, &tmp)
	if slen != slen2 {
		return nil, errors.New("session had different lengths")
	}

	return C.GoBytes(unsafe.Pointer(buf), slen), nil
}

// setSession loads a DER-encoded session, as produced by GetSession, and
// installs it on this connection so that a later handshake can attempt to
// resume it.
//
// session must be non-empty. The bytes are copied into C memory before
// parsing because d2i_SSL_SESSION takes a pointer to the input pointer and
// advances it. Passing C a Go pointer to a variable that itself holds a Go
// pointer (as &(&session[0]) would) violates the cgo pointer-passing rules.
func (c *Conn) setSession(session []byte) error {
	if len(session) == 0 {
		return errors.New("unable to load session: empty session data")
	}

	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	buf := C.CBytes(session)
	defer C.free(buf)

	// ptr is a Go variable holding a C pointer, so &ptr may be passed to C.
	ptr := (*C.uchar)(buf)
	s := C.d2i_SSL_SESSION(nil, &ptr, C.long(len(session)))
	if s == nil {
		return fmt.Errorf("unable to load session: %s", errorFromErrorQueue())
	}
	// SSL_set_session takes its own reference, so release ours on return.
	defer C.SSL_SESSION_free(s)

	if ret := C.SSL_set_session(c.ssl, s); ret != 1 {
		return fmt.Errorf("unable to set session: %s", errorFromErrorQueue())
	}
	return nil
}