From 8c58b83ec211f3432064edbf96819968769fdd67 Mon Sep 17 00:00:00 2001
From: Him188 <Him188@mamoe.net>
Date: Sat, 14 Dec 2019 21:25:22 +0800
Subject: [PATCH] Enhance LockFreeLinkedList

---
 .../utils/LockFreeLinkedList.kt               | 239 +++++++++++++-----
 .../mirai/utils/LockFreeLinkedListTest.kt     | 111 ++++++--
 2 files changed, 262 insertions(+), 88 deletions(-)

diff --git a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt
index 05683a812..fb41bda83 100644
--- a/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt
+++ b/mirai-core/src/commonMain/kotlin/net.mamoe.mirai/utils/LockFreeLinkedList.kt
@@ -4,7 +4,35 @@ package net.mamoe.mirai.utils
 
 import kotlinx.atomicfu.AtomicRef
 import kotlinx.atomicfu.atomic
-import net.mamoe.mirai.utils.Node.Companion.equals
+import kotlinx.atomicfu.loop
+
+fun <E> LockFreeLinkedList<E>.joinToString(
+    separator: CharSequence = ", ",
+    prefix: CharSequence = "[",
+    postfix: CharSequence = "]",
+    transform: ((E) -> CharSequence)? = null
+): String = prefix.toString() + buildString {
+    this@joinToString.forEach {
+        if (transform != null) {
+            append(transform(it))
+        } else append(it)
+        append(separator)
+    }
+}.dropLast(2) + postfix
+
+/**
+ * Returns a [List] containing all the elements in [this] in the same order
+ */
+fun <E> LockFreeLinkedList<E>.toList(): List<E> = toMutableList()
+
+/**
+ * Returns a [MutableList] containing all the elements in [this] in the same order
+ */
+fun <E> LockFreeLinkedList<E>.toMutableList(): MutableList<E> {
+    val list = mutableListOf<E>()
+    this.forEach { list.add(it) }
+    return list
+}
 
 /**
  * Implementation of lock-free LinkedList.
@@ -12,12 +40,43 @@ import net.mamoe.mirai.utils.Node.Companion.equals
  * Modifying can be performed concurrently.
  * Iterating concurrency is guaranteed.
  */
