diff --git a/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt b/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt index c112554..ab791e2 100644 --- a/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt +++ b/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt @@ -1,3 +1,5 @@ +@file:Suppress("unused") + package cn.tursom.utils.coroutine import cn.tursom.core.cast @@ -167,38 +169,31 @@ private val BaseContinuationImpl = Class.forName("kotlin.coroutines.jvm.internal private val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true } fun Continuation<*>.injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean { - return if (context[CoroutineLocalContext] == null) { - if (BaseContinuationImpl.isInstance(this)) { - BaseContinuationImplCompletion.get(this).cast>().injectCoroutineLocalContext(coroutineLocalContext) - } - combinedContext(context) - if (context[CoroutineLocalContext] == null) { - javaClass.forAllFields { - if (!it.type.isInheritanceFrom(CoroutineContext::class.java)) { - return@forAllFields - } - it.isAccessible = true - val coroutineContext = it.get(this).cast() - it.set(this, coroutineContext + coroutineLocalContext) - } - context[CoroutineLocalContext] != null - } else { - true - } - } else { - true + if (context[CoroutineLocalContext] != null) return true + if (BaseContinuationImpl.isInstance(this)) { + BaseContinuationImplCompletion.get(this).cast>().injectCoroutineLocalContext(coroutineLocalContext) } + combinedContext(context) + if (context[CoroutineLocalContext] != null) return true + javaClass.forAllFields { + if (!it.type.isInheritanceFrom(CoroutineContext::class.java)) { + return@forAllFields + } + it.isAccessible = true + val coroutineContext = it.get(this).cast() + it.set(this, coroutineContext + coroutineLocalContext) + } + return context[CoroutineLocalContext] != null } private val combinedContextClass = Class.forName("kotlin.coroutines.CombinedContext") private val left = combinedContextClass.getDeclaredField("left").apply { isAccessible = true } fun combinedContext(coroutineContext: CoroutineContext): Boolean { - return if (coroutineContext.javaClass == combinedContextClass && coroutineContext[CoroutineLocalContext] == null) { + if (!combinedContextClass.isInstance(coroutineContext)) return false + if (coroutineContext[CoroutineLocalContext] == null) { val leftObj = left.get(coroutineContext).cast() left.set(coroutineContext, leftObj + CoroutineLocalContext()) - true - } else { - false } + return true } \ No newline at end of file diff --git a/utils/src/test/kotlin/cn/tursom/utils/coroutine/CoroutineLocalTest.kt b/utils/src/test/kotlin/cn/tursom/utils/coroutine/CoroutineLocalTest.kt index a7eca6f..c80f630 100644 --- a/utils/src/test/kotlin/cn/tursom/utils/coroutine/CoroutineLocalTest.kt +++ b/utils/src/test/kotlin/cn/tursom/utils/coroutine/CoroutineLocalTest.kt @@ -1,9 +1,8 @@ package cn.tursom.utils.coroutine -import cn.tursom.core.allFields -import cn.tursom.core.cast -import kotlin.coroutines.Continuation import kotlin.coroutines.coroutineContext +import kotlin.coroutines.resume +import kotlin.coroutines.suspendCoroutine val testCoroutineLocal = CoroutineLocal() @@ -12,47 +11,17 @@ suspend fun testCustomContext() { testInlineCustomContext() } -fun Any.printMsg() { - javaClass.allFields.forEach { - it.isAccessible = true - println("${it.type} ${it.name} = ${it.get(this)}") - val value: Any? = it.get(this) - println("${value?.javaClass} $value") - println(it.get(this) == this) - if (it.name == "completion") { - println((value as Continuation<*>).context) - } - println() - } -} - -val BaseContinuationImpl = Class.forName("kotlin.coroutines.jvm.internal.BaseContinuationImpl") -val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true } - -fun Continuation<*>.rootCompletion(): Continuation<*> { - var completion = this.javaClass.allFields.firstOrNull { it.name == "completion" } - val coroutineLocalContext = CoroutineLocalContext() - @Suppress("NAME_SHADOWING") var continuation = this - while (completion != null) { - continuation.injectCoroutineLocalContext(coroutineLocalContext) - completion.isAccessible = true - val newContinuation = completion.get(continuation)?.cast>() ?: return continuation - if (newContinuation == continuation) { - return continuation - } - completion = newContinuation.javaClass.allFields.firstOrNull { it.name == "completion" } - continuation = newContinuation - } - continuation.injectCoroutineLocalContext(coroutineLocalContext) - return continuation -} - -suspend inline fun testInlineCustomContext() { +suspend fun testInlineCustomContext() { println(coroutineContext) println("===================") } suspend fun main() { + println(getContinuation()) + suspendCoroutine { cont -> + println(cont) + cont.resume(0) + } testCustomContext() println(testCoroutineLocal.get()) testInlineCustomContext()