更加完善协程本地变量

This commit is contained in:
tursom 2020-07-18 22:50:12 +08:00
parent 46cf3932f3
commit e5c7633f72

View File

@ -157,7 +157,18 @@ suspend inline fun getContinuation(): Continuation<*> {
return (getContinuation.cast<suspend () -> Continuation<*>>()).invoke() return (getContinuation.cast<suspend () -> Continuation<*>>()).invoke()
} }
suspend inline fun injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean { suspend fun injectCoroutineContext(
coroutineContext: CoroutineContext,
key: CoroutineContext.Key<out CoroutineContext.Element>? = null
): Boolean {
return if (key == null || coroutineContext[key] == null) {
getContinuation().injectCoroutineContext(coroutineContext, key)
} else {
true
}
}
suspend fun injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext? = null): Boolean {
return if (coroutineContext[CoroutineLocalContext] == null) { return if (coroutineContext[CoroutineLocalContext] == null) {
getContinuation().injectCoroutineLocalContext(coroutineLocalContext) getContinuation().injectCoroutineLocalContext(coroutineLocalContext)
} else { } else {
@ -165,13 +176,25 @@ suspend inline fun injectCoroutineLocalContext(coroutineLocalContext: CoroutineL
} }
} }
fun Continuation<*>.injectCoroutineLocalContext(
coroutineLocalContext: CoroutineLocalContext? = null
): Boolean {
return if (context[CoroutineLocalContext] == null) {
injectCoroutineContext(coroutineLocalContext ?: CoroutineLocalContext(), CoroutineLocalContext)
} else {
true
}
}
private val BaseContinuationImpl = Class.forName("kotlin.coroutines.jvm.internal.BaseContinuationImpl") private val BaseContinuationImpl = Class.forName("kotlin.coroutines.jvm.internal.BaseContinuationImpl")
private val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true } private val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true }
fun Continuation<*>.injectCoroutineContext(
fun Continuation<*>.injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean { coroutineContext: CoroutineContext,
if (context[CoroutineLocalContext] != null) return true key: CoroutineContext.Key<out CoroutineContext.Element>? = null
): Boolean {
if (key != null && context[key] != null) return true
if (BaseContinuationImpl.isInstance(this)) { if (BaseContinuationImpl.isInstance(this)) {
BaseContinuationImplCompletion.get(this).cast<Continuation<*>>().injectCoroutineLocalContext(coroutineLocalContext) BaseContinuationImplCompletion.get(this).cast<Continuation<*>>().injectCoroutineContext(coroutineContext, key)
} }
combinedContext(context) combinedContext(context)
if (context[CoroutineLocalContext] != null) return true if (context[CoroutineLocalContext] != null) return true
@ -180,8 +203,7 @@ fun Continuation<*>.injectCoroutineLocalContext(coroutineLocalContext: Coroutine
return@forAllFields return@forAllFields
} }
it.isAccessible = true it.isAccessible = true
val coroutineContext = it.get(this).cast<CoroutineContext>() it.set(this, it.get(this).cast<CoroutineContext>() + coroutineContext)
it.set(this, coroutineContext + coroutineLocalContext)
} }
return context[CoroutineLocalContext] != null return context[CoroutineLocalContext] != null
} }