diff --git a/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClient.kt b/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClient.kt index c0eae91..b6a4c68 100644 --- a/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClient.kt +++ b/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClient.kt @@ -18,13 +18,24 @@ import io.netty.handler.codec.http.HttpClientCodec import io.netty.handler.codec.http.HttpObjectAggregator import io.netty.handler.codec.http.websocketx.* import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler +import io.netty.handler.logging.LoggingHandler import io.netty.handler.ssl.SslContextBuilder import io.netty.handler.ssl.util.InsecureTrustManagerFactory import java.net.URI -class WebSocketClient(uri: String, val handler: WebSocketHandler, val autoWrap: Boolean = true) { - private val uri: URI = URI.create(uri) +@Suppress("unused") +class WebSocketClient( + url: String, + val handler: WebSocketHandler, + val autoWrap: Boolean = true, + val log: Boolean = false, + val compressed: Boolean = true, + val maxContextLength: Int = 4096, + private val headers: Map? = null, + private val handshakerUri: URI? = null, +) { + private val uri: URI = URI.create(url) internal var ch: Channel? = null fun open() { @@ -54,29 +65,39 @@ class WebSocketClient(uri: String, val handler: WebSocketHandler, val autoWrap: } else { null } - - val handler = WebSocketClientChannelHandler( - WebSocketClientHandshakerFactory.newHandshaker( - uri, WebSocketVersion.V13, null, false, DefaultHttpHeaders() - ), this, handler - ) + val httpHeaders = DefaultHttpHeaders() + headers?.forEach { (k, v) -> + httpHeaders[k] = v + } + val handshakerAdapter = WebSocketClientHandshakerAdapter(WebSocketClientHandshakerFactory.newHandshaker( + handshakerUri ?: uri, WebSocketVersion.V13, null, true, httpHeaders + ), this, handler) + val handler = WebSocketClientChannelHandler(this, handler) val bootstrap = Bootstrap() bootstrap.group(group) .channel(NioSocketChannel::class.java) .handler(object : ChannelInitializer() { override fun initChannel(ch: SocketChannel) { - val pipeline = ch.pipeline() - if (sslCtx != null) { - pipeline.addLast(sslCtx.newHandler(ch.alloc(), host, port)) - } - pipeline.addLast( - HttpClientCodec(), - HttpObjectAggregator(4096), - WebSocketClientCompressionHandler.INSTANCE, - handler, - ) - if (autoWrap) { - pipeline.addLast(WebSocketFrameWrapper) + ch.pipeline().apply { + if (log) { + addLast(LoggingHandler()) + } + if (sslCtx != null) { + addLast(sslCtx.newHandler(ch.alloc(), host, port)) + } + addLast(HttpClientCodec()) + addLast(HttpObjectAggregator(maxContextLength)) + if (compressed) { + addLast(WebSocketClientCompressionHandler.INSTANCE) + } + addLast(handshakerAdapter) + //if (log) { + // addLast(LoggingHandler()) + //} + addLast(handler) + if (autoWrap) { + addLast(WebSocketFrameWrapper) + } } } }) @@ -84,8 +105,12 @@ class WebSocketClient(uri: String, val handler: WebSocketHandler, val autoWrap: //handler.handshakeFuture().sync() } - fun close() { - ch?.writeAndFlush(CloseWebSocketFrame()) + fun close(reasonText: String? = null) { + if (reasonText == null) { + ch?.writeAndFlush(CloseWebSocketFrame()) + } else { + ch?.writeAndFlush(CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, reasonText)) + } ch?.closeFuture()?.sync() } @@ -127,6 +152,44 @@ class WebSocketClient(uri: String, val handler: WebSocketHandler, val autoWrap: return ch!!.writeAndFlush(TextWebSocketFrame(data)) } + fun ping(data: ByteArray): ChannelFuture { + return ch!!.writeAndFlush(PingWebSocketFrame(Unpooled.wrappedBuffer(data))) + } + + fun ping(data: ByteBuffer): ChannelFuture { + return ch!!.writeAndFlush( + PingWebSocketFrame( + when (data) { + is NettyByteBuffer -> data.byteBuf + else -> Unpooled.wrappedBuffer(data.getBytes()) + } + ) + ) + } + + fun ping(data: ByteBuf): ChannelFuture { + return ch!!.writeAndFlush(PingWebSocketFrame(data)) + } + + fun pong(data: ByteArray): ChannelFuture { + return ch!!.writeAndFlush(PongWebSocketFrame(Unpooled.wrappedBuffer(data))) + } + + fun pong(data: ByteBuffer): ChannelFuture { + return ch!!.writeAndFlush( + PongWebSocketFrame( + when (data) { + is NettyByteBuffer -> data.byteBuf + else -> Unpooled.wrappedBuffer(data.getBytes()) + } + ) + ) + } + + fun pong(data: ByteBuf): ChannelFuture { + return ch!!.writeAndFlush(PongWebSocketFrame(data)) + } + companion object { private val group: EventLoopGroup = NioEventLoopGroup() } diff --git a/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClientChannelHandler.kt b/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClientChannelHandler.kt index c4e1fe2..106edc0 100644 --- a/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClientChannelHandler.kt +++ b/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClientChannelHandler.kt @@ -5,32 +5,14 @@ import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelPromise import io.netty.channel.SimpleChannelInboundHandler import io.netty.handler.codec.http.FullHttpResponse -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame -import io.netty.handler.codec.http.websocketx.TextWebSocketFrame -import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker +import io.netty.handler.codec.http.websocketx.* import io.netty.util.CharsetUtil class WebSocketClientChannelHandler( - private val handshaker: WebSocketClientHandshaker, val client: WebSocketClient, val handler: WebSocketHandler -) : SimpleChannelInboundHandler() { - private var handshakeFuture: ChannelPromise? = null - - fun handshakeFuture(): ChannelFuture? { - return handshakeFuture - } - - override fun handlerAdded(ctx: ChannelHandlerContext) { - handshakeFuture = ctx.newPromise() - } - - override fun channelActive(ctx: ChannelHandlerContext) { - client.ch = ctx.channel() - handshaker.handshake(ctx.channel()) - } +) : SimpleChannelInboundHandler() { override fun channelInactive(ctx: ChannelHandlerContext) { handler.onClose(client) @@ -39,34 +21,12 @@ class WebSocketClientChannelHandler( } } - override fun channelRead0(ctx: ChannelHandlerContext, msg: Any) { + override fun channelRead0(ctx: ChannelHandlerContext, msg: WebSocketFrame) { val ch = ctx.channel() - if (!handshaker.isHandshakeComplete) { - // web socket client connected - handshaker.finishHandshake(ch, msg as FullHttpResponse) - handshakeFuture!!.setSuccess() - handler.onOpen(client) - return - } - if (msg is FullHttpResponse) { - throw Exception("Unexpected FullHttpResponse (getStatus=${msg.status()}, content=${msg.content().toString(CharsetUtil.UTF_8)})") - } when (msg) { is TextWebSocketFrame -> handler.readMessage(client, msg) is BinaryWebSocketFrame -> handler.readMessage(client, msg) is CloseWebSocketFrame -> ch.close() } } - - override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { - try { - handler.onError(client, cause) - } catch (e: Exception) { - e.printStackTrace() - if (!handshakeFuture!!.isDone) { - handshakeFuture!!.setFailure(cause) - } - ctx.close() - } - } } \ No newline at end of file diff --git a/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClientHandshakerAdapter.kt b/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClientHandshakerAdapter.kt new file mode 100644 index 0000000..acc560a --- /dev/null +++ b/utils/ws-client/src/main/kotlin/cn/tursom/ws/WebSocketClientHandshakerAdapter.kt @@ -0,0 +1,52 @@ +package cn.tursom.ws + +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelPromise +import io.netty.channel.SimpleChannelInboundHandler +import io.netty.handler.codec.http.FullHttpResponse +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker +import io.netty.util.CharsetUtil + +class WebSocketClientHandshakerAdapter( + private val handshaker: WebSocketClientHandshaker, + private val client: WebSocketClient, + private val handler: WebSocketHandler, +) : SimpleChannelInboundHandler() { + private var handshakeFuture: ChannelPromise? = null + + override fun handlerAdded(ctx: ChannelHandlerContext) { + handshakeFuture = ctx.newPromise() + } + + override fun channelActive(ctx: ChannelHandlerContext) { + client.ch = ctx.channel() + handshaker.handshake(ctx.channel()) + } + + override fun channelRead0(ctx: ChannelHandlerContext, msg: FullHttpResponse) { + if (!handshaker.isHandshakeComplete) { + handshaker.finishHandshake(ctx.channel(), msg) + handshakeFuture!!.setSuccess() + msg.retain() + ctx.fireChannelRead(msg) + handler.onOpen(client) + return + } else { + throw Exception("Unexpected FullHttpResponse (getStatus=${msg.status()}, content=${ + msg.content().toString(CharsetUtil.UTF_8) + })") + } + } + + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + try { + handler.onError(client, cause) + } catch (e: Exception) { + e.printStackTrace() + if (!handshakeFuture!!.isDone) { + handshakeFuture!!.setFailure(cause) + } + ctx.close() + } + } +} \ No newline at end of file