ChannelManager.kt

/*
 * Copyright 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.paging.multicast

import androidx.paging.multicast.ChannelManager.Message
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.flow.Flow
import java.util.ArrayDeque
import java.util.Collections

/**
 * Tracks active downstream channels and dispatches incoming upstream values to each of them.
 * The upstream is suspended after producing a value until at least one of the downstreams
 * acknowledges receiving it via [Message.Dispatch.Value.delivered].
 *
 * The [ChannelManager] will start the upstream from the given [upstream] [Flow] if there
 * is no active upstream and there's at least one downstream that has not received a value.
 *
 * All mutable state is owned by a private actor; the public functions only post messages to it
 * (or close it).
 */
internal class ChannelManager<T>(
    /**
     * The scope in which ChannelManager actor runs
     */
    private val scope: CoroutineScope,
    /**
     * The buffer size that is used while the upstream is active. Buffered values are sent to
     * downstreams that are added later.
     */
    private val bufferSize: Int,
    /**
     * If true, downstream is never closed by the ChannelManager unless upstream throws an error.
     * Instead, it is kept open and if a new downstream shows up that causes us to restart the flow,
     * it will receive values as well.
     */
    private val piggybackingDownstream: Boolean = false,
    /**
     * Called when a value is dispatched
     */
    private val onEach: suspend (T) -> Unit,
    /**
     * If true, the current producer is not cancelled when the last downstream is removed.
     */
    private val keepUpstreamAlive: Boolean = false,
    /**
     * The flow that is handed to every newly created [SharedFlowProducer].
     */
    private val upstream: Flow<T>
) {

    /**
     * Posts a request to the actor to register [channel] as a new downstream.
     */
    suspend fun addDownstream(channel: SendChannel<Message.Dispatch.Value<T>>) =
        actor.send(Message.AddChannel(channel))

    /**
     * Posts a request to the actor to unregister [channel].
     */
    suspend fun removeDownstream(channel: SendChannel<Message.Dispatch.Value<T>>) =
        actor.send(Message.RemoveChannel(channel))

    /**
     * Closes the actor.
     */
    suspend fun close() = actor.close()

    private val actor = Actor()

    /**
     * Actor that does all the work. Any state and functionality should go here.
     */
    private inner class Actor : StoreRealActor<Message<T>>(scope) {

        /**
         * Values kept for downstreams that are added later; sized by [bufferSize].
         */
        private val buffer = Buffer<T>(bufferSize)

        /**
         * The current producer
         */
        private var producer: SharedFlowProducer<T>? = null

        /**
         * Tracks whether we've ever dispatched value or error from the current producer.
         * Reset when a new producer is started.
         */
        private var dispatchedValue: Boolean = false

        /**
         * The ack for the very last message we've delivered.
         * When a new downstream comes and buffer is 0, we ack this message so that new downstream
         * can immediately start receiving values instead of waiting for values that it'll never
         * receive.
         */
        private var lastDeliveryAck: CompletableDeferred<Unit>? = null

        /**
         * List of downstream collectors.
         */
        private val channels = mutableListOf<ChannelEntry<T>>()

        /**
         * Routes each incoming message to the function that handles it.
         */
        override suspend fun handle(msg: Message<T>) {
            when (msg) {
                is Message.AddChannel -> doAdd(msg)
                is Message.RemoveChannel -> doRemove(msg.channel)
                is Message.Dispatch.Value -> doDispatchValue(msg)
                is Message.Dispatch.Error -> doDispatchError(msg)
                is Message.Dispatch.UpstreamFinished -> doHandleUpstreamClose(msg.producer)
            }
        }

        /**
         * Called when the channel manager is active (e.g. it has downstream collectors and needs a
         * producer)
         */
        private fun newProducer() =
            SharedFlowProducer(scope, upstream, ::send)

        /**
         * We are closing. Do a cleanup on existing channels where we'll close them and also decide
         * on the list of leftovers.
         *
         * Notifications from any producer other than the current one are ignored.
         */
        private fun doHandleUpstreamClose(producer: SharedFlowProducer<T>?) {
            if (this.producer !== producer) {
                return
            }
            val piggyBacked = mutableListOf<ChannelEntry<T>>()
            val leftovers = mutableListOf<ChannelEntry<T>>()
            channels.forEach { entry ->
                when {
                    // we dispatched a value but this channel didn't receive it so put it into
                    // leftovers
                    dispatchedValue && !entry.receivedValue -> leftovers.add(entry)
                    // either the channel was served or the upstream didn't dispatch at all;
                    // in both cases the channel is kept only when piggybacking
                    piggybackingDownstream -> piggyBacked.add(entry)
                    else -> entry.close()
                }
            }
            channels.clear() // empty references
            channels.addAll(leftovers)
            channels.addAll(piggyBacked)
            this.producer = null
            // we only reactivate if leftovers is not empty
            if (leftovers.isNotEmpty()) {
                activateIfNecessary()
            }
        }

        /**
         * Closes and forgets every downstream channel, then cancels the current producer if
         * there is one.
         */
        override fun onClosed() {
            channels.forEach {
                it.close()
            }
            channels.clear()
            producer?.cancel()
        }

        /**
         * Dispatch value to all downstream collectors.
         */
        private suspend fun doDispatchValue(msg: Message.Dispatch.Value<T>) {
            onEach(msg.value)
            buffer.add(msg)
            dispatchedValue = true
            // still empty right after an add means nothing is being buffered
            if (buffer.isEmpty()) {
                // if a new downstream arrives, we need to ack this so that it won't wait for
                // values that it'll never receive
                lastDeliveryAck = msg.delivered
            }
            channels.forEach {
                it.dispatchValue(msg)
            }
        }

        /**
         * Dispatch an upstream error to downstream collectors.
         */
        private fun doDispatchError(msg: Message.Dispatch.Error<T>) {
            // dispatching error is as good as dispatching value
            dispatchedValue = true
            channels.forEach {
                it.dispatchError(msg.error)
            }
        }

        /**
         * Remove a downstream collector. When the last collector is removed and
         * [keepUpstreamAlive] is false, the current producer is cancelled and joined.
         * Channels that are not registered are ignored.
         */
        private suspend fun doRemove(channel: SendChannel<Message.Dispatch.Value<T>>) {
            val index = channels.indexOfFirst {
                it.hasChannel(channel)
            }
            if (index >= 0) {
                channels.removeAt(index)
                if (channels.isEmpty() && !keepUpstreamAlive) {
                    producer?.cancelAndJoin()
                }
            }
        }

        /**
         * Add a new downstream collector
         */
        private suspend fun doAdd(msg: Message.AddChannel<T>) {
            addEntry(
                entry = ChannelEntry(
                    channel = msg.channel
                )
            )
            activateIfNecessary()
        }

        /**
         * Creates and starts a new producer unless one is already set.
         */
        private fun activateIfNecessary() {
            if (producer != null) {
                return
            }
            val created = newProducer()
            producer = created
            // nothing has been dispatched from the new producer yet
            dispatchedValue = false
            created.start()
        }

        /**
         * Internally add the new downstream collector to our list, send it anything buffered.
         *
         * Fails with [IllegalStateException] if the entry's channel is already registered or
         * the entry has already received a value.
         */
        private suspend fun addEntry(entry: ChannelEntry<T>) {
            val new = channels.none {
                it.hasChannel(entry)
            }
            check(new) {
                "$entry is already in the list."
            }
            check(!entry.receivedValue) {
                "$entry already received a value"
            }
            channels.add(entry)
            if (buffer.items.isNotEmpty()) {
                // if there is anything in the buffer, send it
                buffer.items.forEach {
                    entry.dispatchValue(it)
                }
            } else {
                // unlock upstream since we now have a downstream that needs values
                lastDeliveryAck?.complete(Unit)
            }
        }
    }

    /**
     * Holder for each downstream collector
     */
    internal data class ChannelEntry<T>(
        /**
         * The channel used by the collector
         */
        private val channel: SendChannel<Message.Dispatch.Value<T>>,
        /**
         * Tracking whether we've ever dispatched a value or an error to downstream
         */
        private var _receivedValue: Boolean = false
    ) {
        /**
         * True once a value or an error has been dispatched to this entry.
         */
        val receivedValue
            get() = _receivedValue

        /**
         * Marks the entry as served and sends [value] to its channel.
         */
        suspend fun dispatchValue(value: Message.Dispatch.Value<T>) {
            _receivedValue = true
            channel.send(value)
        }

        /**
         * Marks the entry as served and closes its channel with [error].
         */
        fun dispatchError(error: Throwable) {
            _receivedValue = true
            channel.close(error)
        }

        /**
         * Closes the channel without an error.
         */
        fun close() {
            channel.close()
        }

        /**
         * Whether this entry wraps exactly the given [channel] instance (identity comparison).
         */
        fun hasChannel(channel: SendChannel<Message.Dispatch.Value<T>>) = this.channel === channel

        /**
         * Whether this entry and [entry] wrap the same channel instance (identity comparison).
         */
        fun hasChannel(entry: ChannelEntry<T>) = this.channel === entry.channel
    }

    /**
     * Messages accepted by the [ChannelManager].
     */
    sealed class Message<T> {
        /**
         * Add a new channel, that means a new downstream subscriber
         */
        class AddChannel<T>(
            val channel: SendChannel<Dispatch.Value<T>>
        ) : Message<T>()

        /**
         * Remove a downstream subscriber, that means it completed
         */
        class RemoveChannel<T>(val channel: SendChannel<Dispatch.Value<T>>) : Message<T>()

        /**
         * Messages that originate from the upstream producer.
         */
        sealed class Dispatch<T> : Message<T>() {
            /**
             * Upstream dispatched a new value, send it to all downstream items
             */
            class Value<T>(
                /**
                 * The value dispatched by the upstream
                 */
                val value: T,
                /**
                 * Ack that is completed by a receiver once it has the value. Upstream producer
                 * will await this before asking for a new value from upstream
                 */
                val delivered: CompletableDeferred<Unit>
            ) : Dispatch<T>()

            /**
             * Upstream dispatched an error, send it to all downstream items
             */
            class Error<T>(
                /**
                 * The error sent by the upstream
                 */
                val error: Throwable
            ) : Dispatch<T>()

            /**
             * The given producer is done; see the handling of this message in the actor.
             */
            class UpstreamFinished<T>(
                /**
                 * SharedFlowProducer finished emitting
                 */
                val producer: SharedFlowProducer<T>
            ) : Dispatch<T>()
        }
    }
}

