更加完善协程本地变量

This commit is contained in:
tursom 2020-07-18 22:05:23 +08:00
parent 3b8baada36
commit e2f209ae56
5 changed files with 150 additions and 84 deletions

View File

@ -377,4 +377,25 @@ inline fun <reified T : Any?> Any.assert(action: T.() -> Unit): Boolean {
} else {
false
}
}
val Class<*>.allFields: List<Field>
get() {
var clazz = this
val list = ArrayList<Field>()
while (clazz != Any::class.java) {
list.addAll(clazz.declaredFields)
clazz = clazz.superclass
}
list.addAll(clazz.declaredFields)
return list
}
fun Class<*>.forAllFields(action: (Field) -> Unit) {
var clazz = this
while (clazz != Any::class.java) {
clazz.declaredFields.forEach(action)
clazz = clazz.superclass
}
clazz.declaredFields.forEach(action)
}

View File

@ -10,8 +10,13 @@ open class CoroutineLocal<T> {
open suspend fun get(): T? {
var attach: MutableMap<CoroutineLocal<*>, Any?>? = coroutineContext[CoroutineLocalContext]
if (attach == null) {
val job = coroutineContext[Job] ?: return null
attach = attachMap[job]
if (injectCoroutineLocalContext()) {
attach = coroutineContext[CoroutineLocalContext]
}
if (attach == null) {
val job = coroutineContext[Job] ?: return null
attach = attachMap[job]
}
}
return attach?.get(this)?.cast()
}
@ -19,13 +24,18 @@ open class CoroutineLocal<T> {
open suspend infix fun set(value: T): Boolean {
var attach: MutableMap<CoroutineLocal<*>, Any?>? = coroutineContext[CoroutineLocalContext]
if (attach == null) {
val job = coroutineContext[Job] ?: return false
attach = attachMap[job]
if (injectCoroutineLocalContext()) {
attach = coroutineContext[CoroutineLocalContext]
}
if (attach == null) {
attach = HashMap()
attachMap[job] = attach
job.invokeOnCompletion {
attachMap.remove(job)
val job = coroutineContext[Job] ?: return false
attach = attachMap[job]
if (attach == null) {
attach = HashMap()
attachMap[job] = attach
job.invokeOnCompletion {
attachMap.remove(job)
}
}
}
}

View File

@ -3,9 +3,14 @@ package cn.tursom.utils.coroutine
import cn.tursom.core.cast
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
class CoroutineLocalContinuation(
private val completion: Continuation<*>
) : Continuation<Any?> by completion.cast() {
override val context: CoroutineContext = completion.context + CoroutineLocalContext()
override val context: CoroutineContext = completion.context + if (completion.context[CoroutineLocalContext] == null) {
CoroutineLocalContext()
} else {
EmptyCoroutineContext
}
}

View File

@ -1,6 +1,8 @@
package cn.tursom.utils.coroutine
import cn.tursom.core.cast
import cn.tursom.core.forAllFields
import cn.tursom.core.isInheritanceFrom
import kotlinx.coroutines.*
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
@ -122,11 +124,21 @@ fun <T> runBlockingWithEnhanceContext(
}
}
suspend inline fun <T> runWithCoroutineLocalContext(block: () -> T): T {
return (block.cast<(Continuation<*>) -> T>()).invoke(CoroutineLocalContinuation(getContinuation()))
suspend fun <T> runWithCoroutineLocalContext(
block: suspend () -> T
): T {
val continuation: Any? = getContinuation()
val coroutineLocalContinuation = if (continuation is Continuation<*>) {
CoroutineLocalContinuation(continuation.cast())
} else {
return continuation.cast()
}
return (block.cast<(Any?) -> T>()).invoke(coroutineLocalContinuation)
}
suspend inline fun <T> runWithCoroutineLocal(block: () -> T): T {
suspend fun <T> runWithCoroutineLocal(
block: suspend () -> T
): T {
if (coroutineContext[CoroutineLocalContext] == null) {
return runWithCoroutineLocalContext(block)
}
@ -141,4 +153,52 @@ inline fun getContinuation(continuation: Continuation<*>): Continuation<*> {
suspend inline fun getContinuation(): Continuation<*> {
val getContinuation: (continuation: Continuation<*>) -> Continuation<*> = ::getContinuation
return (getContinuation.cast<suspend () -> Continuation<*>>()).invoke()
}
suspend inline fun injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean {
return if (coroutineContext[CoroutineLocalContext] == null) {
getContinuation().injectCoroutineLocalContext(coroutineLocalContext)
} else {
true
}
}
private val BaseContinuationImpl = Class.forName("kotlin.coroutines.jvm.internal.BaseContinuationImpl")
private val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true }
fun Continuation<*>.injectCoroutineLocalContext(coroutineLocalContext: CoroutineLocalContext = CoroutineLocalContext()): Boolean {
return if (context[CoroutineLocalContext] == null) {
if (BaseContinuationImpl.isInstance(this)) {
BaseContinuationImplCompletion.get(this).cast<Continuation<*>>().injectCoroutineLocalContext(coroutineLocalContext)
}
combinedContext(context)
if (context[CoroutineLocalContext] == null) {
javaClass.forAllFields {
if (!it.type.isInheritanceFrom(CoroutineContext::class.java)) {
return@forAllFields
}
it.isAccessible = true
val coroutineContext = it.get(this).cast<CoroutineContext>()
it.set(this, coroutineContext + coroutineLocalContext)
}
context[CoroutineLocalContext] != null
} else {
true
}
} else {
true
}
}
private val combinedContextClass = Class.forName("kotlin.coroutines.CombinedContext")
private val left = combinedContextClass.getDeclaredField("left").apply { isAccessible = true }
fun combinedContext(coroutineContext: CoroutineContext): Boolean {
return if (coroutineContext.javaClass == combinedContextClass && coroutineContext[CoroutineLocalContext] == null) {
val leftObj = left.get(coroutineContext).cast<CoroutineContext>()
left.set(coroutineContext, leftObj + CoroutineLocalContext())
true
} else {
false
}
}

View File

@ -1,89 +1,59 @@
package cn.tursom.utils.coroutine
import cn.tursom.core.allFields
import cn.tursom.core.cast
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlin.coroutines.Continuation
import kotlin.coroutines.coroutineContext
val testCoroutineLocal = CoroutineLocal<Int>()
suspend fun test() {
println(coroutineContext)
println(coroutineContext[Job] is CoroutineScope)
println(CoroutineScopeContext.get())
println(Thread.currentThread().name)
suspend fun testCustomContext() {
testCoroutineLocal.set(1)
testInlineCustomContext()
}
@Suppress("NOTHING_TO_INLINE")
inline fun getContinuation(continuation: Continuation<*>): Continuation<*> {
fun Any.printMsg() {
javaClass.allFields.forEach {
it.isAccessible = true
println("${it.type} ${it.name} = ${it.get(this)}")
val value: Any? = it.get(this)
println("${value?.javaClass} $value")
println(it.get(this) == this)
if (it.name == "completion") {
println((value as Continuation<*>).context)
}
println()
}
}
val BaseContinuationImpl = Class.forName("kotlin.coroutines.jvm.internal.BaseContinuationImpl")
val BaseContinuationImplCompletion = BaseContinuationImpl.getDeclaredField("completion").apply { isAccessible = true }
fun Continuation<*>.rootCompletion(): Continuation<*> {
var completion = this.javaClass.allFields.firstOrNull { it.name == "completion" }
val coroutineLocalContext = CoroutineLocalContext()
@Suppress("NAME_SHADOWING") var continuation = this
while (completion != null) {
continuation.injectCoroutineLocalContext(coroutineLocalContext)
completion.isAccessible = true
val newContinuation = completion.get(continuation)?.cast<Continuation<*>>() ?: return continuation
if (newContinuation == continuation) {
return continuation
}
completion = newContinuation.javaClass.allFields.firstOrNull { it.name == "completion" }
continuation = newContinuation
}
continuation.injectCoroutineLocalContext(coroutineLocalContext)
return continuation
}
suspend inline fun getContinuation(): Continuation<*> {
val getContinuation: (continuation: Continuation<*>) -> Continuation<*> = ::getContinuation
return (getContinuation.cast<suspend () -> Continuation<*>>()).invoke()
}
suspend fun testCustomContext(): Int? = runWithCoroutineLocal {
suspend inline fun testInlineCustomContext() {
println(coroutineContext)
return testCoroutineLocal.get()
println("===================")
}
suspend fun main(): Unit = runWithCoroutineLocal {
repeat(100) {
testCoroutineLocal.set(it)
println(testCustomContext())
}
////println(::main.javaMethod?.parameters?.get(0))
//println(coroutineContext)
//CurrentThreadCoroutineScope.launch {
// println("Unconfined : I'm working in thread ${Thread.currentThread().name}")
// delay(50)
// println("Unconfined : After delay in thread ${Thread.currentThread().name}")
// delay(50)
// println("Unconfined : After delay in thread ${Thread.currentThread().name}")
//}
//GlobalScope.launch(Dispatchers.Unconfined) { // 非受限的——将和主线程一起工作
// println("Unconfined : I'm working in thread ${Thread.currentThread().name}")
// delay(50)
// println("Unconfined : After delay in thread ${Thread.currentThread().name}")
// delay(50)
// println("Unconfined : After delay in thread ${Thread.currentThread().name}")
//}
//GlobalScope.launch { // 父协程的上下文,主 runBlocking 协程
// println("main runBlocking: I'm working in thread ${Thread.currentThread().name}")
// delay(100)
// println("main runBlocking: After delay in thread ${Thread.currentThread().name}")
// delay(100)
// println("main runBlocking: After delay in thread ${Thread.currentThread().name}")
//}
//println("end")
//delay(1000)
//runBlockingWithEnhanceContext {
// println(coroutineContext)
// println(coroutineContext[Job] is CoroutineScope)
// println(CoroutineScopeContext.get())
// println(Thread.currentThread().name)
// CoroutineContextScope(coroutineContext).launch {
// println(coroutineContext)
// println(coroutineContext[Job] is CoroutineScope)
// println(CoroutineScopeContext.get())
// println(Thread.currentThread().name)
// }.join()
// //CoroutineScopeContext.get().launchWithEnhanceContext {
// // println(coroutineContext)
// // println(coroutineContext[Job] is CoroutineScope)
// // println(CoroutineScopeContext.get())
// // println(Thread.currentThread().name)
// // CoroutineScopeContext.get().launchWithEnhanceContext {
// // println(coroutineContext)
// // println(coroutineContext[Job] is CoroutineScope)
// // println(CoroutineScopeContext.get())
// // println(Thread.currentThread().name)
// // }
// //}.join()
// delay(1000)
// println(CoroutineLocal)
//}
suspend fun main() {
testCustomContext()
println(testCoroutineLocal.get())
testInlineCustomContext()
}