FieldReadWriteWriter.kt

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.room.writer

import androidx.room.compiler.codegen.CodeLanguage
import androidx.room.compiler.codegen.XCodeBlock
import androidx.room.compiler.codegen.XTypeName
import androidx.room.compiler.processing.XNullability
import androidx.room.ext.capitalize
import androidx.room.ext.defaultValue
import androidx.room.solver.CodeGenScope
import androidx.room.vo.CallType
import androidx.room.vo.Constructor
import androidx.room.vo.EmbeddedField
import androidx.room.vo.Field
import androidx.room.vo.FieldWithIndex
import androidx.room.vo.Pojo
import androidx.room.vo.RelationCollector
import java.util.Locale

/**
 * Handles writing a field into statement or reading it from statement.
 */
/**
 * Handles binding a field's value into a statement or reading it from a cursor.
 *
 * The companion functions generate code for a whole list of [FieldWithIndex] at once: they
 * group the fields into a tree of embedded objects (see [Node]). The instance functions
 * generate code for the single field this writer was created with.
 *
 * Note: the order in which code is emitted into the [CodeGenScope] builder, and the order in
 * which temporary variable names are requested from it, both determine the generated output.
 */
class FieldReadWriteWriter(fieldWithIndex: FieldWithIndex) {
    // The field whose value is bound to a statement / read from a cursor.
    val field = fieldWithIndex.field
    // Code expression (emitted via `%L`) holding this field's statement / cursor index.
    val indexVar = fieldWithIndex.indexVar
    // When false, the generated read code checks whether `indexVar` is -1 (column missing).
    val alwaysExists = fieldWithIndex.alwaysExists