/**
 * Storage for dispatched values that should still be available to late arrivals.
 */
private interface Buffer<T> {
    /**
     * The values currently held by this buffer.
     */
    val items: Collection<ChannelManager.Message.Dispatch.Value<T>>

    /**
     * Offers [item] to the buffer; implementations decide whether and how long to keep it.
     */
    fun add(item: ChannelManager.Message.Dispatch.Value<T>)

    /**
     * Returns true when [items] holds nothing.
     */
    fun isEmpty(): Boolean = items.isEmpty()
}

/**
 * Buffer that never retains anything: [items] is always empty and [add] drops its argument.
 */
private class NoBuffer<T> : Buffer<T> {
    override val items: Collection<ChannelManager.Message.Dispatch.Value<T>>
        get() = Collections.emptyList<ChannelManager.Message.Dispatch.Value<T>>()

    override fun add(item: ChannelManager.Message.Dispatch.Value<T>) {
        // intentionally a no-op, nothing is kept
    }
}

/**
 * Creates a new buffer instance based on the provided limit: a [NoBuffer] when [limit] is zero
 * or negative, otherwise a [BufferImpl] bounded by [limit].
 */
@Suppress("FunctionName")
private fun <T> Buffer(limit: Int): Buffer<T> {
    if (limit <= 0) {
        return NoBuffer()
    }
    return BufferImpl(limit)
}

