优化 SocketServer 结构

This commit is contained in:
tursom 2019-10-18 21:26:04 +08:00
parent 2dd2e8c5fb
commit 3b78a9c998
15 changed files with 690 additions and 754 deletions

View File

@ -21,22 +21,22 @@ data class Demo(
)
fun main() {
// 获取数据库访问协助对象
val helper = SQLiteHelper("demo.db")
// 获取数据库访问协助对象
val helper = SQLiteHelper("demo.db")
// 插入数据
helper.insert(Demo(name = "tursom"))
// 插入数据
helper.insert(Demo(name = "tursom"))
// 更新数据
helper.update(Demo(name = "tursom", money = 100.0), where = ClauseMaker.make {
// 更新数据
helper.update(Demo(name = "tursom", money = 100.0), where = ClauseMaker.make {
!Demo::name equal "tursom"
})
})
// 获取数据
val data = helper.select(Demo::class.java, where = ClauseMaker {
// 获取数据
val data = helper.select(Demo::class.java, where = ClauseMaker {
(!Demo::id greaterThan !0) and (!Demo::id lessThan !10)
})
})
// 删除数据
helper.delete(Demo::class.java.tableName, where = ClauseMaker.make { !Demo::name equal "tursom" })
// 删除数据
helper.delete(Demo::class.java.tableName, where = ClauseMaker.make { !Demo::name equal "tursom" })
}

View File

