SnapshotWeakSet.kt

/*
 * Copyright 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.compose.runtime.snapshots

import androidx.compose.runtime.TestOnly
import androidx.compose.runtime.WeakReference
import androidx.compose.runtime.identityHashCode

private const val INITIAL_CAPACITY = 16

/**
 * A set of values references where the values are held weakly.
 *
 * This doesn't implement the entire Set<T> API and only implements those methods that are needed
 * for use in [Snapshot].
 *
 * [add], [find] and [findExactIndex] are copied from IdentityArraySet and refined to use weak
 * references. Any bugs found in these methods are likely to also be in IdentityArraySet and vis
 * versa.
 */
internal class SnapshotWeakSet<T : Any> {
    /**
     * The size of the set. The set has at most [size] entries but could have fewer if any of the
     * entries have been collected.
     */
    internal var size: Int = 0

    /**
     * Hashes are kept separately as the original object might not be available but its hash is
     * required to be available as the entries are stored in hash order and found via a binary
     * search.
     */
    internal var hashes = IntArray(INITIAL_CAPACITY)
    internal var values: Array<WeakReference<T>?> = arrayOfNulls(INITIAL_CAPACITY)

    /**
     * Add [value] to the set and return `true` if it was added or `false` if it already existed.
     */
    fun add(value: T): Boolean {
        val index: Int
        val size = size
        val hash = identityHashCode(value)
        if (size > 0) {
            index = find(value, hash)

            if (index >= 0) {
                return false
            }
        } else {
            index = -1
        }

        val insertIndex = -(index + 1)
        val capacity = values.size
        if (size == capacity) {
            val newCapacity = capacity * 2
            val newValues = arrayOfNulls<WeakReference<T>?>(newCapacity)
            val newHashes = IntArray(newCapacity)
            values.copyInto(
                destination = newValues,
                destinationOffset = insertIndex + 1,
                startIndex = insertIndex,
                endIndex = size
            )
            values.copyInto(
                destination = newValues,
                endIndex = insertIndex
            )
            hashes.copyInto(
                destination = newHashes,
                destinationOffset = insertIndex + 1,
                startIndex = insertIndex,
                endIndex = size
            )
            hashes.copyInto(
                destination = newHashes,
                endIndex = insertIndex
            )
            values = newValues
            hashes = newHashes
        } else {
            values.copyInto(
                destination = values,
                destinationOffset = insertIndex + 1,
                startIndex = insertIndex,
                endIndex = size
            )
            hashes.copyInto(
                destination = hashes,
                destinationOffset = insertIndex + 1,
                startIndex = insertIndex,
                endIndex = size
            )
        }

        // A hole for the new items has been opened with the arrays, add the element there.
        values[insertIndex] = WeakReference(value)
        hashes[insertIndex] = hash
        this.size++
        return true
    }

    /**
     * Remove an entry from the set if [block] returns true.
     *
     * This also will discard any weak references that are no longer referring to their objects.
     *
     * This call is inline to avoid allocations while enumerating the set.
     */
     inline fun removeIf(block: (T) -> Boolean) {
        val size = size
        var currentUsed = 0
        // Call `block` on all entries that still have a valid reference
        // removing entries that are not valid or return `true` from block.
        for (i in 0 until size) {
            val entry = values[i]
            val value = entry?.get()
            if (value != null && !block(value)) {
                // We are keeping this entry
                if (currentUsed != i) {
                    values[currentUsed] = entry
                    hashes[currentUsed] = hashes[i]
                }
                currentUsed++
            }
        }

        // Clear the remaining entries
        for (i in currentUsed until size) {
            values[i] = null
            hashes[i] = 0
        }

        // Adjust the size to match number of slots left.
        if (currentUsed != size) {
            this.size = currentUsed
        }
    }

    /**
     * Returns the index of [value] in the set or the negative index - 1 of the location where
     * it would have been if it had been in the set.
     */
    private fun find(value: T, hash: Int): Int {
        var low = 0
        var high = size - 1

        while (low <= high) {
            val mid = (low + high).ushr(1)
            val midHash = hashes[mid]
            when {
                midHash < hash -> low = mid + 1
                midHash > hash -> high = mid - 1
                else -> {
                    val midVal = values[mid]?.get()
                    if (value === midVal) return mid
                    return findExactIndex(mid, value, hash)
                }
            }
        }
        return -(low + 1)
    }

    /**
     * When multiple items share the same [identityHashCode], then we must find the specific
     * index of the target item. This method assumes that [midIndex] has already been checked
     * for an exact match for [value], but will look at nearby values to find the exact item index.
     * If no match is found, the negative index - 1 of the position in which it would be will
     * be returned, which is always after the last item with the same [identityHashCode].
     */
    private fun findExactIndex(midIndex: Int, value: T, valueHash: Int): Int {
        // hunt down first
        for (i in midIndex - 1 downTo 0) {
            if (hashes[i] != valueHash) {
                break // we've gone too far
            }
            val v = values[i]?.get()
            if (v === value) {
                return i
            }
        }

        for (i in midIndex + 1 until size) {
            if (hashes[i] != valueHash) {
                // We've gone too far. We should insert here.
                return -(i + 1)
            }
            val v = values[i]?.get()
            if (v === value) {
                return i
            }
        }

        // We should insert at the end
        return -(size + 1)
    }

    @TestOnly
    internal fun isValid(): Boolean {
        val size = size
        val values = values
        val hashes = hashes
        val capacity = values.size

        // Validate that the size is less than or equal to the capacity
        if (size > capacity) return false

        // Validate that the hashes are in order and they match identity hash of the value or
        // the value has been collected.
        var previous = Int.MIN_VALUE
        for (i in 0 until size) {
            val hash = hashes[i]
            if (hash < previous) return false
            val entry = values[i] ?: return false
            val value = entry.get()
            if (value != null && hash != identityHashCode(value)) return false
            previous = hash
        }

        // Validate that all hashes and entries size and above are empty
        for (i in size until capacity) {
            if (hashes[i] != 0) return false
            if (values[i] != null) return false
        }

        return true
    }
}