优化 SocketServer 结构,添加 README

This commit is contained in:
tursom 2019-10-20 15:48:16 +08:00
parent 3b78a9c998
commit 751b5dda42
19 changed files with 387 additions and 283 deletions

View File

@ -0,0 +1,35 @@
###异步套接字的协程封装
这个包实现了对异步的套接字的语句同步化封装,适用于 Kotlin 协程执行环境。
但是因为需要协程作为执行环境,所以无法在 Java 环境下正常创建。
其核心分别是对 AIO 进行封装的 AsyncAioSocket 和对 NIO 进行封装的
AsyncNioSocket。AsyncAioSocket 实现简单,但是可塑性较低,缺陷也较难解决;
AsyncNioSocket 虽然实现复杂,但是可塑性很高,优化空间大,缺陷一般也都可以解决。
---
AsyncAioSocket 和 AsyncNioSocket 分别通过对应的服务器与客户端创建。
创建一个异步服务器的形式和同步服务器的形式是完全一样的:
```kotlin
// 创建一个自带内存池的异步套接字服务器
val server = BufferedAsyncNioServer(port) { buffer->
// do any thing
// 这里都是用同步语法写出的异步套接字操作
read(buffer)
write(buffer)
}
// 异步服务器不需要创建新线程来执行
server.run()
// 异步套接字的创建既可以在普通环境下,也可以在协程环境下
val client = AsyncNioClient.connect("localhost", port)
runBlocking {
val buffer = ByteArrayAdvanceByteBuffer(1024)
// 向套接字内写数据
buffer.put("Hello!")
client.write(buffer)
// 从套接字内读数据
buffer.reset()
client.read(buffer)
log(buffer.getString())
client.close()
}
```

View File

@ -0,0 +1,38 @@
package cn.tursom.socket
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import java.net.InetSocketAddress
import java.net.SocketAddress
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.CompletionHandler
import kotlin.coroutines.Continuation
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
import kotlin.coroutines.suspendCoroutine
object AsyncAioClient {
private val handler = object : CompletionHandler<Void, Continuation<Void?>> {
override fun completed(result: Void?, attachment: Continuation<Void?>) {
GlobalScope.launch {
attachment.resume(result)
}
}
override fun failed(exc: Throwable, attachment: Continuation<Void?>) {
GlobalScope.launch {
attachment.resumeWithException(exc)
}
}
}
suspend fun connect(host: String, port: Int): AsyncAioSocket {
@Suppress("BlockingMethodInNonBlockingContext")
return connect(AsynchronousSocketChannel.open()!!, host, port)
}
suspend fun connect(socketChannel: AsynchronousSocketChannel, host: String, port: Int): AsyncAioSocket {
suspendCoroutine<Void?> { cont -> socketChannel.connect(InetSocketAddress(host, port) as SocketAddress, cont, handler) }
return AsyncAioSocket(socketChannel)
}
}

View File

@ -1,42 +0,0 @@
package cn.tursom.socket
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import java.net.InetSocketAddress
import java.net.SocketAddress
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.CompletionHandler
import kotlin.coroutines.Continuation
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
import kotlin.coroutines.suspendCoroutine
object AsyncClient {
private val handler = object : CompletionHandler<Void, Continuation<Void?>> {
override fun completed(result: Void?, attachment: Continuation<Void?>) {
GlobalScope.launch {
attachment.resume(result)
}
}
override fun failed(exc: Throwable, attachment: Continuation<Void?>) {
GlobalScope.launch {
attachment.resumeWithException(exc)
}
}
}
suspend fun connect(host: String, port: Int): AsyncAioSocket {
@Suppress("BlockingMethodInNonBlockingContext")
return connect(AsynchronousSocketChannel.open()!!, host, port)
}
suspend fun connect(socketChannel: AsynchronousSocketChannel, host: String, port: Int): AsyncAioSocket {
suspendCoroutine<Void?> { cont ->
socketChannel.connect(InetSocketAddress(host, port) as SocketAddress, cont,
handler
)
}
return AsyncAioSocket(socketChannel)
}
}

View File

