更加完善协程本地变量

This commit is contained in:
tursom 2020-07-18 22:25:22 +08:00
parent e2f209ae56
commit 46cf3932f3
2 changed files with 27 additions and 63 deletions

View File

@ -1,3 +1,5 @@
@file:Suppress("unused")
package cn.tursom.utils.coroutine package cn.tursom.utils.coroutine
import cn.tursom.core.cast import cn.tursom.core.cast
@ -167,38 +169,31 @@ private val BaseContinuationImpl = Class.forName("kotlin.coroutines.jvm.internal
private val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true } private val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true }
fun Continuation<*>.injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean { fun Continuation<*>.injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean {
return if (context[CoroutineLocalContext] == null) { if (context[CoroutineLocalContext] != null) return true
if (BaseContinuationImpl.isInstance(this)) { if (BaseContinuationImpl.isInstance(this)) {
BaseContinuationImplCompletion.get(this).cast<Continuation<*>>().injectCoroutineLocalContext(coroutineLocalContext) BaseContinuationImplCompletion.get(this).cast<Continuation<*>>().injectCoroutineLocalContext(coroutineLocalContext)
}
combinedContext(context)
if (context[CoroutineLocalContext] == null) {
javaClass.forAllFields {
if (!it.type.isInheritanceFrom(CoroutineContext::class.java)) {
return@forAllFields
}
it.isAccessible = true
val coroutineContext = it.get(this).cast<CoroutineContext>()
it.set(this, coroutineContext + coroutineLocalContext)
}
context[CoroutineLocalContext] != null
} else {
true
}
} else {
true
} }
combinedContext(context)
if (context[CoroutineLocalContext] != null) return true
javaClass.forAllFields {
if (!it.type.isInheritanceFrom(CoroutineContext::class.java)) {
return@forAllFields
}
it.isAccessible = true
val coroutineContext = it.get(this).cast<CoroutineContext>()
it.set(this, coroutineContext + coroutineLocalContext)
}
return context[CoroutineLocalContext] != null
} }
private val combinedContextClass = Class.forName("kotlin.coroutines.CombinedContext") private val combinedContextClass = Class.forName("kotlin.coroutines.CombinedContext")
private val left = combinedContextClass.getDeclaredField("left").apply { isAccessible = true } private val left = combinedContextClass.getDeclaredField("left").apply { isAccessible = true }
fun combinedContext(coroutineContext: CoroutineContext): Boolean { fun combinedContext(coroutineContext: CoroutineContext): Boolean {
return if (coroutineContext.javaClass == combinedContextClass && coroutineContext[CoroutineLocalContext] == null) { if (!combinedContextClass.isInstance(coroutineContext)) return false
if (coroutineContext[CoroutineLocalContext] == null) {
val leftObj = left.get(coroutineContext).cast<CoroutineContext>() val leftObj = left.get(coroutineContext).cast<CoroutineContext>()
left.set(coroutineContext, leftObj + CoroutineLocalContext()) left.set(coroutineContext, leftObj + CoroutineLocalContext())
true
} else {
false
} }
return true
} }

View File

@ -1,9 +1,8 @@
package cn.tursom.utils.coroutine package cn.tursom.utils.coroutine
import cn.tursom.core.allFields
import cn.tursom.core.cast
import kotlin.coroutines.Continuation
import kotlin.coroutines.coroutineContext import kotlin.coroutines.coroutineContext
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine
val testCoroutineLocal = CoroutineLocal<Int>() val testCoroutineLocal = CoroutineLocal<Int>()
@ -12,47 +11,17 @@ suspend fun testCustomContext() {
testInlineCustomContext() testInlineCustomContext()
} }
fun Any.printMsg() { suspend fun testInlineCustomContext() {
javaClass.allFields.forEach {
it.isAccessible = true
println("${it.type} ${it.name} = ${it.get(this)}")
val value: Any? = it.get(this)
println("${value?.javaClass} $value")
println(it.get(this) == this)
if (it.name == "completion") {
println((value as Continuation<*>).context)
}
println()
}
}
val BaseContinuationImpl = Class.forName("kotlin.coroutines.jvm.internal.BaseContinuationImpl")
val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true }
fun Continuation<*>.rootCompletion(): Continuation<*> {
var completion = this.javaClass.allFields.firstOrNull { it.name == "completion" }
val coroutineLocalContext = CoroutineLocalContext()
@Suppress("NAME_SHADOWING") var continuation = this
while (completion != null) {
continuation.injectCoroutineLocalContext(coroutineLocalContext)
completion.isAccessible = true
val newContinuation = completion.get(continuation)?.cast<Continuation<*>>() ?: return continuation
if (newContinuation == continuation) {
return continuation
}
completion = newContinuation.javaClass.allFields.firstOrNull { it.name == "completion" }
continuation = newContinuation
}
continuation.injectCoroutineLocalContext(coroutineLocalContext)
return continuation
}
suspend inline fun testInlineCustomContext() {
println(coroutineContext) println(coroutineContext)
println("===================") println("===================")
} }
suspend fun main() { suspend fun main() {
println(getContinuation())
suspendCoroutine<Int> { cont ->
println(cont)
cont.resume(0)
}
testCustomContext() testCustomContext()
println(testCoroutineLocal.get()) println(testCoroutineLocal.get())
testInlineCustomContext() testInlineCustomContext()