@ -18,246 +18,204 @@ import kotlin.coroutines.suspendCoroutine
* 但是对于一般的应用而言是足够使用的
*/
class AsyncNioSocket(override val key: SelectionKey, override val nioThread: INioThread) : IAsyncNioSocket {
override val channel: SocketChannel = key.channel() as SocketChannel
override val channel: SocketChannel = key.channel() as SocketChannel
override suspend fun read(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return -1
return try {
suspendCoroutine {
key.attach(SingleContext(buffer, it))
readMode()
nioThread.wakeup()
}
} catch (e: Exception) {
waitMode()
throw RuntimeException(e)
private suspend inline fun <T> operate(crossinline action: (Continuation<T>) -> Unit): T {
return try {
suspendCoroutine {
action(it)
}
} catch (e: Exception) {
waitMode()
throw RuntimeException(e)
}
}
override suspend fun read(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(SingleContext(buffer, it))
readMode()
nioThread.wakeup()
}
}
override suspend fun read(buffer: Array<out ByteBuffer>): Long {
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(MultiContext(buffer, it))
readMode()
nioThread.wakeup()
}
}
override suspend fun write(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(SingleContext(buffer, it))
writeMode()
nioThread.wakeup()
}
}
override suspend fun write(buffer: Array<out ByteBuffer>): Long {
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(MultiContext(buffer, it))
writeMode()
nioThread.wakeup()
}
}
override suspend fun read(buffer: ByteBuffer, timeout: Long): Int {
if (timeout <= 0) return read(buffer)
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(
SingleContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
readMode()
nioThread.wakeup()
}
}
override suspend fun read(buffer: Array<out ByteBuffer>, timeout: Long): Long {
if (timeout <= 0) return read(buffer)
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(
MultiContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
readMode()
nioThread.wakeup()
}
}
override suspend fun write(buffer: ByteBuffer, timeout: Long): Int {
if (timeout <= 0) return write(buffer)
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(
SingleContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
writeMode()
nioThread.wakeup()
}
}
override suspend fun write(buffer: Array<out ByteBuffer>, timeout: Long): Long {
if (timeout <= 0) return write(buffer)
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(
MultiContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
writeMode()
nioThread.wakeup()
}
}
override fun close() {
nioThread.execute {
channel.close()
key.cancel()
}
}
interface Context {
val cont: Continuation<*>
val timeoutTask: TimerTask? get() = null
}
class SingleContext(
val buffer: ByteBuffer,
override val cont: Continuation<Int>,
override val timeoutTask: TimerTask? = null
) : Context
class MultiContext(
val buffer: Array<out ByteBuffer>,
override val cont: Continuation<Long>,
override val timeoutTask: TimerTask? = null
) : Context
companion object {
val nioSocketProtocol = object : INioProtocol {
override fun handleConnect(key: SelectionKey, nioThread: INioThread) {}
override fun handleRead(key: SelectionKey, nioThread: INioThread) {
key.interestOps(0)
val context = key.attachment() as Context? ?: return
context.timeoutTask?.cancel()
if (context is SingleContext) {
val channel = key.channel() as SocketChannel
val readSize = channel.read(context.buffer)
context.cont.resume(readSize)
} else {
context as MultiContext
val channel = key.channel() as SocketChannel
val readSize = channel.read(context.buffer)
context.cont.resume(readSize)
}
}
}
override suspend fun read(buffer: Array<out ByteBuffer>): Long {
if (buffer.size == 0) return -1
return try {
suspendCoroutine {
key.attach(MultiContext(buffer, it))
readMode()
nioThread.wakeup()
}
} catch (e: Exception) {
waitMode()
throw RuntimeException(e)
override fun handleWrite(key: SelectionKey, nioThread: INioThread) {
key.interestOps(0)
val context = key.attachment() as Context? ?: return
context.timeoutTask?.cancel()
if (context is SingleContext) {
val channel = key.channel() as SocketChannel
val readSize = channel.write(context.buffer)
context.cont.resume(readSize)
} else {
context as MultiContext
val channel = key.channel() as SocketChannel
val readSize = channel.write(context.buffer)
context.cont.resume(readSize)
}
}
}
override suspend fun write(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return -1
return try {
suspendCoroutine {
key.attach(SingleContext(buffer, it))
writeMode()
nioThread.wakeup()
}
} catch (e: Exception) {
waitMode()
throw Exception(e)
override fun exceptionCause(key: SelectionKey, nioThread: INioThread, e: Throwable) {
key.interestOps(0)
val context = key.attachment() as Context?
if (context != null)
context.cont.resumeWithException(e)
else {
key.cancel()
key.channel().close()
e.printStackTrace()
}
}
}
override suspend fun write(buffer: Array<out ByteBuffer>): Long {
if (buffer.isEmpty()) return -1
return try {
suspendCoroutine {
key.attach(MultiContext(buffer, it))
writeMode()
nioThread.wakeup()
}
} catch (e: Exception) {
waitMode()
throw Exception(e)
}
}
//val timer = StaticWheelTimer.timer
val timer = WheelTimer.timer
override suspend fun read(buffer: ByteBuffer, timeout: Long): Int {
if (timeout <= 0) return read(buffer)
if (buffer.remaining() == 0) return -1
return try {
val result: Int = suspendCoroutine {
key.attach(
SingleContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
try {
it.resumeWithException(TimeoutException())
} catch (e: Exception) {
}
})
)
readMode()
nioThread.wakeup()
}
result
} catch (e: Exception) {
waitMode()
throw RuntimeException(e)
}
}
override suspend fun read(buffer: Array<out ByteBuffer>, timeout: Long): Long {
if (timeout <= 0) return read(buffer)
if (buffer.isEmpty()) return -1
return try {
val result: Long = suspendCoroutine {
key.attach(
MultiContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
try {
it.resumeWithException(TimeoutException())
} catch (e: Exception) {
}
})
)
readMode()
nioThread.wakeup()
}
result
} catch (e: Exception) {
waitMode()
throw Exception(e)
}
}
override suspend fun write(buffer: ByteBuffer, timeout: Long): Int {
if (timeout <= 0) return write(buffer)
if (buffer.remaining() == 0) return -1
return try {
val result: Int = suspendCoroutine {
key.attach(
SingleContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
try {
it.resumeWithException(TimeoutException())
} catch (e: Exception) {
}
})
)
writeMode()
nioThread.wakeup()
}
result
} catch (e: Exception) {
waitMode()
throw Exception(e)
}
}
override suspend fun write(buffer: Array<out ByteBuffer>, timeout: Long): Long {
if (timeout <= 0) return write(buffer)
if (buffer.isEmpty()) return -1
return try {
val result: Long = suspendCoroutine {
key.attach(
MultiContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
try {
it.resumeWithException(TimeoutException())
} catch (e: Exception) {
}
})
)
writeMode()
nioThread.wakeup()
}
result
} catch (e: Exception) {
waitMode()
throw Exception(e)
}
}
override fun close() {
nioThread.execute {
channel.close()
key.cancel()
}
}
interface Context {
val cont: Continuation<*>
val timeoutTask: TimerTask? get() = null
}
class SingleContext(
val buffer: ByteBuffer,
override val cont: Continuation<Int>,
override val timeoutTask: TimerTask? = null
) : Context
class MultiContext(
val buffer: Array<out ByteBuffer>,
override val cont: Continuation<Long>,
override val timeoutTask: TimerTask? = null
) : Context
companion object {
val nioSocketProtocol = object : INioProtocol {
override fun handleConnect(key: SelectionKey, nioThread: INioThread) {}
override fun handleRead(key: SelectionKey, nioThread: INioThread) {
key.interestOps(0)
val context = key.attachment() as Context? ?: return
context.timeoutTask?.cancel()
if (context is SingleContext) {
val channel = key.channel() as SocketChannel
val readSize = channel.read(context.buffer)
context.cont.resume(readSize)
} else {
context as MultiContext
val channel = key.channel() as SocketChannel
val readSize = channel.read(context.buffer)
context.cont.resume(readSize)
}
}
override fun handleWrite(key: SelectionKey, nioThread: INioThread) {
key.interestOps(0)
val context = key.attachment() as Context? ?: return
context.timeoutTask?.cancel()
if (context is SingleContext) {
val channel = key.channel() as SocketChannel
val readSize = channel.write(context.buffer)
context.cont.resume(readSize)
} else {
context as MultiContext
val channel = key.channel() as SocketChannel
val readSize = channel.write(context.buffer)
context.cont.resume(readSize)
}
}
override fun exceptionCause(key: SelectionKey, nioThread: INioThread, e: Throwable) {
key.interestOps(0)
val context = key.attachment() as Context?
if (context != null)
context.cont.resumeWithException(e)
else {
key.cancel()
key.channel().close()
e.printStackTrace()
}
}
}
//val timer = StaticWheelTimer.timer
val timer = WheelTimer.timer
}
const val emptyBufferCode = 0
const val emptyBufferLongCode = 0L
}
}

