OnBackPressedDispatcher.kt

/*
 * Copyright 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package androidx.activity

import android.os.Build
import android.window.OnBackInvokedCallback
import android.window.OnBackInvokedDispatcher
import androidx.annotation.DoNotInline
import androidx.annotation.MainThread
import androidx.annotation.RequiresApi
import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.LifecycleOwner

/**
 * Dispatcher that can be used to register [OnBackPressedCallback] instances for handling
 * the [ComponentActivity.onBackPressed] callback via composition.
 *
 * ```
 * class FormEntryFragment : Fragment() {
 *   override fun onAttach(context: Context) {
 *     super.onAttach(context)
 *     val callback = object : OnBackPressedCallback(
 *       true // default to enabled
 *     ) {
 *       override fun handleOnBackPressed() {
 *         showAreYouSureDialog()
 *       }
 *     }
 *     requireActivity().onBackPressedDispatcher.addCallback(
 *       this, // LifecycleOwner
 *       callback
 *     )
 *   }
 * }
 * ```
 *
 * When constructing an instance of this class, the [fallbackOnBackPressed] can be set to
 * receive a callback if [onBackPressed] is called when [hasEnabledCallbacks] returns `false`.
 */
class OnBackPressedDispatcher @JvmOverloads constructor(
    private val fallbackOnBackPressed: Runnable? = null
) {
    private val onBackPressedCallbacks = ArrayDeque<OnBackPressedCallback>()
    private var enabledChangedCallback: (() -> Unit)? = null
    private var onBackInvokedCallback: OnBackInvokedCallback? = null
    private var invokedDispatcher: OnBackInvokedDispatcher? = null
    private var backInvokedCallbackRegistered = false

    /**
     * Sets the [OnBackInvokedDispatcher] for handling system back for Android SDK T+.
     *
     * @param invoker the OnBackInvokedDispatcher to be set on this dispatcher
     */
    @RequiresApi(Build.VERSION_CODES.TIRAMISU)
    fun setOnBackInvokedDispatcher(invoker: OnBackInvokedDispatcher) {
        invokedDispatcher = invoker
        updateBackInvokedCallbackState()
    }

    @RequiresApi(Build.VERSION_CODES.TIRAMISU)
    internal fun updateBackInvokedCallbackState() {
        val shouldBeRegistered = hasEnabledCallbacks()
        val dispatcher = invokedDispatcher
        val onBackInvokedCallback = onBackInvokedCallback
        if (dispatcher != null && onBackInvokedCallback != null) {
            if (shouldBeRegistered && !backInvokedCallbackRegistered) {
                Api33Impl.registerOnBackInvokedCallback(
                    dispatcher,
                    OnBackInvokedDispatcher.PRIORITY_DEFAULT,
                    onBackInvokedCallback
                )
                backInvokedCallbackRegistered = true
            } else if (!shouldBeRegistered && backInvokedCallbackRegistered) {
                Api33Impl.unregisterOnBackInvokedCallback(
                    dispatcher,
                    onBackInvokedCallback
                )
                backInvokedCallbackRegistered = false
            }
        }
    }

    init {
        if (Build.VERSION.SDK_INT >= 33) {
            enabledChangedCallback = {
                updateBackInvokedCallbackState()
            }
            onBackInvokedCallback = Api33Impl.createOnBackInvokedCallback { onBackPressed() }
        }
    }

    /**
     * Add a new [OnBackPressedCallback]. Callbacks are invoked in the reverse order in which
     * they are added, so this newly added [OnBackPressedCallback] will be the first
     * callback to receive a callback if [onBackPressed] is called.
     *
     * This method is **not** [Lifecycle] aware - if you'd like to ensure that
     * you only get callbacks when at least [started][Lifecycle.State.STARTED], use
     * [addCallback]. It is expected that you
     * call [OnBackPressedCallback.remove] to manually remove your callback.
     *
     * @param onBackPressedCallback The callback to add
     *
     * @see onBackPressed
     */
    @MainThread
    fun addCallback(onBackPressedCallback: OnBackPressedCallback) {
        addCancellableCallback(onBackPressedCallback)
    }

    /**
     * Internal implementation of [addCallback] that gives
     * access to the [Cancellable] that specifically removes this callback from
     * the dispatcher without relying on [OnBackPressedCallback.remove] which
     * is what external developers should be using.
     *
     * @param onBackPressedCallback The callback to add
     * @return a [Cancellable] which can be used to [cancel][Cancellable.cancel]
     * the callback and remove it from the set of OnBackPressedCallbacks.
     */
    @MainThread
    internal fun addCancellableCallback(onBackPressedCallback: OnBackPressedCallback): Cancellable {
        onBackPressedCallbacks.add(onBackPressedCallback)
        val cancellable = OnBackPressedCancellable(onBackPressedCallback)
        onBackPressedCallback.addCancellable(cancellable)
        if (Build.VERSION.SDK_INT >= 33) {
            updateBackInvokedCallbackState()
            onBackPressedCallback.enabledChangedCallback = enabledChangedCallback
        }
        return cancellable
    }

    /**
     * Receive callbacks to a new [OnBackPressedCallback] when the given
     * [LifecycleOwner] is at least [started][Lifecycle.State.STARTED].
     *
     * This will automatically call [addCallback] and remove the callback as the lifecycle
     * state changes. As a corollary, if your lifecycle is already at least
     * [started][Lifecycle.State.STARTED], calling this method will result in an immediate
     * call to [addCallback].
     *
     * When the [LifecycleOwner] is [destroyed][Lifecycle.State.DESTROYED], it will
     * automatically be removed from the list of callbacks. The only time you would need to
     * manually call [OnBackPressedCallback.remove] is if
     * you'd like to remove the callback prior to destruction of the associated lifecycle.
     *
     * If the Lifecycle is already [destroyed][Lifecycle.State.DESTROYED]
     * when this method is called, the callback will not be added.
     *
     * @param owner The LifecycleOwner which controls when the callback should be invoked
     * @param onBackPressedCallback The callback to add
     *
     * @see onBackPressed
     */
    @MainThread
    fun addCallback(
        owner: LifecycleOwner,
        onBackPressedCallback: OnBackPressedCallback
    ) {
        val lifecycle = owner.lifecycle
        if (lifecycle.currentState === Lifecycle.State.DESTROYED) {
            return
        }
        onBackPressedCallback.addCancellable(
            LifecycleOnBackPressedCancellable(lifecycle, onBackPressedCallback)
        )
        if (Build.VERSION.SDK_INT >= 33) {
            updateBackInvokedCallbackState()
            onBackPressedCallback.enabledChangedCallback = enabledChangedCallback
        }
    }

    /**
     * Checks if there is at least one [enabled][OnBackPressedCallback.isEnabled]
     * callback registered with this dispatcher.
     *
     * @return True if there is at least one enabled callback.
     */
    @MainThread
    fun hasEnabledCallbacks(): Boolean = onBackPressedCallbacks.any {
        it.isEnabled
    }

    /**
     * Trigger a call to the currently added [callbacks][OnBackPressedCallback] in reverse
     * order in which they were added. Only if the most recently added callback is not
     * [enabled][OnBackPressedCallback.isEnabled]
     * will any previously added callback be called.
     *
     * If [hasEnabledCallbacks] is `false` when this method is called, the
     * [fallbackOnBackPressed] set by the constructor will be triggered.
     */
    @MainThread
    fun onBackPressed() {
        val callback = onBackPressedCallbacks.lastOrNull {
            it.isEnabled
        }
        if (callback != null) {
            callback.handleOnBackPressed()
            return
        }
        fallbackOnBackPressed?.run()
    }

    private inner class OnBackPressedCancellable(
        private val onBackPressedCallback: OnBackPressedCallback
    ) : Cancellable {
        override fun cancel() {
            onBackPressedCallbacks.remove(onBackPressedCallback)
            onBackPressedCallback.removeCancellable(this)
            if (Build.VERSION.SDK_INT >= 33) {
                onBackPressedCallback.enabledChangedCallback = null
                updateBackInvokedCallbackState()
            }
        }
    }

    private inner class LifecycleOnBackPressedCancellable(
        private val lifecycle: Lifecycle,
        private val onBackPressedCallback: OnBackPressedCallback
    ) : LifecycleEventObserver, Cancellable {
        private var currentCancellable: Cancellable? = null

        init {
            lifecycle.addObserver(this)
        }

        override fun onStateChanged(
            source: LifecycleOwner,
            event: Lifecycle.Event
        ) {
            if (event === Lifecycle.Event.ON_START) {
                currentCancellable = addCancellableCallback(onBackPressedCallback)
            } else if (event === Lifecycle.Event.ON_STOP) {
                // Should always be non-null
                currentCancellable?.cancel()
            } else if (event === Lifecycle.Event.ON_DESTROY) {
                cancel()
            }
        }

        override fun cancel() {
            lifecycle.removeObserver(this)
            onBackPressedCallback.removeCancellable(this)
            currentCancellable?.cancel()
            currentCancellable = null
        }
    }

    @RequiresApi(Build.VERSION_CODES.TIRAMISU)
    internal object Api33Impl {
        @DoNotInline
        fun registerOnBackInvokedCallback(
            dispatcher: Any,
            priority: Int,
            callback: Any
        ) {
            val onBackInvokedDispatcher = dispatcher as OnBackInvokedDispatcher
            val onBackInvokedCallback = callback as OnBackInvokedCallback
            onBackInvokedDispatcher.registerOnBackInvokedCallback(priority, onBackInvokedCallback)
        }

        @DoNotInline
        fun unregisterOnBackInvokedCallback(dispatcher: Any, callback: Any) {
            val onBackInvokedDispatcher = dispatcher as OnBackInvokedDispatcher
            val onBackInvokedCallback = callback as OnBackInvokedCallback
            onBackInvokedDispatcher.unregisterOnBackInvokedCallback(onBackInvokedCallback)
        }

        @DoNotInline
        fun createOnBackInvokedCallback(onBackInvoked: () -> Unit): OnBackInvokedCallback {
            return OnBackInvokedCallback { onBackInvoked() }
        }
    }
}

/**
 * Create and add a new [OnBackPressedCallback] that calls [onBackPressed] in
 * [OnBackPressedCallback.handleOnBackPressed].
 *
 * If an [owner] is specified, the callback will only be added when the Lifecycle is
 * [androidx.lifecycle.Lifecycle.State.STARTED].
 *
 * A default [enabled] state can be supplied.
 */
@Suppress("RegistrationName")
fun OnBackPressedDispatcher.addCallback(
    owner: LifecycleOwner? = null,
    enabled: Boolean = true,
    onBackPressed: OnBackPressedCallback.() -> Unit
): OnBackPressedCallback {
    val callback = object : OnBackPressedCallback(enabled) {
        override fun handleOnBackPressed() {
            onBackPressed()
        }
    }
    if (owner != null) {
        addCallback(owner, callback)
    } else {
        addCallback(callback)
    }
    return callback
}