From e5c7633f726e4ab49753bf452b452963ca6c1ab5 Mon Sep 17 00:00:00 2001 From: tursom Date: Sat, 18 Jul 2020 22:50:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E5=8A=A0=E5=AE=8C=E5=96=84=E5=8D=8F?= =?UTF-8?q?=E7=A8=8B=E6=9C=AC=E5=9C=B0=E5=8F=98=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kotlin/cn/tursom/utils/coroutine/utils.kt | 36 +++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt b/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt index ab791e2..3eeb81a 100644 --- a/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt +++ b/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt @@ -157,7 +157,18 @@ suspend inline fun getContinuation(): Continuation<*> { return (getContinuation.cast Continuation<*>>()).invoke() } -suspend inline fun injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean { +suspend fun injectCoroutineContext( + coroutineContext: CoroutineContext, + key: CoroutineContext.Key? = null +): Boolean { + return if (key == null || coroutineContext[key] == null) { + getContinuation().injectCoroutineContext(coroutineContext, key) + } else { + true + } +} + +suspend fun injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext? = null): Boolean { return if (coroutineContext[CoroutineLocalContext] == null) { getContinuation().injectCoroutineLocalContext(coroutineLocalContext) } else { @@ -165,13 +176,25 @@ suspend inline fun injectCoroutineLocalContext(coroutineLocalContext: CoroutineL } } +fun Continuation<*>.injectCoroutineLocalContext( + coroutineLocalContext: CoroutineLocalContext? = null +): Boolean { + return if (context[CoroutineLocalContext] == null) { + injectCoroutineContext(coroutineLocalContext ?: CoroutineLocalContext(), CoroutineLocalContext) + } else { + true + } +} + private val BaseContinuationImpl = Class.forName("kotlin.coroutines.jvm.internal.BaseContinuationImpl") private val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true } - -fun Continuation<*>.injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean { - if (context[CoroutineLocalContext] != null) return true +fun Continuation<*>.injectCoroutineContext( + coroutineContext: CoroutineContext, + key: CoroutineContext.Key? = null +): Boolean { + if (key != null && context[key] != null) return true if (BaseContinuationImpl.isInstance(this)) { - BaseContinuationImplCompletion.get(this).cast>().injectCoroutineLocalContext(coroutineLocalContext) + BaseContinuationImplCompletion.get(this).cast>().injectCoroutineContext(coroutineContext, key) } combinedContext(context) if (context[CoroutineLocalContext] != null) return true @@ -180,8 +203,7 @@ fun Continuation<*>.injectCoroutineLocalContext(coroutineLocalContext: Coroutine return@forAllFields } it.isAccessible = true - val coroutineContext = it.get(this).cast() - it.set(this, coroutineContext + coroutineLocalContext) + it.set(this, it.get(this).cast() + coroutineContext) } return context[CoroutineLocalContext] != null }