View File

@ -9,82 +9,87 @@ import java.nio.channels.SelectionKey
import java.nio.channels.SocketChannel
interface IAsyncNioSocket : AsyncSocket {
val channel: SocketChannel
val key: SelectionKey
val nioThread: INioThread
val channel: SocketChannel
val key: SelectionKey
val nioThread: INioThread
fun waitMode() {
if (Thread.currentThread() == nioThread.thread) {
if (key.isValid) key.interestOps(SelectionKey.OP_WRITE)
} else {
nioThread.execute { if (key.isValid) key.interestOps(0) }
nioThread.wakeup()
}
}
fun waitMode() {
if (Thread.currentThread() == nioThread.thread) {
if (key.isValid) key.interestOps(SelectionKey.OP_WRITE)
} else {
nioThread.execute { if (key.isValid) key.interestOps(0) }
nioThread.wakeup()
}
}
fun readMode() {
if (Thread.currentThread() == nioThread.thread) {
if (key.isValid) key.interestOps(SelectionKey.OP_WRITE)
} else {
nioThread.execute { if (key.isValid) key.interestOps(SelectionKey.OP_READ) }
nioThread.wakeup()
}
}
fun readMode() {
if (Thread.currentThread() == nioThread.thread) {
if (key.isValid) key.interestOps(SelectionKey.OP_WRITE)
} else {
nioThread.execute { if (key.isValid) key.interestOps(SelectionKey.OP_READ) }
nioThread.wakeup()
}
}
fun writeMode() {
if (Thread.currentThread() == nioThread.thread) {
if (key.isValid) key.interestOps(SelectionKey.OP_WRITE)
} else {
nioThread.execute { if (key.isValid) key.interestOps(SelectionKey.OP_WRITE) }
nioThread.wakeup()
}
}
fun writeMode() {
if (Thread.currentThread() == nioThread.thread) {
if (key.isValid) key.interestOps(SelectionKey.OP_WRITE)
} else {
nioThread.execute { if (key.isValid) key.interestOps(SelectionKey.OP_WRITE) }
nioThread.wakeup()
}
}
suspend fun read(buffer: ByteBuffer): Int = read(arrayOf(buffer)).toInt()
suspend fun write(buffer: ByteBuffer): Int = write(arrayOf(buffer)).toInt()
suspend fun read(buffer: Array<out ByteBuffer>): Long
suspend fun write(buffer: Array<out ByteBuffer>): Long
/**
* 如果通道已断开则会抛出异常
*/
suspend fun recv(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return 0
val readSize = read(buffer)
if (readSize < 0) {
throw SocketException("channel closed")
}
return readSize
}
suspend fun read(buffer: ByteBuffer): Int = read(arrayOf(buffer)).toInt()
suspend fun write(buffer: ByteBuffer): Int = write(arrayOf(buffer)).toInt()
suspend fun read(buffer: Array<out ByteBuffer>): Long
suspend fun write(buffer: Array<out ByteBuffer>): Long
/**
* 如果通道已断开则会抛出异常
*/
suspend fun recv(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return emptyBufferCode
val readSize = read(buffer)
if (readSize < 0) {
throw SocketException("channel closed")
}
return readSize
}
suspend fun recv(buffer: ByteBuffer, timeout: Long): Int {
if (buffer.remaining() == 0) return 0
val readSize = read(buffer, timeout)
if (readSize < 0) {
throw SocketException("channel closed")
}
return readSize
}
suspend fun recv(buffer: ByteBuffer, timeout: Long): Int {
if (buffer.remaining() == 0) return emptyBufferCode
val readSize = read(buffer, timeout)
if (readSize < 0) {
throw SocketException("channel closed")
}
return readSize
}
suspend fun recv(buffers: Array<out ByteBuffer>, timeout: Long): Long {
if (buffers.isEmpty()) return 0
val readSize = read(buffers, timeout)
if (readSize < 0) {
throw SocketException("channel closed")
}
return readSize
}
suspend fun recv(buffers: Array<out ByteBuffer>, timeout: Long): Long {
if (buffers.isEmpty()) return emptyBufferLongCode
val readSize = read(buffers, timeout)
if (readSize < 0) {
throw SocketException("channel closed")
}
return readSize
}
suspend fun recv(buffer: AdvanceByteBuffer, timeout: Long = 0): Int {
return if (buffer.bufferCount == 1) {
buffer.writeNioBuffer {
recv(it, timeout)
}
} else {
val readMode = buffer.readMode
buffer.resumeWriteMode()
val value = recv(buffer.nioBuffers, timeout).toInt()
if (readMode) buffer.readMode()
value
}
}
suspend fun recv(buffer: AdvanceByteBuffer, timeout: Long = 0): Int {
return if (buffer.bufferCount == 1) {
buffer.writeNioBuffer {
recv(it, timeout)
}
} else {
val readMode = buffer.readMode
buffer.resumeWriteMode()
val value = recv(buffer.nioBuffers, timeout).toInt()
if (readMode) buffer.readMode()
value
}
}
companion object {
const val emptyBufferCode = 0
const val emptyBufferLongCode = 0L
}
}