    companion object {
        /**
         * Gets all embedded parents of the given fields. This includes the ones that have
         * grandchildren in this list but do not have any direct children in the list.
         *
         * @param fields the fields whose ancestors are collected
         * @return every [EmbeddedField] found by walking up each field's `parent` chain
         */
        fun getAllParents(fields: List<Field>): Set<EmbeddedField> {
            val allParents = mutableSetOf<EmbeddedField>()
            fun addAllParents(field: Field) {
                var parent = field.parent
                while (parent != null) {
                    if (allParents.add(parent)) {
                        parent = parent.parent
                    } else {
                        // Already in the set: its ancestors were added when it was first seen,
                        // so there is no need to keep walking up.
                        break
                    }
                }
            }
            fields.forEach(::addAllParents)
            return allParents
        }

        /**
         * Converts the fields with indices into a Node tree so that they can be processed
         * recursively. This work is done here instead of during parsing because the result may
         * include arbitrary fields.
         *
         * The root node uses [rootVar] as its variable and holds the fields that have no parent.
         * Every embedded parent gets its own node, whose variable name is a new temporary
         * requested from [scope]. That node is attached under its own parent's node, or under
         * the root when it has no parent.
         *
         * @param rootVar variable name for the root object
         * @param fieldsWithIndices the fields to place into the tree
         * @param scope used to allocate temporary variable names for embedded nodes
         * @return the root node of the tree
         */
        private fun createNodeTree(
            rootVar: String,
            fieldsWithIndices: List<FieldWithIndex>,
            scope: CodeGenScope
        ): Node {
            val allParents = getAllParents(fieldsWithIndices.map { it.field })
            val rootNode = Node(rootVar, null)
            rootNode.directFields = fieldsWithIndices.filter { it.field.parent == null }
            // Temporary names are allocated here, in the iteration order of allParents.
            val parentNodes = allParents.associateWith {
                Node(
                    varName = scope.getTmpVar("_tmp${it.field.name.capitalize(Locale.US)}"),
                    fieldParent = it
                )
            }
            parentNodes.values.forEach { node ->
                // Non-null: every node in parentNodes was created with a fieldParent above.
                val fieldParent = node.fieldParent!!
                val grandParent = fieldParent.parent
                val grandParentNode = grandParent?.let {
                    parentNodes[it]
                } ?: rootNode
                node.directFields = fieldsWithIndices.filter { it.field.parent == fieldParent }
                node.parentNode = grandParentNode
                grandParentNode.subNodes.add(node)
            }
            return rootNode
        }

        /**
         * Generates code that binds every field in [fieldsWithIndices] into a statement.
         *
         * Each embedded object is first read into a temporary variable through its getter.
         * - If the embedded field is non-null, its descendants are bound directly.
         * - Otherwise, the generated code binds the descendants only when the embedded object is
         *   not null, and binds null to every descendant's index when it is null.
         *
         * @param ownerVar variable holding the root entity / pojo
         * @param stmtParamVar the statement variable
         * @param fieldsWithIndices the fields to bind, along with their index variables
         * @param scope the code generation scope
         */
        fun bindToStatement(
            ownerVar: String,
            stmtParamVar: String,
            fieldsWithIndices: List<FieldWithIndex>,
            scope: CodeGenScope
        ) {
            fun visitNode(node: Node) {
                // Binds this node's own fields, then recurses into its nested embedded nodes.
                fun bindWithDescendants() {
                    node.directFields.forEach {
                        FieldReadWriteWriter(it).bindToStatement(
                            ownerVar = node.varName,
                            stmtParamVar = stmtParamVar,
                            scope = scope
                        )
                    }
                    node.subNodes.forEach(::visitNode)
                }

                val fieldParent = node.fieldParent
                if (fieldParent != null) {
                    // Read the embedded object out of its owner into this node's variable.
                    fieldParent.getter.writeGet(
                        ownerVar = node.parentNode!!.varName,
                        outVar = node.varName,
                        builder = scope.builder
                    )
                    scope.builder.apply {
                        if (fieldParent.nonNull) {
                            bindWithDescendants()
                        } else {
                            beginControlFlow("if (%L != null)", node.varName).apply {
                                bindWithDescendants()
                            }
                            nextControlFlow("else").apply {
                                // Embedded object is null: bind null for all of its columns.
                                node.allFields().forEach {
                                    addStatement("%L.bindNull(%L)", stmtParamVar, it.indexVar)
                                }
                            }
                            endControlFlow()
                        }
                    }
                } else {
                    // Root node: ownerVar is already available, nothing to read first.
                    bindWithDescendants()
                }
            }
            visitNode(createNodeTree(ownerVar, fieldsWithIndices, scope))
        }

        /**
         * Just constructs the given item, does NOT DECLARE. Declaration happens outside the
         * reading statement since we may never read if the cursor does not have necessary
         * columns.
         *
         * - With no [constructor], the item is created with `XCodeBlock.ofNewInstance`.
         * - Otherwise, each constructor parameter is matched to a local variable: a field value,
         *   an embedded node's variable, or a relation variable. A parameter with no match is
         *   passed the literal `null`.
         *
         * @param outVar variable to assign the constructed item to
         * @param constructor the constructor to call, or null to use the no-arg instantiation
         * @param typeName type of the item being constructed
         * @param localVariableNames local variable name -> field read into it
         * @param localEmbeddeds embedded nodes constructed for this item
         * @param localRelations local variable name -> relation field read into it
         * @param scope the code generation scope
         */
        private fun construct(
            outVar: String,
            constructor: Constructor?,
            typeName: XTypeName,
            localVariableNames: Map<String, FieldWithIndex>,
            localEmbeddeds: List<Node>,
            localRelations: Map<String, Field>,
            scope: CodeGenScope
        ) {
            if (constructor == null) {
                // Instantiate with default constructor, best hope for code generation
                scope.builder.apply {
                    addStatement(
                        "%L = %L",
                        outVar,
                        XCodeBlock.ofNewInstance(scope.language, typeName)
                    )
                }
                return
            }
            val variableNames = constructor.params.map { param ->
                when (param) {
                    is Constructor.Param.FieldParam -> localVariableNames.entries.firstOrNull {
                        it.value.field === param.field
                    }?.key
                    is Constructor.Param.EmbeddedParam -> localEmbeddeds.firstOrNull {
                        it.fieldParent == param.embedded
                    }?.varName
                    is Constructor.Param.RelationParam -> localRelations.entries.firstOrNull {
                        it.value === param.relation.field
                    }?.key
                }
            }
            val args = variableNames.joinToString(",") { it ?: "null" }
            constructor.writeConstructor(outVar, args, scope.builder)
        }

        /**
         * Reads the row into the given variable. It does not declare the variable, but it does
         * construct the object.
         *
         * For each node in the field tree, the generated code does the following, in order:
         * 1. Reads the constructor fields into temporaries.
         * 2. Reads nested embedded nodes.
         * 3. Reads the relation collections that belong to this level.
         * 4. Constructs the object.
         * 5. Assigns the remaining fields, embedded objects and relations through their setters.
         *
         * A nullable embedded object is only constructed when at least one of its columns is
         * present and non-null; otherwise it is set to null.
         *
         * @param outVar variable (declared by the caller) that receives the root object
         * @param outPojo the pojo describing the root object
         * @param cursorVar the cursor variable
         * @param fieldsWithIndices the fields to read, along with their index variables
         * @param scope the code generation scope
         * @param relationCollectors collectors for relation fields; the ones whose field
         * belongs to the node being read are used for that node
         */
        fun readFromCursor(
            outVar: String,
            outPojo: Pojo,
            cursorVar: String,
            fieldsWithIndices: List<FieldWithIndex>,
            scope: CodeGenScope,
            relationCollectors: List<RelationCollector>
        ) {
            fun visitNode(node: Node) {
                val fieldParent = node.fieldParent
                fun readNode() {
                    // read constructor parameters into local fields
                    val constructorFields = node.directFields.filter {
                        it.field.setter.callType == CallType.CONSTRUCTOR
                    }.associateBy { fwi ->
                        FieldReadWriteWriter(fwi).readIntoTmpVar(
                            cursorVar,
                            fwi.field.setter.type.asTypeName(),
                            scope
                        )
                    }
                    // read decomposed fields (e.g. embedded)
                    node.subNodes.forEach(::visitNode)
                    // read relationship fields
                    val relationFields = relationCollectors.filter { (relation) ->
                        relation.field.parent === fieldParent
                    }.associate {
                        it.writeReadCollectionIntoTmpVar(
                            cursorVarName = cursorVar,
                            fieldsWithIndices = fieldsWithIndices,
                            scope = scope
                        )
                    }

                    // construct the object
                    if (fieldParent != null) {
                        construct(
                            outVar = node.varName,
                            constructor = fieldParent.pojo.constructor,
                            typeName = fieldParent.field.typeName,
                            localEmbeddeds = node.subNodes,
                            localRelations = relationFields,
                            localVariableNames = constructorFields,
                            scope = scope
                        )
                    } else {
                        construct(
                            outVar = node.varName,
                            constructor = outPojo.constructor,
                            typeName = outPojo.typeName,
                            localEmbeddeds = node.subNodes,
                            localRelations = relationFields,
                            localVariableNames = constructorFields,
                            scope = scope
                        )
                    }
                    // read any field that was not part of the constructor
                    node.directFields.filterNot {
                        it.field.setter.callType == CallType.CONSTRUCTOR
                    }.forEach { fwi ->
                        FieldReadWriteWriter(fwi).readFromCursor(
                            ownerVar = node.varName,
                            cursorVar = cursorVar,
                            scope = scope
                        )
                    }
                    // assign sub nodes to fields if they were not part of the constructor.
                    node.subNodes.mapNotNull {
                        val setter = it.fieldParent?.setter
                        if (setter != null && setter.callType != CallType.CONSTRUCTOR) {
                            Pair(it.varName, setter)
                        } else {
                            null
                        }
                    }.forEach { (varName, setter) ->
                        setter.writeSet(
                            ownerVar = node.varName,
                            inVar = varName,
                            builder = scope.builder
                        )
                    }
                    // assign relation fields that were not part of the constructor
                    relationFields.filterNot { (_, field) ->
                        field.setter.callType == CallType.CONSTRUCTOR
                    }.forEach { (varName, field) ->
                        field.setter.writeSet(
                            ownerVar = node.varName,
                            inVar = varName,
                            builder = scope.builder
                        )
                    }
                }
                if (fieldParent == null) {
                    // root element
                    // always declared by the caller so we don't declare this
                    readNode()
                } else {
                    // always declare, we'll set below
                    scope.builder.addLocalVariable(
                        node.varName,
                        fieldParent.field.typeName
                    )
                    if (fieldParent.nonNull) {
                        readNode()
                    } else {
                        // Build a condition that is true when every column of this embedded
                        // object is null (or missing from the cursor when not alwaysExists).
                        val myDescendants = node.allFields()
                        val allNullCheck = myDescendants.joinToString(" && ") {
                            if (it.alwaysExists) {
                                "$cursorVar.isNull(${it.indexVar})"
                            } else {
                                "(${it.indexVar} == -1 || $cursorVar.isNull(${it.indexVar}))"
                            }
                        }
                        scope.builder.apply {
                            beginControlFlow("if (!(%L))", allNullCheck).apply {
                                readNode()
                            }
                            nextControlFlow("else").apply {
                                addStatement("%L = null", node.varName)
                            }
                            endControlFlow()
                        }
                    }
                }
            }
            visitNode(createNodeTree(outVar, fieldsWithIndices, scope))
        }
    }

