/*
* Copyright 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package androidx.paging.testing
import androidx.paging.CombinedLoadStates
import androidx.paging.DifferCallback
import androidx.paging.ItemSnapshotList
import androidx.paging.LoadState
import androidx.paging.LoadStates
import androidx.paging.NullPaddedList
import androidx.paging.Pager
import androidx.paging.PagingData
import androidx.paging.PagingDataDiffer
import androidx.paging.testing.ErrorRecovery.THROW
import androidx.paging.testing.ErrorRecovery.RETRY
import androidx.paging.testing.ErrorRecovery.RETURN_CURRENT_SNAPSHOT
import androidx.paging.testing.LoaderCallback.CallbackType.ON_CHANGED
import androidx.paging.testing.LoaderCallback.CallbackType.ON_INSERTED
import androidx.paging.testing.LoaderCallback.CallbackType.ON_REMOVED
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collectLatest
import kotlinx.coroutines.flow.debounce
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
/**
 * Runs the [SnapshotLoader] load operations that are passed in and returns a List of data
 * that would be presented to the UI after all load operations are complete.
 *
 * @param coroutineScope The [CoroutineScope] to collect from this Flow<PagingData> and contains
 * the [CoroutineScope.coroutineContext] to load data from.
 *
 * @param onError The error recovery strategy when PagingSource returns LoadResult.Error. A lambda
 * that returns an [ErrorRecovery] value. The default strategy is [ErrorRecovery.THROW].
 *
 * @param loadOperations The block containing [SnapshotLoader] load operations.
 *
 * @return the items of the differ's snapshot after [loadOperations] has run and loading has gone
 * idle, or the items loaded so far if [onError] chose [ErrorRecovery.RETURN_CURRENT_SNAPSHOT].
 * Any other throwable raised while loading propagates to the caller.
 */
public suspend fun <Value : Any> Flow<PagingData<Value>>.asSnapshot(
    coroutineScope: CoroutineScope,
    onError: LoadErrorHandler = LoadErrorHandler { THROW },
    loadOperations: suspend SnapshotLoader<Value>.() -> @JvmSuppressWildcards Unit
): @JvmSuppressWildcards List<Value> {
    // Assigned right after the differ is built; the callbacks below only run once collection
    // has started, which happens after that assignment.
    lateinit var loader: SnapshotLoader<Value>

    // Forwards every list update to the loader, tagged with the loader's current generation.
    val callback = object : DifferCallback {
        override fun onChanged(position: Int, count: Int) {
            loader.onDataSetChanged(
                loader.generations.value,
                LoaderCallback(ON_CHANGED, position, count)
            )
        }

        override fun onInserted(position: Int, count: Int) {
            loader.onDataSetChanged(
                loader.generations.value,
                LoaderCallback(ON_INSERTED, position, count)
            )
        }

        override fun onRemoved(position: Int, count: Int) {
            loader.onDataSetChanged(
                loader.generations.value,
                LoaderCallback(ON_REMOVED, position, count)
            )
        }
    }

    // PagingDataDiffer automatically switches to Dispatchers.Main to call presentNewList
    val differ = object : PagingDataDiffer<Value>(callback) {
        override suspend fun presentNewList(
            previousList: NullPaddedList<Value>,
            newList: NullPaddedList<Value>,
            lastAccessedIndex: Int,
            onListPresentable: () -> Unit
        ): Int? {
            onListPresentable()
            /**
             * On new generation, SnapshotLoader needs the latest [ItemSnapshotList]
             * state so that it can initialize lastAccessedIndex to prepend/append from onwards.
             *
             * This initial lastAccessedIndex is necessary because initial load
             * key may not be 0, for example when [Pager].initialKey != 0. It is calculated
             * based on [ItemSnapshotList.placeholdersBefore] + [1/2 initial load size] to match
             * the initial ViewportHint that [PagingDataDiffer.presentNewList] sends on
             * first generation to auto-trigger prefetches on either direction.
             *
             * Any subsequent SnapshotLoader loads are based on the index tracked by
             * [SnapshotLoader] internally.
             */
            // Read the snapshot once so both terms are derived from the same list state.
            val current = snapshot()
            val lastLoadedIndex = current.placeholdersBefore + (current.items.size / 2)
            loader.generations.value.lastAccessedIndex.set(lastLoadedIndex)
            return null
        }
    }
    loader = SnapshotLoader(differ, onError)

    /**
     * Launches collection on this [Pager.flow].
     *
     * The collection job is cancelled automatically after [loadOperations] completes.
     */
    val collectPagingData = coroutineScope.launch {
        this@asSnapshot.collectLatest {
            incrementGeneration(loader)
            differ.collectFrom(it)
        }
    }

    /**
     * Runs the input [loadOperations].
     *
     * Awaits for initial refresh to complete before invoking [loadOperations]. Automatically
     * cancels the collection on this [Pager.flow] after [loadOperations] completes and Paging
     * is idle.
     *
     * Returns a List of loaded data.
     */
    return withContext(coroutineScope.coroutineContext) {
        try {
            differ.awaitNotLoading(onError)
            loader.loadOperations()
            differ.awaitNotLoading(onError)
        } catch (stub: ReturnSnapshotStub) {
            // RETURN_CURRENT_SNAPSHOT was chosen: stop loading and return what is loaded so far.
            // Every other throwable propagates unchanged (no catch-and-rethrow needed).
        } finally {
            collectPagingData.cancelAndJoin()
        }
        differ.snapshot().items
    }
}
/**
 * Suspends until this differ's load states settle, then applies [errorHandler] if they settled
 * on an error.
 *
 * "Settled" means either every source and mediator state is `NotLoading`, or at least one of
 * them is an error. `endOfPaginationReached` is ignored. Source and mediator are inspected
 * individually because the aggregated [CombinedLoadStates] can report `NotLoading` while the
 * source is still `Loading`.
 *
 * The flow is debounced by 1ms. A `NotLoading` value that is collected may still belong to the
 * previous load, and there is no way to tell it apart from one produced by the current load.
 * The short debounce gives a pending `Loading` emission a chance to arrive first.
 *
 * Returns without doing anything if the flow completes before a settled state is seen.
 */
