diff --git a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt new file mode 100644 index 000000000..d71df904c --- /dev/null +++ b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt @@ -0,0 +1,389 @@ +@file:Suppress("NOTHING_TO_INLINE") + +package net.mamoe.mirai.utils + +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic + + +@MiraiExperimentalAPI +inline fun lockFreeLinkedListOf(vararg elements: E): LockFreeLinkedList = LockFreeLinkedList().apply { + addAll(elements) +} + +@MiraiExperimentalAPI +inline fun lockFreeLinkedListOf(): LockFreeLinkedList = LockFreeLinkedList() + +/** + * 无锁链表实现. 元素值不能为 null + */ +@MiraiExperimentalAPI +class LockFreeLinkedList : MutableList, RandomAccess { + private val tail: Tail = Tail() + + private val head: Head = Head(tail) + + override fun add(element: E): Boolean { + val node = element.asNode(tail) + + while (true) { + val tail = head.iterateBeforeFirst { it === tail } // find the last node. + if (tail.nextNodeRef.compareAndSet(this.tail, node)) { // ensure the last node is the last node + return true + } + } + + + } + + internal fun getLinkStucture(): String = buildString { + head.childIterate>({ + append(it.toString()) + append("->") + it.nextNode + }, { it !is Tail }) + }.let { + if (it.lastIndex > 0) { + it.substring(0..it.lastIndex - 2) + } else it + } + + override fun remove(element: E): Boolean { + while (true) { + val before = head.iterateBeforeNodeValue(element) + val toRemove = before.nextNode + val next = toRemove.nextNode + if (toRemove === tail) { + return false + } + + if (before.nextNodeRef.compareAndSet(toRemove, next)) { + return true + } + } + } + + private fun removeNode(node: Node): Boolean { + if (node == tail) { + return false + } + while (true) { + val before = head.iterateBeforeFirst { it === node } + val toRemove = before.nextNode + val next = toRemove.nextNode + if (toRemove == tail) { // This + return true + } + toRemove.nodeValue = null // logaically remove first, then all the operations will recognize this node invalid + + if (before.nextNodeRef.compareAndSet(toRemove, next)) { // physically remove: try to fix the link + return true + } + } + } + + override val size: Int + get() = head.countChildIterate>({ it.nextNodeRef.value }, { it !is Tail }) - 1 // empty head is always included + + override operator fun contains(element: E): Boolean = head.iterateBeforeNodeValue(element) !== tail + + override fun containsAll(elements: Collection): Boolean = elements.all { contains(it) } + + override operator fun get(index: Int): E { + require(index >= 0) { "Index must be >= 0" } + var i = index + 1 // 1 for head + return head.iterateStopOnFirst { i-- == 0 }.nodeValueRef.value ?: noSuchElement() + } + + override fun indexOf(element: E): Int { + var i = -1 // head + if (!head.iterateStopOnFirst { + i++ + it.nodeValueRef.value == element + }.isValidElementNode()) { + return -1 + } + return i - 1 // iteration is stopped at the next node + } + + override fun isEmpty(): Boolean = head.allMatching { it.nodeValueRef.value == null } + + /** + * Create a concurrent-unsafe iterator + */ + override operator fun iterator(): MutableIterator = object : MutableIterator { + var currentNode: Node + get() = currentNoderef.value + set(value) { + currentNoderef.value = value + } + + private var currentNoderef: AtomicRef> = atomic(head) // concurrent compatibility + + /** + * Check if + * + * **Notice That:** + * if `hasNext` returned `true`, then the last remaining element is removed concurrently, + * [next] will produce a [NoSuchElementException] + */ + override fun hasNext(): Boolean = !currentNode.iterateStopOnFirst { it.isValidElementNode() }.isTail() + + /** + * Iterate until the next node is not + */ + override fun next(): E { + while (true) { + val next = currentNode.nextNode + if (next.isTail()) noSuchElement() + + currentNode = next + + val nodeValue = next.nodeValue + if (nodeValue != null) { // next node is not removed, that's what we want + return nodeValue + } // or try again + } + } + + override fun remove() { + if (!removeNode(currentNode)) { // search from head onto the node, concurrent compatibility + noSuchElement() + } + } + } + + /** + * Find the last index of the element in the list that is [equals] to [element], with concurrent compatibility. + * + * For a typical list, say `head <- Node#1(1) <- Node#2(2) <- Node#3(3) <- Node#4(4) <- Node#5(2) <- tail`, + * the procedures of `lastIndexOf(2)` is: + * + * 1. Iterate each element, until 2 is found, accumulate the index found, which is 1 + * 2. Search again from the first matching element, which is Node#2 + * 3. Accumulate the index found. + * 4. Repeat 2,3 until the `tail` is reached. + * + * Concurrent changes may influence the result. + * While searching, + * + */ + override fun lastIndexOf(element: E): Int { + var lastMatching: Node = head + var searchStartingFrom: Node = lastMatching + var index = 0 // accumulated index from each search + + findTheLastMatchingElement@ while (true) { // iterate to find the last matching element. + var timesOnThisTurn = if (searchStartingFrom === head) -1 else 0 // ignore the head + val got = searchStartingFrom.nextNode.iterateBeforeFirst { timesOnThisTurn++; it.nodeValue == element } + // find the first match starting from `searchStartingFrom` + + if (got.isTail()) break@findTheLastMatchingElement // no further elements + check(timesOnThisTurn >= 0) { "Internal check failed: too many times ran: $timesOnThisTurn" } + + searchStartingFrom = got.nextNode + index += timesOnThisTurn + + if (!got.isRemoved()) lastMatching = got // only record the lastMatching if got is not removed. + } + + if (!lastMatching.isValidElementNode()) { + // found is invalid means not found + return -1 + } + + return index + } + + override fun listIterator(): MutableListIterator = listIterator0(0) + override fun listIterator(index: Int): MutableListIterator = listIterator0(index) + + @Suppress("NOTHING_TO_INLINE") + internal inline fun listIterator0(index: Int): MutableListIterator { + TODO() + } + + override fun subList(fromIndex: Int, toIndex: Int): MutableList { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + + override fun add(index: Int, element: E) { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override fun addAll(index: Int, elements: Collection): Boolean { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override fun addAll(elements: Collection): Boolean { + elements.forEach { add(it) } + return true + } + + override fun clear() { + head.nextNode = tail + + // TODO: 2019/12/13 check ? + } + + override fun removeAll(elements: Collection): Boolean { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override fun removeAt(index: Int): E { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override fun retainAll(elements: Collection): Boolean { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override operator fun set(index: Int, element: E): E { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + + // NO INLINE: currently exceptions thrown in a inline function cannot be traced + private fun noSuchElement(): Nothing = throw NoSuchElementException() + +} + +@Suppress("NOTHING_TO_INLINE") +private inline fun E.asNode(nextNode: Node): Node = Node(nextNode).apply { nodeValueRef.value = this@asNode } + +/** + * 使用 [iterator] 进行自我迭代, 直到 [mustBeTrue] 返回 false 时停止迭代. 返回最后一个满足条件的元素 + */ +private inline fun > N.childIterate(iterator: (N) -> N, mustBeTrue: (N) -> Boolean): N { + if (!mustBeTrue(this)) return this + var value: N = this + + while (true) { + val newValue = iterator(value) + if (mustBeTrue(newValue)) { + value = newValue + } else { + return value + } + + if (newValue is Tail<*>) return newValue + } +} + +/** + * 使用 [iterator] 进行自我迭代, 直到 [mustBeTrue] 返回 false 时停止迭代. 返回第一个不满足条件的元素 + */ +private inline fun E.childIterateReturnFirstUnsitisfying(iterator: (E) -> E, mustBeTrue: (E) -> Boolean): E { + if (!mustBeTrue(this)) return this + var value: E = this + + while (true) { + val newValue = iterator(value) + if (mustBeTrue(newValue)) { + value = newValue + } else { + return newValue + } + + if (newValue is Tail<*>) return newValue + } +} + +/** + * 使用 [iterator] 进行自我迭代, 直到 [mustBeTrue] 返回 false 时停止迭代. 返回满足条件的元素数量 + */ +private inline fun E.countChildIterate(iterator: (E) -> E, mustBeTrue: (E) -> Boolean): Int { + var count = 0 + var value: E = this + if (!mustBeTrue(value)) return count + + while (true) { + count++ + val newValue = iterator(value) + if (mustBeTrue(newValue)) { + value = newValue + } else { + return count + } + } +} + +private class Head( + nextNode: Node +) : Node(nextNode) { + +} + +private open class Node( + nextNode: Node? +) { + internal val id: Int = nextId(); + + companion object { + private val idCount = atomic(0) + internal fun nextId() = idCount.getAndIncrement() + } + + override fun toString(): String = "Node#$id(${nodeValueRef.value})" + + + val nodeValueRef: AtomicRef = atomic(null) + + inline var nodeValue: E? + get() = nodeValueRef.value + set(value) { + nodeValueRef.value = value + } + + @Suppress("LeakingThis") + val nextNodeRef: AtomicRef> = atomic(nextNode ?: this) + inline var nextNode: Node + get() = nextNodeRef.value + set(value) { + nextNodeRef.value = value + } + + + inline fun iterateWhile(filter: (Node) -> Boolean): Node = this.childIterate>({ it.nextNode }, filter) + + inline fun iterateBeforeFirst(filter: (Node) -> Boolean): Node = + this.childIterate>({ it.nextNode }, { !filter(it) }) + + inline fun iterateStopOnFirst(filter: (Node) -> Boolean): Node = + iterateBeforeFirst(filter).nextNode + + @Suppress("NOTHING_TO_INLINE") + inline fun iterateBeforeNotnull(): Node = iterateBeforeFirst { it.nodeValue != null } + + @Suppress("NOTHING_TO_INLINE") + inline fun nextValidElement(): Node = this.iterateBeforeFirst { !it.isValidElementNode() } + + @Suppress("NOTHING_TO_INLINE") + inline fun nextNotnull(): Node = this.iterateBeforeFirst { it.nodeValueRef.value == null } + + inline fun allMatching(filter: (Node) -> Boolean): Boolean = this.iterateWhile(filter) !is Tail + + @Suppress("NOTHING_TO_INLINE") + inline fun iterateBeforeNodeValue(element: E): Node = this.iterateBeforeFirst { it.nodeValueRef.value == element } + + @Suppress("NOTHING_TO_INLINE") + inline fun iterateStopOnNodeValue(element: E): Node = this.iterateBeforeNodeValue(element).nextNode +} + +private open class Tail : Node(null) + +@Suppress("unused") +private fun AtomicRef>.getNodeValue(): E? = if (this.value is Tail) null else this.value.nodeValueRef.value + +@Suppress("NOTHING_TO_INLINE") +private inline fun Node<*>.isValidElementNode(): Boolean = !isHead() && !isTail() && !isRemoved() + +@Suppress("NOTHING_TO_INLINE") +private inline fun Node<*>.isHead(): Boolean = this is Head + +@Suppress("NOTHING_TO_INLINE") +private inline fun Node<*>.isTail(): Boolean = this is Tail + +@Suppress("NOTHING_TO_INLINE") +private inline fun Node<*>.isRemoved(): Boolean = this.nodeValue == null \ No newline at end of file diff --git a/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt b/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt new file mode 100644 index 000000000..b81c18938 --- /dev/null +++ b/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt @@ -0,0 +1,240 @@ +@file:Suppress("RemoveRedundantBackticks", "NonAsciiCharacters") + +package net.mamoe.mirai.utils + +import kotlinx.coroutines.* +import net.mamoe.mirai.test.shouldBeEqualTo +import net.mamoe.mirai.test.shouldBeFalse +import net.mamoe.mirai.test.shouldBeTrue +import org.junit.Test +import kotlin.system.exitProcess +import kotlin.test.* + +internal class LockFreeLinkedListTest { + init { + GlobalScope.launch { + delay(5000) + exitProcess(-100) + } + } + + @Test + fun addAndGetSingleThreaded() { + val list = LockFreeLinkedList() + list.add(1) + list.add(2) + list.add(3) + list.add(4) + + assertEquals(list[0], 1, "Failed on list[0]") + assertEquals(list[1], 2, "Failed on list[1]") + assertEquals(list[2], 3, "Failed on list[2]") + assertEquals(list[3], 4, "Failed on list[3]") + } + + @Test + fun addAndGetSingleConcurrent() { + val list = LockFreeLinkedList() + val add = GlobalScope.async { list.concurrentAdd(1000, 10, 1) } + val remove = GlobalScope.async { + add.join() + list.concurrentDo(100, 10) { + remove(1) + } + } + runBlocking { + joinAll(add, remove) + } + assertEquals(1000 * 10 - 100 * 10, list.size) + } + + @Test + fun remove() { + val list = LockFreeLinkedList() + + assertFalse { list.remove(1) } + assertEquals(0, list.size) + + list.add(1) + assertTrue { list.remove(1) } + assertEquals(0, list.size) + + list.add(2) + assertFalse { list.remove(1) } + assertEquals(1, list.size) + } + + @Test + fun getSize() { + val list = lockFreeLinkedListOf(1, 2, 3, 4, 5) + assertEquals(5, list.size) + + val list2 = lockFreeLinkedListOf() + assertEquals(0, list2.size) + } + + @Test + fun contains() { + val list = lockFreeLinkedListOf() + assertFalse { list.contains(0) } + + list.add(0) + assertTrue { list.contains(0) } + } + + @Test + fun containsAll() { + var list = lockFreeLinkedListOf(1, 2, 3) + assertTrue { list.containsAll(listOf(1, 2, 3)) } + assertTrue { list.containsAll(listOf()) } + + list = lockFreeLinkedListOf(1, 2) + assertFalse { list.containsAll(listOf(1, 2, 3)) } + + list = lockFreeLinkedListOf() + assertTrue { list.containsAll(listOf()) } + assertFalse { list.containsAll(listOf(1)) } + } + + @Test + fun indexOf() { + val list: LockFreeLinkedList = lockFreeLinkedListOf(1, 2, 3, 3) + assertEquals(0, list.indexOf(1)) + assertEquals(2, list.indexOf(3)) + + assertEquals(-1, list.indexOf(4)) + } + + @Test + fun isEmpty() { + val list: LockFreeLinkedList = lockFreeLinkedListOf() + list.isEmpty().shouldBeTrue() + + list.add(1) + list.isEmpty().shouldBeFalse() + } + + @Test + fun iterator() { + var list: LockFreeLinkedList = lockFreeLinkedListOf(2) + list.forEach { + it shouldBeEqualTo 2 + } + + list = lockFreeLinkedListOf(1, 2) + list.joinToString { it.toString() } shouldBeEqualTo "1, 2" + + + list = lockFreeLinkedListOf(1, 2) + val iterator = list.iterator() + iterator.remove() + var reached = false + for (i in iterator) { + i shouldBeEqualTo 2 + reached = true + } + reached shouldBeEqualTo true + + list.joinToString { it.toString() } shouldBeEqualTo "2" + iterator.remove() + assertFailsWith { iterator.remove() } + } + + @Test + fun `lastIndexOf of exact 1 match at first`() { + val list: LockFreeLinkedList = lockFreeLinkedListOf(2, 1) + list.lastIndexOf(2) shouldBeEqualTo 0 + } + + @Test + fun `lastIndexOf of exact 1 match`() { + val list: LockFreeLinkedList = lockFreeLinkedListOf(1, 2) + list.lastIndexOf(2) shouldBeEqualTo 1 + } + + @Test + fun `lastIndexOf of multiply matches`() { + val list: LockFreeLinkedList = lockFreeLinkedListOf(1, 2, 2) + list.lastIndexOf(2) shouldBeEqualTo 2 + } + + @Test + fun `lastIndexOf of no match`() { + val list: LockFreeLinkedList = lockFreeLinkedListOf(2) + list.lastIndexOf(3) shouldBeEqualTo -1 + } + + @Test + fun `lastIndexOf of many elements`() { + val list: LockFreeLinkedList = lockFreeLinkedListOf(1, 4, 2, 3, 4, 5) + list.lastIndexOf(4) shouldBeEqualTo 4 + } + + +/* + companion object{ + @JvmStatic + fun main(vararg args: String) { + LockFreeLinkedListTest().`lastIndexOf of many elements`() + } + }*/ + + @Test + fun listIterator() { + } + + @Test + fun testListIterator() { + } + + @Test + fun subList() { + } + + @Test + fun testAdd() { + } + + @Test + fun addAll() { + } + + @Test + fun testAddAll() { + } + + @Test + fun clear() { + } + + @Test + fun removeAll() { + } + + @Test + fun removeAt() { + } + + @Test + fun retainAll() { + } + + @Test + fun set() { + } +} + +internal fun withTimeoutBlocking(timeout: Long = 500L, block: suspend () -> Unit) = runBlocking { withTimeout(timeout) { block() } } + +internal suspend fun LockFreeLinkedList.concurrentAdd(numberOfCoroutines: Int, timesOfAdd: Int, element: E) = + concurrentDo(numberOfCoroutines, timesOfAdd) { add(element) } + +internal suspend fun > E.concurrentDo(numberOfCoroutines: Int, timesOfAdd: Int, todo: E.() -> Unit) = coroutineScope { + repeat(numberOfCoroutines) { + launch { + repeat(timesOfAdd) { + todo() + } + } + } +} \ No newline at end of file