@ -16,8 +16,8 @@ class AsyncGroupNioServer(
override val port: Int,
val threads: Int = Runtime.getRuntime().availableProcessors(),
backlog: Int = 50,
val handler: suspend AsyncNioSocket.() -> Unit
) : ISocketServer by GroupNioServer(
override val handler: suspend AsyncNioSocket.() -> Unit
) : IAsyncNioServer, ISocketServer by GroupNioServer(
port,
threads,
object : INioProtocol by AsyncNioSocket.nioSocketProtocol {

View File

@ -15,8 +15,8 @@ import java.nio.channels.SelectionKey
class AsyncNioServer(
override val port: Int,
backlog: Int = 50,
val handler: suspend AsyncNioSocket.() -> Unit
) : ISocketServer by NioServer(port, object : INioProtocol by AsyncNioSocket.nioSocketProtocol {
override val handler: suspend AsyncNioSocket.() -> Unit
) : IAsyncNioServer, ISocketServer by NioServer(port, object : INioProtocol by AsyncNioSocket.nioSocketProtocol {
override fun handleConnect(key: SelectionKey, nioThread: INioThread) {
GlobalScope.launch {
val socket = AsyncNioSocket(key, nioThread)

View File

@ -3,7 +3,6 @@ package cn.tursom.socket.server
import cn.tursom.socket.AsyncAioSocket
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import java.io.Closeable
import java.net.InetSocketAddress
import java.nio.channels.AsynchronousCloseException
import java.nio.channels.AsynchronousServerSocketChannel

View File

@ -0,0 +1,27 @@
package cn.tursom.socket.server
import cn.tursom.core.bytebuffer.AdvanceByteBuffer
import cn.tursom.core.bytebuffer.ByteArrayAdvanceByteBuffer
import cn.tursom.core.pool.DirectMemoryPool
import cn.tursom.core.pool.MemoryPool
import cn.tursom.core.pool.usingAdvanceByteBuffer
import cn.tursom.socket.AsyncNioSocket
class BuffedAsyncNioServer(
port: Int,
backlog: Int = 50,
memoryPool: MemoryPool,
handler: suspend AsyncNioSocket.(buffer: AdvanceByteBuffer) -> Unit
) : IAsyncNioServer by AsyncNioServer(port, backlog, {
memoryPool.usingAdvanceByteBuffer {
handler(it ?: ByteArrayAdvanceByteBuffer(memoryPool.blockSize))
}
}) {
constructor(
port: Int,
blockSize: Int = 1024,
blockCount: Int = 128,
backlog: Int = 50,
handler: suspend AsyncNioSocket.(buffer: AdvanceByteBuffer) -> Unit
) : this(port, backlog, DirectMemoryPool(blockSize, blockCount), handler)
}

View File

@ -0,0 +1,7 @@
package cn.tursom.socket.server
import cn.tursom.socket.AsyncNioSocket
interface IAsyncNioServer : ISocketServer {
val handler: suspend AsyncNioSocket.() -> Unit
}

View File

@ -1,9 +1,8 @@
package cn.tursom.socket
import cn.tursom.core.*
import java.io.*
import cn.tursom.core.put
import java.io.Closeable
import java.net.Socket
import java.net.SocketTimeoutException
/**
* 对基础的Socket做了些许封装

View File

@ -0,0 +1,11 @@
package cn.tursom.socket.server
import cn.tursom.socket.BaseSocket
interface ISimpleSocketServer : ISocketServer {
val handler: BaseSocket.() -> Unit
interface Handler {
fun handle(socket: BaseSocket)
}
}

View File

@ -4,29 +4,37 @@ import cn.tursom.core.cpuNumber
import cn.tursom.socket.BaseSocket
import java.net.ServerSocket
/**
* 这是一个自动启用多个线程来处理请求的套接字服务器
*/
class MultithreadingSocketServer(
private val serverSocket: ServerSocket,
private val threadNumber: Int = cpuNumber,
val exception: Exception.() -> Unit = {
printStackTrace()
},
override val handler: BaseSocket.() -> Unit
) : SocketServer {
) : ISimpleSocketServer {
override val port = serverSocket.localPort
constructor(
port: Int,
threadNumber: Int = cpuNumber,
exception: Exception.() -> Unit = {
printStackTrace()
},
handler: BaseSocket.() -> Unit
) : this(ServerSocket(port), threadNumber, exception, handler)
) : this(ServerSocket(port), threadNumber, handler)
constructor(
port: Int,
handler: BaseSocket.() -> Unit
) : this(port, cpuNumber, { printStackTrace() }, handler)
) : this(port, cpuNumber, handler)
constructor(
port: Int,
threadNumber: Int = cpuNumber,
handler: ISimpleSocketServer.Handler
) : this(ServerSocket(port), threadNumber, handler::handle)
constructor(
port: Int,
handler: ISimpleSocketServer.Handler
) : this(port, cpuNumber, handler::handle)
private val threadList = ArrayList<Thread>()
@ -38,7 +46,7 @@ class MultithreadingSocketServer(
try {
BaseSocket(it).handler()
} catch (e: Exception) {
e.exception()
e.printStackTrace()
}
}
}

View File

@ -4,23 +4,25 @@ import cn.tursom.socket.BaseSocket
import java.net.ServerSocket
import java.net.SocketException
/**
* 单线程套接字服务器
* 可以用多个线程同时运行该服务器可以正常工作
*/
class SingleThreadSocketServer(
private val serverSocket: ServerSocket,
val exception: Exception.() -> Unit = { printStackTrace() },
override val handler: BaseSocket.() -> Unit
) : SocketServer {
) : ISimpleSocketServer {
override val port = serverSocket.localPort
constructor(
port: Int,
exception: Exception.() -> Unit = { printStackTrace() },
handler: BaseSocket.() -> Unit
) : this(ServerSocket(port), exception, handler)
) : this(ServerSocket(port), handler)
constructor(
port: Int,
handler: BaseSocket.() -> Unit
) : this(port, { printStackTrace() }, handler)
handler: ISimpleSocketServer.Handler
) : this(ServerSocket(port), handler::handle)
override fun run() {
while (!serverSocket.isClosed) {
@ -29,14 +31,14 @@ class SingleThreadSocketServer(
try {
BaseSocket(it).handler()
} catch (e: Exception) {
e.exception()
e.printStackTrace()
}
}
} catch (e: SocketException) {
if (e.message == "Socket closed" || e.message == "cn.tursom.socket closed") {
break
} else {
e.exception()
e.printStackTrace()
}
}
}

View File

@ -1,11 +0,0 @@
package cn.tursom.socket.server
import cn.tursom.socket.BaseSocket
interface SocketServer : ISocketServer {
val handler: BaseSocket.() -> Unit
companion object {
val cpuNumber = Runtime.getRuntime().availableProcessors() //CPU处理器的个数
}
}

View File

@ -47,7 +47,7 @@ class ThreadPoolSocketServer
keepAliveTime: Long = 60_000L,
timeUnit: TimeUnit = TimeUnit.MILLISECONDS,
override val handler: BaseSocket.() -> Unit
) : SocketServer {
) : ISimpleSocketServer {
constructor(
port: Int,

View File

@ -0,0 +1,21 @@
package cn.tursom.socket.server
import cn.tursom.socket.BaseSocket
import org.junit.Test
class SingleThreadSocketServerTest {
@Test
fun testCreateServer() {
val port = 12345
// Kotlin 写法
SingleThreadSocketServer(port) {
}.close()
// Java 写法
SingleThreadSocketServer(port, object : ISimpleSocketServer.Handler {
override fun handle(socket: BaseSocket) {
}
}).close()
}
}

View File

@ -1,26 +1,35 @@
@file:Suppress("unused")
package cn.tursom.core
import sun.misc.Unsafe
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.net.URLDecoder
import java.net.URLEncoder
import java.security.MessageDigest
import java.security.NoSuchAlgorithmException
import java.util.*
import java.util.jar.JarFile
import kotlin.collections.ArrayList
inline fun <reified T> Array<out T?>.excludeNull(): List<T> {
val list = ArrayList<T>()
forEach { if (it != null) list.add(it) }
return list
}
fun printNonDaemonThread() {
val currentGroup = Thread.currentThread().threadGroup
val noThreads = currentGroup.activeCount()
val lstThreads = arrayOfNulls<Thread>(noThreads)
currentGroup.enumerate(lstThreads)
for (i in 0 until noThreads) {
val t = lstThreads[i]
if (t?.isDaemon != true) {
println("${System.currentTimeMillis()}: ${t?.name}")
}
}
println()
val currentGroup = Thread.currentThread().threadGroup
val noThreads = currentGroup.activeCount()
val lstThreads = arrayOfNulls<Thread>(noThreads)
currentGroup.enumerate(lstThreads)
lstThreads.excludeNull().forEach { t ->
if (!t.isDaemon) {
println("${System.currentTimeMillis()}: ${t.name}")
}
}
println()
}
fun log(log: String) = println("${System.currentTimeMillis()}: $log")
@ -31,184 +40,182 @@ val String.urlDecode: String get() = URLDecoder.decode(this, "utf-8")
val String.urlEncode: String get() = URLEncoder.encode(this, "utf-8")
inline fun <T> usingTime(action: () -> T): Long {
val t1 = System.currentTimeMillis()
action()
val t2 = System.currentTimeMillis()
return t2 - t1
val t1 = System.currentTimeMillis()
action()
val t2 = System.currentTimeMillis()
return t2 - t1
}
inline fun <T> Collection<T>.doEach(block: (T) -> Any): String {
val iterator = iterator()
if (!iterator.hasNext()) return "[]"
val sb = StringBuilder("[${block(iterator.next())}")
iterator.forEach {
sb.append(", ")
sb.append(block(it))
}
sb.append("]")
return sb.toString()
inline fun <T> Collection<T>.toString(action: (T) -> Any): String {
val iterator = iterator()
if (!iterator.hasNext()) return "[]"
val sb = StringBuilder("[${action(iterator.next())}")
iterator.forEach {
sb.append(", ")
sb.append(action(it))
}
sb.append("]")
return sb.toString()
}
//利用Unsafe绕过构造函数获取变量
val unsafe by lazy {
val field = Unsafe::class.java.getDeclaredField("theUnsafe")
//允许通过反射设置属性的值
field.isAccessible = true
field.get(null) as Unsafe
val field = Unsafe::class.java.getDeclaredField("theUnsafe")
field.isAccessible = true
field.get(null) as Unsafe
}
@Suppress("UNCHECKED_CAST")
fun <T> Class<T>.unsafeInstance() = unsafe.allocateInstance(this) as T
val Class<*>.actualTypeArguments
val Class<*>.actualTypeArguments: Array<out Type>
get() = (genericSuperclass as ParameterizedType).actualTypeArguments
fun Class<*>.isInheritanceFrom(parent: Class<*>) = parent.isAssignableFrom(this)
fun getClassName(jarPath: String): List<String> {
val myClassName = ArrayList<String>()
for (entry in JarFile(jarPath).entries()) {
val entryName = entry.name
if (entryName.endsWith(".class")) {
myClassName.add(entryName.replace("/", ".").substring(0, entryName.lastIndexOf(".")))
}
}
return myClassName
val myClassName = ArrayList<String>()
for (entry in JarFile(jarPath).entries()) {
val entryName = entry.name
if (entryName.endsWith(".class")) {
myClassName.add(entryName.replace("/", ".").substring(0, entryName.lastIndexOf(".")))
}
}
return myClassName
}
fun <T> List<T>.binarySearch(comparison: (T) -> Int): T? {
val index = binarySearch(0, size, comparison)
return if (index < 0) null
else get(index)
val index = binarySearch(0, size, comparison)
return if (index < 0) null
else get(index)
}
val cpuNumber = Runtime.getRuntime().availableProcessors()
fun String.simplifyPath(): String {
if (isEmpty()) {
return "/"
}
val strs = split("/".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
val list = LinkedList<String>()
for (str in strs) {
if (str.isEmpty() || "." == str) {
continue
}
if (".." == str) {
list.pollLast()
continue
}
list.addLast(str)
}
var result = ""
while (list.size > 0) {
result += "/" + list.pollFirst()!!
}
return if (result.isNotEmpty()) result else "/"
if (isEmpty()) {
return "."
}
val pathList = split(java.io.File.separator).dropLastWhile { it.isEmpty() }
val list = LinkedList<String>()
for (path in pathList) {
if (path.isEmpty() || "." == path) {
continue
}
if (".." == path) {
list.pollLast()
continue
}
list.addLast(path)
}
var result = ""
while (list.size > 0) {
result += java.io.File.separator + list.pollFirst()!!
}
return if (result.isNotEmpty()) result else "."
}
fun ByteArray.md5(): ByteArray? {
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("MD5")
//加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("MD5")
//加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
}
fun String.md5(): String? {
return toByteArray().md5()?.toHexString()
return toByteArray().md5()?.toHexString()
}
fun ByteArray.sha256(): ByteArray? {
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("SHA-256")
//加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("SHA-256")
//加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
}
fun String.sha256(): String? {
return toByteArray().sha256()?.toHexString()
return toByteArray().sha256()?.toHexString()
}
fun ByteArray.sha(): ByteArray? {
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("SHA")
//对字符串加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("SHA")
//对字符串加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
}
fun String.sha(): String? = toByteArray().sha()?.toHexString()
fun ByteArray.sha1(): ByteArray? {
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("SHA-1")
//对字符串加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("SHA-1")
//对字符串加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
}
fun String.sha1(): String? = toByteArray().sha1()?.toHexString()
fun ByteArray.sha384(): ByteArray? {
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("SHA-384")
//对字符串加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("SHA-384")
//对字符串加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
}
fun String.sha384(): String? = toByteArray().sha384()?.toHexString()
fun ByteArray.sha512(): ByteArray? {
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("SHA-512")
//对字符串加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
return try {
//获取md5加密对象
val instance = MessageDigest.getInstance("SHA-512")
//对字符串加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
}
}
fun String.sha512(): String? = toByteArray().sha512()?.toHexString()
fun ByteArray.toHexString(): String? {
val sb = StringBuilder()
forEach {
//获取低八位有效值+
val i: Int = it.toInt() and 0xff
//将整数转化为16进制
var hexString = Integer.toHexString(i)
if (hexString.length < 2) {
//如果是一位的话补0
hexString = "0$hexString"
}
sb.append(hexString)
}
return sb.toString()
val sb = StringBuilder()
forEach {
//获取低八位有效值+
val i: Int = it.toInt() and 0xff
//将整数转化为16进制
var hexString = Integer.toHexString(i)
if (hexString.length < 2) {
//如果是一位的话补0
hexString = "0$hexString"
}
sb.append(hexString)
}
return sb.toString()
}
fun ByteArray.toUTF8String() = String(this, Charsets.UTF_8)
@ -216,7 +223,7 @@ fun ByteArray.toUTF8String() = String(this, Charsets.UTF_8)
fun String.base64() = this.toByteArray().base64().toUTF8String()
fun ByteArray.base64(): ByteArray {
return Base64.getEncoder().encode(this)
return Base64.getEncoder().encode(this)
}
fun String.base64decode() = Base64.getDecoder().decode(this).toUTF8String()
@ -226,18 +233,18 @@ fun ByteArray.base64decode(): ByteArray = Base64.getDecoder().decode(this)
fun String.digest(type: String) = toByteArray().digest(type)?.toHexString()
fun ByteArray.digest(type: String) = try {
//获取加密对象
val instance = MessageDigest.getInstance(type)
//加密,返回字节数组
instance.digest(this)
//获取加密对象
val instance = MessageDigest.getInstance(type)
//加密,返回字节数组
instance.digest(this)
} catch (e: NoSuchAlgorithmException) {
e.printStackTrace()
null
e.printStackTrace()
null
}
fun randomInt(min: Int, max: Int) = Random().nextInt(max) % (max - min + 1) + min
fun getTAG(cls: Class<*>): String {
return cls.name.split(".").last().dropLast(10)
return cls.name.split(".").last().dropLast(10)
}

View File

@ -5,46 +5,46 @@ import cn.tursom.core.bytebuffer.NioAdvanceByteBuffer
import cn.tursom.core.datastruct.ArrayBitSet
import java.nio.ByteBuffer
class DirectMemoryPool(val blockSize: Int = 1024, val blockCount: Int = 16) : MemoryPool {
private val memoryPool = ByteBuffer.allocateDirect(blockSize * blockCount)
private val bitMap = ArrayBitSet(blockCount.toLong())
class DirectMemoryPool(override val blockSize: Int = 1024, override val blockCount: Int = 16) : MemoryPool {
private val memoryPool = ByteBuffer.allocateDirect(blockSize * blockCount)
private val bitMap = ArrayBitSet(blockCount.toLong())
/**
* @return token
*/
override fun allocate(): Int = synchronized(this) {
val index = bitMap.firstDown()
if (index in 0 until blockCount) {
bitMap.up(index)
index.toInt()
} else {
-1
}
}
/**
* @return token
*/
override fun allocate(): Int = synchronized(this) {
val index = bitMap.firstDown()
if (index in 0 until blockCount) {
bitMap.up(index)
index.toInt()
} else {
-1
}
}
override fun free(token: Int) {
if (token in 0 until blockCount) synchronized(this) {
bitMap.down(token.toLong())
}
}
override fun free(token: Int) {
if (token in 0 until blockCount) synchronized(this) {
bitMap.down(token.toLong())
}
}
override fun getMemory(token: Int): ByteBuffer? = if (token in 0 until blockCount) {
synchronized(this) {
memoryPool.limit((token + 1) * blockSize)
memoryPool.position(token * blockSize)
return memoryPool.slice()
}
} else {
null
}
override fun getMemory(token: Int): ByteBuffer? = if (token in 0 until blockCount) {
synchronized(this) {
memoryPool.limit((token + 1) * blockSize)
memoryPool.position(token * blockSize)
return memoryPool.slice()
}
} else {
null
}
override fun getAdvanceByteBuffer(token: Int): AdvanceByteBuffer? = if (token in 0 until blockCount) {
synchronized(this) {
memoryPool.limit((token + 1) * blockSize)
memoryPool.position(token * blockSize)
return NioAdvanceByteBuffer(memoryPool.slice())
}
} else {
null
}
override fun getAdvanceByteBuffer(token: Int): AdvanceByteBuffer? = if (token in 0 until blockCount) {
synchronized(this) {
memoryPool.limit((token + 1) * blockSize)
memoryPool.position(token * blockSize)
return NioAdvanceByteBuffer(memoryPool.slice())
}
} else {
null
}
}

View File

@ -6,7 +6,7 @@ import cn.tursom.core.datastruct.ArrayBitSet
import java.nio.ByteBuffer
@Suppress("MemberVisibilityCanBePrivate")
class HeapMemoryPool(val blockSize: Int = 1024, val blockCount: Int = 16) : MemoryPool {
class HeapMemoryPool(override val blockSize: Int = 1024, override val blockCount: Int = 16) : MemoryPool {
private val memoryPool = ByteBuffer.allocate(blockSize * blockCount)
private val bitMap = ArrayBitSet(blockCount.toLong())

View File

@ -5,34 +5,37 @@ import cn.tursom.core.bytebuffer.NioAdvanceByteBuffer
import java.nio.ByteBuffer
interface MemoryPool {
fun allocate(): Int
fun free(token: Int)
fun getMemory(token: Int): ByteBuffer?
fun getAdvanceByteBuffer(token: Int): AdvanceByteBuffer? {
val buffer = getMemory(token)
return if (buffer != null) {
NioAdvanceByteBuffer(buffer)
} else {
null
}
}
val blockSize: Int
val blockCount: Int
fun allocate(): Int
fun free(token: Int)
fun getMemory(token: Int): ByteBuffer?
fun getAdvanceByteBuffer(token: Int): AdvanceByteBuffer? {
val buffer = getMemory(token)
return if (buffer != null) {
NioAdvanceByteBuffer(buffer)
} else {
null
}
}
}
inline fun MemoryPool.usingMemory(action: (ByteBuffer?) -> Unit) {
val token = allocate()
try {
action(getMemory(token))
} finally {
free(token)
}
inline fun <T> MemoryPool.usingMemory(action: (ByteBuffer?) -> T): T {
val token = allocate()
return try {
action(getMemory(token))
} finally {
free(token)
}
}
inline fun MemoryPool.usingAdvanceByteBuffer(action: (AdvanceByteBuffer?) -> Unit) {
val token = allocate()
try {
action(getAdvanceByteBuffer(token))
} finally {
free(token)
}
inline fun <T> MemoryPool.usingAdvanceByteBuffer(action: (AdvanceByteBuffer?) -> T): T {
val token = allocate()
return try {
action(getAdvanceByteBuffer(token))
} finally {
free(token)
}
}