LazyStaggeredGridMeasure.kt

/*
 * Copyright 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.compose.foundation.lazy.staggeredgrid

import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.OverscrollEffect
import androidx.compose.foundation.checkScrollableContainerConstraints
import androidx.compose.foundation.gestures.Orientation
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.lazy.layout.LazyLayoutItemProvider
import androidx.compose.foundation.lazy.layout.LazyLayoutMeasureScope
import androidx.compose.runtime.Composable
import androidx.compose.runtime.remember
import androidx.compose.runtime.snapshots.Snapshot
import androidx.compose.ui.layout.Placeable
import androidx.compose.ui.unit.Constraints
import androidx.compose.ui.unit.Density
import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.unit.IntSize
import androidx.compose.ui.unit.constrainHeight
import androidx.compose.ui.unit.constrainWidth
import androidx.compose.ui.util.fastForEach
import kotlin.math.abs
import kotlin.math.roundToInt
import kotlin.math.sign

/**
 * Creates and remembers the measure policy of the lazy staggered grid.
 *
 * The returned lambda runs on every measure pass. It does the following, in order:
 * - reads the per-lane first visible item indices and scroll offsets from [state];
 * - applies the pending scroll delta (`state.scrollToBeConsumed`);
 * - measures items lane by lane until the viewport is filled, then places them;
 * - hands the result to `state.applyMeasureResult` and updates [overscrollEffect].
 *
 * A new policy is created whenever any of the parameters changes.
 *
 * @param state scroll state. It is the source of the starting position and of the
 *   item-to-lane span lookup, and it receives the measure result.
 * @param itemProvider provides the item count, keys and content to measure.
 * @param contentPadding not applied yet (see todo(b/182882362)); only used as a remember key.
 * @param reverseLayout not applied yet; only used as a remember key.
 * @param orientation main (scrolling) axis of the grid.
 * @param verticalArrangement not applied yet; only used as a remember key.
 * @param horizontalArrangement not applied yet; only used as a remember key.
 * @param slotSizesSums computes the cumulative cross-axis lane sizes for the given constraints.
 *   Entry `i` is the cross-axis end of lane `i`, and the array size is the number of lanes.
 * @param overscrollEffect enabled only when the content can scroll in some direction.
 */
