优化 SocketServer 结构

This commit is contained in:
tursom 2019-10-18 21:26:04 +08:00
parent 2dd2e8c5fb
commit 3b78a9c998
15 changed files with 690 additions and 754 deletions

View File

@ -21,22 +21,22 @@ data class Demo(
) )
fun main() { fun main() {
// 获取数据库访问协助对象 // 获取数据库访问协助对象
val helper = SQLiteHelper("demo.db") val helper = SQLiteHelper("demo.db")
// 插入数据 // 插入数据
helper.insert(Demo(name = "tursom")) helper.insert(Demo(name = "tursom"))
// 更新数据 // 更新数据
helper.update(Demo(name = "tursom", money = 100.0), where = ClauseMaker.make { helper.update(Demo(name = "tursom", money = 100.0), where = ClauseMaker.make {
!Demo::name equal "tursom" !Demo::name equal "tursom"
}) })
// 获取数据 // 获取数据
val data = helper.select(Demo::class.java, where = ClauseMaker { val data = helper.select(Demo::class.java, where = ClauseMaker {
(!Demo::id greaterThan !0) and (!Demo::id lessThan !10) (!Demo::id greaterThan !0) and (!Demo::id lessThan !10)
}) })
// 删除数据 // 删除数据
helper.delete(Demo::class.java.tableName, where = ClauseMaker.make { !Demo::name equal "tursom" }) helper.delete(Demo::class.java.tableName, where = ClauseMaker.make { !Demo::name equal "tursom" })
} }

View File

