diff --git a/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt b/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt index ab791e2..3eeb81a 100644 --- a/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt +++ b/utils/src/main/kotlin/cn/tursom/utils/coroutine/utils.kt @@ -157,7 +157,18 @@ suspend inline fun getContinuation(): Continuation<*> { return (getContinuation.cast Continuation<*>>()).invoke() } -suspend inline fun injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean { +suspend fun injectCoroutineContext( + coroutineContext: CoroutineContext, + key: CoroutineContext.Key? = null +): Boolean { + return if (key == null || coroutineContext[key] == null) { + getContinuation().injectCoroutineContext(coroutineContext, key) + } else { + true + } +} + +suspend fun injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext? = null): Boolean { return if (coroutineContext[CoroutineLocalContext] == null) { getContinuation().injectCoroutineLocalContext(coroutineLocalContext) } else { @@ -165,13 +176,25 @@ suspend inline fun injectCoroutineLocalContext(coroutineLocalContext: CoroutineL } } +fun Continuation<*>.injectCoroutineLocalContext( + coroutineLocalContext: CoroutineLocalContext? = null +): Boolean { + return if (context[CoroutineLocalContext] == null) { + injectCoroutineContext(coroutineLocalContext ?: CoroutineLocalContext(), CoroutineLocalContext) + } else { + true + } +} + private val BaseContinuationImpl = Class.forName("kotlin.coroutines.jvm.internal.BaseContinuationImpl") private val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true } - -fun Continuation<*>.injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean { - if (context[CoroutineLocalContext] != null) return true +fun Continuation<*>.injectCoroutineContext( + coroutineContext: CoroutineContext, + key: CoroutineContext.Key? = null +): Boolean { + if (key != null && context[key] != null) return true if (BaseContinuationImpl.isInstance(this)) { - BaseContinuationImplCompletion.get(this).cast>().injectCoroutineLocalContext(coroutineLocalContext) + BaseContinuationImplCompletion.get(this).cast>().injectCoroutineContext(coroutineContext, key) } combinedContext(context) if (context[CoroutineLocalContext] != null) return true @@ -180,8 +203,7 @@ fun Continuation<*>.injectCoroutineLocalContext(coroutineLocalContext: Coroutine return@forAllFields } it.isAccessible = true - val coroutineContext = it.get(this).cast() - it.set(this, coroutineContext + coroutineLocalContext) + it.set(this, it.get(this).cast() + coroutineContext) } return context[CoroutineLocalContext] != null }