@OptIn(kotlinx.coroutines.FlowPreview::class)
internal suspend fun <Value : Any> PagingDataDiffer<Value>.awaitNotLoading(
    errorHandler: LoadErrorHandler
) {
    val settled: CombinedLoadStates = loadStateFlow
        .filterNotNull()
        .debounce(1)
        .filter { loadStates -> loadStates.isIdle() || loadStates.hasError() }
        .firstOrNull()
        ?: return

    if (settled.hasError()) {
        handleLoadError(settled, errorHandler)
    }
}
/**
 * Asks [errorHandler] how to recover from the error contained in [state] and carries it out.
 *
 * - [ErrorRecovery.THROW]: throws the underlying [LoadState.Error.error].
 * - [ErrorRecovery.RETRY]: calls [PagingDataDiffer.retry] and returns right away. It does not
 *   wait for the retried load to finish.
 * - [ErrorRecovery.RETURN_CURRENT_SNAPSHOT]: throws the file-private [ReturnSnapshotStub] marker
 *   so the caller can stop and return the current snapshot.
 */
internal fun <Value : Any> PagingDataDiffer<Value>.handleLoadError(
    state: CombinedLoadStates,
    errorHandler: LoadErrorHandler
) {
    when (errorHandler.onError(state)) {
        RETRY -> retry()
        RETURN_CURRENT_SNAPSHOT -> throw ReturnSnapshotStub()
        THROW -> throw state.getErrorState().error
    }
}
/**
 * Control-flow marker thrown by [handleLoadError] when [ErrorRecovery.RETURN_CURRENT_SNAPSHOT]
 * is chosen. [asSnapshot] catches it and returns the snapshot loaded so far.
 */
private class ReturnSnapshotStub() : Exception()
/**
 * Returns true when these states are non-null and idle. This requires [CombinedLoadStates.source]
 * to be idle, and [CombinedLoadStates.mediator] to be absent or idle.
 */
private fun CombinedLoadStates?.isIdle(): Boolean =
    this != null && source.isIdle() && mediator?.isIdle() != false
/** Returns true when refresh, append and prepend are all [LoadState.NotLoading]. */
private fun LoadStates.isIdle(): Boolean =
    listOf(refresh, append, prepend).all { state -> state is LoadState.NotLoading }
/**
 * Returns true when these states are non-null and either [CombinedLoadStates.source] or a
 * present [CombinedLoadStates.mediator] contains a [LoadState.Error].
 */
private fun CombinedLoadStates?.hasError(): Boolean =
    this != null && (source.hasError() || mediator?.hasError() == true)
/** Returns true when any of refresh, append or prepend is a [LoadState.Error]. */
private fun LoadStates.hasError(): Boolean =
    listOf(refresh, append, prepend).any { state -> state is LoadState.Error }
/**
 * Returns the first [LoadState.Error] found in these states.
 *
 * Search order:
 * 1. The aggregated [CombinedLoadStates.refresh], [CombinedLoadStates.append] and
 *    [CombinedLoadStates.prepend], in that order. This keeps the original precedence.
 * 2. The same three states on [CombinedLoadStates.source].
 * 3. The same three states on [CombinedLoadStates.mediator], if present.
 *
 * [hasError] inspects source and mediator individually, and the aggregated states are not
 * guaranteed to mirror them. Without the fallback, a state that [hasError] reported as failing
 * could crash here with a ClassCastException instead of surfacing its real error.
 *
 * @throws IllegalStateException if no [LoadState.Error] is present anywhere in these states.
 */
private fun CombinedLoadStates.getErrorState(): LoadState.Error =
    listOfNotNull(
        refresh, append, prepend,
        source.refresh, source.append, source.prepend,
        mediator?.refresh, mediator?.append, mediator?.prepend
    ).filterIsInstance<LoadState.Error>().firstOrNull()
        ?: error("Expected a LoadState.Error in $this")
/**
 * Replaces [loader]'s current [Generation] with a new one whose id is one greater.
 *
 * It is invoked by the collector in [asSnapshot] each time a new [PagingData] is collected.
 * Only `id` is passed to the [Generation] constructor, so any other fields take their
 * constructor defaults.
 *
 * This is a plain read followed by a write, not an atomic update. The previous guard compared
 * the id with a re-read of the same value, so it was effectively always true and never made
 * the update atomic; it has been removed.
 */
private fun <Value : Any> incrementGeneration(loader: SnapshotLoader<Value>) {
    val nextId = loader.generations.value.id + 1
    loader.generations.value = Generation(id = nextId)
}