@ -18,246 +18,204 @@ import kotlin.coroutines.suspendCoroutine
* 但是对于一般的应用而言是足够使用的 * 但是对于一般的应用而言是足够使用的
*/ */
class AsyncNioSocket(override val key: SelectionKey, override val nioThread: INioThread) : IAsyncNioSocket { class AsyncNioSocket(override val key: SelectionKey, override val nioThread: INioThread) : IAsyncNioSocket {
override val channel: SocketChannel = key.channel() as SocketChannel override val channel: SocketChannel = key.channel() as SocketChannel
override suspend fun read(buffer: ByteBuffer): Int { private suspend inline fun <T> operate(crossinline action: (Continuation<T>) -> Unit): T {
if (buffer.remaining() == 0) return -1 return try {
return try { suspendCoroutine {
suspendCoroutine { action(it)
key.attach(SingleContext(buffer, it)) }
readMode() } catch (e: Exception) {
nioThread.wakeup() waitMode()
} throw RuntimeException(e)
} catch (e: Exception) { }
waitMode() }
throw RuntimeException(e)
override suspend fun read(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(SingleContext(buffer, it))
readMode()
nioThread.wakeup()
}
}
override suspend fun read(buffer: Array<out ByteBuffer>): Long {
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(MultiContext(buffer, it))
readMode()
nioThread.wakeup()
}
}
override suspend fun write(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(SingleContext(buffer, it))
writeMode()
nioThread.wakeup()
}
}
override suspend fun write(buffer: Array<out ByteBuffer>): Long {
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(MultiContext(buffer, it))
writeMode()
nioThread.wakeup()
}
}
override suspend fun read(buffer: ByteBuffer, timeout: Long): Int {
if (timeout <= 0) return read(buffer)
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(
SingleContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
readMode()
nioThread.wakeup()
}
}
override suspend fun read(buffer: Array<out ByteBuffer>, timeout: Long): Long {
if (timeout <= 0) return read(buffer)
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(
MultiContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
readMode()
nioThread.wakeup()
}
}
override suspend fun write(buffer: ByteBuffer, timeout: Long): Int {
if (timeout <= 0) return write(buffer)
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(
SingleContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
writeMode()
nioThread.wakeup()
}
}
override suspend fun write(buffer: Array<out ByteBuffer>, timeout: Long): Long {
if (timeout <= 0) return write(buffer)
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(
MultiContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
writeMode()
nioThread.wakeup()
}
}
override fun close() {
nioThread.execute {
channel.close()
key.cancel()
}
}
interface Context {
val cont: Continuation<*>
val timeoutTask: TimerTask? get() = null
}
class SingleContext(
val buffer: ByteBuffer,
override val cont: Continuation<Int>,
override val timeoutTask: TimerTask? = null
) : Context
class MultiContext(
val buffer: Array<out ByteBuffer>,
override val cont: Continuation<Long>,
override val timeoutTask: TimerTask? = null
) : Context
companion object {
val nioSocketProtocol = object : INioProtocol {
override fun handleConnect(key: SelectionKey, nioThread: INioThread) {}
override fun handleRead(key: SelectionKey, nioThread: INioThread) {
key.interestOps(0)
val context = key.attachment() as Context? ?: return
context.timeoutTask?.cancel()
if (context is SingleContext) {
val channel = key.channel() as SocketChannel
val readSize = channel.read(context.buffer)
context.cont.resume(readSize)
} else {
context as MultiContext
val channel = key.channel() as SocketChannel
val readSize = channel.read(context.buffer)
context.cont.resume(readSize)
} }
} }
override suspend fun read(buffer: Array<out ByteBuffer>): Long { override fun handleWrite(key: SelectionKey, nioThread: INioThread) {
if (buffer.size == 0) return -1 key.interestOps(0)
return try { val context = key.attachment() as Context? ?: return
suspendCoroutine { context.timeoutTask?.cancel()
key.attach(MultiContext(buffer, it)) if (context is SingleContext) {
readMode() val channel = key.channel() as SocketChannel
nioThread.wakeup() val readSize = channel.write(context.buffer)
} context.cont.resume(readSize)
} catch (e: Exception) { } else {
waitMode() context as MultiContext
throw RuntimeException(e) val channel = key.channel() as SocketChannel
val readSize = channel.write(context.buffer)
context.cont.resume(readSize)
} }
} }
override suspend fun write(buffer: ByteBuffer): Int { override fun exceptionCause(key: SelectionKey, nioThread: INioThread, e: Throwable) {
if (buffer.remaining() == 0) return -1 key.interestOps(0)
return try { val context = key.attachment() as Context?
suspendCoroutine { if (context != null)
key.attach(SingleContext(buffer, it)) context.cont.resumeWithException(e)
writeMode() else {
nioThread.wakeup() key.cancel()
} key.channel().close()
} catch (e: Exception) { e.printStackTrace()
waitMode()
throw Exception(e)
} }
}
} }
override suspend fun write(buffer: Array<out ByteBuffer>): Long { //val timer = StaticWheelTimer.timer
if (buffer.isEmpty()) return -1 val timer = WheelTimer.timer
return try {
suspendCoroutine {
key.attach(MultiContext(buffer, it))
writeMode()
nioThread.wakeup()
}
} catch (e: Exception) {
waitMode()
throw Exception(e)
}
}
override suspend fun read(buffer: ByteBuffer, timeout: Long): Int { const val emptyBufferCode = 0
if (timeout <= 0) return read(buffer) const val emptyBufferLongCode = 0L
if (buffer.remaining() == 0) return -1 }
return try {
val result: Int = suspendCoroutine {
key.attach(
SingleContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
try {
it.resumeWithException(TimeoutException())
} catch (e: Exception) {
}
})
)
readMode()
nioThread.wakeup()
}
result
} catch (e: Exception) {
waitMode()
throw RuntimeException(e)
}
}
override suspend fun read(buffer: Array<out ByteBuffer>, timeout: Long): Long {
if (timeout <= 0) return read(buffer)
if (buffer.isEmpty()) return -1
return try {
val result: Long = suspendCoroutine {
key.attach(
MultiContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
try {
it.resumeWithException(TimeoutException())
} catch (e: Exception) {
}
})
)
readMode()
nioThread.wakeup()
}
result
} catch (e: Exception) {
waitMode()
throw Exception(e)
}
}
override suspend fun write(buffer: ByteBuffer, timeout: Long): Int {
if (timeout <= 0) return write(buffer)
if (buffer.remaining() == 0) return -1
return try {
val result: Int = suspendCoroutine {
key.attach(
SingleContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
try {
it.resumeWithException(TimeoutException())
} catch (e: Exception) {
}
})
)
writeMode()
nioThread.wakeup()
}
result
} catch (e: Exception) {
waitMode()
throw Exception(e)
}
}
override suspend fun write(buffer: Array<out ByteBuffer>, timeout: Long): Long {
if (timeout <= 0) return write(buffer)
if (buffer.isEmpty()) return -1
return try {
val result: Long = suspendCoroutine {
key.attach(
MultiContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
try {
it.resumeWithException(TimeoutException())
} catch (e: Exception) {
}
})
)
writeMode()
nioThread.wakeup()
}
result
} catch (e: Exception) {
waitMode()
throw Exception(e)
}
}
override fun close() {
nioThread.execute {
channel.close()
key.cancel()
}
}
interface Context {
val cont: Continuation<*>
val timeoutTask: TimerTask? get() = null
}
class SingleContext(
val buffer: ByteBuffer,
override val cont: Continuation<Int>,
override val timeoutTask: TimerTask? = null
) : Context
class MultiContext(
val buffer: Array<out ByteBuffer>,
override val cont: Continuation<Long>,
override val timeoutTask: TimerTask? = null
) : Context
companion object {
val nioSocketProtocol = object : INioProtocol {
override fun handleConnect(key: SelectionKey, nioThread: INioThread) {}
override fun handleRead(key: SelectionKey, nioThread: INioThread) {
key.interestOps(0)
val context = key.attachment() as Context? ?: return
context.timeoutTask?.cancel()
if (context is SingleContext) {
val channel = key.channel() as SocketChannel
val readSize = channel.read(context.buffer)
context.cont.resume(readSize)
} else {
context as MultiContext
val channel = key.channel() as SocketChannel
val readSize = channel.read(context.buffer)
context.cont.resume(readSize)
}
}
override fun handleWrite(key: SelectionKey, nioThread: INioThread) {
key.interestOps(0)
val context = key.attachment() as Context? ?: return
context.timeoutTask?.cancel()
if (context is SingleContext) {
val channel = key.channel() as SocketChannel
val readSize = channel.write(context.buffer)
context.cont.resume(readSize)
} else {
context as MultiContext
val channel = key.channel() as SocketChannel
val readSize = channel.write(context.buffer)
context.cont.resume(readSize)
}
}
override fun exceptionCause(key: SelectionKey, nioThread: INioThread, e: Throwable) {
key.interestOps(0)
val context = key.attachment() as Context?
if (context != null)
context.cont.resumeWithException(e)
else {
key.cancel()
key.channel().close()
e.printStackTrace()
}
}
}
//val timer = StaticWheelTimer.timer
val timer = WheelTimer.timer
}
} }