View File

@ -9,33 +9,33 @@ import java.nio.channels.SelectionKey
/**
* 有多个工作线程的协程套接字服务器
* 不过因为结构复杂所以性能实际上比多线程的 ProtocolAsyncNioServer
* 不过因为结构复杂所以性能一般比单个工作线程的 AsyncNioServer
*/
@Suppress("MemberVisibilityCanBePrivate")
class AsyncGroupNioServer(
val port: Int,
val threads: Int = Runtime.getRuntime().availableProcessors(),
backlog: Int = 50,
val handler: suspend AsyncNioSocket.() -> Unit
override val port: Int,
val threads: Int = Runtime.getRuntime().availableProcessors(),
backlog: Int = 50,
val handler: suspend AsyncNioSocket.() -> Unit
) : ISocketServer by GroupNioServer(
port,
threads,
object : INioProtocol by AsyncNioSocket.nioSocketProtocol {
override fun handleConnect(key: SelectionKey, nioThread: INioThread) {
GlobalScope.launch {
val socket = AsyncNioSocket(key, nioThread)
try {
socket.handler()
} catch (e: Exception) {
e.printStackTrace()
} finally {
try {
nioThread.execute { socket.close() }
} catch (e: Exception) {
}
}
}
}
},
backlog
port,
threads,
object : INioProtocol by AsyncNioSocket.nioSocketProtocol {
override fun handleConnect(key: SelectionKey, nioThread: INioThread) {
GlobalScope.launch {
val socket = AsyncNioSocket(key, nioThread)
try {
socket.handler()
} catch (e: Exception) {
e.printStackTrace()
} finally {
try {
nioThread.execute { socket.close() }
} catch (e: Exception) {
}
}
}
}
},
backlog
)

View File