-class LockFreeLinkedList<E> {
-    private val tail: Tail<E> = Tail()
+open class LockFreeLinkedList<E> {
+    @PublishedApi
+    internal val tail: Tail<E> = Tail()
 
-    private val head: Head<E> = Head(tail)
+    @PublishedApi
+    internal val head: Head<E> = Head(tail)
 
-    fun add(element: E) {
+    fun removeFirst(): E {
+        while (true) {
+            val currentFirst = head.nextNode
+            if (!currentFirst.isValidElementNode()) {
+                throw NoSuchElementException()
+            }
+            if (head.compareAndSetNextNodeRef(currentFirst, currentFirst.nextNode)) {
+                return currentFirst.nodeValue
+            }
+        }
+    }
+
+    fun peekFirst(): E = head.nextNode.letValueIfValid { return it } ?: throw NoSuchElementException()
+
+    fun peekLast(): E = head.iterateBeforeFirst { it === tail }.letValueIfValid { return it } ?: throw NoSuchElementException()
+
+    fun removeLast(): E {
+        while (true) {
+            val beforeLast = head.iterateBeforeFirst { it.nextNode === tail }
+            if (!beforeLast.isValidElementNode()) {
+                throw NoSuchElementException()
+            }
+            val last = beforeLast.nextNode
+            if (beforeLast.nextNodeRef.compareAndSet(last, last.nextNode)) {
+                return last.nodeValue
+            }
+        }
+    }
+
+    fun addLast(element: E) {
         val node = element.asNode(tail)
 
         while (true) {
@@ -28,21 +87,42 @@ class LockFreeLinkedList<E> {
         }
     }
 
-    override fun toString(): String = "[" + buildString {
-        this@LockFreeLinkedList.forEach {
-            append(it)
-            append(", ")
+    operator fun plusAssign(element: E) = this.addLast(element)
+
+    inline fun filteringGetOrAdd(filter: (E) -> Boolean, noinline supplier: () -> E): E {
+        val node = LazyNode(tail, supplier)
+
+        while (true) {
+            var current: Node<E> = head
+
+            findLastNode@ while (true) {
+                if (current.isValidElementNode() && filter(current.nodeValue))
+                    return current.nodeValue
+
+                if (current.nextNode === tail) {
+                    if (current.compareAndSetNextNodeRef(tail, node)) { // ensure only one attempt can put the lazyNode in
+                        return node.nodeValue
+                    }
+                }
+
+                current = current.nextNode
+            }
         }
-    }.dropLast(2) + "]"
+    }
+
+    @PublishedApi // limitation by atomicfu
+    internal fun <E> Node<E>.compareAndSetNextNodeRef(expect: Node<E>, update: Node<E>) = this.nextNodeRef.compareAndSet(expect, update)
+
+    override fun toString(): String = joinToString()
 
     @Suppress("unused")
     internal fun getLinkStructure(): String = buildString {
         head.childIterateReturnsLastSatisfying<Node<*>>({
             append(it.toString())
-            append("->")
+            append(" <- ")
             it.nextNode
         }, { it !is Tail })
-    }.dropLast(2)
+    }.dropLast(4)
 
     fun remove(element: E): Boolean {
         while (true) {
@@ -51,40 +131,45 @@ class LockFreeLinkedList<E> {
             if (toRemove === tail) {
                 return false
             }
-            if (toRemove.nodeValue === null) {
+            if (toRemove.isRemoved()) {
                 continue
             }
-            toRemove.nodeValue = null // logically remove: all the operations will recognize this node invalid
+            toRemove.removed.value = true // logically remove: all the operations will recognize this node invalid
 
+
+            // physically remove: try to fix the link
             var next: Node<E> = toRemove.nextNode
-            while (next !== tail && next.nodeValue === null) {
+            while (next !== tail && next.isRemoved()) {
                 next = next.nextNode
             }
-            if (before.nextNodeRef.compareAndSet(toRemove, next)) {// physically remove: try to fix the link
+            if (before.nextNodeRef.compareAndSet(toRemove, next)) {
                 return true
             }
         }
     }
 
-    val size: Int get() = head.countChildIterate<Node<E>>({ it.nextNodeRef.value }, { it !is Tail }) - 1 // empty head is always included
+    val size: Int get() = head.countChildIterate<Node<E>>({ it.nextNode }, { it !is Tail }) - 1 // empty head is always included
 
-    operator fun contains(element: E): Boolean = head.iterateBeforeNodeValue(element) !== tail
+    operator fun contains(element: E): Boolean {
+        forEach { if (it == element) return true }
+        return false
+    }
 
     @Suppress("unused")
     fun containsAll(elements: Collection<E>): Boolean = elements.all { contains(it) }
 
     fun isEmpty(): Boolean = head.allMatching { it.isValidElementNode().not() }
 
-    fun forEach(block: (E) -> Unit) {
+    inline fun forEach(block: (E) -> Unit) {
         var node: Node<E> = head
         while (true) {
-            if (node === tail) return
-            node.letIfNotnull(block)
+            node.letValueIfValid(block)
             node = node.nextNode
+            if (node === tail) return
         }
     }
 
-    fun addAll(elements: Collection<E>) = elements.forEach { add(it) }
+    fun addAll(elements: Collection<E>) = elements.forEach { addLast(it) }
 
     @Suppress("unused")
     fun clear() {
@@ -93,7 +178,7 @@ class LockFreeLinkedList<E> {
         first.childIterateReturnFirstUnsatisfying({
             val n = it.nextNode
             it.nextNode = tail
-            it.nodeValue = null
+            it.removed.value = true
             n
         }, { it !== tail }) // clear the link structure, help GC.
     }
@@ -417,31 +502,18 @@ class LockFreeLinkedList<E> {
     */
 }
 
-/**
- * Returns a [List] containing all the elements in [this] in the same order
- */
-fun <E> LockFreeLinkedList<E>.toList(): List<E> = toMutableList()
-
-/**
- * Returns a [MutableList] containing all the elements in [this] in the same order
- */
-fun <E> LockFreeLinkedList<E>.toMutableList(): MutableList<E> {
-    val list = mutableListOf<E>()
-    this.forEach { list.add(it) }
-    return list
-}
-
 
 // region internal
 
 @Suppress("NOTHING_TO_INLINE")
-private inline fun <E> E.asNode(nextNode: Node<E>): Node<E> = Node(nextNode).apply { nodeValueRef.value = this@asNode }
+private inline fun <E> E.asNode(nextNode: Node<E>): Node<E> = Node(nextNode, this)
 
 /**
  * Self-iterate using the [iterator], until [mustBeTrue] returns `false`.
  * Returns the element at the last time when the [mustBeTrue] returns `true`
  */
-private inline fun <N : Node<*>> N.childIterateReturnsLastSatisfying(iterator: (N) -> N, mustBeTrue: (N) -> Boolean): N {
+@PublishedApi
+internal inline fun <N : Node<*>> N.childIterateReturnsLastSatisfying(iterator: (N) -> N, mustBeTrue: (N) -> Boolean): N {
     if (!mustBeTrue(this)) return this
     var value: N = this
 
@@ -497,38 +569,68 @@ private inline fun <E> E.countChildIterate(iterator: (E) -> E, mustBeTrue: (E) -
     }
 }
 
-private class Head<E>(
-    nextNode: Node<E>
-) : Node<E>(nextNode)
+@PublishedApi
+internal class LazyNode<E> @PublishedApi internal constructor(
+    nextNode: Node<E>,
+    private val valueComputer: () -> E
+) : Node<E>(nextNode, null) {
+    private val initialized = atomic(false)
 
-private open class Node<E>(
-    nextNode: Node<E>?
+    private val value: AtomicRef<E?> = atomic(null)
+
+    override val nodeValue: E
+        get() {
+            @Suppress("BooleanLiteralArgument") // false positive warning
+            if (initialized.compareAndSet(false, true)) { // ensure only one lucky attempt can go into the if
+                val value = valueComputer()
+                this.value.value = value
+                return value // fast path
+            }
+            value.loop {
+                if (it != null) {
+                    return it
+                }
+            }
+        }
+}
+
+@PublishedApi
+internal class Head<E>(nextNode: Node<E>) : Node<E>(nextNode, null) {
+    override fun toString(): String = "Head"
+    override val nodeValue: Nothing get() = error("Internal error: trying to get the value of a Head")
+}
+
+@PublishedApi
+internal open class Tail<E> : Node<E>(null, null) {
+    override fun toString(): String = "Tail"
+    override val nodeValue: Nothing get() = error("Internal error: trying to get the value of a Tail")
+}
+
+@PublishedApi
+internal open class Node<E>(
+    nextNode: Node<E>?,
+    private var initialNodeValue: E?
 ) {
+    /*
     internal val id: Int = nextId()
-
     companion object {
         private val idCount = atomic(0)
         internal fun nextId() = idCount.getAndIncrement()
-    }
+    }*/
 
-    override fun toString(): String = "Node#$id(${nodeValueRef.value})"
+    override fun toString(): String = "$nodeValue"
 
+    open val nodeValue: E get() = initialNodeValue ?: error("Internal error: nodeValue is not initialized")
 
-    val nodeValueRef: AtomicRef<E?> = atomic(null)
-
-    /**
-     * Short cut for accessing [nodeValueRef]
-     */
-    inline var nodeValue: E?
-        get() = nodeValueRef.value
-        set(value) {
-            nodeValueRef.value = value
-        }
+    val removed = atomic(false)
 
     @Suppress("LeakingThis")
     val nextNodeRef: AtomicRef<Node<E>> = atomic(nextNode ?: this)
 
-    inline fun <R> letIfNotnull(block: (E) -> R): R? {
+    inline fun <R> letValueIfValid(block: (E) -> R): R? {
+        if (!this.isValidElementNode()) {
+            return null
+        }
         val value = this.nodeValue
         return if (value !== null) block(value) else null
     }
@@ -536,13 +638,12 @@ private open class Node<E>(
     /**
      * Short cut for accessing [nextNodeRef]
      */
-    inline var nextNode: Node<E>
+    var nextNode: Node<E>
         get() = nextNodeRef.value
         set(value) {
             nextNodeRef.value = value
         }
 
-
     /**
      * Returns the former node of the last node whence [filter] returns true
      */
@@ -563,21 +664,23 @@ private open class Node<E>(
      * E.g.: for `head <- 1 <- 2 <- 3 <- tail`, `iterateStopOnNodeValue(2)` returns the node whose value is 1
      */
     @Suppress("NOTHING_TO_INLINE")
-    inline fun iterateBeforeNodeValue(element: E): Node<E> = this.iterateBeforeFirst { it.nodeValueRef.value == element }
+    internal inline fun iterateBeforeNodeValue(element: E): Node<E> = this.iterateBeforeFirst { it.isValidElementNode() && it.nodeValue == element }
+
 }
 
-private open class Tail<E> : Node<E>(null)
+@PublishedApi
+internal fun <E> Node<E>.isRemoved() = this.removed.value
 
+@PublishedApi
 @Suppress("NOTHING_TO_INLINE")
-private inline fun Node<*>.isValidElementNode(): Boolean = !isHead() && !isTail() && !isRemoved()
+internal inline fun Node<*>.isValidElementNode(): Boolean = !isHead() && !isTail() && !isRemoved()
 
+@PublishedApi
 @Suppress("NOTHING_TO_INLINE")
-private inline fun Node<*>.isHead(): Boolean = this is Head
+internal inline fun Node<*>.isHead(): Boolean = this is Head
 
+@PublishedApi
 @Suppress("NOTHING_TO_INLINE")
-private inline fun Node<*>.isTail(): Boolean = this is Tail
-
-@Suppress("NOTHING_TO_INLINE")
-private inline fun Node<*>.isRemoved(): Boolean = this.nodeValue == null
+internal inline fun Node<*>.isTail(): Boolean = this is Tail
 
 // en dregion
\ No newline at end of file
diff --git a/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt b/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt
index 976b65e32..eaa02fbea 100644
--- a/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt
+++ b/mirai-core/src/jvmTest/kotlin/net/mamoe/mirai/utils/LockFreeLinkedListTest.kt
@@ -23,10 +23,10 @@ internal class LockFreeLinkedListTest {
     @Test
     fun addAndGetSingleThreaded() {
         val list = LockFreeLinkedList<Int>()
-        list.add(1)
-        list.add(2)
-        list.add(3)
-        list.add(4)
+        list.addLast(1)
+        list.addLast(2)
+        list.addLast(3)
+        list.addLast(4)
 
         list.size shouldBeEqualTo 4
     }
@@ -36,7 +36,7 @@ internal class LockFreeLinkedListTest {
         //withContext(Dispatchers.Default){
         val list = LockFreeLinkedList<Int>()
 
-        list.concurrentAdd(1000, 10, 1)
+        list.concurrentDo(1000, 10) { addLast(1) }
         list.size shouldBeEqualTo 1000 * 10
 
         list.concurrentDo(100, 10) {
@@ -51,18 +51,55 @@ internal class LockFreeLinkedListTest {
     fun addAndGetMassConcurrentAccess() = runBlocking {
         val list = LockFreeLinkedList<Int>()
 
-        val addJob = async { list.concurrentAdd(5000, 10, 1) }
+        val addJob = async { list.concurrentDo(2, 30000) { addLast(1) } }
 
-        delay(10) // let addJob fly
+        //delay(1) // let addJob fly
         if (addJob.isCompleted) {
             error("Number of elements are not enough")
         }
-        list.concurrentDo(1000, 10) {
-            remove(1).shouldBeTrue()
+        val foreachJob = async {
+            list.concurrentDo(1, 10000) {
+                forEach { it + it }
+            }
         }
-        addJob.join()
+        val removeLastJob = async {
+            list.concurrentDo(1, 15000) {
+                removeLast() shouldBeEqualTo 1
+            }
+        }
+        val removeFirstJob = async {
+            list.concurrentDo(1, 10000) {
+                removeFirst() shouldBeEqualTo 1
+            }
+        }
+        val addJob2 = async {
+            list.concurrentDo(1, 5000) {
+                addLast(1)
+            }
+        }
+        val removeExactJob = launch {
+            list.concurrentDo(3, 1000) {
+                remove(1).shouldBeTrue()
+            }
+        }
+        val filteringGetOrAddJob = launch {
+            list.concurrentDo(1, 10000) {
+                filteringGetOrAdd({ it == 2 }, { 1 })
+            }
+        }
+        joinAll(addJob, addJob2, foreachJob, removeLastJob, removeFirstJob, removeExactJob, filteringGetOrAddJob)
 
-        list.size shouldBeEqualTo 5000 * 10 - 1000 * 10
+        list.size shouldBeEqualTo 2 * 30000 - 1 * 15000 - 1 * 10000 + 1 * 5000 - 3 * 1000 + 1 * 10000
+    }
+
+    @Test
+    fun removeWhileForeach() {
+        val list = LockFreeLinkedList<Int>()
+        repeat(10) { list.addLast(it) }
+        list.forEach {
+            list.remove(it + 1)
+        }
+        list.peekFirst() shouldBeEqualTo 0
     }
 
     @Test
@@ -72,11 +109,11 @@ internal class LockFreeLinkedListTest {
         assertFalse { list.remove(1) }
         assertEquals(0, list.size)
 
-        list.add(1)
+        list.addLast(1)
         assertTrue { list.remove(1) }
         assertEquals(0, list.size)
 
-        list.add(2)
+        list.addLast(2)
         assertFalse { list.remove(1) }
         assertEquals(1, list.size)
     }
@@ -107,6 +144,43 @@ internal class LockFreeLinkedListTest {
         list.toString() shouldBeEqualTo "[1, 2, 3, 4, 5]"
     }
 
+    @Test
+    fun `filteringGetOrAdd when add`() {
+        val list = LockFreeLinkedList<Int>()
+        list.addAll(listOf(1, 2, 3, 4, 5))
+        val value = list.filteringGetOrAdd({ it == 6 }, { 6 })
+
+        println("Check value")
+        value shouldBeEqualTo 6
+        println("Check size")
+        println(list.getLinkStructure())
+        list.size shouldBeEqualTo 6
+    }
+
+    @Test
+    fun `filteringGetOrAdd when get`() {
+        val list = LockFreeLinkedList<Int>()
+        list.addAll(listOf(1, 2, 3, 4, 5))
+        val value = list.filteringGetOrAdd({ it == 2 }, { 2 })
+
+        println("Check value")
+        value shouldBeEqualTo 2
+        println("Check size")
+        println(list.getLinkStructure())
+        list.size shouldBeEqualTo 5
+    }
+
+    @Test
+    fun `filteringGetOrAdd when empty`() {
+        val list = LockFreeLinkedList<Int>()
+        val value = list.filteringGetOrAdd({ it == 2 }, { 2 })
+
+        println("Check value")
+        value shouldBeEqualTo 2
+        println("Check size")
+        println(list.getLinkStructure())
+        list.size shouldBeEqualTo 1
+    }
     /*
     @Test
     fun indexOf() {
@@ -176,16 +250,13 @@ internal class LockFreeLinkedListTest {
      */
 }
 
+@UseExperimental(ExperimentalCoroutinesApi::class)
 @MiraiExperimentalAPI
-internal suspend inline fun <E> LockFreeLinkedList<E>.concurrentAdd(numberOfCoroutines: Int, timesOfAdd: Int, element: E) =
-    concurrentDo(numberOfCoroutines, timesOfAdd) { add(element) }
-
-@MiraiExperimentalAPI
-internal suspend inline fun <E : LockFreeLinkedList<*>> E.concurrentDo(numberOfCoroutines: Int, timesOfAdd: Int, crossinline todo: E.() -> Unit) =
+internal suspend inline fun <E : LockFreeLinkedList<*>> E.concurrentDo(numberOfCoroutines: Int, times: Int, crossinline todo: E.() -> Unit) =
     coroutineScope {
         repeat(numberOfCoroutines) {
-            launch {
-                repeat(timesOfAdd) {
+            launch(start = CoroutineStart.UNDISPATCHED) {
+                repeat(times) {
                     todo()
                 }
             }