View File

@ -9,82 +9,87 @@ import java.nio.channels.SelectionKey
import java.nio.channels.SocketChannel import java.nio.channels.SocketChannel
interface IAsyncNioSocket : AsyncSocket { interface IAsyncNioSocket : AsyncSocket {
val channel: SocketChannel val channel: SocketChannel
val key: SelectionKey val key: SelectionKey
val nioThread: INioThread val nioThread: INioThread
fun waitMode() { fun waitMode() {
if (Thread.currentThread() == nioThread.thread) { if (Thread.currentThread() == nioThread.thread) {
if (key.isValid) key.interestOps(SelectionKey.OP_WRITE) if (key.isValid) key.interestOps(SelectionKey.OP_WRITE)
} else { } else {
nioThread.execute { if (key.isValid) key.interestOps(0) } nioThread.execute { if (key.isValid) key.interestOps(0) }
nioThread.wakeup() nioThread.wakeup()
} }
} }
fun readMode() { fun readMode() {
if (Thread.currentThread() == nioThread.thread) { if (Thread.currentThread() == nioThread.thread) {
if (key.isValid) key.interestOps(SelectionKey.OP_WRITE) if (key.isValid) key.interestOps(SelectionKey.OP_WRITE)
} else { } else {
nioThread.execute { if (key.isValid) key.interestOps(SelectionKey.OP_READ) } nioThread.execute { if (key.isValid) key.interestOps(SelectionKey.OP_READ) }
nioThread.wakeup() nioThread.wakeup()
} }
} }
fun writeMode() { fun writeMode() {
if (Thread.currentThread() == nioThread.thread) { if (Thread.currentThread() == nioThread.thread) {
if (key.isValid) key.interestOps(SelectionKey.OP_WRITE) if (key.isValid) key.interestOps(SelectionKey.OP_WRITE)
} else { } else {
nioThread.execute { if (key.isValid) key.interestOps(SelectionKey.OP_WRITE) } nioThread.execute { if (key.isValid) key.interestOps(SelectionKey.OP_WRITE) }
nioThread.wakeup() nioThread.wakeup()
} }
} }
suspend fun read(buffer: ByteBuffer): Int = read(arrayOf(buffer)).toInt() suspend fun read(buffer: ByteBuffer): Int = read(arrayOf(buffer)).toInt()
suspend fun write(buffer: ByteBuffer): Int = write(arrayOf(buffer)).toInt() suspend fun write(buffer: ByteBuffer): Int = write(arrayOf(buffer)).toInt()
suspend fun read(buffer: Array<out ByteBuffer>): Long suspend fun read(buffer: Array<out ByteBuffer>): Long
suspend fun write(buffer: Array<out ByteBuffer>): Long suspend fun write(buffer: Array<out ByteBuffer>): Long
/** /**
* 如果通道已断开则会抛出异常 * 如果通道已断开则会抛出异常
*/ */
suspend fun recv(buffer: ByteBuffer): Int { suspend fun recv(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return 0 if (buffer.remaining() == 0) return emptyBufferCode
val readSize = read(buffer) val readSize = read(buffer)
if (readSize < 0) { if (readSize < 0) {
throw SocketException("channel closed") throw SocketException("channel closed")
} }
return readSize return readSize
} }
suspend fun recv(buffer: ByteBuffer, timeout: Long): Int { suspend fun recv(buffer: ByteBuffer, timeout: Long): Int {
if (buffer.remaining() == 0) return 0 if (buffer.remaining() == 0) return emptyBufferCode
val readSize = read(buffer, timeout) val readSize = read(buffer, timeout)
if (readSize < 0) { if (readSize < 0) {
throw SocketException("channel closed") throw SocketException("channel closed")
} }
return readSize return readSize
} }
suspend fun recv(buffers: Array<out ByteBuffer>, timeout: Long): Long { suspend fun recv(buffers: Array<out ByteBuffer>, timeout: Long): Long {
if (buffers.isEmpty()) return 0 if (buffers.isEmpty()) return emptyBufferLongCode
val readSize = read(buffers, timeout) val readSize = read(buffers, timeout)
if (readSize < 0) { if (readSize < 0) {
throw SocketException("channel closed") throw SocketException("channel closed")
} }
return readSize return readSize
} }
suspend fun recv(buffer: AdvanceByteBuffer, timeout: Long = 0): Int { suspend fun recv(buffer: AdvanceByteBuffer, timeout: Long = 0): Int {
return if (buffer.bufferCount == 1) { return if (buffer.bufferCount == 1) {
buffer.writeNioBuffer { buffer.writeNioBuffer {
recv(it, timeout) recv(it, timeout)
} }
} else { } else {
val readMode = buffer.readMode val readMode = buffer.readMode
buffer.resumeWriteMode() buffer.resumeWriteMode()
val value = recv(buffer.nioBuffers, timeout).toInt() val value = recv(buffer.nioBuffers, timeout).toInt()
if (readMode) buffer.readMode() if (readMode) buffer.readMode()
value value
} }
} }
companion object {
const val emptyBufferCode = 0
const val emptyBufferLongCode = 0L
}
} }

View File