@Composable
@ExperimentalFoundationApi
internal fun rememberStaggeredGridMeasurePolicy(
    state: LazyStaggeredGridState,
    itemProvider: LazyLayoutItemProvider,
    contentPadding: PaddingValues,
    reverseLayout: Boolean,
    orientation: Orientation,
    verticalArrangement: Arrangement.Vertical,
    horizontalArrangement: Arrangement.Horizontal,
    slotSizesSums: Density.(Constraints) -> IntArray,
    overscrollEffect: OverscrollEffect
): LazyLayoutMeasureScope.(Constraints) -> LazyStaggeredGridMeasureResult = remember(
    state,
    itemProvider,
    contentPadding,
    reverseLayout,
    orientation,
    verticalArrangement,
    horizontalArrangement,
    slotSizesSums,
    overscrollEffect,
) {
    { constraints ->
        checkScrollableContainerConstraints(
            constraints,
            orientation
        )
        val isVertical = orientation == Orientation.Vertical

        // cumulative cross-axis lane sizes; the number of entries is the number of lanes
        val resolvedSlotSums = slotSizesSums(this, constraints)
        val itemCount = itemProvider.itemCount

        val mainAxisAvailableSize =
            if (isVertical) constraints.maxHeight else constraints.maxWidth

        val measuredItemProvider = LazyStaggeredGridMeasureProvider(
            isVertical,
            itemProvider,
            this,
            resolvedSlotSums
        ) { index, key, placeables ->
            LazyStaggeredGridMeasuredItem(
                index,
                key,
                placeables,
                isVertical
            )
        }

        // content padding is not supported yet, see todo(b/182882362)
        val beforeContentPadding = 0
        val afterContentPadding = 0

        val initialItemIndices: IntArray
        val initialItemOffsets: IntArray

        // Read the stored scroll position without observing it. If the lane count changed since
        // the position was stored, start from an empty position (-1 index, 0 offset per lane).
        Snapshot.withoutReadObservation {
            initialItemIndices =
                if (state.firstVisibleItems.size == resolvedSlotSums.size) {
                    state.firstVisibleItems
                } else {
                    IntArray(resolvedSlotSums.size) { -1 }
                }
            initialItemOffsets =
                if (state.firstVisibleItemScrollOffsets.size == resolvedSlotSums.size) {
                    state.firstVisibleItemScrollOffsets
                } else {
                    IntArray(resolvedSlotSums.size) { 0 }
                }
        }

        val spans = state.spans
        // working copies, so the arrays held by the state are never mutated in place
        val firstItemIndices = initialItemIndices.copyOf()
        val firstItemOffsets = initialItemOffsets.copyOf()

        // Measure items

        if (itemCount <= 0) {
            LazyStaggeredGridMeasureResult(
                firstVisibleItemIndices = IntArray(0),
                firstVisibleItemScrollOffsets = IntArray(0),
                consumedScroll = 0f,
                measureResult = layout(constraints.minWidth, constraints.minHeight) {},
                canScrollForward = false,
                canScrollBackward = false,
                visibleItemsInfo = emptyArray()
            )
        } else {
            // todo(b/182882362): content padding

            // represents the real amount of scroll we applied as a result of this measure pass.
            var scrollDelta = state.scrollToBeConsumed.roundToInt()

            // applying the whole requested scroll offset. we will figure out if we can't consume
            // all of it later
            firstItemOffsets.offsetBy(-scrollDelta)

            // if the current scroll offset is less than minimally possible
            if (firstItemIndices[0] == 0 && firstItemOffsets[0] < 0) {
                scrollDelta += firstItemOffsets[0]
                firstItemOffsets.fill(0)
            }

            // this will contain all the MeasuredItems representing the visible items, per lane
            val measuredItems = Array(resolvedSlotSums.size) {
                mutableListOf<LazyStaggeredGridMeasuredItem>()
            }

            // include the start padding so we compose items in the padding area. before starting
            // scrolling forward we would remove it back
            firstItemOffsets.offsetBy(-beforeContentPadding)

            // define min and max offsets (min offset currently includes beforeContentPadding)
            val minOffset = -beforeContentPadding
            val maxOffset = mainAxisAvailableSize

            // true while some lane's first item starts at or after the top edge and there are
            // earlier items that could still be composed before it
            fun hasSpaceOnTop(): Boolean {
                for (column in firstItemIndices.indices) {
                    val itemIndex = firstItemIndices[column]
                    val itemOffset = firstItemOffsets[column]

                    if (itemOffset <= 0 && itemIndex > 0) {
                        return true
                    }
                }

                return false
            }

            // we had scrolled backward or we compose items in the start padding area, which means
            // items before current firstItemScrollOffset should be visible. compose them and update
            // firstItemScrollOffset
            while (hasSpaceOnTop()) {
                val columnIndex = firstItemOffsets.indexOfMinValue()
                val previousItemIndex = spans.findPreviousItemIndex(
                    item = firstItemIndices[columnIndex],
                    column = columnIndex
                )

                if (previousItemIndex < 0) {
                    break
                }

                if (spans.getSpan(previousItemIndex) == SpanLookup.SpanUnset) {
                    spans.setSpan(previousItemIndex, columnIndex)
                }

                val measuredItem = measuredItemProvider.getAndMeasure(
                    previousItemIndex,
                    columnIndex
                )
                measuredItems[columnIndex].add(0, measuredItem)

                firstItemIndices[columnIndex] = previousItemIndex
                firstItemOffsets[columnIndex] += measuredItem.sizeWithSpacings
            }

            // if we were scrolled backward, but there were not enough items before. this means
            // not the whole scroll was consumed
            if (firstItemOffsets[0] < minOffset) {
                scrollDelta += firstItemOffsets[0]
                firstItemOffsets.offsetBy(minOffset - firstItemOffsets[0])
            }

            // main-axis start of each lane's first stored item after applying scrollDelta; it
            // becomes the lane's end offset as items are appended below
            val currentItemIndices = initialItemIndices.copyOf()
            val currentItemOffsets = IntArray(initialItemOffsets.size) {
                -(initialItemOffsets[it] - scrollDelta)
            }

            // neutralize previously added start padding as we stopped filling the before content padding
            firstItemOffsets.offsetBy(beforeContentPadding)

            val maxMainAxis = (maxOffset + afterContentPadding).coerceAtLeast(0)

            // compose first visible items we received from state
            currentItemIndices.forEachIndexed { columnIndex, itemIndex ->
                if (itemIndex == -1) return@forEachIndexed

                val measuredItem = measuredItemProvider.getAndMeasure(itemIndex, columnIndex)
                currentItemOffsets[columnIndex] += measuredItem.sizeWithSpacings

                if (
                    currentItemOffsets[columnIndex] <= minOffset &&
                        measuredItem.index != itemCount - 1
                ) {
                    // this item is offscreen and will not be placed. advance item index
                    firstItemIndices[columnIndex] = -1
                    firstItemOffsets[columnIndex] -= measuredItem.sizeWithSpacings
                } else {
                    measuredItems[columnIndex].add(measuredItem)
                }
            }

            // then composing visible items forward until we fill the whole viewport.
            // we want to have at least one item in visibleItems even if in fact all the items are
            // offscreen, this can happen if the content padding is larger than the available size.
            while (
                currentItemOffsets.any { it <= maxMainAxis } ||
                    measuredItems.all { it.isEmpty() }
            ) {
                // always extend the shortest lane first
                val columnIndex = currentItemOffsets.indexOfMinValue()
                val nextItemIndex = spans.findNextItemIndex(
                    currentItemIndices[columnIndex],
                    columnIndex
                )

                // findNextItemIndex signals "nothing further" with an index derived from the span
                // lookup capacity rather than the item count, so it may exceed itemCount. Anything
                // past the last item means the data set is exhausted.
                if (nextItemIndex >= itemCount) {
                    break
                }

                if (firstItemIndices[columnIndex] == -1) {
                    firstItemIndices[columnIndex] = nextItemIndex
                }
                spans.setSpan(nextItemIndex, columnIndex)

                val measuredItem = measuredItemProvider.getAndMeasure(nextItemIndex, columnIndex)
                currentItemOffsets[columnIndex] += measuredItem.sizeWithSpacings

                if (
                    currentItemOffsets[columnIndex] <= minOffset &&
                        measuredItem.index != itemCount - 1
                ) {
                    // this item is offscreen and will not be placed. advance item index
                    firstItemIndices[columnIndex] = -1
                    firstItemOffsets[columnIndex] -= measuredItem.sizeWithSpacings
                } else {
                    measuredItems[columnIndex].add(measuredItem)
                }

                currentItemIndices[columnIndex] = nextItemIndex
            }

            // we didn't fill the whole viewport with items starting from firstVisibleItemIndex.
            // lets try to scroll back if we have enough items before firstVisibleItemIndex.
            if (currentItemOffsets.all { it < maxOffset }) {
                val maxOffsetColumn = currentItemOffsets.indexOfMaxValue()
                val toScrollBack = maxOffset - currentItemOffsets[maxOffsetColumn]
                firstItemOffsets.offsetBy(-toScrollBack)
                currentItemOffsets.offsetBy(toScrollBack)
                while (
                    firstItemOffsets.any { it < beforeContentPadding } &&
                        firstItemIndices.all { it != 0 }
                ) {
                    val columnIndex = firstItemOffsets.indexOfMinValue()
                    // a lane without a first item searches backward from the end of the data set
                    val currentIndex =
                        if (firstItemIndices[columnIndex] == -1) {
                            itemCount
                        } else {
                            firstItemIndices[columnIndex]
                        }

                    val previousIndex =
                        spans.findPreviousItemIndex(currentIndex, columnIndex)

                    if (previousIndex < 0) {
                        break
                    }

                    val measuredItem = measuredItemProvider.getAndMeasure(
                        previousIndex,
                        columnIndex
                    )
                    measuredItems[columnIndex].add(0, measuredItem)
                    firstItemOffsets[columnIndex] += measuredItem.sizeWithSpacings
                    firstItemIndices[columnIndex] = previousIndex
                }
                scrollDelta += toScrollBack

                // do not scroll back further than the start of the content
                val minOffsetColumn = firstItemOffsets.indexOfMinValue()
                if (firstItemOffsets[minOffsetColumn] < 0) {
                    val offsetValue = firstItemOffsets[minOffsetColumn]
                    scrollDelta += offsetValue
                    currentItemOffsets.offsetBy(offsetValue)
                    firstItemOffsets.offsetBy(-offsetValue)
                }
            }

            // report the amount of pixels we consumed. scrollDelta can be smaller than
            // scrollToBeConsumed if there were not enough items to fill the offered space or it
            // can be larger if items were resized, or if, for example, we were previously
            // displaying the item 15, but now we have only 10 items in total in the data set.
            val consumedScroll = if (
                state.scrollToBeConsumed.roundToInt().sign == scrollDelta.sign &&
                    abs(state.scrollToBeConsumed.roundToInt()) >= abs(scrollDelta)
            ) {
                scrollDelta.toFloat()
            } else {
                state.scrollToBeConsumed
            }

            // todo(b/182882362):
            // even if we compose items to fill before content padding we should ignore items fully
            // located there for the state's scroll position calculation (first item + first offset)

            // end measure

            // the main-axis layout size wraps the longest lane, within the constraints
            val layoutWidth = if (isVertical) {
                constraints.maxWidth
            } else {
                constraints.constrainWidth(currentItemOffsets.max())
            }
            val layoutHeight = if (isVertical) {
                constraints.constrainHeight(currentItemOffsets.max())
            } else {
                constraints.maxHeight
            }

            // Placement

            val itemScrollOffsets = firstItemOffsets.map { -it }
            val positionedItems = Array(measuredItems.size) {
                mutableListOf<LazyStaggeredGridPositionedItem>()
            }

            measuredItems.forEachIndexed { i, columnItems ->
                var currentMainAxis = itemScrollOffsets[i]
                // Lane i starts at the cumulative size of all preceding lanes. Taking this from the
                // slot sums, rather than from the first item of every preceding lane, keeps lanes
                // aligned even when an earlier lane has no visible items.
                val currentCrossAxis = if (i == 0) 0 else resolvedSlotSums[i - 1]

                // todo(b/182882362): arrangement/spacing support

                columnItems.fastForEach { item ->
                    positionedItems[i] += item.position(
                        currentMainAxis,
                        currentCrossAxis,
                    )
                    currentMainAxis += item.sizeWithSpacings
                }
            }

            // End placement

            // todo: reverse layout support
            // only scroll backward if the first item is not on screen or fully visible
            val canScrollBackward = !(firstItemIndices[0] == 0 && firstItemOffsets[0] <= 0)
            // Only scroll forward if the last item has not been measured yet, or if some lane
            // still ends beyond the viewport. That covers both the last item and a taller item
            // in another lane not being fully visible.
            val canScrollForward =
                currentItemIndices.indexOf(itemCount - 1) == -1 ||
                    currentItemOffsets.any { it > mainAxisAvailableSize }

            @Suppress("UNCHECKED_CAST")
            LazyStaggeredGridMeasureResult(
                firstVisibleItemIndices = firstItemIndices,
                firstVisibleItemScrollOffsets = firstItemOffsets,
                consumedScroll = consumedScroll,
                measureResult = layout(layoutWidth, layoutHeight) {
                    positionedItems.forEach {
                        it.fastForEach { item ->
                            item.place(this)
                        }
                    }
                },
                canScrollForward = canScrollForward,
                canScrollBackward = canScrollBackward,
                visibleItemsInfo = positionedItems as Array<List<LazyStaggeredGridItemInfo>>
            ).also {
                state.applyMeasureResult(it)
                refreshOverscrollInfo(overscrollEffect, it)
            }
        }
    }
}

