// LazyStaggeredGridItemPlacementAnimator.kt

/*
 * Copyright 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.compose.foundation.lazy.staggeredgrid

import androidx.compose.foundation.lazy.layout.LazyLayoutAnimateItemModifierNode
import androidx.compose.foundation.lazy.layout.LazyLayoutKeyIndexMap
import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.util.fastAny
import androidx.compose.ui.util.fastForEach

/**
 * Handles the item placement animations when it is set via
 * [LazyStaggeredGridItemScope.animateItemPlacement].
 *
 * This class is responsible for detecting when an item's position changed, figuring out the
 * start/end offsets and starting the animations.
 *
 * State is kept between calls to [onMeasured]: the per-key info of every tracked item, the key to
 * index map used for the last measuring and the index of the first positioned item. [reset] drops
 * all of it.
 */
internal class LazyStaggeredGridItemPlacementAnimator {
    // state containing relevant info for active items.
    private val keyToItemInfoMap = mutableMapOf<Any, ItemInfo>()

    // snapshot of the key to index map used for the last measuring.
    private var keyToIndexMap: LazyLayoutKeyIndexMap = LazyLayoutKeyIndexMap

    // keeps the index of the first visible item index.
    private var firstVisibleIndex = 0

    // stored to not allocate it every pass. All of these are filled and cleared again within a
    // single onMeasured() call.
    private val movingAwayKeys = LinkedHashSet<Any>()
    private val movingInFromStartBound = mutableListOf<LazyStaggeredGridPositionedItem>()
    private val movingInFromEndBound = mutableListOf<LazyStaggeredGridPositionedItem>()
    private val movingAwayToStartBound = mutableListOf<LazyStaggeredGridMeasuredItem>()
    private val movingAwayToEndBound = mutableListOf<LazyStaggeredGridMeasuredItem>()