@ -9,33 +9,33 @@ import java.nio.channels.SelectionKey
/** /**
* 有多个工作线程的协程套接字服务器 * 有多个工作线程的协程套接字服务器
* 不过因为结构复杂所以性能实际上比多线程的 ProtocolAsyncNioServer * 不过因为结构复杂所以性能一般比单个工作线程的 AsyncNioServer
*/ */
@Suppress("MemberVisibilityCanBePrivate") @Suppress("MemberVisibilityCanBePrivate")
class AsyncGroupNioServer( class AsyncGroupNioServer(
val port: Int, override val port: Int,
val threads: Int = Runtime.getRuntime().availableProcessors(), val threads: Int = Runtime.getRuntime().availableProcessors(),
backlog: Int = 50, backlog: Int = 50,
val handler: suspend AsyncNioSocket.() -> Unit val handler: suspend AsyncNioSocket.() -> Unit
) : ISocketServer by GroupNioServer( ) : ISocketServer by GroupNioServer(
port, port,
threads, threads,
object : INioProtocol by AsyncNioSocket.nioSocketProtocol { object : INioProtocol by AsyncNioSocket.nioSocketProtocol {
override fun handleConnect(key: SelectionKey, nioThread: INioThread) { override fun handleConnect(key: SelectionKey, nioThread: INioThread) {
GlobalScope.launch { GlobalScope.launch {
val socket = AsyncNioSocket(key, nioThread) val socket = AsyncNioSocket(key, nioThread)
try { try {
socket.handler() socket.handler()
} catch (e: Exception) { } catch (e: Exception) {
e.printStackTrace() e.printStackTrace()
} finally { } finally {
try { try {
nioThread.execute { socket.close() } nioThread.execute { socket.close() }
} catch (e: Exception) { } catch (e: Exception) {
} }
} }
} }
} }
}, },
backlog backlog
) )

View File

@ -9,43 +9,41 @@ import java.nio.channels.SelectionKey
/** /**
* 只有一个工作线程的协程套接字服务器 * 只有一个工作线程的协程套接字服务器
* 不过因为结构更加简单所以性能实际上比多线程的 ProtocolGroupAsyncNioServer * 不过因为结构更加简单所以性能一般比多个工作线程的 ProtocolGroupAsyncNioServer
* 而且协程是天生多线程并不需要太多的接受线程来处理所以一般只需要用本服务器即可 * 而且协程是天生多线程并不需要太多的接受线程来处理所以一般只需要用本服务器即可
*/ */
class AsyncNioServer( class AsyncNioServer(
val port: Int, override val port: Int,
backlog: Int = 50, backlog: Int = 50,
val handler: suspend AsyncNioSocket.() -> Unit val handler: suspend AsyncNioSocket.() -> Unit
) : ISocketServer by NioServer(port, object : INioProtocol by AsyncNioSocket.nioSocketProtocol { ) : ISocketServer by NioServer(port, object : INioProtocol by AsyncNioSocket.nioSocketProtocol {
override fun handleConnect(key: SelectionKey, nioThread: INioThread) { override fun handleConnect(key: SelectionKey, nioThread: INioThread) {
GlobalScope.launch { GlobalScope.launch {
val socket = AsyncNioSocket(key, nioThread) val socket = AsyncNioSocket(key, nioThread)
try { try {
socket.handler() socket.handler()
} catch (e: Exception) { } catch (e: Exception) {
e.printStackTrace() e.printStackTrace()
} finally { } finally {
try { try {
socket.close() socket.close()
} catch (e: Exception) { } catch (e: Exception) {
}
}
} }
}
} }
}
}, backlog) { }, backlog) {
/** /**
* 次要构造方法为使用Spring的同学们准备的 * 次要构造方法为使用Spring的同学们准备的
*/ */
constructor( constructor(
port: Int, port: Int,
backlog: Int = 50, backlog: Int = 50,
handler: Handler handler: Handler
) : this(port, backlog, { ) : this(port, backlog, { handler.handle(this) })
handler.handle(this)
})
interface Handler { interface Handler {
fun handle(socket: AsyncNioSocket) fun handle(socket: AsyncNioSocket)
} }
} }

View File

@ -0,0 +1,51 @@
package cn.tursom.socket.server
import cn.tursom.socket.AsyncAioSocket
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import java.io.Closeable
import java.net.InetSocketAddress
import java.nio.channels.AsynchronousCloseException
import java.nio.channels.AsynchronousServerSocketChannel
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.CompletionHandler
class AsyncSocketServer(
override val port: Int,
host: String = "0.0.0.0",
private val handler: suspend AsyncAioSocket.() -> Unit
) : ISocketServer {
private val server = AsynchronousServerSocketChannel
.open()
.bind(InetSocketAddress(host, port))
override fun run() {
server.accept(0, object : CompletionHandler<AsynchronousSocketChannel, Int> {
override fun completed(result: AsynchronousSocketChannel?, attachment: Int) {
try {
server.accept(attachment + 1, this)
} catch (e: Throwable) {
e.printStackTrace()
}
result ?: return
GlobalScope.launch {
AsyncAioSocket(result).handler()
}
}
override fun failed(exc: Throwable?, attachment: Int?) {
when (exc) {
is AsynchronousCloseException -> {
}
else -> exc?.printStackTrace()
}
}
})
}
override fun close() {
server.close()
}
}

View File