/**
 * Enables [overscrollEffect] when [result] reports that the content can scroll in at least one
 * direction, and disables it otherwise.
 */
@OptIn(ExperimentalFoundationApi::class)
private fun refreshOverscrollInfo(
    overscrollEffect: OverscrollEffect,
    result: LazyStaggeredGridMeasureResult
) {
    val isScrollable = result.canScrollForward || result.canScrollBackward
    overscrollEffect.isEnabled = isScrollable
}

/** Shifts every element of this array by [delta], in place (with normal `Int` overflow). */
private fun IntArray.offsetBy(delta: Int) {
    for (position in 0 until size) {
        this[position] += delta
    }
}

/**
 * Returns the index of the smallest element, preferring the lowest index on ties, or -1 when
 * the array is empty.
 *
 * The search is seeded with the first element instead of an `Int.MAX_VALUE` sentinel. An array
 * whose minimum equals `Int.MAX_VALUE` therefore still yields a valid index rather than -1.
 */
private fun IntArray.indexOfMinValue(): Int {
    if (isEmpty()) return -1

    var result = 0
    for (i in 1 until size) {
        // strict comparison keeps the first occurrence on ties
        if (this[i] < this[result]) {
            result = i
        }
    }
    return result
}

/**
 * Returns the index of the largest element, preferring the lowest index on ties, or -1 when
 * the array is empty.
 *
 * The search is seeded with the first element instead of an `Int.MIN_VALUE` sentinel. An array
 * whose maximum equals `Int.MIN_VALUE` therefore still yields a valid index rather than -1.
 */
