This commit is contained in:
tursom 2022-05-13 23:16:40 +08:00
parent 5a89be88df
commit 78b79bedec
18 changed files with 249 additions and 246 deletions

View File

@ -17,6 +17,13 @@ class ArrayContextEnv : ContextEnv {
private val idGenerator: AtomicInteger,
) : Context {
private var array = arrayOfNulls<Any?>(idGenerator.get())
override fun get(id: Int): Any? {
return if (array.size > id) {
array[id]
} else {
null
}
}
override operator fun <T> get(key: ContextKey<T>): T? {
checkEnv(key)
@ -42,7 +49,8 @@ class ArrayContextEnv : ContextEnv {
private val idGenerator: AtomicInteger,
) : Context {
override fun <T> get(key: ContextKey<T>): T? = null
override fun <T> get(key: DefaultContextKey<T>): T = key.provider()
override fun get(id: Int): Any? = null
override fun <T> get(key: DefaultContextKey<T>): T = key.provider(this)
override fun <T> set(key: ContextKey<T>, value: T) =
ArrayContext(envId, idGenerator).set(key, value)
}

View File

@ -3,12 +3,14 @@ package cn.tursom.core.context
interface Context {
val envId: Int
operator fun get(id: Int): Any?
operator fun <T> get(key: DefaultContextKey<T>): T {
var value = get(key as ContextKey<T>)
return if (value != null) {
value
} else {
value = key.provider()
value = key.provider(this)
set(key, value)
value
}

View File

@ -1,10 +1,21 @@
package cn.tursom.core.context
import cn.tursom.core.uncheckedCast
open class ContextKey<T>(
val envId: Int,
val id: Int,
) {
fun withDefault(provider: () -> T) = DefaultContextKey(envId, id, provider)
fun withDefault(provider: ContextKey<T>.(Context) -> T) = DefaultContextKey(envId, id, provider)
fun withSynchronizedDefault(provider: ContextKey<T>.(Context) -> T) = DefaultContextKey<T>(envId, id) { context ->
synchronized(context) {
context[id]?.uncheckedCast() ?: run {
val value = provider(context)
context[this] = value
value
}
}
}
override fun hashCode(): Int {
return id

View File

@ -1,3 +1,11 @@
package cn.tursom.core.context
class DefaultContextKey<T>(envId: Int, id: Int, val provider: () -> T) : ContextKey<T>(envId, id)
class DefaultContextKey<T>(
envId: Int,
id: Int,
private val realProvider: ContextKey<T>.(Context) -> T,
) : ContextKey<T>(envId, id) {
val provider: (Context) -> T = {
realProvider(it)
}
}

View File

@ -14,6 +14,9 @@ class HashMapContextEnv : ContextEnv {
override val envId: Int,
) : Context {
private var map = HashMap<Int, Any?>()
override fun get(id: Int): Any? {
return map[id]
}
override fun <T> get(key: ContextKey<T>): T? {
checkEnv(key)

View File

@ -468,6 +468,21 @@ inline fun <T> Any.notifyAll(action: () -> T) = synchronized(this) {
t
}
fun Any.wait() {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
(this as Object).wait()
}
fun Any.notify() {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
(this as Object).notify()
}
fun Any.notifyAll() {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
(this as Object).notifyAll()
}
inline val KClass<*>.companionObjectInstanceOrNull: Any?
get() = try {
companionObjectInstance
@ -669,11 +684,13 @@ fun File.ungz(): File {
if (!isFile) {
throw UnsupportedOperationException("$name is not file")
}
val ungzFile = File(if (name.endsWith(".gz")) {
val ungzFile = File(
if (name.endsWith(".gz")) {
name.dropLast(3)
} else {
"$name.ungz"
})
}
)
ungzFile.outputStream().use { os ->
GZIPInputStream(inputStream()).use { inputStream ->
inputStream.copyTo(os)
@ -700,3 +717,8 @@ fun StringBuilder(vararg strings: String): StringBuilder {
builder.append(value = strings)
return builder
}
@Suppress("NOTHING_TO_INLINE")
inline fun Throwable.throws(): Nothing {
throw this
}

View File

@ -0,0 +1,29 @@
package cn.tursom.core.coroutine
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
class SingletonCoroutine(
private val scope: CoroutineScope = GlobalScope,
private val handler: suspend CoroutineScope.() -> Unit
) : Runnable {
private val run = AtomicBoolean()
override fun run() = run(EmptyCoroutineContext)
fun run(context: CoroutineContext) {
if (!run.compareAndSet(false, true)) {
return
}
scope.launch(context) {
try {
handler()
} finally {
run.set(false)
}
}
}
}

View File

@ -9,6 +9,9 @@ dependencies {
implementation(project(":ts-core:ts-buffer"))
implementation(project(":ts-core:ts-pool"))
implementation(project(":ts-core:ts-datastruct"))
testApi(group = "com.google.code.gson", name = "gson", version = "2.8.9")
testApi(group = "junit", name = "junit", version = "4.13.2")
}

View File

@ -9,6 +9,8 @@ import javax.crypto.spec.SecretKeySpec
class AES(
@Suppress("CanBeParameter") val secKey: SecretKey,
) : Encrypt {
override val algorithm: String get() = "AES"
private val decryptCipher = Cipher.getInstance("AES")!!
private val encryptCipher = Cipher.getInstance("AES")!!

View File

@ -1,19 +1,16 @@
package cn.tursom.core.encrypt
import cn.tursom.core.toUTF8String
import java.security.*
import java.security.spec.X509EncodedKeySpec
import javax.crypto.Cipher
import kotlin.experimental.xor
import kotlin.math.min
import kotlin.random.Random
@Suppress("unused", "MemberVisibilityCanBePrivate")
abstract class AbstractPublicKeyEncrypt(
val algorithm: String,
final override val algorithm: String,
final override val publicKey: PublicKey,
final override val privateKey: PrivateKey? = null,
val modeOfOperation: BlockCipherModeOfOperation = BlockCipherModeOfOperation.ECB,
) : PublicKeyEncrypt {
val publicKeyEncoded get() = publicKey.encoded!!
val privateKeyEncoded get() = privateKey?.encoded
@ -40,63 +37,54 @@ abstract class AbstractPublicKeyEncrypt(
} else object : AbstractPublicKeyEncrypt(algorithm, publicKey) {
override val decryptMaxLen: Int get() = this@AbstractPublicKeyEncrypt.decryptMaxLen
override val encryptMaxLen: Int get() = this@AbstractPublicKeyEncrypt.encryptMaxLen
override val cipherAlgorithm: String get() = this@AbstractPublicKeyEncrypt.cipherAlgorithm
override fun signature(digest: String): String = this@AbstractPublicKeyEncrypt.signature(digest)
}
private val blockCipher: Encrypt = when (modeOfOperation) {
BlockCipherModeOfOperation.ECB -> ECBBlockCipher()
BlockCipherModeOfOperation.CBC -> CBCBlockCipher()
else -> TODO()
}
override var encryptInitVector: ByteArray?
get() = blockCipher.encryptInitVector
set(value) {
blockCipher.encryptInitVector = value
}
override var decryptInitVector: ByteArray?
get() = blockCipher.decryptInitVector
set(value) {
blockCipher.decryptInitVector = value
}
constructor(
algorithm: String,
keyPair: KeyPair,
modeOfOperation: BlockCipherModeOfOperation = BlockCipherModeOfOperation.ECB,
) : this(algorithm, keyPair.public as PublicKey, keyPair.private as PrivateKey, modeOfOperation = modeOfOperation)
) : this(algorithm, keyPair.public as PublicKey, keyPair.private as PrivateKey)
constructor(
algorithm: String,
keySize: Int = 1024,
modeOfOperation: BlockCipherModeOfOperation = BlockCipherModeOfOperation.ECB,
) : this(
algorithm,
KeyPairGenerator.getInstance(algorithm).let {
it.initialize(keySize)
it.generateKeyPair()
},
modeOfOperation = modeOfOperation
)
constructor(
algorithm: String,
publicKey: ByteArray,
modeOfOperation: BlockCipherModeOfOperation = BlockCipherModeOfOperation.ECB,
) : this(
algorithm,
KeyFactory.getInstance(algorithm).generatePublic(X509EncodedKeySpec(publicKey)) as PublicKey,
modeOfOperation = modeOfOperation
)
override fun encrypt(data: ByteArray, offset: Int, size: Int): ByteArray = blockCipher.encrypt(data, offset, size)
override fun decrypt(data: ByteArray, offset: Int, size: Int): ByteArray = blockCipher.decrypt(data, offset, size)
override fun encrypt(data: ByteArray, offset: Int, size: Int): ByteArray {
if (size < encryptMaxLen) {
return encryptCipher.doFinal(data, offset, size)
}
var encrypted = 0
while (encrypted < size) {
val enc = min(size - encrypted, encryptMaxLen)
encryptCipher.update(data, offset + encrypted, enc)
encrypted += enc
}
return encryptCipher.doFinal()
}
override fun decrypt(data: ByteArray, offset: Int, size: Int): ByteArray = decryptCipher.doFinal(data, offset, size)
override fun encrypt(data: ByteArray, buffer: ByteArray, bufferOffset: Int, offset: Int, size: Int): Int =
blockCipher.encrypt(data, buffer, bufferOffset, offset, size)
encryptCipher.doFinal(data, offset, size, buffer, bufferOffset)
override fun decrypt(data: ByteArray, buffer: ByteArray, bufferOffset: Int, offset: Int, size: Int): Int =
blockCipher.decrypt(data, buffer, bufferOffset, offset, size)
decryptCipher.doFinal(data, offset, size, buffer, bufferOffset)
protected open fun signature(digest: String) = "${digest}with$algorithm"
@ -132,161 +120,7 @@ abstract class AbstractPublicKeyEncrypt(
return result
}
protected inner class ECBBlockCipher : Encrypt {
override fun encrypt(data: ByteArray, offset: Int, size: Int): ByteArray {
return if (size < encryptMaxLen) {
encryptCipher.doFinal(data, offset, size)
} else {
val buffer = ByteArray(((size - 1) / encryptMaxLen + 1) * decryptMaxLen)
buffer.copyOf(doFinal(data, offset, size, buffer, encryptCipher, encryptMaxLen))
}
}
override fun decrypt(data: ByteArray, offset: Int, size: Int): ByteArray {
return if (data.size < decryptMaxLen) {
decryptCipher.doFinal(data, offset, size)
} else {
val buffer = ByteArray(size / decryptMaxLen * encryptMaxLen + 11)
buffer.copyOf(doFinal(data, offset, size, buffer, decryptCipher, decryptMaxLen))
}
}
override fun encrypt(data: ByteArray, buffer: ByteArray, bufferOffset: Int, offset: Int, size: Int): Int {
return if (data.size < decryptMaxLen) {
encryptCipher.doFinal(data, offset, size, buffer, bufferOffset)
} else {
doFinal(data, offset, size, buffer, encryptCipher, decryptMaxLen, bufferOffset)
}
}
override fun decrypt(data: ByteArray, buffer: ByteArray, bufferOffset: Int, offset: Int, size: Int): Int {
return if (data.size < decryptMaxLen) {
decryptCipher.doFinal(data, offset, size, buffer, bufferOffset)
} else {
doFinal(data, offset, size, buffer, decryptCipher, decryptMaxLen, bufferOffset)
}
}
private fun doFinal(
data: ByteArray,
offset: Int,
size: Int,
buffer: ByteArray,
cipher: Cipher,
blockSize: Int,
bufferOffset: Int = 0,
): Int {
var readPosition = offset
var writeIndex = bufferOffset
while (readPosition + blockSize < size) {
writeIndex += cipher.doFinal(data, readPosition, blockSize, buffer, writeIndex)
readPosition += blockSize
}
writeIndex += cipher.doFinal(data, readPosition, size - readPosition, buffer, writeIndex)
return writeIndex - bufferOffset
}
}
protected inner class CBCBlockCipher : Encrypt {
override var encryptInitVector: ByteArray? = Random.nextBytes(encryptMaxLen)
set(value) {
value ?: return
field = value
encBuf = value
}
override var decryptInitVector: ByteArray? = null
set(value) {
field = value
decBuf = value
}
private var encBuf = encryptInitVector!!
private var decBuf: ByteArray? = decryptInitVector
override fun encrypt(data: ByteArray, offset: Int, size: Int): ByteArray {
val buffer = ByteArray(((size - 1) / encryptMaxLen + 1) * decryptMaxLen)
//return buffer.copyOf(encrypt(data, buffer, 0, offset, size))
encrypt(data, buffer, 0, offset, size)
return buffer
}
override fun encrypt(data: ByteArray, buffer: ByteArray, bufferOffset: Int, offset: Int, size: Int): Int {
var end = offset
var start: Int
var writeIndex = bufferOffset
do {
start = end
end += encryptMaxLen
end = min(data.size, end)
(0 until end - start).forEach { index ->
encBuf[index] = encBuf[index] xor data[start + index]
}
writeIndex += encryptCipher.doFinal(encBuf, 0, encBuf.size, buffer, writeIndex)
//println("${data.size} $start->$end $writeIndex")
} while (end < offset + size)
return writeIndex - bufferOffset
}
override fun decrypt(data: ByteArray, offset: Int, size: Int): ByteArray {
val decryptInitVector = decBuf!!
var start: Int
var end = offset
val buffer = ByteArray(((size - 1) / decryptMaxLen + 1) * encryptMaxLen + 11)
var writeIndex = 0
do {
start = end
end += decryptMaxLen
end = min(data.size, end)
println("${data.size}, $start->$end, ${buffer.size}, $writeIndex")
val writeIndexBefore = writeIndex
writeIndex += decryptCipher.doFinal(data, start, end - start, buffer, writeIndex)
if (start == 0) {
repeat(encryptMaxLen) {
buffer[it] = buffer[it] xor decryptInitVector[it]
}
} else {
repeat(writeIndex - writeIndexBefore) {
buffer[writeIndexBefore + it] = buffer[writeIndexBefore + it] xor data[start + it]
}
}
} while (end < offset + size)
decBuf = buffer.copyOfRange(buffer.size - encryptMaxLen, buffer.size)
return buffer.copyOf(writeIndex)
}
override fun decrypt(data: ByteArray, buffer: ByteArray, bufferOffset: Int, offset: Int, size: Int): Int {
TODO("Not yet implemented")
}
//private fun doFinal(data: ByteArray, buffer: ByteArray, bufferOffset: Int, offset: Int, size: Int, cipher: Cipher): Int {
// var start = offset
// var end = offset
// var writeIndex = bufferOffset
// do {
// end += decryptMaxLen
// end = min(data.size, end)
// encBuf.indices.forEach { index ->
// encBuf[index] = encBuf[index] xor data[start + index]
// }
// writeIndex += cipher.doFinal(encBuf, 0, encBuf.size, buffer, writeIndex)
// start += decryptMaxLen
// } while (end < offset + size)
// return writeIndex - bufferOffset
//}
}
companion object {
private val random = Random(System.currentTimeMillis())
}
}
fun main() {
val source = "HelloWorld".repeat(100).toByteArray()
val rsa = RSA()
val decodeRsa = rsa.public
decodeRsa.decryptInitVector = rsa.encryptInitVector
val encrypt = rsa.encrypt(source)
//println(encrypt.toHexString())
println(decodeRsa.decrypt(encrypt).toUTF8String())
}

View File

@ -11,8 +11,8 @@ import java.security.spec.X509EncodedKeySpec
class DSA(
publicKey: DSAPublicKey,
privateKey: DSAPrivateKey? = null,
) : AbstractPublicKeyEncrypt("DSA", publicKey, privateKey) {
algorithm: String = "DSA",
) : AbstractPublicKeyEncrypt(algorithm, publicKey, privateKey) {
override val decryptMaxLen = Int.MAX_VALUE
override val encryptMaxLen = Int.MAX_VALUE
@ -24,14 +24,23 @@ class DSA(
}
}
constructor(keyPair: KeyPair) : this(keyPair.public as DSAPublicKey, keyPair.private as DSAPrivateKey)
constructor(keyPair: KeyPair, algorithm: String = "DSA") : this(
keyPair.public as DSAPublicKey,
keyPair.private as DSAPrivateKey,
algorithm,
)
constructor(keySize: Int = 1024) : this(KeyPairGenerator.getInstance("DSA").let {
constructor(keySize: Int = 1024, algorithm: String = "DSA") : this(
KeyPairGenerator.getInstance("DSA").let {
it.initialize(keySize)
it.generateKeyPair()
})
},
algorithm,
)
constructor(publicKey: ByteArray) : this(
KeyFactory.getInstance("DSA").generatePublic(X509EncodedKeySpec(publicKey)) as DSAPublicKey
constructor(publicKey: ByteArray, algorithm: String = "DSA") : this(
KeyFactory.getInstance("DSA").generatePublic(X509EncodedKeySpec(publicKey)) as DSAPublicKey,
null,
algorithm,
)
}

View File

@ -11,14 +11,16 @@ import java.security.interfaces.ECPublicKey
import java.security.spec.AlgorithmParameterSpec
import java.security.spec.ECGenParameterSpec
import java.security.spec.X509EncodedKeySpec
import javax.crypto.Cipher
import javax.crypto.NullCipher
@Suppress("unused", "MemberVisibilityCanBePrivate")
class ECC(
publicKey: ECPublicKey,
privateKey: ECPrivateKey? = null,
) : AbstractPublicKeyEncrypt("EC", publicKey, privateKey) {
algorithm: String = "EC",
) : AbstractPublicKeyEncrypt(algorithm, publicKey, privateKey) {
override val decryptMaxLen = Int.MAX_VALUE
override val encryptMaxLen = Int.MAX_VALUE
@ -30,25 +32,36 @@ class ECC(
}
}
constructor(keyPair: KeyPair) : this(keyPair.public as ECPublicKey, keyPair.private as ECPrivateKey)
constructor(keySize: Int = 256, spec: AlgorithmParameterSpec) : this(KeyPairGenerator.getInstance("EC").let {
constructor(
keyPair: KeyPair,
algorithm: String = "EC",
) : this(keyPair.public as ECPublicKey, keyPair.private as ECPrivateKey, algorithm)
constructor(
keySize: Int = 256,
spec: AlgorithmParameterSpec,
algorithm: String = "EC",
) : this(KeyPairGenerator.getInstance("EC").let {
val generator = KeyPairGenerator.getInstance("EC")
generator.initialize(spec, SecureRandom())
generator.initialize(keySize)
generator.generateKeyPair()
})
}, algorithm)
constructor(
keySize: Int = 256,
standardCurveLine: String = StandardCurveLine.secp256k1.name.replace('_', ' '),
algorithm: String = "EC",
) : this(
keySize,
ECGenParameterSpec(standardCurveLine)
ECGenParameterSpec(standardCurveLine),
algorithm,
)
constructor(keySize: Int = 256, standardCurveLine: StandardCurveLine) : this(
constructor(keySize: Int = 256, standardCurveLine: StandardCurveLine, algorithm: String = "EC") : this(
keySize,
standardCurveLine.name.replace('_', ' ')
standardCurveLine.name.replace('_', ' '),
algorithm,
)
constructor(publicKey: ByteArray) : this(
@ -57,6 +70,17 @@ class ECC(
override fun signature(digest: String): String = "${digest}withECDSA"
override val encryptCipher by lazy {
val cipher = NullCipher()
cipher.init(Cipher.ENCRYPT_MODE, privateKey ?: publicKey)
cipher
}
override val decryptCipher by lazy {
val cipher = NullCipher()
cipher.init(Cipher.DECRYPT_MODE, privateKey ?: publicKey)
cipher
}
@Suppress("EnumEntryName", "SpellCheckingInspection")
enum class StandardCurveLine {
secp224r1, `NIST_B-233`, secp160r1, secp160r2, `NIST_K-233`, sect163r2, secp128r1, sect163r1, `NIST_P-256`,
@ -74,7 +98,7 @@ class ECC(
Unsafe {
Class.forName("sun.security.ec.CurveDB").getField("nameMap").uncheckedCast<Map<String, Any>>().keys
}
} catch (e: Exception) {
} catch (e: Throwable) {
emptySet()
}
}

View File

@ -3,12 +3,7 @@ package cn.tursom.core.encrypt
import cn.tursom.core.buffer.ByteBuffer
interface Encrypt {
var encryptInitVector: ByteArray?
get() = null
set(_) {}
var decryptInitVector: ByteArray?
get() = null
set(_) {}
val algorithm: String
fun encrypt(data: ByteArray, offset: Int = 0, size: Int = data.size - offset): ByteArray
fun decrypt(data: ByteArray, offset: Int = 0, size: Int = data.size - offset): ByteArray
@ -44,3 +39,4 @@ interface Encrypt {
}
}
}

View File

@ -11,10 +11,10 @@ import java.security.spec.X509EncodedKeySpec
class RSA(
publicKey: RSAPublicKey,
privateKey: RSAPrivateKey? = null,
modeOfOperation: BlockCipherModeOfOperation = BlockCipherModeOfOperation.ECB,
) : AbstractPublicKeyEncrypt("RSA", publicKey, privateKey, modeOfOperation = modeOfOperation) {
algorithm: String = "RSA",
) : AbstractPublicKeyEncrypt(algorithm, publicKey, privateKey) {
val keySize get() = (publicKey as RSAPublicKey).modulus.bitLength()
override val decryptMaxLen get() = keySize / 8
override val encryptMaxLen get() = decryptMaxLen - 11
@ -22,37 +22,37 @@ class RSA(
if (privateKey == null) {
this
} else {
RSA(publicKey, modeOfOperation = modeOfOperation)
RSA(publicKey)
}
}
constructor(
keyPair: KeyPair,
modeOfOperation: BlockCipherModeOfOperation = BlockCipherModeOfOperation.ECB,
algorithm: String = "RSA",
) : this(
keyPair.public as RSAPublicKey,
keyPair.private as RSAPrivateKey,
modeOfOperation
algorithm,
)
constructor(
keySize: Int = 1024,
modeOfOperation: BlockCipherModeOfOperation = BlockCipherModeOfOperation.ECB,
algorithm: String = "RSA",
) : this(
KeyPairGenerator.getInstance("RSA").let {
it.initialize(keySize)
it.generateKeyPair()
},
modeOfOperation
algorithm,
)
constructor(
publicKey: ByteArray,
modeOfOperation: BlockCipherModeOfOperation = BlockCipherModeOfOperation.ECB,
algorithm: String = "RSA",
) : this(
KeyFactory.getInstance("RSA").generatePublic(X509EncodedKeySpec(publicKey)) as RSAPublicKey,
null,
modeOfOperation
algorithm,
)
}

View File

@ -0,0 +1,2 @@
package cn.tursom.core.encrypt

View File

@ -0,0 +1,18 @@
package cn.tursom.core.encrypt
import org.junit.Test
class CBCBlockCipherTest {
private val encrypt = ECC()
@Test
fun test() {
//println(Security.getProviders().map { provider ->
// provider.services.map { service ->
// service.algorithm
// }
//}.toPrettyJson())
val bytes = encrypt.sign("test".repeat(1000).toByteArray())
assert(encrypt.public.verify("test".repeat(1000).toByteArray(), bytes))
}
}

View File

@ -23,10 +23,11 @@ import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.util.InsecureTrustManagerFactory
import java.net.URI
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
@Suppress("unused")
@Suppress("unused", "MemberVisibilityCanBePrivate")
open class WebSocketClient<in T : WebSocketClient<T, H>, H : WebSocketHandler<T, H>>(
url: String,
open val handler: H,
@ -35,15 +36,26 @@ open class WebSocketClient<in T : WebSocketClient<T, H>, H : WebSocketHandler<T,
val compressed: Boolean = true,
val maxContextLength: Int = 4096,
private val headers: Map<String, String>? = null,
private val handshakerUri: URI? = null,
private val handshakeUri: URI? = null,
val autoRelease: Boolean = true,
var initChannel: ((ch: SocketChannel) -> Unit)? = null,
) {
companion object {
private val threadId = AtomicInteger()
private val group: EventLoopGroup = NioEventLoopGroup(0, ThreadFactory {
val thread = Thread(it, "WebSocketClient-${threadId.incrementAndGet()}")
thread.isDaemon = true
thread
})
}
private val uri: URI = URI.create(url)
var ch: Channel? = null
internal set
var closed: Boolean = false
private set
private val onOpenLock = AtomicBoolean()
val onOpen get() = onOpenLock.get()
private val hook = ShutdownHook.addHook(true) {
close()
@ -54,7 +66,18 @@ open class WebSocketClient<in T : WebSocketClient<T, H>, H : WebSocketHandler<T,
}
fun open(): ChannelFuture? {
if (!onOpenLock.compareAndSet(false, true)) {
return null
}
try {
close()
return open1()
} finally {
onOpenLock.set(false)
}
}
private fun open1(): ChannelFuture? {
val scheme = if (uri.scheme == null) "ws" else uri.scheme
val host = if (uri.host == null) "127.0.0.1" else uri.host
val port = if (uri.port == -1) {
@ -85,7 +108,7 @@ open class WebSocketClient<in T : WebSocketClient<T, H>, H : WebSocketHandler<T,
}
val handshakerAdapter = WebSocketClientHandshakerAdapter(
WebSocketClientHandshakerFactory.newHandshaker(
handshakerUri ?: uri, WebSocketVersion.V13, null, true, httpHeaders
handshakeUri ?: uri, WebSocketVersion.V13, null, true, httpHeaders
), uncheckedCast(), handler
)
val handler = WebSocketClientChannelHandler(uncheckedCast(), handler, autoRelease)
@ -186,20 +209,18 @@ open class WebSocketClient<in T : WebSocketClient<T, H>, H : WebSocketHandler<T,
}
open fun onClose() {
synchronized(this) {
closed = true
notifyAll { }
notifyAll()
}
}
fun waitClose() {
if (!closed) wait { }
if (!closed) synchronized(this) {
if (closed) {
return
}
wait()
}
companion object {
private val threadId = AtomicInteger()
val group: EventLoopGroup = NioEventLoopGroup(0, ThreadFactory {
val thread = Thread(it, "WebSocketClient-${threadId.incrementAndGet()}")
thread.isDaemon = true
thread
})
}
}

View File

@ -0,0 +1,11 @@
package cn.tursom.core.ws
import kotlin.coroutines.CoroutineContext
class WebSocketClientContainer<C : WebSocketClient<C, H>, H : WebSocketHandler<C, H>>(
val client: C,
) : CoroutineContext.Element {
companion object Key : CoroutineContext.Key<WebSocketClientContainer<*, *>>
override val key = Key
}