This commit is contained in:
tursom 2021-01-12 09:10:35 +08:00
parent 7a1371f16c
commit df633fc80d
3 changed files with 140 additions and 65 deletions

View File

@ -18,13 +18,24 @@ import io.netty.handler.codec.http.HttpClientCodec
import io.netty.handler.codec.http.HttpObjectAggregator
import io.netty.handler.codec.http.websocketx.*
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler
import io.netty.handler.logging.LoggingHandler
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.util.InsecureTrustManagerFactory
import java.net.URI
class WebSocketClient(uri: String, val handler: WebSocketHandler, val autoWrap: Boolean = true) {
private val uri: URI = URI.create(uri)
@Suppress("unused")
class WebSocketClient(
url: String,
val handler: WebSocketHandler,
val autoWrap: Boolean = true,
val log: Boolean = false,
val compressed: Boolean = true,
val maxContextLength: Int = 4096,
private val headers: Map<String, String>? = null,
private val handshakerUri: URI? = null,
) {
private val uri: URI = URI.create(url)
internal var ch: Channel? = null
fun open() {
@ -54,29 +65,39 @@ class WebSocketClient(uri: String, val handler: WebSocketHandler, val autoWrap:
} else {
null
}
val handler = WebSocketClientChannelHandler(
WebSocketClientHandshakerFactory.newHandshaker(
uri, WebSocketVersion.V13, null, false, DefaultHttpHeaders()
), this, handler
)
val httpHeaders = DefaultHttpHeaders()
headers?.forEach { (k, v) ->
httpHeaders[k] = v
}
val handshakerAdapter = WebSocketClientHandshakerAdapter(WebSocketClientHandshakerFactory.newHandshaker(
handshakerUri ?: uri, WebSocketVersion.V13, null, true, httpHeaders
), this, handler)
val handler = WebSocketClientChannelHandler(this, handler)
val bootstrap = Bootstrap()
bootstrap.group(group)
.channel(NioSocketChannel::class.java)
.handler(object : ChannelInitializer<SocketChannel>() {
override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline()
if (sslCtx != null) {
pipeline.addLast(sslCtx.newHandler(ch.alloc(), host, port))
}
pipeline.addLast(
HttpClientCodec(),
HttpObjectAggregator(4096),
WebSocketClientCompressionHandler.INSTANCE,
handler,
)
if (autoWrap) {
pipeline.addLast(WebSocketFrameWrapper)
ch.pipeline().apply {
if (log) {
addLast(LoggingHandler())
}
if (sslCtx != null) {
addLast(sslCtx.newHandler(ch.alloc(), host, port))
}
addLast(HttpClientCodec())
addLast(HttpObjectAggregator(maxContextLength))
if (compressed) {
addLast(WebSocketClientCompressionHandler.INSTANCE)
}
addLast(handshakerAdapter)
//if (log) {
// addLast(LoggingHandler())
//}
addLast(handler)
if (autoWrap) {
addLast(WebSocketFrameWrapper)
}
}
}
})
@ -84,8 +105,12 @@ class WebSocketClient(uri: String, val handler: WebSocketHandler, val autoWrap:
//handler.handshakeFuture().sync()
}
fun close() {
ch?.writeAndFlush(CloseWebSocketFrame())
fun close(reasonText: String? = null) {
if (reasonText == null) {
ch?.writeAndFlush(CloseWebSocketFrame())
} else {
ch?.writeAndFlush(CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, reasonText))
}
ch?.closeFuture()?.sync()
}
@ -127,6 +152,44 @@ class WebSocketClient(uri: String, val handler: WebSocketHandler, val autoWrap:
return ch!!.writeAndFlush(TextWebSocketFrame(data))
}
fun ping(data: ByteArray): ChannelFuture {
return ch!!.writeAndFlush(PingWebSocketFrame(Unpooled.wrappedBuffer(data)))
}
fun ping(data: ByteBuffer): ChannelFuture {
return ch!!.writeAndFlush(
PingWebSocketFrame(
when (data) {
is NettyByteBuffer -> data.byteBuf
else -> Unpooled.wrappedBuffer(data.getBytes())
}
)
)
}
fun ping(data: ByteBuf): ChannelFuture {
return ch!!.writeAndFlush(PingWebSocketFrame(data))
}
fun pong(data: ByteArray): ChannelFuture {
return ch!!.writeAndFlush(PongWebSocketFrame(Unpooled.wrappedBuffer(data)))
}
fun pong(data: ByteBuffer): ChannelFuture {
return ch!!.writeAndFlush(
PongWebSocketFrame(
when (data) {
is NettyByteBuffer -> data.byteBuf
else -> Unpooled.wrappedBuffer(data.getBytes())
}
)
)
}
fun pong(data: ByteBuf): ChannelFuture {
return ch!!.writeAndFlush(PongWebSocketFrame(data))
}
companion object {
private val group: EventLoopGroup = NioEventLoopGroup()
}

View File

@ -5,32 +5,14 @@ import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelPromise
import io.netty.channel.SimpleChannelInboundHandler
import io.netty.handler.codec.http.FullHttpResponse
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker
import io.netty.handler.codec.http.websocketx.*
import io.netty.util.CharsetUtil
class WebSocketClientChannelHandler(
private val handshaker: WebSocketClientHandshaker,
val client: WebSocketClient,
val handler: WebSocketHandler
) : SimpleChannelInboundHandler<Any>() {
private var handshakeFuture: ChannelPromise? = null
fun handshakeFuture(): ChannelFuture? {
return handshakeFuture
}
override fun handlerAdded(ctx: ChannelHandlerContext) {
handshakeFuture = ctx.newPromise()
}
override fun channelActive(ctx: ChannelHandlerContext) {
client.ch = ctx.channel()
handshaker.handshake(ctx.channel())
}
) : SimpleChannelInboundHandler<WebSocketFrame>() {
override fun channelInactive(ctx: ChannelHandlerContext) {
handler.onClose(client)
@ -39,34 +21,12 @@ class WebSocketClientChannelHandler(
}
}
override fun channelRead0(ctx: ChannelHandlerContext, msg: Any) {
override fun channelRead0(ctx: ChannelHandlerContext, msg: WebSocketFrame) {
val ch = ctx.channel()
if (!handshaker.isHandshakeComplete) {
// web socket client connected
handshaker.finishHandshake(ch, msg as FullHttpResponse)
handshakeFuture!!.setSuccess()
handler.onOpen(client)
return
}
if (msg is FullHttpResponse) {
throw Exception("Unexpected FullHttpResponse (getStatus=${msg.status()}, content=${msg.content().toString(CharsetUtil.UTF_8)})")
}
when (msg) {
is TextWebSocketFrame -> handler.readMessage(client, msg)
is BinaryWebSocketFrame -> handler.readMessage(client, msg)
is CloseWebSocketFrame -> ch.close()
}
}
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
try {
handler.onError(client, cause)
} catch (e: Exception) {
e.printStackTrace()
if (!handshakeFuture!!.isDone) {
handshakeFuture!!.setFailure(cause)
}
ctx.close()
}
}
}

View File

@ -0,0 +1,52 @@
package cn.tursom.ws
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelPromise
import io.netty.channel.SimpleChannelInboundHandler
import io.netty.handler.codec.http.FullHttpResponse
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker
import io.netty.util.CharsetUtil
class WebSocketClientHandshakerAdapter(
private val handshaker: WebSocketClientHandshaker,
private val client: WebSocketClient,
private val handler: WebSocketHandler,
) : SimpleChannelInboundHandler<FullHttpResponse>() {
private var handshakeFuture: ChannelPromise? = null
override fun handlerAdded(ctx: ChannelHandlerContext) {
handshakeFuture = ctx.newPromise()
}
override fun channelActive(ctx: ChannelHandlerContext) {
client.ch = ctx.channel()
handshaker.handshake(ctx.channel())
}
override fun channelRead0(ctx: ChannelHandlerContext, msg: FullHttpResponse) {
if (!handshaker.isHandshakeComplete) {
handshaker.finishHandshake(ctx.channel(), msg)
handshakeFuture!!.setSuccess()
msg.retain()
ctx.fireChannelRead(msg)
handler.onOpen(client)
return
} else {
throw Exception("Unexpected FullHttpResponse (getStatus=${msg.status()}, content=${
msg.content().toString(CharsetUtil.UTF_8)
})")
}
}
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
try {
handler.onError(client, cause)
} catch (e: Exception) {
e.printStackTrace()
if (!handshakeFuture!!.isDone) {
handshakeFuture!!.setFailure(cause)
}
ctx.close()
}
}
}