    /**
     * Generates code that binds this field's value into the statement at [indexVar]. Emits
     * nothing if the field has no `statementBinder`.
     *
     * @param ownerVar The entity / pojo variable that owns this field.
     * It must own this field! (not the container pojo)
     * @param stmtParamVar The statement variable
     * @param scope The code generation scope
     */
    private fun bindToStatement(ownerVar: String, stmtParamVar: String, scope: CodeGenScope) {
        val binder = field.statementBinder ?: return
        field.getter.writeGetToStatement(
            ownerVar, stmtParamVar, indexVar, binder, scope
        )
    }

    /**
     * Generates code that reads this field from the cursor and assigns it through the field's
     * setter. Emits nothing if the field has no `cursorValueReader`.
     *
     * When [alwaysExists] is false, the read is wrapped in an `if (indexVar != -1)` check.
     *
     * @param ownerVar The entity / pojo variable that owns this field.
     * It must own this field (not the container pojo)
     * @param cursorVar The cursor variable
     * @param scope The code generation scope
     */
    private fun readFromCursor(ownerVar: String, cursorVar: String, scope: CodeGenScope) {
        fun doRead() {
            val reader = field.cursorValueReader ?: return
            field.setter.writeSetFromCursor(
                ownerVar, cursorVar, indexVar, reader, scope
            )
        }
        if (alwaysExists) {
            doRead()
        } else {
            scope.builder.apply {
                beginControlFlow("if (%L != -1)", indexVar).apply {
                    doRead()
                }
                endControlFlow()
            }
        }
    }