    /**
     * Should be called after the measuring so we can detect position changes and start animations.
     *
     * Note that this method can compose new item and add it into the [positionedItems] list.
     *
     * @param consumedScroll scroll delta applied on the main axis to the stored raw offsets of the
     * already tracked items, so that scrolling itself is not animated.
     * @param layoutWidth used as the main axis layout size when [isVertical] is false.
     * @param layoutHeight used as the main axis layout size when [isVertical] is true.
     * @param positionedItems items positioned during this pass. Items which are still animating
     * away are measured, positioned outside of the bounds and appended to this list.
     * @param itemProvider source of the current key to index map; also used to measure the items
     * which are moving away.
     * @param isVertical selects which axis is treated as the main axis.
     * @param laneCount size of the per-lane array used to stack the items which are moving in or
     * out of the bounds.
     */
    fun onMeasured(
        consumedScroll: Int,
        layoutWidth: Int,
        layoutHeight: Int,
        positionedItems: MutableList<LazyStaggeredGridPositionedItem>,
        itemProvider: LazyStaggeredGridMeasureProvider,
        isVertical: Boolean,
        laneCount: Int
    ) {
        if (!positionedItems.fastAny { it.hasAnimations } && keyToItemInfoMap.isEmpty()) {
            // no animations specified - no work needed
            reset()
            return
        }

        val previousFirstVisibleIndex = firstVisibleIndex
        firstVisibleIndex = positionedItems.firstOrNull()?.index ?: 0
        val previousKeyToIndexMap = keyToIndexMap
        keyToIndexMap = itemProvider.keyToIndexMap

        val mainAxisLayoutSize = if (isVertical) layoutHeight else layoutWidth

        // the consumed scroll is considered as a delta we don't need to animate
        val scrollOffset = if (isVertical) {
            IntOffset(0, consumedScroll)
        } else {
            IntOffset(consumedScroll, 0)
        }

        // first add all items we had in the previous run
        movingAwayKeys.addAll(keyToItemInfoMap.keys)
        // iterate through the items which are visible (without animated offsets)
        positionedItems.fastForEach { item ->
            // remove items we have in the current one as they are still visible.
            movingAwayKeys.remove(item.key)
            if (item.hasAnimations) {
                val itemInfo = keyToItemInfoMap[item.key]
                // there is no state associated with this item yet
                if (itemInfo == null) {
                    keyToItemInfoMap[item.key] =
                        ItemInfo(item.lane, item.span, item.crossAxisOffset)
                    val previousIndex = previousKeyToIndexMap[item.key]
                    if (previousIndex != -1 && item.index != previousIndex) {
                        // the item existed before at a different index: animate it in from
                        // the bound it was closer to. The start offsets are assigned below,
                        // once all such items are collected and sorted.
                        if (previousIndex < previousFirstVisibleIndex) {
                            // the larger index will be in the start of the list
                            movingInFromStartBound.add(item)
                        } else {
                            movingInFromEndBound.add(item)
                        }
                    } else {
                        // new item or the index didn't change: start from the final position
                        initializeNode(
                            item,
                            item.offset.let { if (item.isVertical) it.y else it.x }
                        )
                    }
                } else {
                    item.forEachNode {
                        if (it.rawOffset != LazyLayoutAnimateItemModifierNode.NotInitialized) {
                            it.rawOffset += scrollOffset
                        }
                    }
                    itemInfo.lane = item.lane
                    itemInfo.span = item.span
                    itemInfo.crossAxisOffset = item.crossAxisOffset
                    startAnimationsIfNeeded(item)
                }
            } else {
                // no animation, clean up if needed
                keyToItemInfoMap.remove(item.key)
            }
        }

        // main axis space already taken in each lane by the items handled so far. It is zeroed
        // after each of the groups below so every group stacks from its own bound.
        val accumulatedOffsetPerLane = IntArray(laneCount)
        if (movingInFromStartBound.isNotEmpty()) {
            movingInFromStartBound.sortByDescending { previousKeyToIndexMap[it.key] }
            movingInFromStartBound.fastForEach { item ->
                accumulatedOffsetPerLane[item.lane] += item.mainAxisSize
                val mainAxisOffset = 0 - accumulatedOffsetPerLane[item.lane]
                initializeNode(item, mainAxisOffset)
                startAnimationsIfNeeded(item)
            }
            accumulatedOffsetPerLane.fill(0)
        }
        if (movingInFromEndBound.isNotEmpty()) {
            movingInFromEndBound.sortBy { previousKeyToIndexMap[it.key] }
            movingInFromEndBound.fastForEach { item ->
                val mainAxisOffset = mainAxisLayoutSize + accumulatedOffsetPerLane[item.lane]
                accumulatedOffsetPerLane[item.lane] += item.mainAxisSize
                initializeNode(item, mainAxisOffset)
                startAnimationsIfNeeded(item)
            }
            accumulatedOffsetPerLane.fill(0)
        }

        movingAwayKeys.forEach { key ->
            // found an item which was in our map previously but is not a part of the
            // positionedItems now
            val itemInfo = keyToItemInfoMap.getValue(key)
            val newIndex = keyToIndexMap[key]

            if (newIndex == -1) {
                // the key is not in the current map anymore, stop tracking it
                keyToItemInfoMap.remove(key)
            } else {
                val item = itemProvider.getAndMeasure(
                    newIndex,
                    SpanRange(itemInfo.lane, itemInfo.span)
                )
                // check if we have any active placement animation on the item. A plain loop is
                // used so we really stop at the first match: `return@repeat` inside `repeat`
                // would only skip to the next iteration.
                var inProgress = false
                for (placeableIndex in 0 until item.placeablesCount) {
                    if (item.getParentData(placeableIndex).node?.isAnimationInProgress == true) {
                        inProgress = true
                        break
                    }
                }
                if (!inProgress && newIndex == previousKeyToIndexMap[key]) {
                    // nothing is animating and the index didn't change: stop tracking it
                    keyToItemInfoMap.remove(key)
                } else if (newIndex < firstVisibleIndex) {
                    movingAwayToStartBound.add(item)
                } else {
                    movingAwayToEndBound.add(item)
                }
            }
        }

        if (movingAwayToStartBound.isNotEmpty()) {
            movingAwayToStartBound.sortByDescending { keyToIndexMap[it.key] }
            movingAwayToStartBound.fastForEach { item ->
                accumulatedOffsetPerLane[item.lane] += item.mainAxisSize
                val mainAxisOffset = 0 - accumulatedOffsetPerLane[item.lane]

                val itemInfo = keyToItemInfoMap.getValue(item.key)
                val positionedItem =
                    item.position(mainAxisOffset, itemInfo.crossAxisOffset, mainAxisLayoutSize)
                positionedItems.add(positionedItem)
                startAnimationsIfNeeded(positionedItem)
            }
            accumulatedOffsetPerLane.fill(0)
        }
        if (movingAwayToEndBound.isNotEmpty()) {
            movingAwayToEndBound.sortBy { keyToIndexMap[it.key] }
            movingAwayToEndBound.fastForEach { item ->
                val mainAxisOffset = mainAxisLayoutSize + accumulatedOffsetPerLane[item.lane]
                accumulatedOffsetPerLane[item.lane] += item.mainAxisSize

                val itemInfo = keyToItemInfoMap.getValue(item.key)
                val positionedItem =
                    item.position(mainAxisOffset, itemInfo.crossAxisOffset, mainAxisLayoutSize)
                positionedItems.add(positionedItem)
                startAnimationsIfNeeded(positionedItem)
            }
        }

        // release the temporary state so the collections can be reused by the next pass
        movingInFromStartBound.clear()
        movingInFromEndBound.clear()
        movingAwayToStartBound.clear()
        movingAwayToEndBound.clear()
        movingAwayKeys.clear()
    }