private fun IntArray.indexOfMaxValue(): Int {
    if (isEmpty()) return -1

    var result = 0
    for (i in 1 until size) {
        // strict comparison keeps the first occurrence on ties
        if (this[i] > this[result]) {
            result = i
        }
    }
    return result
}

/**
 * Walks backward from the index just before [item]. Returns the nearest index whose span is
 * either [column] or still unassigned, or -1 when no such index exists down to 0.
 */
private fun SpanLookup.findPreviousItemIndex(item: Int, column: Int): Int {
    var candidate = item - 1
    while (candidate >= 0) {
        val span = getSpan(candidate)
        val belongsHere = span == column || span == SpanLookup.SpanUnset
        if (belongsHere) return candidate
        candidate--
    }
    return -1
}

/**
 * Returns the nearest index after [item] whose span is either [column] or still unassigned.
 *
 * Only indices below `capacity()` are inspected. Indices at or beyond the capacity are treated
 * as unassigned, so when nothing inside the lookup matches, the first index past both [item]
 * and the capacity is returned. Without the `maxOf`, an [item] lying beyond the recorded spans
 * would produce an index at or before [item].
 *
 * The result is not bounded by the number of items; callers must compare it with their item
 * count.
 */
private fun SpanLookup.findNextItemIndex(item: Int, column: Int): Int {
    val capacity = capacity()
    for (i in (item + 1) until capacity) {
        val span = getSpan(i)
        if (span == column || span == SpanLookup.SpanUnset) {
            return i
        }
    }
    return maxOf(item + 1, capacity)
}