@ -1,7 +0,0 @@
package cn.tursom.socket.server.async
import java.io.Closeable
interface AsyncServer : Runnable, Closeable {
val port: Int
}

View File

@ -1,51 +0,0 @@
package cn.tursom.socket.server.async
import cn.tursom.socket.AsyncAioSocket
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import java.io.Closeable
import java.net.InetSocketAddress
import java.nio.channels.AsynchronousCloseException
import java.nio.channels.AsynchronousServerSocketChannel
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.CompletionHandler
class AsyncSocketServer(
port: Int,
host: String = "0.0.0.0",
private val handler: suspend AsyncAioSocket.() -> Unit
) : Runnable, Closeable {
private val server = AsynchronousServerSocketChannel
.open()
.bind(InetSocketAddress(host, port))
override fun run() {
server.accept(0, object : CompletionHandler<AsynchronousSocketChannel, Int> {
override fun completed(result: AsynchronousSocketChannel?, attachment: Int) {
try {
server.accept(attachment + 1, this)
} catch (e: Throwable) {
e.printStackTrace()
}
result ?: return
GlobalScope.launch {
AsyncAioSocket(result).handler()
}
}
override fun failed(exc: Throwable?, attachment: Int?) {
when (exc) {
is AsynchronousCloseException -> {
}
else -> exc?.printStackTrace()
}
}
})
}
override fun close() {
server.close()
}
}

View File

@ -16,89 +16,89 @@ import java.util.concurrent.LinkedBlockingDeque
*/ */
@Suppress("MemberVisibilityCanBePrivate") @Suppress("MemberVisibilityCanBePrivate")
class GroupNioServer( class GroupNioServer(
val port: Int, override val port: Int,
val threads: Int = Runtime.getRuntime().availableProcessors(), val threads: Int = Runtime.getRuntime().availableProcessors(),
private val protocol: INioProtocol, private val protocol: INioProtocol,
backlog: Int = 50, backlog: Int = 50,
val nioThreadGenerator: ( val nioThreadGenerator: (
threadName: String, threadName: String,
threads: Int, threads: Int,
worker: (thread: INioThread) -> Unit worker: (thread: INioThread) -> Unit
) -> IWorkerGroup = { name, _, worker -> ) -> IWorkerGroup = { name, _, worker ->
ThreadPoolWorkerGroup(threads, name, false, worker) ThreadPoolWorkerGroup(threads, name, false, worker)
} }
) : ISocketServer { ) : ISocketServer {
private val listenChannel = ServerSocketChannel.open() private val listenChannel = ServerSocketChannel.open()
private val listenThreads = LinkedBlockingDeque<INioThread>() private val listenThreads = LinkedBlockingDeque<INioThread>()
private val workerGroupList = LinkedBlockingDeque<IWorkerGroup>() private val workerGroupList = LinkedBlockingDeque<IWorkerGroup>()
init { init {
listenChannel.socket().bind(InetSocketAddress(port), backlog) listenChannel.socket().bind(InetSocketAddress(port), backlog)
listenChannel.configureBlocking(false) listenChannel.configureBlocking(false)
} }
override fun run() { override fun run() {
val workerGroup = nioThreadGenerator( val workerGroup = nioThreadGenerator(
"nioWorkerGroup", threads, "nioWorkerGroup", threads,
NioServer.LoopHandler(protocol)::handle NioServer.LoopHandler(protocol)::handle
) )
workerGroupList.add(workerGroup) workerGroupList.add(workerGroup)
val nioThread = ThreadPoolNioThread("nioAccepter") { nioThread -> val nioThread = ThreadPoolNioThread("nioAccepter") { nioThread ->
val selector = nioThread.selector val selector = nioThread.selector
if (selector.isOpen) { if (selector.isOpen) {
forEachKey(selector) { key -> forEachKey(selector) { key ->
try { try {
when { when {
key.isAcceptable -> { key.isAcceptable -> {
val serverChannel = key.channel() as ServerSocketChannel val serverChannel = key.channel() as ServerSocketChannel
var channel = serverChannel.accept() var channel = serverChannel.accept()
while (channel != null) { while (channel != null) {
channel.configureBlocking(false) channel.configureBlocking(false)
workerGroup.register(channel) { (key, thread) -> workerGroup.register(channel) { (key, thread) ->
protocol.handleConnect(key, thread) protocol.handleConnect(key, thread)
} }
channel = serverChannel.accept() channel = serverChannel.accept()
}
}
}
} catch (e: Throwable) {
try {
protocol.exceptionCause(key, nioThread, e)
} catch (e1: Throwable) {
e.printStackTrace()
e1.printStackTrace()
key.cancel()
key.channel().close()
}
}
nioThread.execute(this)
} }
} }
} }
listenThreads.add(nioThread) } catch (e: Throwable) {
listenChannel.register(nioThread.selector, SelectionKey.OP_ACCEPT) try {
nioThread.wakeup() protocol.exceptionCause(key, nioThread, e)
} } catch (e1: Throwable) {
e.printStackTrace()
e1.printStackTrace()
key.cancel()
key.channel().close()
}
}
nioThread.execute(this)
}
}
}
listenThreads.add(nioThread)
listenChannel.register(nioThread.selector, SelectionKey.OP_ACCEPT)
nioThread.wakeup()
}
override fun close() { override fun close() {
listenChannel.close() listenChannel.close()
listenThreads.forEach { it.close() } listenThreads.forEach { it.close() }
workerGroupList.forEach { it.close() } workerGroupList.forEach { it.close() }
} }
companion object { companion object {
const val TIMEOUT = 3000L const val TIMEOUT = 3000L
inline fun forEachKey(selector: Selector, action: (key: SelectionKey) -> Unit) { inline fun forEachKey(selector: Selector, action: (key: SelectionKey) -> Unit) {
if (selector.select(TIMEOUT) != 0) { if (selector.select(TIMEOUT) != 0) {
val keyIter = selector.selectedKeys().iterator() val keyIter = selector.selectedKeys().iterator()
while (keyIter.hasNext()) run whileBlock@{ while (keyIter.hasNext()) run whileBlock@{
val key = keyIter.next() val key = keyIter.next()
keyIter.remove() keyIter.remove()
action(key) action(key)
} }
} }
} }
} }
} }