    /**
     * Should be called when the animations are not needed for the next positions change,
     * for example when we snap to a new position.
     *
     * Forgets all the tracked items, restores the empty key to index map and sets the remembered
     * first visible index to -1.
     */
    fun reset() {
        keyToItemInfoMap.clear()
        keyToIndexMap = LazyLayoutKeyIndexMap
        firstVisibleIndex = -1
    }

    /**
     * Sets the raw offset of every animation node of [item] to the item's own offset with the
     * main axis coordinate replaced by [mainAxisOffset]. The cross axis coordinate is kept.
     */
    private fun initializeNode(
        item: LazyStaggeredGridPositionedItem,
        mainAxisOffset: Int
    ) {
        val itemOffset = item.offset
        val initialOffset = if (item.isVertical) {
            itemOffset.copy(y = mainAxisOffset)
        } else {
            itemOffset.copy(x = mainAxisOffset)
        }

        // all nodes share the single offset of the item, so there is no per-node difference
        // which would need to be added here.
        item.forEachNode { node ->
            node.rawOffset = initialOffset
        }
    }

    /**
     * For every animation node of [item]: when the node already has an initialized raw offset
     * which differs from the current offset of the item, requests an animation of the
     * difference. The raw offset is then always updated to the current offset of the item.
     */
    private fun startAnimationsIfNeeded(item: LazyStaggeredGridPositionedItem) {
        item.forEachNode { node ->
            val newTarget = item.offset
            val currentTarget = node.rawOffset
            if (currentTarget != LazyLayoutAnimateItemModifierNode.NotInitialized &&
                currentTarget != newTarget
            ) {
                node.animatePlacementDelta(newTarget - currentTarget)
            }
            node.rawOffset = newTarget
        }
    }

    // parent data viewed as the animation node, or null when it is something else.
    private val Any?.node get() = this as? LazyLayoutAnimateItemModifierNode

    // true when at least one placeable of the item carries an animation node.
    private val LazyStaggeredGridPositionedItem.hasAnimations: Boolean
        get() {
            forEachNode { return true }
            return false
        }

    // invokes [block] for the animation node of each placeable which has one.
    private inline fun LazyStaggeredGridPositionedItem.forEachNode(
        block: (LazyLayoutAnimateItemModifierNode) -> Unit
    ) {
        repeat(placeablesCount) { index ->
            getParentData(index).node?.let(block)
        }
    }
}

/**
 * Mutable record of the layout values last seen for a single item: the lane it was placed in,
 * the span it occupied and its offset on the cross axis.
 */
private class ItemInfo(lane: Int, span: Int, crossAxisOffset: Int) {
    // lane the item was last placed in.
    var lane: Int = lane

    // span last recorded for the item, stored together with the lane.
    var span: Int = span

    // cross axis offset last recorded for the item.
    var crossAxisOffset: Int = crossAxisOffset
}