/**
 * Measures items of the staggered grid into lanes ("slots").
 *
 * @param isVertical whether the main (scrolling) axis is vertical.
 * @param itemProvider source of item keys.
 * @param measureScope scope used to compose and measure item content.
 * @param resolvedSlotSums cumulative cross-axis lane sizes; entry `i` is the end of lane `i`.
 * @param measuredItemFactory wraps the measured placeables into an item.
 */
@OptIn(ExperimentalFoundationApi::class)
private class LazyStaggeredGridMeasureProvider(
    private val isVertical: Boolean,
    private val itemProvider: LazyLayoutItemProvider,
    private val measureScope: LazyLayoutMeasureScope,
    private val resolvedSlotSums: IntArray,
    private val measuredItemFactory: MeasuredItemFactory
) {
    /**
     * Returns constraints for an item in lane [slot]. The cross-axis size is fixed to the lane's
     * size, computed as the difference of consecutive cumulative sums.
     */
    fun childConstraints(slot: Int): Constraints {
        val laneStart = if (slot == 0) 0 else resolvedSlotSums[slot - 1]
        val laneSize = resolvedSlotSums[slot] - laneStart
        return when {
            isVertical -> Constraints.fixedWidth(laneSize)
            else -> Constraints.fixedHeight(laneSize)
        }
    }

    /** Measures item [index] with the constraints of lane [slot] and wraps the result. */
    fun getAndMeasure(index: Int, slot: Int): LazyStaggeredGridMeasuredItem =
        measuredItemFactory.createItem(
            index,
            itemProvider.getKey(index),
            measureScope.measure(index, childConstraints(slot))
        )
}