View File

@ -2,4 +2,6 @@ package cn.tursom.socket.server
import java.io.Closeable import java.io.Closeable
interface ISocketServer : Runnable, Closeable interface ISocketServer : Runnable, Closeable {
val port: Int
}

View File

@ -1,5 +1,6 @@
package cn.tursom.socket.server package cn.tursom.socket.server
import cn.tursom.core.cpuNumber
import cn.tursom.socket.BaseSocket import cn.tursom.socket.BaseSocket
import java.net.ServerSocket import java.net.ServerSocket
@ -7,46 +8,47 @@ class MultithreadingSocketServer(
private val serverSocket: ServerSocket, private val serverSocket: ServerSocket,
private val threadNumber: Int = cpuNumber, private val threadNumber: Int = cpuNumber,
val exception: Exception.() -> Unit = { val exception: Exception.() -> Unit = {
printStackTrace() printStackTrace()
}, },
handler: BaseSocket.() -> Unit override val handler: BaseSocket.() -> Unit
) : SocketServer(handler) { ) : SocketServer {
override val port = serverSocket.localPort
constructor( constructor(
port: Int, port: Int,
threadNumber: Int = cpuNumber, threadNumber: Int = cpuNumber,
exception: Exception.() -> Unit = { exception: Exception.() -> Unit = {
printStackTrace() printStackTrace()
}, },
handler: BaseSocket.() -> Unit handler: BaseSocket.() -> Unit
) : this(ServerSocket(port), threadNumber, exception, handler) ) : this(ServerSocket(port), threadNumber, exception, handler)
constructor( constructor(
port: Int, port: Int,
handler: BaseSocket.() -> Unit handler: BaseSocket.() -> Unit
) : this(port, cpuNumber, { printStackTrace() }, handler) ) : this(port, cpuNumber, { printStackTrace() }, handler)
private val threadList = ArrayList<Thread>() private val threadList = ArrayList<Thread>()
override fun run() { override fun run() {
for (i in 1..threadNumber) { for (i in 1..threadNumber) {
val thread = Thread { val thread = Thread {
while (true) { while (true) {
serverSocket.accept().use { serverSocket.accept().use {
try { try {
BaseSocket(it).handler() BaseSocket(it).handler()
} catch (e: Exception) { } catch (e: Exception) {
e.exception() e.exception()
}
}
}
} }
thread.start() }
threadList.add(thread)
} }
}
thread.start()
threadList.add(thread)
} }
}
override fun close() { override fun close() {
serverSocket.close() serverSocket.close()
} }
} }

View File