@ -9,43 +9,41 @@ import java.nio.channels.SelectionKey
/**
* 只有一个工作线程的协程套接字服务器
* 不过因为结构更加简单所以性能实际上比多线程的 ProtocolGroupAsyncNioServer
* 不过因为结构更加简单所以性能一般比多个工作线程的 ProtocolGroupAsyncNioServer
* 而且协程是天生多线程并不需要太多的接受线程来处理所以一般只需要用本服务器即可
*/
class AsyncNioServer(
val port: Int,
backlog: Int = 50,
val handler: suspend AsyncNioSocket.() -> Unit
override val port: Int,
backlog: Int = 50,
val handler: suspend AsyncNioSocket.() -> Unit
) : ISocketServer by NioServer(port, object : INioProtocol by AsyncNioSocket.nioSocketProtocol {
override fun handleConnect(key: SelectionKey, nioThread: INioThread) {
GlobalScope.launch {
val socket = AsyncNioSocket(key, nioThread)
try {
socket.handler()
} catch (e: Exception) {
e.printStackTrace()
} finally {
try {
socket.close()
} catch (e: Exception) {
}
}
override fun handleConnect(key: SelectionKey, nioThread: INioThread) {
GlobalScope.launch {
val socket = AsyncNioSocket(key, nioThread)
try {
socket.handler()
} catch (e: Exception) {
e.printStackTrace()
} finally {
try {
socket.close()
} catch (e: Exception) {
}
}
}
}
}, backlog) {
/**
* 次要构造方法为使用Spring的同学们准备的
*/
constructor(
port: Int,
backlog: Int = 50,
handler: Handler
) : this(port, backlog, {
handler.handle(this)
})
/**
* 次要构造方法为使用Spring的同学们准备的
*/
constructor(
port: Int,
backlog: Int = 50,
handler: Handler
) : this(port, backlog, { handler.handle(this) })
interface Handler {
fun handle(socket: AsyncNioSocket)
}
interface Handler {
fun handle(socket: AsyncNioSocket)
}
}

View File

@ -0,0 +1,51 @@
package cn.tursom.socket.server
import cn.tursom.socket.AsyncAioSocket
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import java.io.Closeable
import java.net.InetSocketAddress
import java.nio.channels.AsynchronousCloseException
import java.nio.channels.AsynchronousServerSocketChannel
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.CompletionHandler
class AsyncSocketServer(
override val port: Int,
host: String = "0.0.0.0",
private val handler: suspend AsyncAioSocket.() -> Unit
) : ISocketServer {
private val server = AsynchronousServerSocketChannel
.open()
.bind(InetSocketAddress(host, port))
override fun run() {
server.accept(0, object : CompletionHandler<AsynchronousSocketChannel, Int> {
override fun completed(result: AsynchronousSocketChannel?, attachment: Int) {
try {
server.accept(attachment + 1, this)
} catch (e: Throwable) {
e.printStackTrace()
}
result ?: return
GlobalScope.launch {
AsyncAioSocket(result).handler()
}
}
override fun failed(exc: Throwable?, attachment: Int?) {
when (exc) {
is AsynchronousCloseException -> {
}
else -> exc?.printStackTrace()
}
}
})
}
override fun close() {
server.close()
}
}

View File

@ -1,7 +0,0 @@
package cn.tursom.socket.server.async
import java.io.Closeable
interface AsyncServer : Runnable, Closeable {
val port: Int
}

View File

@ -1,51 +0,0 @@
package cn.tursom.socket.server.async
import cn.tursom.socket.AsyncAioSocket
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import java.io.Closeable
import java.net.InetSocketAddress
import java.nio.channels.AsynchronousCloseException
import java.nio.channels.AsynchronousServerSocketChannel
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.CompletionHandler
class AsyncSocketServer(
port: Int,
host: String = "0.0.0.0",
private val handler: suspend AsyncAioSocket.() -> Unit
) : Runnable, Closeable {
private val server = AsynchronousServerSocketChannel
.open()
.bind(InetSocketAddress(host, port))
override fun run() {
server.accept(0, object : CompletionHandler<AsynchronousSocketChannel, Int> {
override fun completed(result: AsynchronousSocketChannel?, attachment: Int) {
try {
server.accept(attachment + 1, this)
} catch (e: Throwable) {
e.printStackTrace()
}
result ?: return
GlobalScope.launch {
AsyncAioSocket(result).handler()
}
}
override fun failed(exc: Throwable?, attachment: Int?) {
when (exc) {
is AsynchronousCloseException -> {
}
else -> exc?.printStackTrace()
}
}
})
}
override fun close() {
server.close()
}
}

View File

@ -16,89 +16,89 @@ import java.util.concurrent.LinkedBlockingDeque
*/
@Suppress("MemberVisibilityCanBePrivate")
class GroupNioServer(
val port: Int,
val threads: Int = Runtime.getRuntime().availableProcessors(),
private val protocol: INioProtocol,
backlog: Int = 50,
val nioThreadGenerator: (
threadName: String,
threads: Int,
worker: (thread: INioThread) -> Unit
) -> IWorkerGroup = { name, _, worker ->
ThreadPoolWorkerGroup(threads, name, false, worker)
}
override val port: Int,
val threads: Int = Runtime.getRuntime().availableProcessors(),
private val protocol: INioProtocol,
backlog: Int = 50,
val nioThreadGenerator: (
threadName: String,
threads: Int,
worker: (thread: INioThread) -> Unit
) -> IWorkerGroup = { name, _, worker ->
ThreadPoolWorkerGroup(threads, name, false, worker)
}
) : ISocketServer {
private val listenChannel = ServerSocketChannel.open()
private val listenThreads = LinkedBlockingDeque<INioThread>()
private val workerGroupList = LinkedBlockingDeque<IWorkerGroup>()
private val listenChannel = ServerSocketChannel.open()
private val listenThreads = LinkedBlockingDeque<INioThread>()
private val workerGroupList = LinkedBlockingDeque<IWorkerGroup>()
init {
listenChannel.socket().bind(InetSocketAddress(port), backlog)
listenChannel.configureBlocking(false)
}
init {
listenChannel.socket().bind(InetSocketAddress(port), backlog)
listenChannel.configureBlocking(false)
}
override fun run() {
val workerGroup = nioThreadGenerator(
"nioWorkerGroup", threads,
NioServer.LoopHandler(protocol)::handle
)
workerGroupList.add(workerGroup)
override fun run() {
val workerGroup = nioThreadGenerator(
"nioWorkerGroup", threads,
NioServer.LoopHandler(protocol)::handle
)
workerGroupList.add(workerGroup)
val nioThread = ThreadPoolNioThread("nioAccepter") { nioThread ->
val selector = nioThread.selector
if (selector.isOpen) {
forEachKey(selector) { key ->
try {
when {
key.isAcceptable -> {
val serverChannel = key.channel() as ServerSocketChannel
var channel = serverChannel.accept()
while (channel != null) {
channel.configureBlocking(false)
workerGroup.register(channel) { (key, thread) ->
protocol.handleConnect(key, thread)
}
channel = serverChannel.accept()
}
}
}
} catch (e: Throwable) {
try {
protocol.exceptionCause(key, nioThread, e)
} catch (e1: Throwable) {
e.printStackTrace()
e1.printStackTrace()
key.cancel()
key.channel().close()
}
}
nioThread.execute(this)
val nioThread = ThreadPoolNioThread("nioAccepter") { nioThread ->
val selector = nioThread.selector
if (selector.isOpen) {
forEachKey(selector) { key ->
try {
when {
key.isAcceptable -> {
val serverChannel = key.channel() as ServerSocketChannel
var channel = serverChannel.accept()
while (channel != null) {
channel.configureBlocking(false)
workerGroup.register(channel) { (key, thread) ->
protocol.handleConnect(key, thread)
}
channel = serverChannel.accept()
}
}
}
listenThreads.add(nioThread)
listenChannel.register(nioThread.selector, SelectionKey.OP_ACCEPT)
nioThread.wakeup()
}
}
}
} catch (e: Throwable) {
try {
protocol.exceptionCause(key, nioThread, e)
} catch (e1: Throwable) {
e.printStackTrace()
e1.printStackTrace()
key.cancel()
key.channel().close()
}
}
nioThread.execute(this)
}
}
}
listenThreads.add(nioThread)
listenChannel.register(nioThread.selector, SelectionKey.OP_ACCEPT)
nioThread.wakeup()
}
override fun close() {
listenChannel.close()
listenThreads.forEach { it.close() }
workerGroupList.forEach { it.close() }
}
override fun close() {
listenChannel.close()
listenThreads.forEach { it.close() }
workerGroupList.forEach { it.close() }
}
companion object {
const val TIMEOUT = 3000L
companion object {
const val TIMEOUT = 3000L
inline fun forEachKey(selector: Selector, action: (key: SelectionKey) -> Unit) {
if (selector.select(TIMEOUT) != 0) {
val keyIter = selector.selectedKeys().iterator()
while (keyIter.hasNext()) run whileBlock@{
val key = keyIter.next()
keyIter.remove()
action(key)
}
}
}
}
inline fun forEachKey(selector: Selector, action: (key: SelectionKey) -> Unit) {
if (selector.select(TIMEOUT) != 0) {
val keyIter = selector.selectedKeys().iterator()
while (keyIter.hasNext()) run whileBlock@{
val key = keyIter.next()
keyIter.remove()
action(key)
}
}
}
}
}

View File

@ -2,4 +2,6 @@ package cn.tursom.socket.server
import java.io.Closeable
interface ISocketServer : Runnable, Closeable
interface ISocketServer : Runnable, Closeable {
val port: Int
}

View File

@ -1,5 +1,6 @@
package cn.tursom.socket.server
import cn.tursom.core.cpuNumber
import cn.tursom.socket.BaseSocket
import java.net.ServerSocket
@ -7,46 +8,47 @@ class MultithreadingSocketServer(
private val serverSocket: ServerSocket,
private val threadNumber: Int = cpuNumber,
val exception: Exception.() -> Unit = {
printStackTrace()
printStackTrace()
},
handler: BaseSocket.() -> Unit
) : SocketServer(handler) {
override val handler: BaseSocket.() -> Unit
) : SocketServer {
override val port = serverSocket.localPort
constructor(
port: Int,
threadNumber: Int = cpuNumber,
exception: Exception.() -> Unit = {
printStackTrace()
},
handler: BaseSocket.() -> Unit
) : this(ServerSocket(port), threadNumber, exception, handler)
constructor(
port: Int,
threadNumber: Int = cpuNumber,
exception: Exception.() -> Unit = {
printStackTrace()
},
handler: BaseSocket.() -> Unit
) : this(ServerSocket(port), threadNumber, exception, handler)
constructor(
port: Int,
handler: BaseSocket.() -> Unit
) : this(port, cpuNumber, { printStackTrace() }, handler)
constructor(
port: Int,
handler: BaseSocket.() -> Unit
) : this(port, cpuNumber, { printStackTrace() }, handler)
private val threadList = ArrayList<Thread>()
private val threadList = ArrayList<Thread>()
override fun run() {
for (i in 1..threadNumber) {
val thread = Thread {
while (true) {
serverSocket.accept().use {
try {
BaseSocket(it).handler()
} catch (e: Exception) {
e.exception()
}
}
}
override fun run() {
for (i in 1..threadNumber) {
val thread = Thread {
while (true) {
serverSocket.accept().use {
try {
BaseSocket(it).handler()
} catch (e: Exception) {
e.exception()
}
thread.start()
threadList.add(thread)
}
}
}
thread.start()
threadList.add(thread)
}
}
override fun close() {
serverSocket.close()
}
override fun close() {
serverSocket.close()
}
}

View File

@ -12,86 +12,86 @@ import java.util.concurrent.ConcurrentLinkedDeque
* 工作在单线程上的 Nio 服务器
*/
class NioServer(
val port: Int,
override val port: Int,
private val protocol: INioProtocol,
backLog: Int = 50,
val nioThreadGenerator: (threadName: String, workLoop: (thread: INioThread) -> Unit) -> INioThread
) : ISocketServer {
private val listenChannel = ServerSocketChannel.open()
private val threadList = ConcurrentLinkedDeque<INioThread>()
private val listenChannel = ServerSocketChannel.open()
private val threadList = ConcurrentLinkedDeque<INioThread>()
init {
listenChannel.socket().bind(InetSocketAddress(port), backLog)
listenChannel.configureBlocking(false)
init {
listenChannel.socket().bind(InetSocketAddress(port), backLog)
listenChannel.configureBlocking(false)
}
constructor(
port: Int,
protocol: INioProtocol,
backLog: Int = 50
) : this(port, protocol, backLog, { name, workLoop ->
WorkerLoopNioThread(name, workLoop = workLoop, isDaemon = false)
})
override fun run() {
val nioThread = nioThreadGenerator("nio worker", LoopHandler(protocol)::handle)
nioThread.register(listenChannel, SelectionKey.OP_ACCEPT) {}
threadList.add(nioThread)
}
override fun close() {
listenChannel.close()
threadList.forEach {
it.close()
}
}
constructor(
port: Int,
protocol: INioProtocol,
backLog: Int = 50
) : this(port, protocol, backLog, { name, workLoop ->
WorkerLoopNioThread(name, workLoop = workLoop, isDaemon = false)
})
override fun run() {
val nioThread = nioThreadGenerator("nio worker", LoopHandler(protocol)::handle)
nioThread.register(listenChannel, SelectionKey.OP_ACCEPT) {}
threadList.add(nioThread)
}
override fun close() {
listenChannel.close()
threadList.forEach {
it.close()
}
}
class LoopHandler(val protocol: INioProtocol) {
fun handle(nioThread: INioThread) {
val selector = nioThread.selector
if (selector.isOpen) {
if (selector.select(TIMEOUT) != 0) {
val keyIter = selector.selectedKeys().iterator()
while (keyIter.hasNext()) run whileBlock@{
val key = keyIter.next()
keyIter.remove()
try {
when {
key.isAcceptable -> {
val serverChannel = key.channel() as ServerSocketChannel
var channel = serverChannel.accept()
while (channel != null) {
channel.configureBlocking(false)
nioThread.register(channel, 0) {
protocol.handleConnect(it, nioThread)
}
channel = serverChannel.accept()
}
}
key.isReadable -> {
protocol.handleRead(key, nioThread)
}
key.isWritable -> {
protocol.handleWrite(key, nioThread)
}
}
} catch (e: Throwable) {
try {
protocol.exceptionCause(key, nioThread, e)
} catch (e1: Throwable) {
e.printStackTrace()
e1.printStackTrace()
key.cancel()
key.channel().close()
}
}
class LoopHandler(val protocol: INioProtocol) {
fun handle(nioThread: INioThread) {
val selector = nioThread.selector
if (selector.isOpen) {
if (selector.select(TIMEOUT) != 0) {
val keyIter = selector.selectedKeys().iterator()
while (keyIter.hasNext()) run whileBlock@{
val key = keyIter.next()
keyIter.remove()
try {
when {
key.isAcceptable -> {
val serverChannel = key.channel() as ServerSocketChannel
var channel = serverChannel.accept()
while (channel != null) {
channel.configureBlocking(false)
nioThread.register(channel, 0) {
protocol.handleConnect(it, nioThread)
}
channel = serverChannel.accept()
}
}
key.isReadable -> {
protocol.handleRead(key, nioThread)
}
key.isWritable -> {
protocol.handleWrite(key, nioThread)
}
}
} catch (e: Throwable) {
try {
protocol.exceptionCause(key, nioThread, e)
} catch (e1: Throwable) {
e.printStackTrace()
e1.printStackTrace()
key.cancel()
key.channel().close()
}
}
}
}
}
}
}
companion object {
private const val TIMEOUT = 1000L
}
companion object {
private const val TIMEOUT = 1000L
}
}

View File

@ -7,45 +7,46 @@ import java.net.SocketException
class SingleThreadSocketServer(
private val serverSocket: ServerSocket,
val exception: Exception.() -> Unit = { printStackTrace() },
handler: BaseSocket.() -> Unit
) : SocketServer(handler) {
override val handler: BaseSocket.() -> Unit
) : SocketServer {
override val port = serverSocket.localPort
constructor(
port: Int,
exception: Exception.() -> Unit = { printStackTrace() },
handler: BaseSocket.() -> Unit
) : this(ServerSocket(port), exception, handler)
constructor(
port: Int,
exception: Exception.() -> Unit = { printStackTrace() },
handler: BaseSocket.() -> Unit
) : this(ServerSocket(port), exception, handler)
constructor(
port: Int,
handler: BaseSocket.() -> Unit
) : this(port, { printStackTrace() }, handler)
constructor(
port: Int,
handler: BaseSocket.() -> Unit
) : this(port, { printStackTrace() }, handler)
override fun run() {
while (!serverSocket.isClosed) {
try {
serverSocket.accept().use {
try {
BaseSocket(it).handler()
} catch (e: Exception) {
e.exception()
}
}
} catch (e: SocketException) {
if (e.message == "Socket closed" || e.message == "cn.tursom.socket closed") {
break
} else {
e.exception()
}
}
override fun run() {
while (!serverSocket.isClosed) {
try {
serverSocket.accept().use {
try {
BaseSocket(it).handler()
} catch (e: Exception) {
e.exception()
}
}
}
override fun close() {
try {
serverSocket.close()
} catch (e: Exception) {
e.printStackTrace()
} catch (e: SocketException) {
if (e.message == "Socket closed" || e.message == "cn.tursom.socket closed") {
break
} else {
e.exception()
}
}
}
}
override fun close() {
try {
serverSocket.close()
} catch (e: Exception) {
e.printStackTrace()
}
}
}

View File

@ -2,8 +2,10 @@ package cn.tursom.socket.server
import cn.tursom.socket.BaseSocket
abstract class SocketServer(val handler: BaseSocket.() -> Unit) : ISocketServer {
companion object {
val cpuNumber = Runtime.getRuntime().availableProcessors() //CPU处理器的个数
}
interface SocketServer : ISocketServer {
val handler: BaseSocket.() -> Unit
companion object {
val cpuNumber = Runtime.getRuntime().availableProcessors() //CPU处理器的个数
}
}

View File

@ -30,7 +30,7 @@ import java.util.concurrent.TimeUnit
* }
*
*/
open class ThreadPoolSocketServer
class ThreadPoolSocketServer
/**
* 使用代码而不是配置文件的构造函数
*
@ -39,125 +39,100 @@ open class ThreadPoolSocketServer
* @param queueSize 线程池任务队列大小
* @param keepAliveTime 线程最长存活时间
* @param timeUnit timeout的单位默认毫秒
* @param startImmediately 是否立即启动
* @param handler 对套接字处理的业务逻辑
*/(
port: Int,
threads: Int = 1,
queueSize: Int = 1,
keepAliveTime: Long = 60_000L,
timeUnit: TimeUnit = TimeUnit.MILLISECONDS,
handler: BaseSocket.() -> Unit
) : SocketServer(handler) {
override val port: Int,
threads: Int = 1,
queueSize: Int = 1,
keepAliveTime: Long = 60_000L,
timeUnit: TimeUnit = TimeUnit.MILLISECONDS,
override val handler: BaseSocket.() -> Unit
) : SocketServer {
constructor(
port: Int,
handler: BaseSocket.() -> Unit
) : this(port, 1, 1, 60_000L, TimeUnit.MILLISECONDS, handler)
constructor(
port: Int,
handler: BaseSocket.() -> Unit
) : this(port, 1, 1, 60_000L, TimeUnit.MILLISECONDS, handler)
var socket = Socket()
private val pool: ThreadPoolExecutor =
ThreadPoolExecutor(threads, threads, keepAliveTime, timeUnit, LinkedBlockingQueue(queueSize))
private var serverSocket: ServerSocket = ServerSocket(port)
var socket = Socket()
private val pool: ThreadPoolExecutor =
ThreadPoolExecutor(threads, threads, keepAliveTime, timeUnit, LinkedBlockingQueue(queueSize))
private var serverSocket: ServerSocket = ServerSocket(port)
/**
* 为了在构造函数中自动启动服务我们需要封闭start()防止用户重载start()
*/
private fun start() {
Thread(this).start()
}
/**
* 主要作用
* 循环接受连接请求
* 讲接收的连接交给handler处理
* 连接初期异常处理
* 自动关闭套接字服务器与线程池
*/
final override fun run() {
while (!serverSocket.isClosed) {
try {
socket = serverSocket.accept()
println("$TAG: run(): get connect: $socket")
pool.execute {
socket.use {
BaseSocket(it).handler()
}
}
} catch (e: IOException) {
if (pool.isShutdown || serverSocket.isClosed) {
System.err.println("server closed")
break
}
e.printStackTrace()
} catch (e: SocketException) {
e.printStackTrace()
break
} catch (e: RejectedExecutionException) {
socket.getOutputStream()?.write(poolIsFull)
} catch (e: Exception) {
e.printStackTrace()
break
}
/**
* 主要作用
* 循环接受连接请求
* 讲接收的连接交给handler处理
* 连接初期异常处理
* 自动关闭套接字服务器与线程池
*/
override fun run() {
while (!serverSocket.isClosed) {
try {
socket = serverSocket.accept()
println("$TAG: run(): get connect: $socket")
pool.execute {
socket.use {
BaseSocket(it).handler()
}
}
whenClose()
close()
System.err.println("server closed")
}
/**
* 关闭服务器套接字
*/
private fun closeServer() {
if (!serverSocket.isClosed) {
serverSocket.close()
} catch (e: IOException) {
if (pool.isShutdown || serverSocket.isClosed) {
System.err.println("server closed")
break
}
e.printStackTrace()
} catch (e: SocketException) {
e.printStackTrace()
break
} catch (e: RejectedExecutionException) {
socket.getOutputStream()?.write(poolIsFull)
} catch (e: Exception) {
e.printStackTrace()
break
}
}
close()
System.err.println("server closed")
}
/**
* 关闭线程池
*/
private fun shutdownPool() {
if (!pool.isShutdown) {
pool.shutdown()
}
/**
* 关闭服务器套接字
*/
private fun closeServer() {
if (!serverSocket.isClosed) {
serverSocket.close()
}
}
/**
* 服务器是否已经关闭
*/
@Suppress("unused")
fun isClosed() = pool.isShutdown || serverSocket.isClosed
/**
* 关闭服务器
*/
override fun close() {
shutdownPool()
closeServer()
/**
* 关闭线程池
*/
private fun shutdownPool() {
if (!pool.isShutdown) {
pool.shutdown()
}
}
/**
* 关闭服务器时执行
*/
open fun whenClose() {
}
/**
* 服务器是否已经关闭
*/
@Suppress("unused")
fun isClosed() = pool.isShutdown || serverSocket.isClosed
/**
* 关闭服务器
*/
override fun close() {
shutdownPool()
closeServer()
}
companion object {
val TAG = getTAG(this::class.java)
/**
* 线程池满时返回给客户端的信息
*/
open val poolIsFull
get() = Companion.poolIsFull
private data class ServerConfigData(
val port: Int = 0,
val threads: Int = 1,
val queueSize: Int = 1,
val timeout: Long = 0L,
val startImmediately: Boolean = false
)
companion object {
val TAG = getTAG(this::class.java)
val poolIsFull = "server pool is full".toByteArray()
}
val poolIsFull = "server pool is full".toByteArray()
}
}