添加注释,优化 AsyncNioSocket 结构

This commit is contained in:
tursom 2019-10-22 13:59:40 +08:00
parent 751b5dda42
commit 9703056095
8 changed files with 226 additions and 276 deletions

View File

@ -1,7 +1,5 @@
package cn.tursom.socket
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import java.net.InetSocketAddress
import java.net.SocketAddress
import java.nio.channels.AsynchronousSocketChannel
@ -14,15 +12,11 @@ import kotlin.coroutines.suspendCoroutine
object AsyncAioClient {
private val handler = object : CompletionHandler<Void, Continuation<Void?>> {
override fun completed(result: Void?, attachment: Continuation<Void?>) {
GlobalScope.launch {
attachment.resume(result)
}
attachment.resume(result)
}
override fun failed(exc: Throwable, attachment: Continuation<Void?>) {
GlobalScope.launch {
attachment.resumeWithException(exc)
}
attachment.resumeWithException(exc)
}
}

View File

@ -20,50 +20,35 @@ import kotlin.coroutines.suspendCoroutine
class AsyncNioSocket(override val key: SelectionKey, override val nioThread: INioThread) : IAsyncNioSocket {
override val channel: SocketChannel = key.channel() as SocketChannel
private suspend inline fun <T> operate(crossinline action: (Continuation<T>) -> Unit): T {
return try {
suspendCoroutine {
action(it)
}
} catch (e: Exception) {
waitMode()
throw RuntimeException(e)
}
}
override suspend fun read(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(SingleContext(buffer, it))
readMode()
nioThread.wakeup()
waitRead()
channel.read(buffer)
}
}
override suspend fun read(buffer: Array<out ByteBuffer>): Long {
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(MultiContext(buffer, it))
readMode()
nioThread.wakeup()
waitRead()
channel.read(buffer)
}
}
override suspend fun write(buffer: ByteBuffer): Int {
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(SingleContext(buffer, it))
writeMode()
nioThread.wakeup()
waitWrite()
channel.write(buffer)
}
}
override suspend fun write(buffer: Array<out ByteBuffer>): Long {
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(MultiContext(buffer, it))
writeMode()
nioThread.wakeup()
waitWrite()
channel.write(buffer)
}
}
@ -71,17 +56,8 @@ class AsyncNioSocket(override val key: SelectionKey, override val nioThread: INi
if (timeout <= 0) return read(buffer)
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(
SingleContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
readMode()
nioThread.wakeup()
waitRead(timeout)
channel.read(buffer)
}
}
@ -89,17 +65,8 @@ class AsyncNioSocket(override val key: SelectionKey, override val nioThread: INi
if (timeout <= 0) return read(buffer)
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(
MultiContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
readMode()
nioThread.wakeup()
waitRead(timeout)
channel.read(buffer)
}
}
@ -107,17 +74,8 @@ class AsyncNioSocket(override val key: SelectionKey, override val nioThread: INi
if (timeout <= 0) return write(buffer)
if (buffer.remaining() == 0) return emptyBufferCode
return operate {
key.attach(
SingleContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
writeMode()
nioThread.wakeup()
waitWrite(timeout)
channel.write(buffer)
}
}
@ -125,17 +83,8 @@ class AsyncNioSocket(override val key: SelectionKey, override val nioThread: INi
if (timeout <= 0) return write(buffer)
if (buffer.isEmpty()) return emptyBufferLongCode
return operate {
key.attach(
MultiContext(
buffer,
it,
timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
})
)
writeMode()
nioThread.wakeup()
waitWrite(timeout)
channel.write(buffer)
}
}
@ -144,24 +93,57 @@ class AsyncNioSocket(override val key: SelectionKey, override val nioThread: INi
channel.close()
key.cancel()
}
nioThread.wakeup()
}
interface Context {
val cont: Continuation<*>
val timeoutTask: TimerTask? get() = null
private inline fun <T> operate(action: () -> T): T {
return try {
action()
} catch (e: Exception) {
waitMode()
throw RuntimeException(e)
}
}
class SingleContext(
val buffer: ByteBuffer,
override val cont: Continuation<Int>,
override val timeoutTask: TimerTask? = null
) : Context
private suspend inline fun waitRead(timeout: Long) {
suspendCoroutine<Int> {
key.attach(Context(it, timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
readMode()
nioThread.wakeup()
}))
}
}
class MultiContext(
val buffer: Array<out ByteBuffer>,
override val cont: Continuation<Long>,
override val timeoutTask: TimerTask? = null
) : Context
private suspend inline fun waitWrite(timeout: Long) {
suspendCoroutine<Int> {
key.attach(Context(it, timer.exec(timeout) {
key.attach(null)
it.resumeWithException(TimeoutException())
writeMode()
nioThread.wakeup()
}))
}
}
private suspend inline fun waitRead() {
suspendCoroutine<Int> {
key.attach(Context(it))
readMode()
nioThread.wakeup()
}
}
private suspend inline fun waitWrite() {
suspendCoroutine<Int> {
key.attach(Context(it))
writeMode()
nioThread.wakeup()
}
}
data class Context(val cont: Continuation<Int>, val timeoutTask: TimerTask? = null)
companion object {
val nioSocketProtocol = object : INioProtocol {
@ -171,32 +153,13 @@ class AsyncNioSocket(override val key: SelectionKey, override val nioThread: INi
key.interestOps(0)
val context = key.attachment() as Context? ?: return
context.timeoutTask?.cancel()
if (context is SingleContext) {
val channel = key.channel() as SocketChannel
val readSize = channel.read(context.buffer)
context.cont.resume(readSize)
} else {
context as MultiContext
val channel = key.channel() as SocketChannel
val readSize = channel.read(context.buffer)
context.cont.resume(readSize)
}
context.cont.resume(0)
}
override fun handleWrite(key: SelectionKey, nioThread: INioThread) {
key.interestOps(0)
val context = key.attachment() as Context? ?: return
context.timeoutTask?.cancel()
if (context is SingleContext) {
val channel = key.channel() as SocketChannel
val readSize = channel.write(context.buffer)
context.cont.resume(readSize)
} else {
context as MultiContext
val channel = key.channel() as SocketChannel
val readSize = channel.write(context.buffer)
context.cont.resume(readSize)
}
context.cont.resume(0)
}
override fun exceptionCause(key: SelectionKey, nioThread: INioThread, e: Throwable) {

View File

@ -7,10 +7,15 @@ import cn.tursom.core.pool.MemoryPool
import cn.tursom.core.pool.usingAdvanceByteBuffer
import cn.tursom.socket.AsyncNioSocket
/**
* 带内存池的 NIO 套接字服务器<br />
* 其构造函数是标准写法的改造会向 handler 方法传入一个 AdvanceByteBuffer默认是 DirectAdvanceByteBuffer
* 当内存池用完之后会换为 ByteArrayAdvanceByteBuffer
*/
class BuffedAsyncNioServer(
port: Int,
backlog: Int = 50,
memoryPool: MemoryPool,
backlog: Int = 50,
handler: suspend AsyncNioSocket.(buffer: AdvanceByteBuffer) -> Unit
) : IAsyncNioServer by AsyncNioServer(port, backlog, {
memoryPool.usingAdvanceByteBuffer {
@ -23,5 +28,5 @@ class BuffedAsyncNioServer(
blockCount: Int = 128,
backlog: Int = 50,
handler: suspend AsyncNioSocket.(buffer: AdvanceByteBuffer) -> Unit
) : this(port, backlog, DirectMemoryPool(blockSize, blockCount), handler)
) : this(port, DirectMemoryPool(blockSize, blockCount), backlog, handler)
}

View File

@ -7,7 +7,8 @@ import java.nio.channels.Selector
import java.util.concurrent.Callable
/**
* 一个 nio 工作线程一个线程只有一个 Selector
* 一个 nio 工作线程
* 一个线程对应一个 Selector 选择器
*/
interface INioThread : Closeable {
val selector: Selector
@ -20,6 +21,9 @@ interface INioThread : Closeable {
if (Thread.currentThread() != thread) selector.wakeup()
}
/**
* 将通道注册到线程对应的选择器上
*/
fun register(channel: SelectableChannel, ops: Int, onComplete: (key: SelectionKey) -> Unit) {
if (Thread.currentThread() == thread) {
val key = channel.register(selector, ops)

View File

@ -1,86 +1,68 @@
package cn.tursom.socket.niothread
import cn.tursom.core.timer.WheelTimer
import java.nio.channels.SelectableChannel
import java.nio.channels.SelectionKey
import java.nio.channels.Selector
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
@Suppress("MemberVisibilityCanBePrivate")
class ThreadPoolNioThread(
val threadName: String = "",
override val selector: Selector = Selector.open(),
override val isDaemon: Boolean = false,
override val workLoop: (thread: INioThread) -> Unit
val threadName: String = "",
override val selector: Selector = Selector.open(),
override val isDaemon: Boolean = false,
override val workLoop: (thread: INioThread) -> Unit
) : INioThread {
private var onWakeup: AtomicBoolean = AtomicBoolean(false)
override lateinit var thread: Thread
//val threadPool: ExecutorService = Executors.newSingleThreadExecutor {
// val thread = Thread(it)
// workerThread = thread
// thread.isDaemon = true
// thread.name = threadName
// thread
//}
val threadPool: ExecutorService = ThreadPoolExecutor(1, 1,
0L, TimeUnit.MILLISECONDS,
LinkedBlockingQueue<Runnable>(),
ThreadFactory {
val thread = Thread(it)
this.thread = thread
thread.isDaemon = isDaemon
thread.name = threadName
thread
})
override lateinit var thread: Thread
val threadPool: ExecutorService = ThreadPoolExecutor(1, 1,
0L, TimeUnit.MILLISECONDS,
LinkedBlockingQueue<Runnable>(),
ThreadFactory {
val thread = Thread(it)
this.thread = thread
thread.isDaemon = isDaemon
thread.name = threadName
thread
})
override var closed: Boolean = false
init {
threadPool.execute(object : Runnable {
override fun run() {
workLoop(this@ThreadPoolNioThread)
if (!threadPool.isShutdown) threadPool.execute(this)
}
})
}
init {
threadPool.execute(object : Runnable {
override fun run() {
workLoop(this@ThreadPoolNioThread)
if (!threadPool.isShutdown) threadPool.execute(this)
}
})
}
override var closed: Boolean = false
override fun wakeup() {
if (Thread.currentThread() != thread) {
selector.wakeup()
}
}
override fun wakeup() {
if (Thread.currentThread() != thread && onWakeup.compareAndSet(false, true)) {
timer.exec(50) {
onWakeup.set(false)
selector.wakeup()
}
}
}
override fun register(channel: SelectableChannel, ops: Int, onComplete: (key: SelectionKey) -> Unit) {
if (Thread.currentThread() == thread) {
onComplete(channel.register(selector, ops))
} else {
threadPool.execute { register(channel, ops, onComplete) }
wakeup()
}
}
override fun register(channel: SelectableChannel, ops: Int, onComplete: (key: SelectionKey) -> Unit) {
if (Thread.currentThread() == thread) {
onComplete(channel.register(selector, ops))
} else {
threadPool.execute { register(channel, ops, onComplete) }
wakeup()
}
}
override fun execute(command: Runnable) = threadPool.execute(command)
override fun <T> call(task: Callable<T>): T = threadPool.submit(task).get()
override fun <T> submit(task: Callable<T>): NioThreadFuture<T> = ThreadPoolFuture(threadPool.submit(task))
override fun execute(command: Runnable) = threadPool.execute(command)
override fun <T> call(task: Callable<T>): T = threadPool.submit(task).get()
override fun <T> submit(task: Callable<T>): NioThreadFuture<T> = ThreadPoolFuture(threadPool.submit(task))
override fun close() {
closed = true
threadPool.shutdown()
}
override fun close() {
closed = true
threadPool.shutdown()
}
class ThreadPoolFuture<T>(val future: Future<T>) : NioThreadFuture<T> {
override fun get(): T = future.get()
}
class ThreadPoolFuture<T>(val future: Future<T>) : NioThreadFuture<T> {
override fun get(): T = future.get()
}
override fun toString(): String {
return "SingleThreadNioThread($threadName)"
}
companion object {
val timer = WheelTimer.smoothTimer
}
override fun toString(): String {
return "SingleThreadNioThread($threadName)"
}
}

View File

@ -8,112 +8,103 @@ import java.util.concurrent.atomic.AtomicBoolean
@Suppress("MemberVisibilityCanBePrivate", "CanBeParameter")
class WorkerLoopNioThread(
val threadName: String = "nioLoopThread",
override val selector: Selector = Selector.open(),
override val isDaemon: Boolean = false,
override val workLoop: (thread: INioThread) -> Unit
val threadName: String = "nioLoopThread",
override val selector: Selector = Selector.open(),
override val isDaemon: Boolean = false,
override val workLoop: (thread: INioThread) -> Unit
) : INioThread {
private var onWakeup: AtomicBoolean = AtomicBoolean(false)
override var closed: Boolean = false
override var closed: Boolean = false
val waitQueue = LinkedBlockingDeque<Runnable>()
val taskQueue = LinkedBlockingDeque<Future<Any?>>()
val waitQueue = LinkedBlockingDeque<Runnable>()
val taskQueue = LinkedBlockingDeque<Future<Any?>>()
override val thread = Thread {
while (!closed) {
try {
workLoop(this)
} catch (e: Exception) {
e.printStackTrace()
}
//System.err.println("$threadName worker loop finish once")
while (waitQueue.isNotEmpty()) try {
waitQueue.poll().run()
} catch (e: Exception) {
e.printStackTrace()
}
while (taskQueue.isNotEmpty()) try {
val task = taskQueue.poll()
try {
task.resume(task.task.call())
} catch (e: Throwable) {
task.resumeWithException(e)
}
} catch (e: Exception) {
e.printStackTrace()
}
}
}
override val thread = Thread {
while (!closed) {
try {
workLoop(this)
} catch (e: Exception) {
e.printStackTrace()
}
//System.err.println("$threadName worker loop finish once")
while (waitQueue.isNotEmpty()) try {
waitQueue.poll().run()
} catch (e: Exception) {
e.printStackTrace()
}
while (taskQueue.isNotEmpty()) try {
val task = taskQueue.poll()
try {
task.resume(task.task.call())
} catch (e: Throwable) {
task.resumeWithException(e)
}
} catch (e: Exception) {
e.printStackTrace()
}
}
}
init {
thread.name = threadName
thread.isDaemon = isDaemon
thread.start()
}
init {
thread.name = threadName
thread.isDaemon = isDaemon
thread.start()
}
override fun execute(command: Runnable) {
waitQueue.add(command)
}
override fun execute(command: Runnable) {
waitQueue.add(command)
}
override fun <T> submit(task: Callable<T>): NioThreadFuture<T> {
val f = Future(task)
@Suppress("UNCHECKED_CAST")
taskQueue.add(f as Future<Any?>)
return f
}
override fun <T> submit(task: Callable<T>): NioThreadFuture<T> {
val f = Future(task)
@Suppress("UNCHECKED_CAST")
taskQueue.add(f as Future<Any?>)
return f
}
override fun close() {
closed = true
}
override fun close() {
closed = true
}
override fun wakeup() {
if (Thread.currentThread() != thread) {
selector.wakeup()
}
}
override fun wakeup() {
if (Thread.currentThread() != thread && onWakeup.compareAndSet(false, true)) {
timer.exec(50) {
onWakeup.set(false)
selector.wakeup()
}
}
}
class Future<T>(val task: Callable<T>) : NioThreadFuture<T> {
private val lock = Object()
private var exception: Throwable? = null
private var result: Pair<T, Boolean>? = null
class Future<T>(val task: Callable<T>) : NioThreadFuture<T> {
private val lock = Object()
private var exception: Throwable? = null
private var result: Pair<T, Boolean>? = null
override fun get(): T {
val result = this.result
return when {
exception != null -> throw RuntimeException(exception)
result != null -> result.first
else -> synchronized(lock) {
lock.wait()
val exception = this.exception
if (exception != null) {
throw RuntimeException(exception)
} else {
this.result!!.first
}
}
}
}
override fun get(): T {
val result = this.result
return when {
exception != null -> throw RuntimeException(exception)
result != null -> result.first
else -> synchronized(lock) {
lock.wait()
val exception = this.exception
if (exception != null) {
throw RuntimeException(exception)
} else {
this.result!!.first
}
}
}
}
fun resume(value: T) {
result = value to true
synchronized(lock) {
lock.notifyAll()
}
}
fun resume(value: T) {
result = value to true
synchronized(lock) {
lock.notifyAll()
}
}
fun resumeWithException(e: Throwable) {
exception = e
synchronized(lock) {
lock.notifyAll()
}
}
}
companion object {
val timer = WheelTimer.smoothTimer
}
fun resumeWithException(e: Throwable) {
exception = e
synchronized(lock) {
lock.notifyAll()
}
}
}
}

View File

@ -2,6 +2,13 @@ package cn.tursom.socket.server
import java.io.Closeable
/**
* 套接字服务器的基本形式提供运行关闭的基本操作
* 其应支持最基本的创建形式
* XXXServer(port) {
* // 业务逻辑
* }
*/
interface ISocketServer : Runnable, Closeable {
val port: Int
}

View File

@ -4,6 +4,10 @@ import cn.tursom.core.bytebuffer.AdvanceByteBuffer
import cn.tursom.core.bytebuffer.NioAdvanceByteBuffer
import java.nio.ByteBuffer
/**
* 内存池提供批量的等大小的 ByteBuffer
* 使用 allocate 分配内存使用 getMemory getAdvanceByteBuffer 获得内存使用 free 释放内存
*/
interface MemoryPool {
val blockSize: Int
val blockCount: Int