@ -12,86 +12,86 @@ import java.util.concurrent.ConcurrentLinkedDeque
* 工作在单线程上的 Nio 服务器 * 工作在单线程上的 Nio 服务器
*/ */
class NioServer( class NioServer(
val port: Int, override val port: Int,
private val protocol: INioProtocol, private val protocol: INioProtocol,
backLog: Int = 50, backLog: Int = 50,
val nioThreadGenerator: (threadName: String, workLoop: (thread: INioThread) -> Unit) -> INioThread val nioThreadGenerator: (threadName: String, workLoop: (thread: INioThread) -> Unit) -> INioThread
) : ISocketServer { ) : ISocketServer {
private val listenChannel = ServerSocketChannel.open() private val listenChannel = ServerSocketChannel.open()
private val threadList = ConcurrentLinkedDeque<INioThread>() private val threadList = ConcurrentLinkedDeque<INioThread>()
init { init {
listenChannel.socket().bind(InetSocketAddress(port), backLog) listenChannel.socket().bind(InetSocketAddress(port), backLog)
listenChannel.configureBlocking(false) listenChannel.configureBlocking(false)
}
constructor(
port: Int,
protocol: INioProtocol,
backLog: Int = 50
) : this(port, protocol, backLog, { name, workLoop ->
WorkerLoopNioThread(name, workLoop = workLoop, isDaemon = false)
})
override fun run() {
val nioThread = nioThreadGenerator("nio worker", LoopHandler(protocol)::handle)
nioThread.register(listenChannel, SelectionKey.OP_ACCEPT) {}
threadList.add(nioThread)
}
override fun close() {
listenChannel.close()
threadList.forEach {
it.close()
} }
}
constructor( class LoopHandler(val protocol: INioProtocol) {
port: Int, fun handle(nioThread: INioThread) {
protocol: INioProtocol, val selector = nioThread.selector
backLog: Int = 50 if (selector.isOpen) {
) : this(port, protocol, backLog, { name, workLoop -> if (selector.select(TIMEOUT) != 0) {
WorkerLoopNioThread(name, workLoop = workLoop, isDaemon = false) val keyIter = selector.selectedKeys().iterator()
}) while (keyIter.hasNext()) run whileBlock@{
val key = keyIter.next()
override fun run() { keyIter.remove()
val nioThread = nioThreadGenerator("nio worker", LoopHandler(protocol)::handle) try {
nioThread.register(listenChannel, SelectionKey.OP_ACCEPT) {} when {
threadList.add(nioThread) key.isAcceptable -> {
} val serverChannel = key.channel() as ServerSocketChannel
var channel = serverChannel.accept()
override fun close() { while (channel != null) {
listenChannel.close() channel.configureBlocking(false)
threadList.forEach { nioThread.register(channel, 0) {
it.close() protocol.handleConnect(it, nioThread)
}
}
class LoopHandler(val protocol: INioProtocol) {
fun handle(nioThread: INioThread) {
val selector = nioThread.selector
if (selector.isOpen) {
if (selector.select(TIMEOUT) != 0) {
val keyIter = selector.selectedKeys().iterator()
while (keyIter.hasNext()) run whileBlock@{
val key = keyIter.next()
keyIter.remove()
try {
when {
key.isAcceptable -> {
val serverChannel = key.channel() as ServerSocketChannel
var channel = serverChannel.accept()
while (channel != null) {
channel.configureBlocking(false)
nioThread.register(channel, 0) {
protocol.handleConnect(it, nioThread)
}
channel = serverChannel.accept()
}
}
key.isReadable -> {
protocol.handleRead(key, nioThread)
}
key.isWritable -> {
protocol.handleWrite(key, nioThread)
}
}
} catch (e: Throwable) {
try {
protocol.exceptionCause(key, nioThread, e)
} catch (e1: Throwable) {
e.printStackTrace()
e1.printStackTrace()
key.cancel()
key.channel().close()
}
}
} }
channel = serverChannel.accept()
}
} }
key.isReadable -> {
protocol.handleRead(key, nioThread)
}
key.isWritable -> {
protocol.handleWrite(key, nioThread)
}
}
} catch (e: Throwable) {
try {
protocol.exceptionCause(key, nioThread, e)
} catch (e1: Throwable) {
e.printStackTrace()
e1.printStackTrace()
key.cancel()
key.channel().close()
}
} }
}
} }
}
} }
}
companion object { companion object {
private const val TIMEOUT = 1000L private const val TIMEOUT = 1000L
} }
} }

View File

@ -7,45 +7,46 @@ import java.net.SocketException
class SingleThreadSocketServer( class SingleThreadSocketServer(
private val serverSocket: ServerSocket, private val serverSocket: ServerSocket,
val exception: Exception.() -> Unit = { printStackTrace() }, val exception: Exception.() -> Unit = { printStackTrace() },
handler: BaseSocket.() -> Unit override val handler: BaseSocket.() -> Unit
) : SocketServer(handler) { ) : SocketServer {
override val port = serverSocket.localPort
constructor( constructor(
port: Int, port: Int,
exception: Exception.() -> Unit = { printStackTrace() }, exception: Exception.() -> Unit = { printStackTrace() },
handler: BaseSocket.() -> Unit handler: BaseSocket.() -> Unit
) : this(ServerSocket(port), exception, handler) ) : this(ServerSocket(port), exception, handler)
constructor( constructor(
port: Int, port: Int,
handler: BaseSocket.() -> Unit handler: BaseSocket.() -> Unit
) : this(port, { printStackTrace() }, handler) ) : this(port, { printStackTrace() }, handler)
override fun run() { override fun run() {
while (!serverSocket.isClosed) { while (!serverSocket.isClosed) {
try { try {
serverSocket.accept().use { serverSocket.accept().use {
try { try {
BaseSocket(it).handler() BaseSocket(it).handler()
} catch (e: Exception) { } catch (e: Exception) {
e.exception() e.exception()
} }
}
} catch (e: SocketException) {
if (e.message == "Socket closed" || e.message == "cn.tursom.socket closed") {
break
} else {
e.exception()
}
}
} }
} } catch (e: SocketException) {
if (e.message == "Socket closed" || e.message == "cn.tursom.socket closed") {
override fun close() { break
try { } else {
serverSocket.close() e.exception()
} catch (e: Exception) {
e.printStackTrace()
} }
}
} }
}
override fun close() {
try {
serverSocket.close()
} catch (e: Exception) {
e.printStackTrace()
}
}
} }

View File

@ -2,8 +2,10 @@ package cn.tursom.socket.server
import cn.tursom.socket.BaseSocket import cn.tursom.socket.BaseSocket
abstract class SocketServer(val handler: BaseSocket.() -> Unit) : ISocketServer { interface SocketServer : ISocketServer {
companion object { val handler: BaseSocket.() -> Unit
val cpuNumber = Runtime.getRuntime().availableProcessors() //CPU处理器的个数
} companion object {
val cpuNumber = Runtime.getRuntime().availableProcessors() //CPU处理器的个数
}
} }

View File

