From 8c58b83ec211f3432064edbf96819968769fdd67 Mon Sep 17 00:00:00 2001 From: Him188 <Him188@mamoe.net> Date: Sat, 14 Dec 2019 21:25:22 +0800 Subject: [PATCH] Enhance LockFreeLinkedList --- .../utils/LockFreeLinkedList.kt | 239 +++++++++++++----- .../mirai/utils/LockFreeLinkedListTest.kt | 111 ++++++-- 2 files changed, 262 insertions(+), 88 deletions(-) diff --git a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt index 05683a812..fb41bda83 100644 --- a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt +++ b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt @@ -4,7 +4,35 @@ package net.mamoe.mirai.utils import kotlinx.atomicfu.AtomicRef import kotlinx.atomicfu.atomic -import net.mamoe.mirai.utils.Node.Companion.equals +import kotlinx.atomicfu.loop + +fun <E> LockFreeLinkedList<E>.joinToString( + separator: CharSequence = ", ", + prefix: CharSequence = "[", + postfix: CharSequence = "]", + transform: ((E) -> CharSequence)? = null +): String = prefix.toString() + buildString { + this@joinToString.forEach { + if (transform != null) { + append(transform(it)) + } else append(it) + append(separator) + } +}.dropLast(2) + postfix + +/** + * Returns a [List] containing all the elements in [this] in the same order + */ +fun <E> LockFreeLinkedList<E>.toList(): List<E> = toMutableList() + +/** + * Returns a [MutableList] containing all the elements in [this] in the same order + */ +fun <E> LockFreeLinkedList<E>.toMutableList(): MutableList<E> { + val list = mutableListOf<E>() + this.forEach { list.add(it) } + return list +} /** * Implementation of lock-free LinkedList. @@ -12,12 +40,43 @@ import net.mamoe.mirai.utils.Node.Companion.equals * Modifying can be performed concurrently. * Iterating concurrency is guaranteed. */ -class LockFreeLinkedList<E> { - private val tail: Tail<E> = Tail() +open class LockFreeLinkedList<E> { + @PublishedApi + internal val tail: Tail<E> = Tail() - private val head: Head<E> = Head(tail) + @PublishedApi + internal val head: Head<E> = Head(tail) - fun add(element: E) { + fun removeFirst(): E { + while (true) { + val currentFirst = head.nextNode + if (!currentFirst.isValidElementNode()) { + throw NoSuchElementException() + } + if (head.compareAndSetNextNodeRef(currentFirst, currentFirst.nextNode)) { + return currentFirst.nodeValue + } + } + } + + fun peekFirst(): E = head.nextNode.letValueIfValid { return it } ?: throw NoSuchElementException() + + fun peekLast(): E = head.iterateBeforeFirst { it === tail }.letValueIfValid { return it } ?: throw NoSuchElementException() + + fun removeLast(): E { + while (true) { + val beforeLast = head.iterateBeforeFirst { it.nextNode === tail } + if (!beforeLast.isValidElementNode()) { + throw NoSuchElementException() + } + val last = beforeLast.nextNode + if (beforeLast.nextNodeRef.compareAndSet(last, last.nextNode)) { + return last.nodeValue + } + } + } + + fun addLast(element: E) { val node = element.asNode(tail) while (true) { @@ -28,21 +87,42 @@ class LockFreeLinkedList<E> { } } - override fun toString(): String = "[" + buildString { - this@LockFreeLinkedList.forEach { - append(it) - append(", ") + operator fun plusAssign(element: E) = this.addLast(element) + + inline fun filteringGetOrAdd(filter: (E) -> Boolean, noinline supplier: () -> E): E { + val node = LazyNode(tail, supplier) + + while (true) { + var current: Node<E> = head + + findLastNode@ while (true) { + if (current.isValidElementNode() && filter(current.nodeValue)) + return current.nodeValue + + if (current.nextNode === tail) { + if (current.compareAndSetNextNodeRef(tail, node)) { // ensure only one attempt can put the lazyNode in + return node.nodeValue + } + } + + current = current.nextNode + } } - }.dropLast(2) + "]" + } + + @PublishedApi // limitation by atomicfu + internal fun <E> Node<E>.compareAndSetNextNodeRef(expect: Node<E>, update: Node<E>) = this.nextNodeRef.compareAndSet(expect, update) + + override fun toString(): String = joinToString() @Suppress("unused") internal fun getLinkStructure(): String = buildString { head.childIterateReturnsLastSatisfying<Node<*>>({ append(it.toString()) - append("->") + append(" <- ") it.nextNode }, { it !is Tail }) - }.dropLast(2) + }.dropLast(4) fun remove(element: E): Boolean { while (true) { @@ -51,40 +131,45 @@ class LockFreeLinkedList<E> { if (toRemove === tail) { return false } - if (toRemove.nodeValue === null) { + if (toRemove.isRemoved()) { continue } - toRemove.nodeValue = null // logically remove: all the operations will recognize this node invalid + toRemove.removed.value = true // logically remove: all the operations will recognize this node invalid + + // physically remove: try to fix the link var next: Node<E> = toRemove.nextNode - while (next !== tail && next.nodeValue === null) { + while (next !== tail && next.isRemoved()) { next = next.nextNode } - if (before.nextNodeRef.compareAndSet(toRemove, next)) {// physically remove: try to fix the link + if (before.nextNodeRef.compareAndSet(toRemove, next)) { return true } } } - val size: Int get() = head.countChildIterate<Node<E>>({ it.nextNodeRef.value }, { it !is Tail }) - 1 // empty head is always included + val size: Int get() = head.countChildIterate<Node<E>>({ it.nextNode }, { it !is Tail }) - 1 // empty head is always included - operator fun contains(element: E): Boolean = head.iterateBeforeNodeValue(element) !== tail + operator fun contains(element: E): Boolean { + forEach { if (it == element) return true } + return false + } @Suppress("unused") fun containsAll(elements: Collection<E>): Boolean = elements.all { contains(it) } fun isEmpty(): Boolean = head.allMatching { it.isValidElementNode().not() } - fun forEach(block: (E) -> Unit) { + inline fun forEach(block: (E) -> Unit) { var node: Node<E> = head while (true) { - if (node === tail) return - node.letIfNotnull(block) + node.letValueIfValid(block) node = node.nextNode + if (node === tail) return } } - fun addAll(elements: Collection<E>) = elements.forEach { add(it) } + fun addAll(elements: Collection<E>) = elements.forEach { addLast(it) } @Suppress("unused") fun clear() { @@ -93,7 +178,7 @@ class LockFreeLinkedList<E> { first.childIterateReturnFirstUnsatisfying({ val n = it.nextNode it.nextNode = tail - it.nodeValue = null + it.removed.value = true n }, { it !== tail }) // clear the link structure, help GC. } @@ -417,31 +502,18 @@ class LockFreeLinkedList<E> { */ } -/** - * Returns a [List] containing all the elements in [this] in the same order - */ -fun <E> LockFreeLinkedList<E>.toList(): List<E> = toMutableList() - -/** - * Returns a [MutableList] containing all the elements in [this] in the same order - */ -fun <E> LockFreeLinkedList<E>.toMutableList(): MutableList<E> { - val list = mutableListOf<E>() - this.forEach { list.add(it) } - return list -} - // region internal @Suppress("NOTHING_TO_INLINE") -private inline fun <E> E.asNode(nextNode: Node<E>): Node<E> = Node(nextNode).apply { nodeValueRef.value = this@asNode } +private inline fun <E> E.asNode(nextNode: Node<E>): Node<E> = Node(nextNode, this) /** * Self-iterate using the [iterator], until [mustBeTrue] returns `false`. * Returns the element at the last time when the [mustBeTrue] returns `true` */ -private inline fun <N : Node<*>> N.childIterateReturnsLastSatisfying(iterator: (N) -> N, mustBeTrue: (N) -> Boolean): N { +@PublishedApi +internal inline fun <N : Node<*>> N.childIterateReturnsLastSatisfying(iterator: (N) -> N, mustBeTrue: (N) -> Boolean): N { if (!mustBeTrue(this)) return this var value: N = this @@ -497,38 +569,68 @@ private inline fun <E> E.countChildIterate(iterator: (E) -> E, mustBeTrue: (E) - } } -private class Head<E>( - nextNode: Node<E> -) : Node<E>(nextNode) +@PublishedApi +internal class LazyNode<E> @PublishedApi internal constructor( + nextNode: Node<E>, + private val valueComputer: () -> E +) : Node<E>(nextNode, null) { + private val initialized = atomic(false) -private open class Node<E>( - nextNode: Node<E>? + private val value: AtomicRef<E?> = atomic(null) + + override val nodeValue: E + get() { + @Suppress("BooleanLiteralArgument") // false positive warning + if (initialized.compareAndSet(false, true)) { // ensure only one lucky attempt can go into the if + val value = valueComputer() + this.value.value = value + return value // fast path + } + value.loop { + if (it != null) { + return it + } + } + } +} + +@PublishedApi +internal class Head<E>(nextNode: Node<E>) : Node<E>(nextNode, null) { + override fun toString(): String = "Head" + override val nodeValue: Nothing get() = error("Internal error: trying to get the value of a Head") +} + +@PublishedApi +internal open class Tail<E> : Node<E>(null, null) { + override fun toString(): String = "Tail" + override val nodeValue: Nothing get() = error("Internal error: trying to get the value of a Tail") +} + +@PublishedApi +internal open class Node<E>( + nextNode: Node<E>?, + private var initialNodeValue: E? ) { + /* internal val id: Int = nextId() - companion object { private val idCount = atomic(0) internal fun nextId() = idCount.getAndIncrement() - } + }*/ - override fun toString(): String = "Node#$id(${nodeValueRef.value})" + override fun toString(): String = "$nodeValue" + open val nodeValue: E get() = initialNodeValue ?: error("Internal error: nodeValue is not initialized") - val nodeValueRef: AtomicRef<E?> = atomic(null) - - /** - * Short cut for accessing [nodeValueRef] - */ - inline var nodeValue: E? - get() = nodeValueRef.value - set(value) { - nodeValueRef.value = value - } + val removed = atomic(false) @Suppress("LeakingThis") val nextNodeRef: AtomicRef<Node<E>> = atomic(nextNode ?: this) - inline fun <R> letIfNotnull(block: (E) -> R): R? { + inline fun <R> letValueIfValid(block: (E) -> R): R? { + if (!this.isValidElementNode()) { + return null + } val value = this.nodeValue return if (value !== null) block(value) else null } @@ -536,13 +638,12 @@ private open class Node<E>( /** * Short cut for accessing [nextNodeRef] */ - inline var nextNode: Node<E> + var nextNode: Node<E> get() = nextNodeRef.value set(value) { nextNodeRef.value = value } - /** * Returns the former node of the last node whence [filter] returns true */ @@ -563,21 +664,23 @@ private open class Node<E>( * E.g.: for `head <- 1 <- 2 <- 3 <- tail`, `iterateStopOnNodeValue(2)` returns the node whose value is 1 */ @Suppress("NOTHING_TO_INLINE") - inline fun iterateBeforeNodeValue(element: E): Node<E> = this.iterateBeforeFirst { it.nodeValueRef.value == element } + internal inline fun iterateBeforeNodeValue(element: E): Node<E> = this.iterateBeforeFirst { it.isValidElementNode() && it.nodeValue == element } + } -private open class Tail<E> : Node<E>(null) +@PublishedApi +internal fun <E> Node<E>.isRemoved() = this.removed.value +@PublishedApi @Suppress("NOTHING_TO_INLINE") -private inline fun Node<*>.isValidElementNode(): Boolean = !isHead() && !isTail() && !isRemoved() +internal inline fun Node<*>.isValidElementNode(): Boolean = !isHead() && !isTail() && !isRemoved() +@PublishedApi @Suppress("NOTHING_TO_INLINE") -private inline fun Node<*>.isHead(): Boolean = this is Head +internal inline fun Node<*>.isHead(): Boolean = this is Head +@PublishedApi @Suppress("NOTHING_TO_INLINE") -private inline fun Node<*>.isTail(): Boolean = this is Tail - -@Suppress("NOTHING_TO_INLINE") -private inline fun Node<*>.isRemoved(): Boolean = this.nodeValue == null +internal inline fun Node<*>.isTail(): Boolean = this is Tail // en dregion \ No newline at end of file diff --git a/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt b/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt index 976b65e32..eaa02fbea 100644 --- a/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt +++ b/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt @@ -23,10 +23,10 @@ internal class LockFreeLinkedListTest { @Test fun addAndGetSingleThreaded() { val list = LockFreeLinkedList<Int>() - list.add(1) - list.add(2) - list.add(3) - list.add(4) + list.addLast(1) + list.addLast(2) + list.addLast(3) + list.addLast(4) list.size shouldBeEqualTo 4 } @@ -36,7 +36,7 @@ internal class LockFreeLinkedListTest { //withContext(Dispatchers.Default){ val list = LockFreeLinkedList<Int>() - list.concurrentAdd(1000, 10, 1) + list.concurrentDo(1000, 10) { addLast(1) } list.size shouldBeEqualTo 1000 * 10 list.concurrentDo(100, 10) { @@ -51,18 +51,55 @@ internal class LockFreeLinkedListTest { fun addAndGetMassConcurrentAccess() = runBlocking { val list = LockFreeLinkedList<Int>() - val addJob = async { list.concurrentAdd(5000, 10, 1) } + val addJob = async { list.concurrentDo(2, 30000) { addLast(1) } } - delay(10) // let addJob fly + //delay(1) // let addJob fly if (addJob.isCompleted) { error("Number of elements are not enough") } - list.concurrentDo(1000, 10) { - remove(1).shouldBeTrue() + val foreachJob = async { + list.concurrentDo(1, 10000) { + forEach { it + it } + } } - addJob.join() + val removeLastJob = async { + list.concurrentDo(1, 15000) { + removeLast() shouldBeEqualTo 1 + } + } + val removeFirstJob = async { + list.concurrentDo(1, 10000) { + removeFirst() shouldBeEqualTo 1 + } + } + val addJob2 = async { + list.concurrentDo(1, 5000) { + addLast(1) + } + } + val removeExactJob = launch { + list.concurrentDo(3, 1000) { + remove(1).shouldBeTrue() + } + } + val filteringGetOrAddJob = launch { + list.concurrentDo(1, 10000) { + filteringGetOrAdd({ it == 2 }, { 1 }) + } + } + joinAll(addJob, addJob2, foreachJob, removeLastJob, removeFirstJob, removeExactJob, filteringGetOrAddJob) - list.size shouldBeEqualTo 5000 * 10 - 1000 * 10 + list.size shouldBeEqualTo 2 * 30000 - 1 * 15000 - 1 * 10000 + 1 * 5000 - 3 * 1000 + 1 * 10000 + } + + @Test + fun removeWhileForeach() { + val list = LockFreeLinkedList<Int>() + repeat(10) { list.addLast(it) } + list.forEach { + list.remove(it + 1) + } + list.peekFirst() shouldBeEqualTo 0 } @Test @@ -72,11 +109,11 @@ internal class LockFreeLinkedListTest { assertFalse { list.remove(1) } assertEquals(0, list.size) - list.add(1) + list.addLast(1) assertTrue { list.remove(1) } assertEquals(0, list.size) - list.add(2) + list.addLast(2) assertFalse { list.remove(1) } assertEquals(1, list.size) } @@ -107,6 +144,43 @@ internal class LockFreeLinkedListTest { list.toString() shouldBeEqualTo "[1, 2, 3, 4, 5]" } + @Test + fun `filteringGetOrAdd when add`() { + val list = LockFreeLinkedList<Int>() + list.addAll(listOf(1, 2, 3, 4, 5)) + val value = list.filteringGetOrAdd({ it == 6 }, { 6 }) + + println("Check value") + value shouldBeEqualTo 6 + println("Check size") + println(list.getLinkStructure()) + list.size shouldBeEqualTo 6 + } + + @Test + fun `filteringGetOrAdd when get`() { + val list = LockFreeLinkedList<Int>() + list.addAll(listOf(1, 2, 3, 4, 5)) + val value = list.filteringGetOrAdd({ it == 2 }, { 2 }) + + println("Check value") + value shouldBeEqualTo 2 + println("Check size") + println(list.getLinkStructure()) + list.size shouldBeEqualTo 5 + } + + @Test + fun `filteringGetOrAdd when empty`() { + val list = LockFreeLinkedList<Int>() + val value = list.filteringGetOrAdd({ it == 2 }, { 2 }) + + println("Check value") + value shouldBeEqualTo 2 + println("Check size") + println(list.getLinkStructure()) + list.size shouldBeEqualTo 1 + } /* @Test fun indexOf() { @@ -176,16 +250,13 @@ internal class LockFreeLinkedListTest { */ } +@UseExperimental(ExperimentalCoroutinesApi::class) @MiraiExperimentalAPI -internal suspend inline fun <E> LockFreeLinkedList<E>.concurrentAdd(numberOfCoroutines: Int, timesOfAdd: Int, element: E) = - concurrentDo(numberOfCoroutines, timesOfAdd) { add(element) } - -@MiraiExperimentalAPI -internal suspend inline fun <E : LockFreeLinkedList<*>> E.concurrentDo(numberOfCoroutines: Int, timesOfAdd: Int, crossinline todo: E.() -> Unit) = +internal suspend inline fun <E : LockFreeLinkedList<*>> E.concurrentDo(numberOfCoroutines: Int, times: Int, crossinline todo: E.() -> Unit) = coroutineScope { repeat(numberOfCoroutines) { - launch { - repeat(timesOfAdd) { + launch(start = CoroutineStart.UNDISPATCHED) { + repeat(times) { todo() } }