/**
 * A real buffer implementation that has a FIFO queue holding at most [limit] values.
 *
 * @param limit the maximum number of values retained; must be positive.
 * @throws IllegalArgumentException if [limit] is not positive.
 */
private class BufferImpl<T>(private val limit: Int) :
    Buffer<T> {
    init {
        // with a non-positive limit the eviction loop in add() would never terminate, because
        // polling an empty deque does not shrink it below the limit
        require(limit > 0) { "Buffer limit must be positive but was $limit" }
    }

    // the initial capacity is only a hint, capped at 10 so large limits do not pre-allocate
    override val items =
        ArrayDeque<ChannelManager.Message.Dispatch.Value<T>>(limit.coerceAtMost(10))

    /**
     * Appends [item], first dropping the oldest values so the size never exceeds [limit].
     */
    override fun add(item: ChannelManager.Message.Dispatch.Value<T>) {
        while (items.size >= limit) {
            items.pollFirst()
        }
        items.offerLast(item)
    }
}

/**
 * Completes the [ChannelManager.Message.Dispatch.Value.delivered] ack of this value.
 *
 * @return the result of [CompletableDeferred.complete] on the ack.
 */
internal fun <T> ChannelManager.Message.Dispatch.Value<T>.markDelivered(): Boolean {
    return delivered.complete(Unit)
}