Enhance LockFreeLinkedList

This commit is contained in:
Him188 2019-12-14 21:25:22 +08:00
parent a1d3cf0fd9
commit 8c58b83ec2
2 changed files with 262 additions and 88 deletions

View File

@ -4,7 +4,35 @@ package net.mamoe.mirai.utils
import kotlinx.atomicfu.AtomicRef import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import net.mamoe.mirai.utils.Node.Companion.equals import kotlinx.atomicfu.loop
fun <E> LockFreeLinkedList<E>.joinToString(
separator: CharSequence = ", ",
prefix: CharSequence = "[",
postfix: CharSequence = "]",
transform: ((E) -> CharSequence)? = null
): String = prefix.toString() + buildString {
this@joinToString.forEach {
if (transform != null) {
append(transform(it))
} else append(it)
append(separator)
}
}.dropLast(2) + postfix
/**
* Returns a [List] containing all the elements in [this] in the same order
*/
fun <E> LockFreeLinkedList<E>.toList(): List<E> = toMutableList()
/**
* Returns a [MutableList] containing all the elements in [this] in the same order
*/
fun <E> LockFreeLinkedList<E>.toMutableList(): MutableList<E> {
val list = mutableListOf<E>()
this.forEach { list.add(it) }
return list
}
/** /**
* Implementation of lock-free LinkedList. * Implementation of lock-free LinkedList.
@ -12,12 +40,43 @@ import net.mamoe.mirai.utils.Node.Companion.equals
* Modifying can be performed concurrently. * Modifying can be performed concurrently.
* Iterating concurrency is guaranteed. * Iterating concurrency is guaranteed.
*/ */
class LockFreeLinkedList<E> { open class LockFreeLinkedList<E> {
private val tail: Tail<E> = Tail() @PublishedApi
internal val tail: Tail<E> = Tail()
private val head: Head<E> = Head(tail) @PublishedApi
internal val head: Head<E> = Head(tail)
fun add(element: E) { fun removeFirst(): E {
while (true) {
val currentFirst = head.nextNode
if (!currentFirst.isValidElementNode()) {
throw NoSuchElementException()
}
if (head.compareAndSetNextNodeRef(currentFirst, currentFirst.nextNode)) {
return currentFirst.nodeValue
}
}
}
fun peekFirst(): E = head.nextNode.letValueIfValid { return it } ?: throw NoSuchElementException()
fun peekLast(): E = head.iterateBeforeFirst { it === tail }.letValueIfValid { return it } ?: throw NoSuchElementException()
fun removeLast(): E {
while (true) {
val beforeLast = head.iterateBeforeFirst { it.nextNode === tail }
if (!beforeLast.isValidElementNode()) {
throw NoSuchElementException()
}
val last = beforeLast.nextNode
if (beforeLast.nextNodeRef.compareAndSet(last, last.nextNode)) {
return last.nodeValue
}
}
}
fun addLast(element: E) {
val node = element.asNode(tail) val node = element.asNode(tail)
while (true) { while (true) {
@ -28,21 +87,42 @@ class LockFreeLinkedList<E> {
} }
} }
override fun toString(): String = "[" + buildString { operator fun plusAssign(element: E) = this.addLast(element)
this@LockFreeLinkedList.forEach {
append(it) inline fun filteringGetOrAdd(filter: (E) -> Boolean, noinline supplier: () -> E): E {
append(", ") val node = LazyNode(tail, supplier)
while (true) {
var current: Node<E> = head
findLastNode@ while (true) {
if (current.isValidElementNode() && filter(current.nodeValue))
return current.nodeValue
if (current.nextNode === tail) {
if (current.compareAndSetNextNodeRef(tail, node)) { // ensure only one attempt can put the lazyNode in
return node.nodeValue
}
}
current = current.nextNode
}
} }
}.dropLast(2) + "]" }
@PublishedApi // limitation by atomicfu
internal fun <E> Node<E>.compareAndSetNextNodeRef(expect: Node<E>, update: Node<E>) = this.nextNodeRef.compareAndSet(expect, update)
override fun toString(): String = joinToString()
@Suppress("unused") @Suppress("unused")
internal fun getLinkStructure(): String = buildString { internal fun getLinkStructure(): String = buildString {
head.childIterateReturnsLastSatisfying<Node<*>>({ head.childIterateReturnsLastSatisfying<Node<*>>({
append(it.toString()) append(it.toString())
append("->") append(" <- ")
it.nextNode it.nextNode
}, { it !is Tail }) }, { it !is Tail })
}.dropLast(2) }.dropLast(4)
fun remove(element: E): Boolean { fun remove(element: E): Boolean {
while (true) { while (true) {
@ -51,40 +131,45 @@ class LockFreeLinkedList<E> {
if (toRemove === tail) { if (toRemove === tail) {
return false return false
} }
if (toRemove.nodeValue === null) { if (toRemove.isRemoved()) {
continue continue
} }
toRemove.nodeValue = null // logically remove: all the operations will recognize this node invalid toRemove.removed.value = true // logically remove: all the operations will recognize this node invalid
// physically remove: try to fix the link
var next: Node<E> = toRemove.nextNode var next: Node<E> = toRemove.nextNode
while (next !== tail && next.nodeValue === null) { while (next !== tail && next.isRemoved()) {
next = next.nextNode next = next.nextNode
} }
if (before.nextNodeRef.compareAndSet(toRemove, next)) {// physically remove: try to fix the link if (before.nextNodeRef.compareAndSet(toRemove, next)) {
return true return true
} }
} }
} }
val size: Int get() = head.countChildIterate<Node<E>>({ it.nextNodeRef.value }, { it !is Tail }) - 1 // empty head is always included val size: Int get() = head.countChildIterate<Node<E>>({ it.nextNode }, { it !is Tail }) - 1 // empty head is always included
operator fun contains(element: E): Boolean = head.iterateBeforeNodeValue(element) !== tail operator fun contains(element: E): Boolean {
forEach { if (it == element) return true }
return false
}
@Suppress("unused") @Suppress("unused")
fun containsAll(elements: Collection<E>): Boolean = elements.all { contains(it) } fun containsAll(elements: Collection<E>): Boolean = elements.all { contains(it) }
fun isEmpty(): Boolean = head.allMatching { it.isValidElementNode().not() } fun isEmpty(): Boolean = head.allMatching { it.isValidElementNode().not() }
fun forEach(block: (E) -> Unit) { inline fun forEach(block: (E) -> Unit) {
var node: Node<E> = head var node: Node<E> = head
while (true) { while (true) {
if (node === tail) return node.letValueIfValid(block)
node.letIfNotnull(block)
node = node.nextNode node = node.nextNode
if (node === tail) return
} }
} }
fun addAll(elements: Collection<E>) = elements.forEach { add(it) } fun addAll(elements: Collection<E>) = elements.forEach { addLast(it) }
@Suppress("unused") @Suppress("unused")
fun clear() { fun clear() {
@ -93,7 +178,7 @@ class LockFreeLinkedList<E> {
first.childIterateReturnFirstUnsatisfying({ first.childIterateReturnFirstUnsatisfying({
val n = it.nextNode val n = it.nextNode
it.nextNode = tail it.nextNode = tail
it.nodeValue = null it.removed.value = true
n n
}, { it !== tail }) // clear the link structure, help GC. }, { it !== tail }) // clear the link structure, help GC.
} }
@ -417,31 +502,18 @@ class LockFreeLinkedList<E> {
*/ */
} }
/**
* Returns a [List] containing all the elements in [this] in the same order
*/
fun <E> LockFreeLinkedList<E>.toList(): List<E> = toMutableList()
/**
* Returns a [MutableList] containing all the elements in [this] in the same order
*/
fun <E> LockFreeLinkedList<E>.toMutableList(): MutableList<E> {
val list = mutableListOf<E>()
this.forEach { list.add(it) }
return list
}
// region internal // region internal
@Suppress("NOTHING_TO_INLINE") @Suppress("NOTHING_TO_INLINE")
private inline fun <E> E.asNode(nextNode: Node<E>): Node<E> = Node(nextNode).apply { nodeValueRef.value = this@asNode } private inline fun <E> E.asNode(nextNode: Node<E>): Node<E> = Node(nextNode, this)
/** /**
* Self-iterate using the [iterator], until [mustBeTrue] returns `false`. * Self-iterate using the [iterator], until [mustBeTrue] returns `false`.
* Returns the element at the last time when the [mustBeTrue] returns `true` * Returns the element at the last time when the [mustBeTrue] returns `true`
*/ */
private inline fun <N : Node<*>> N.childIterateReturnsLastSatisfying(iterator: (N) -> N, mustBeTrue: (N) -> Boolean): N { @PublishedApi
internal inline fun <N : Node<*>> N.childIterateReturnsLastSatisfying(iterator: (N) -> N, mustBeTrue: (N) -> Boolean): N {
if (!mustBeTrue(this)) return this if (!mustBeTrue(this)) return this
var value: N = this var value: N = this
@ -497,38 +569,68 @@ private inline fun <E> E.countChildIterate(iterator: (E) -> E, mustBeTrue: (E) -
} }
} }
private class Head<E>( @PublishedApi
nextNode: Node<E> internal class LazyNode<E> @PublishedApi internal constructor(
) : Node<E>(nextNode) nextNode: Node<E>,
private val valueComputer: () -> E
) : Node<E>(nextNode, null) {
private val initialized = atomic(false)
private open class Node<E>( private val value: AtomicRef<E?> = atomic(null)
nextNode: Node<E>?
override val nodeValue: E
get() {
@Suppress("BooleanLiteralArgument") // false positive warning
if (initialized.compareAndSet(false, true)) { // ensure only one lucky attempt can go into the if
val value = valueComputer()
this.value.value = value
return value // fast path
}
value.loop {
if (it != null) {
return it
}
}
}
}
@PublishedApi
internal class Head<E>(nextNode: Node<E>) : Node<E>(nextNode, null) {
override fun toString(): String = "Head"
override val nodeValue: Nothing get() = error("Internal error: trying to get the value of a Head")
}
@PublishedApi
internal open class Tail<E> : Node<E>(null, null) {
override fun toString(): String = "Tail"
override val nodeValue: Nothing get() = error("Internal error: trying to get the value of a Tail")
}
@PublishedApi
internal open class Node<E>(
nextNode: Node<E>?,
private var initialNodeValue: E?
) { ) {
/*
internal val id: Int = nextId() internal val id: Int = nextId()
companion object { companion object {
private val idCount = atomic(0) private val idCount = atomic(0)
internal fun nextId() = idCount.getAndIncrement() internal fun nextId() = idCount.getAndIncrement()
} }*/
override fun toString(): String = "Node#$id(${nodeValueRef.value})" override fun toString(): String = "$nodeValue"
open val nodeValue: E get() = initialNodeValue ?: error("Internal error: nodeValue is not initialized")
val nodeValueRef: AtomicRef<E?> = atomic(null) val removed = atomic(false)
/**
* Short cut for accessing [nodeValueRef]
*/
inline var nodeValue: E?
get() = nodeValueRef.value
set(value) {
nodeValueRef.value = value
}
@Suppress("LeakingThis") @Suppress("LeakingThis")
val nextNodeRef: AtomicRef<Node<E>> = atomic(nextNode ?: this) val nextNodeRef: AtomicRef<Node<E>> = atomic(nextNode ?: this)
inline fun <R> letIfNotnull(block: (E) -> R): R? { inline fun <R> letValueIfValid(block: (E) -> R): R? {
if (!this.isValidElementNode()) {
return null
}
val value = this.nodeValue val value = this.nodeValue
return if (value !== null) block(value) else null return if (value !== null) block(value) else null
} }
@ -536,13 +638,12 @@ private open class Node<E>(
/** /**
* Short cut for accessing [nextNodeRef] * Short cut for accessing [nextNodeRef]
*/ */
inline var nextNode: Node<E> var nextNode: Node<E>
get() = nextNodeRef.value get() = nextNodeRef.value
set(value) { set(value) {
nextNodeRef.value = value nextNodeRef.value = value
} }
/** /**
* Returns the former node of the last node whence [filter] returns true * Returns the former node of the last node whence [filter] returns true
*/ */
@ -563,21 +664,23 @@ private open class Node<E>(
* E.g.: for `head <- 1 <- 2 <- 3 <- tail`, `iterateStopOnNodeValue(2)` returns the node whose value is 1 * E.g.: for `head <- 1 <- 2 <- 3 <- tail`, `iterateStopOnNodeValue(2)` returns the node whose value is 1
*/ */
@Suppress("NOTHING_TO_INLINE") @Suppress("NOTHING_TO_INLINE")
inline fun iterateBeforeNodeValue(element: E): Node<E> = this.iterateBeforeFirst { it.nodeValueRef.value == element } internal inline fun iterateBeforeNodeValue(element: E): Node<E> = this.iterateBeforeFirst { it.isValidElementNode() && it.nodeValue == element }
} }
private open class Tail<E> : Node<E>(null) @PublishedApi
internal fun <E> Node<E>.isRemoved() = this.removed.value
@PublishedApi
@Suppress("NOTHING_TO_INLINE") @Suppress("NOTHING_TO_INLINE")
private inline fun Node<*>.isValidElementNode(): Boolean = !isHead() && !isTail() && !isRemoved() internal inline fun Node<*>.isValidElementNode(): Boolean = !isHead() && !isTail() && !isRemoved()
@PublishedApi
@Suppress("NOTHING_TO_INLINE") @Suppress("NOTHING_TO_INLINE")
private inline fun Node<*>.isHead(): Boolean = this is Head internal inline fun Node<*>.isHead(): Boolean = this is Head
@PublishedApi
@Suppress("NOTHING_TO_INLINE") @Suppress("NOTHING_TO_INLINE")
private inline fun Node<*>.isTail(): Boolean = this is Tail internal inline fun Node<*>.isTail(): Boolean = this is Tail
@Suppress("NOTHING_TO_INLINE")
private inline fun Node<*>.isRemoved(): Boolean = this.nodeValue == null
// en dregion // en dregion

View File

@ -23,10 +23,10 @@ internal class LockFreeLinkedListTest {
@Test @Test
fun addAndGetSingleThreaded() { fun addAndGetSingleThreaded() {
val list = LockFreeLinkedList<Int>() val list = LockFreeLinkedList<Int>()
list.add(1) list.addLast(1)
list.add(2) list.addLast(2)
list.add(3) list.addLast(3)
list.add(4) list.addLast(4)
list.size shouldBeEqualTo 4 list.size shouldBeEqualTo 4
} }
@ -36,7 +36,7 @@ internal class LockFreeLinkedListTest {
//withContext(Dispatchers.Default){ //withContext(Dispatchers.Default){
val list = LockFreeLinkedList<Int>() val list = LockFreeLinkedList<Int>()
list.concurrentAdd(1000, 10, 1) list.concurrentDo(1000, 10) { addLast(1) }
list.size shouldBeEqualTo 1000 * 10 list.size shouldBeEqualTo 1000 * 10
list.concurrentDo(100, 10) { list.concurrentDo(100, 10) {
@ -51,18 +51,55 @@ internal class LockFreeLinkedListTest {
fun addAndGetMassConcurrentAccess() = runBlocking { fun addAndGetMassConcurrentAccess() = runBlocking {
val list = LockFreeLinkedList<Int>() val list = LockFreeLinkedList<Int>()
val addJob = async { list.concurrentAdd(5000, 10, 1) } val addJob = async { list.concurrentDo(2, 30000) { addLast(1) } }
delay(10) // let addJob fly //delay(1) // let addJob fly
if (addJob.isCompleted) { if (addJob.isCompleted) {
error("Number of elements are not enough") error("Number of elements are not enough")
} }
list.concurrentDo(1000, 10) { val foreachJob = async {
remove(1).shouldBeTrue() list.concurrentDo(1, 10000) {
forEach { it + it }
}
} }
addJob.join() val removeLastJob = async {
list.concurrentDo(1, 15000) {
removeLast() shouldBeEqualTo 1
}
}
val removeFirstJob = async {
list.concurrentDo(1, 10000) {
removeFirst() shouldBeEqualTo 1
}
}
val addJob2 = async {
list.concurrentDo(1, 5000) {
addLast(1)
}
}
val removeExactJob = launch {
list.concurrentDo(3, 1000) {
remove(1).shouldBeTrue()
}
}
val filteringGetOrAddJob = launch {
list.concurrentDo(1, 10000) {
filteringGetOrAdd({ it == 2 }, { 1 })
}
}
joinAll(addJob, addJob2, foreachJob, removeLastJob, removeFirstJob, removeExactJob, filteringGetOrAddJob)
list.size shouldBeEqualTo 5000 * 10 - 1000 * 10 list.size shouldBeEqualTo 2 * 30000 - 1 * 15000 - 1 * 10000 + 1 * 5000 - 3 * 1000 + 1 * 10000
}
@Test
fun removeWhileForeach() {
val list = LockFreeLinkedList<Int>()
repeat(10) { list.addLast(it) }
list.forEach {
list.remove(it + 1)
}
list.peekFirst() shouldBeEqualTo 0
} }
@Test @Test
@ -72,11 +109,11 @@ internal class LockFreeLinkedListTest {
assertFalse { list.remove(1) } assertFalse { list.remove(1) }
assertEquals(0, list.size) assertEquals(0, list.size)
list.add(1) list.addLast(1)
assertTrue { list.remove(1) } assertTrue { list.remove(1) }
assertEquals(0, list.size) assertEquals(0, list.size)
list.add(2) list.addLast(2)
assertFalse { list.remove(1) } assertFalse { list.remove(1) }
assertEquals(1, list.size) assertEquals(1, list.size)
} }
@ -107,6 +144,43 @@ internal class LockFreeLinkedListTest {
list.toString() shouldBeEqualTo "[1, 2, 3, 4, 5]" list.toString() shouldBeEqualTo "[1, 2, 3, 4, 5]"
} }
@Test
fun `filteringGetOrAdd when add`() {
val list = LockFreeLinkedList<Int>()
list.addAll(listOf(1, 2, 3, 4, 5))
val value = list.filteringGetOrAdd({ it == 6 }, { 6 })
println("Check value")
value shouldBeEqualTo 6
println("Check size")
println(list.getLinkStructure())
list.size shouldBeEqualTo 6
}
@Test
fun `filteringGetOrAdd when get`() {
val list = LockFreeLinkedList<Int>()
list.addAll(listOf(1, 2, 3, 4, 5))
val value = list.filteringGetOrAdd({ it == 2 }, { 2 })
println("Check value")
value shouldBeEqualTo 2
println("Check size")
println(list.getLinkStructure())
list.size shouldBeEqualTo 5
}
@Test
fun `filteringGetOrAdd when empty`() {
val list = LockFreeLinkedList<Int>()
val value = list.filteringGetOrAdd({ it == 2 }, { 2 })
println("Check value")
value shouldBeEqualTo 2
println("Check size")
println(list.getLinkStructure())
list.size shouldBeEqualTo 1
}
/* /*
@Test @Test
fun indexOf() { fun indexOf() {
@ -176,16 +250,13 @@ internal class LockFreeLinkedListTest {
*/ */
} }
@UseExperimental(ExperimentalCoroutinesApi::class)
@MiraiExperimentalAPI @MiraiExperimentalAPI
internal suspend inline fun <E> LockFreeLinkedList<E>.concurrentAdd(numberOfCoroutines: Int, timesOfAdd: Int, element: E) = internal suspend inline fun <E : LockFreeLinkedList<*>> E.concurrentDo(numberOfCoroutines: Int, times: Int, crossinline todo: E.() -> Unit) =
concurrentDo(numberOfCoroutines, timesOfAdd) { add(element) }
@MiraiExperimentalAPI
internal suspend inline fun <E : LockFreeLinkedList<*>> E.concurrentDo(numberOfCoroutines: Int, timesOfAdd: Int, crossinline todo: E.() -> Unit) =
coroutineScope { coroutineScope {
repeat(numberOfCoroutines) { repeat(numberOfCoroutines) {
launch { launch(start = CoroutineStart.UNDISPATCHED) {
repeat(timesOfAdd) { repeat(times) {
todo() todo()
} }
} }