@ -30,7 +30,7 @@ import java.util.concurrent.TimeUnit
* } * }
* *
*/ */
open class ThreadPoolSocketServer class ThreadPoolSocketServer
/** /**
* 使用代码而不是配置文件的构造函数 * 使用代码而不是配置文件的构造函数
* *
@ -39,125 +39,100 @@ open class ThreadPoolSocketServer
* @param queueSize 线程池任务队列大小 * @param queueSize 线程池任务队列大小
* @param keepAliveTime 线程最长存活时间 * @param keepAliveTime 线程最长存活时间
* @param timeUnit timeout的单位默认毫秒 * @param timeUnit timeout的单位默认毫秒
* @param startImmediately 是否立即启动 * @param handler 对套接字处理的业务逻辑
*/( */(
port: Int, override val port: Int,
threads: Int = 1, threads: Int = 1,
queueSize: Int = 1, queueSize: Int = 1,
keepAliveTime: Long = 60_000L, keepAliveTime: Long = 60_000L,
timeUnit: TimeUnit = TimeUnit.MILLISECONDS, timeUnit: TimeUnit = TimeUnit.MILLISECONDS,
handler: BaseSocket.() -> Unit override val handler: BaseSocket.() -> Unit
) : SocketServer(handler) { ) : SocketServer {
constructor( constructor(
port: Int, port: Int,
handler: BaseSocket.() -> Unit handler: BaseSocket.() -> Unit
) : this(port, 1, 1, 60_000L, TimeUnit.MILLISECONDS, handler) ) : this(port, 1, 1, 60_000L, TimeUnit.MILLISECONDS, handler)
var socket = Socket() var socket = Socket()
private val pool: ThreadPoolExecutor = private val pool: ThreadPoolExecutor =
ThreadPoolExecutor(threads, threads, keepAliveTime, timeUnit, LinkedBlockingQueue(queueSize)) ThreadPoolExecutor(threads, threads, keepAliveTime, timeUnit, LinkedBlockingQueue(queueSize))
private var serverSocket: ServerSocket = ServerSocket(port) private var serverSocket: ServerSocket = ServerSocket(port)
/** /**
* 为了在构造函数中自动启动服务我们需要封闭start()防止用户重载start() * 主要作用
*/ * 循环接受连接请求
private fun start() { * 讲接收的连接交给handler处理
Thread(this).start() * 连接初期异常处理
} * 自动关闭套接字服务器与线程池
*/
/** override fun run() {
* 主要作用 while (!serverSocket.isClosed) {
* 循环接受连接请求 try {
* 讲接收的连接交给handler处理 socket = serverSocket.accept()
* 连接初期异常处理 println("$TAG: run(): get connect: $socket")
* 自动关闭套接字服务器与线程池 pool.execute {
*/ socket.use {
final override fun run() { BaseSocket(it).handler()
while (!serverSocket.isClosed) { }
try {
socket = serverSocket.accept()
println("$TAG: run(): get connect: $socket")
pool.execute {
socket.use {
BaseSocket(it).handler()
}
}
} catch (e: IOException) {
if (pool.isShutdown || serverSocket.isClosed) {
System.err.println("server closed")
break
}
e.printStackTrace()
} catch (e: SocketException) {
e.printStackTrace()
break
} catch (e: RejectedExecutionException) {
socket.getOutputStream()?.write(poolIsFull)
} catch (e: Exception) {
e.printStackTrace()
break
}
} }
whenClose() } catch (e: IOException) {
close() if (pool.isShutdown || serverSocket.isClosed) {
System.err.println("server closed") System.err.println("server closed")
} break
/**
* 关闭服务器套接字
*/
private fun closeServer() {
if (!serverSocket.isClosed) {
serverSocket.close()
} }
e.printStackTrace()
} catch (e: SocketException) {
e.printStackTrace()
break
} catch (e: RejectedExecutionException) {
socket.getOutputStream()?.write(poolIsFull)
} catch (e: Exception) {
e.printStackTrace()
break
}
} }
close()
System.err.println("server closed")
}
/** /**
* 关闭线程池 * 关闭服务器套接字
*/ */
private fun shutdownPool() { private fun closeServer() {
if (!pool.isShutdown) { if (!serverSocket.isClosed) {
pool.shutdown() serverSocket.close()
}
} }
}
/** /**
* 服务器是否已经关闭 * 关闭线程池
*/ */
@Suppress("unused") private fun shutdownPool() {
fun isClosed() = pool.isShutdown || serverSocket.isClosed if (!pool.isShutdown) {
pool.shutdown()
/**
* 关闭服务器
*/
override fun close() {
shutdownPool()
closeServer()
} }
}
/** /**
* 关闭服务器时执行 * 服务器是否已经关闭
*/ */
open fun whenClose() { @Suppress("unused")
} fun isClosed() = pool.isShutdown || serverSocket.isClosed
/**
* 关闭服务器
*/
override fun close() {
shutdownPool()
closeServer()
}
companion object {
val TAG = getTAG(this::class.java)
/** /**
* 线程池满时返回给客户端的信息 * 线程池满时返回给客户端的信息
*/ */
open val poolIsFull val poolIsFull = "server pool is full".toByteArray()
get() = Companion.poolIsFull }
private data class ServerConfigData(
val port: Int = 0,
val threads: Int = 1,
val queueSize: Int = 1,
val timeout: Long = 0L,
val startImmediately: Boolean = false
)
companion object {
val TAG = getTAG(this::class.java)
val poolIsFull = "server pool is full".toByteArray()
}
} }