/**
 * An item that has been measured but not yet positioned.
 *
 * @property index position of the item in the data set.
 * @property key key reported by the item provider for [index].
 * @property placeables measured content of the item (may be empty).
 * @property isVertical whether the main (scrolling) axis is vertical.
 */
private class LazyStaggeredGridMeasuredItem(
    val index: Int,
    val key: Any,
    val placeables: Array<Placeable>,
    val isVertical: Boolean
) {
    /** Sum of the main-axis sizes of all [placeables]; no spacing is added yet. */
    val sizeWithSpacings: Int = placeables.fold(0) { size, placeable ->
        size + if (isVertical) placeable.height else placeable.width
    }

    /**
     * Largest cross-axis size among [placeables], or 0 when the item emitted no content.
     * `maxOf` would throw on an empty array.
     */
    val crossAxisSize: Int = placeables.maxOfOrNull {
        if (isVertical) it.width else it.height
    } ?: 0

    /**
     * Creates the positioned counterpart of this item at the given main/cross axis offsets.
     *
     * [IntSize] is width x height, so the main-axis size goes into the height for a vertical
     * grid and into the width for a horizontal one.
     */
    fun position(
        mainAxis: Int,
        crossAxis: Int,
    ): LazyStaggeredGridPositionedItem =
        LazyStaggeredGridPositionedItem(
            offset = if (isVertical) {
                IntOffset(crossAxis, mainAxis)
            } else {
                IntOffset(mainAxis, crossAxis)
            },
            index = index,
            key = key,
            size = if (isVertical) {
                IntSize(crossAxisSize, sizeWithSpacings)
            } else {
                IntSize(sizeWithSpacings, crossAxisSize)
            },
            placeables = placeables
        )
}

/**
 * A measured item with its final [offset] inside the grid layout, exposed to callers as
 * [LazyStaggeredGridItemInfo].
 */
private class LazyStaggeredGridPositionedItem(
    override val offset: IntOffset,
    override val index: Int,
    override val key: Any,
    override val size: IntSize,
    val placeables: Array<Placeable>
) : LazyStaggeredGridItemInfo {
    /**
     * Places every placeable of this item at [offset] using a graphics layer.
     *
     * NOTE(review): all placeables share the same offset, while the measured main-axis size of
     * an item sums its placeables. Confirm whether multi-placeable items should instead be
     * stacked along the main axis.
     */
    fun place(scope: Placeable.PlacementScope) {
        with(scope) {
            for (placeable in placeables) {
                placeable.placeWithLayer(offset)
            }
        }
    }
}

/**
 * Builds a [LazyStaggeredGridMeasuredItem] from the data of a measured item.
 *
 * Declared as a `fun interface` with a primitive `index: Int` parameter. A generic function
 * type such as `(Int, Any, Array<Placeable>) -> T` would box the index on every call; this
 * declaration avoids that.
 */
private fun interface MeasuredItemFactory {
    fun createItem(
        index: Int,
        key: Any,
        placeables: Array<Placeable>,
    ): LazyStaggeredGridMeasuredItem
}