    /**
     * Reads the value into a temporary local variable.
     *
     * The temporary is declared with [typeName]. When [alwaysExists] is false, the generated
     * code handles a missing column (index == -1) as follows:
     * - Normally, it assigns the default value of [typeName].
     * - For Kotlin output with a non-null type whose default value is `null`, it emits an
     *   `error(...)` call reporting the missing column instead.
     *
     * @param cursorVar the cursor variable
     * @param typeName the type of the temporary variable
     * @param scope the code generation scope
     * @return the name of the declared temporary variable
     */
    fun readIntoTmpVar(
        cursorVar: String,
        typeName: XTypeName,
        scope: CodeGenScope
    ): String {
        val tmpField = scope.getTmpVar("_tmp${field.name.capitalize(Locale.US)}")
        scope.builder.apply {
            addLocalVariable(tmpField, typeName)
            if (alwaysExists) {
                field.cursorValueReader?.readFromCursor(tmpField, cursorVar, indexVar, scope)
            } else {
                beginControlFlow("if (%L == -1)", indexVar).apply {
                    val defaultValue = typeName.defaultValue()
                    if (
                        language == CodeLanguage.KOTLIN &&
                        typeName.nullability == XNullability.NONNULL &&
                        defaultValue == "null"
                    ) {
                        // TODO(b/249984504): Generate / output a better message.
                        addStatement(
                            "error(%S)",
                            "Missing column '${field.columnName}' for a non null value."
                        )
                    } else {
                        addStatement("%L = %L", tmpField, defaultValue)
                    }
                }
                nextControlFlow("else").apply {
                    field.cursorValueReader?.readFromCursor(tmpField, cursorVar, indexVar, scope)
                }
                endControlFlow()
            }
        }
        return tmpField
    }

    /**
     * On-demand node, created from the fields that were passed into this class. The root node
     * represents the top-level object; every other node represents one embedded object.
     */
    private class Node(
        // Variable holding this node's object (for the root, the caller-supplied variable).
        val varName: String,
        // The embedded field this node represents; null for the root node.
        val fieldParent: EmbeddedField?
    ) {
        // The node this one is nested under; null for the root node.
        var parentNode: Node? = null
        // Fields whose direct parent is this node's embedded field (or that have no parent,
        // for the root node).
        lateinit var directFields: List<FieldWithIndex>
        // Embedded nodes nested directly under this node.
        val subNodes = mutableListOf<Node>()

        // This node's direct fields followed by all fields of its descendant nodes.
        fun allFields(): List<FieldWithIndex> {
            return directFields + subNodes.flatMap { it.allFields() }
        }
    }
}