FieldReadWriteWriter.kt

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.room.writer

import androidx.room.compiler.codegen.XCodeBlock
import androidx.room.compiler.codegen.XTypeName
import androidx.room.compiler.codegen.toJavaPoet
import androidx.room.ext.capitalize
import androidx.room.ext.defaultValue
import androidx.room.solver.CodeGenScope
import androidx.room.vo.CallType
import androidx.room.vo.Constructor
import androidx.room.vo.EmbeddedField
import androidx.room.vo.Field
import androidx.room.vo.FieldWithIndex
import androidx.room.vo.Pojo
import androidx.room.vo.RelationCollector
import java.util.Locale

/**
 * Handles writing a field into statement or reading it from statement.
 */
class FieldReadWriteWriter(fieldWithIndex: FieldWithIndex) {
    val field = fieldWithIndex.field
    val indexVar = fieldWithIndex.indexVar
    val alwaysExists = fieldWithIndex.alwaysExists

    companion object {
        /*
         * Get all parents including the ones which have grand children in this list but does not
         * have any direct children in the list.
         */
        fun getAllParents(fields: List<Field>): Set<EmbeddedField> {
            val allParents = mutableSetOf<EmbeddedField>()
            fun addAllParents(field: Field) {
                var parent = field.parent
                while (parent != null) {
                    if (allParents.add(parent)) {
                        parent = parent.parent
                    } else {
                        break
                    }
                }
            }
            fields.forEach(::addAllParents)
            return allParents
        }

        /**
         * Convert the fields with indices into a Node tree so that we can recursively process
         * them. This work is done here instead of parsing because the result may include arbitrary
         * fields.
         */
        private fun createNodeTree(
            rootVar: String,
            fieldsWithIndices: List<FieldWithIndex>,
            scope: CodeGenScope
        ): Node {
            val allParents = getAllParents(fieldsWithIndices.map { it.field })
            val rootNode = Node(rootVar, null)
            rootNode.directFields = fieldsWithIndices.filter { it.field.parent == null }
            val parentNodes = allParents.associateWith {
                Node(
                    varName = scope.getTmpVar("_tmp${it.field.name.capitalize(Locale.US)}"),
                    fieldParent = it
                )
            }
            parentNodes.values.forEach { node ->
                val fieldParent = node.fieldParent!!
                val grandParent = fieldParent.parent
                val grandParentNode = grandParent?.let {
                    parentNodes[it]
                } ?: rootNode
                node.directFields = fieldsWithIndices.filter { it.field.parent == fieldParent }
                node.parentNode = grandParentNode
                grandParentNode.subNodes.add(node)
            }
            return rootNode
        }

        fun bindToStatement(
            ownerVar: String,
            stmtParamVar: String,
            fieldsWithIndices: List<FieldWithIndex>,
            scope: CodeGenScope
        ) {
            fun visitNode(node: Node) {
                fun bindWithDescendants() {
                    node.directFields.forEach {
                        FieldReadWriteWriter(it).bindToStatement(
                            ownerVar = node.varName,
                            stmtParamVar = stmtParamVar,
                            scope = scope
                        )
                    }
                    node.subNodes.forEach(::visitNode)
                }

                val fieldParent = node.fieldParent
                if (fieldParent != null) {
                    fieldParent.getter.writeGet(
                        ownerVar = node.parentNode!!.varName,
                        outVar = node.varName,
                        builder = scope.builder
                    )
                    scope.builder.apply {
                        beginControlFlow("if (%L != null)", node.varName).apply {
                            bindWithDescendants()
                        }
                        nextControlFlow("else").apply {
                            node.allFields().forEach {
                                addStatement("%L.bindNull(%L)", stmtParamVar, it.indexVar)
                            }
                        }
                        endControlFlow()
                    }
                } else {
                    bindWithDescendants()
                }
            }
            visitNode(createNodeTree(ownerVar, fieldsWithIndices, scope))
        }

        /**
         * Just constructs the given item, does NOT DECLARE. Declaration happens outside the
         * reading statement since we may never read if the cursor does not have necessary
         * columns.
         */
        private fun construct(
            outVar: String,
            constructor: Constructor?,
            typeName: XTypeName,
            localVariableNames: Map<String, FieldWithIndex>,
            localEmbeddeds: List<Node>,
            localRelations: Map<String, Field>,
            scope: CodeGenScope
        ) {
            if (constructor == null) {
                // Instantiate with default constructor, best hope for code generation
                scope.builder.apply {
                    addStatement(
                        "%L = %L",
                        outVar,
                        XCodeBlock.ofNewInstance(scope.language, typeName)
                    )
                }
                return
            }
            val variableNames = constructor.params.map { param ->
                when (param) {
                    is Constructor.Param.FieldParam -> localVariableNames.entries.firstOrNull {
                        it.value.field === param.field
                    }?.key
                    is Constructor.Param.EmbeddedParam -> localEmbeddeds.firstOrNull {
                        it.fieldParent == param.embedded
                    }?.varName
                    is Constructor.Param.RelationParam -> localRelations.entries.firstOrNull {
                        it.value === param.relation.field
                    }?.key
                }
            }
            val args = variableNames.joinToString(",") { it ?: "null" }
            constructor.writeConstructor(outVar, args, scope.builder)
        }

        /**
         * Reads the row into the given variable. It does not declare it but constructs it.
         */
        fun readFromCursor(
            outVar: String,
            outPojo: Pojo,
            cursorVar: String,
            fieldsWithIndices: List<FieldWithIndex>,
            scope: CodeGenScope,
            relationCollectors: List<RelationCollector>
        ) {
            fun visitNode(node: Node) {
                val fieldParent = node.fieldParent
                fun readNode() {
                    // read constructor parameters into local fields
                    val constructorFields = node.directFields.filter {
                        it.field.setter.callType == CallType.CONSTRUCTOR
                    }.associateBy { fwi ->
                        FieldReadWriteWriter(fwi).readIntoTmpVar(
                            cursorVar,
                            fwi.field.setter.type.asTypeName(),
                            scope
                        )
                    }
                    // read decomposed fields (e.g. embedded)
                    node.subNodes.forEach(::visitNode)
                    // read relationship fields
                    val relationFields = relationCollectors.filter { (relation) ->
                        relation.field.parent === fieldParent
                    }.associate {
                        it.writeReadCollectionIntoTmpVar(
                            cursorVarName = cursorVar,
                            fieldsWithIndices = fieldsWithIndices,
                            scope = scope
                        )
                    }

                    // construct the object
                    if (fieldParent != null) {
                        construct(
                            outVar = node.varName,
                            constructor = fieldParent.pojo.constructor,
                            typeName = fieldParent.field.typeName,
                            localEmbeddeds = node.subNodes,
                            localRelations = relationFields,
                            localVariableNames = constructorFields,
                            scope = scope
                        )
                    } else {
                        construct(
                            outVar = node.varName,
                            constructor = outPojo.constructor,
                            typeName = outPojo.typeName,
                            localEmbeddeds = node.subNodes,
                            localRelations = relationFields,
                            localVariableNames = constructorFields,
                            scope = scope
                        )
                    }
                    // ready any field that was not part of the constructor
                    node.directFields.filterNot {
                        it.field.setter.callType == CallType.CONSTRUCTOR
                    }.forEach { fwi ->
                        FieldReadWriteWriter(fwi).readFromCursor(
                            ownerVar = node.varName,
                            cursorVar = cursorVar,
                            scope = scope
                        )
                    }
                    // assign sub nodes to fields if they were not part of the constructor.
                    node.subNodes.mapNotNull {
                        val setter = it.fieldParent?.setter
                        if (setter != null && setter.callType != CallType.CONSTRUCTOR) {
                            Pair(it.varName, setter)
                        } else {
                            null
                        }
                    }.forEach { (varName, setter) ->
                        setter.writeSet(
                            ownerVar = node.varName,
                            inVar = varName,
                            builder = scope.builder
                        )
                    }
                    // assign relation fields that were not part of the constructor
                    relationFields.filterNot { (_, field) ->
                        field.setter.callType == CallType.CONSTRUCTOR
                    }.forEach { (varName, field) ->
                        field.setter.writeSet(
                            ownerVar = node.varName,
                            inVar = varName,
                            builder = scope.builder
                        )
                    }
                }
                if (fieldParent == null) {
                    // root element
                    // always declared by the caller so we don't declare this
                    readNode()
                } else {
                    // always declare, we'll set below
                    scope.builder.addLocalVariable(
                        node.varName,
                        fieldParent.pojo.typeName
                    )
                    if (fieldParent.nonNull) {
                        readNode()
                    } else {
                        val myDescendants = node.allFields()
                        val allNullCheck = myDescendants.joinToString(" && ") {
                            if (it.alwaysExists) {
                                "$cursorVar.isNull(${it.indexVar})"
                            } else {
                                "(${it.indexVar} == -1 || $cursorVar.isNull(${it.indexVar}))"
                            }
                        }
                        scope.builder.apply {
                            beginControlFlow("if (!(%L))", allNullCheck).apply {
                                readNode()
                            }
                            nextControlFlow("else").apply {
                                addStatement("%L = null", node.varName)
                            }
                            endControlFlow()
                        }
                    }
                }
            }
            visitNode(createNodeTree(outVar, fieldsWithIndices, scope))
        }
    }

    /**
     * @param ownerVar The entity / pojo that owns this field. It must own this field! (not the
     * container pojo)
     * @param stmtParamVar The statement variable
     * @param scope The code generation scope
     */
    private fun bindToStatement(ownerVar: String, stmtParamVar: String, scope: CodeGenScope) {
        field.statementBinder?.let { binder ->
            val varName = if (field.getter.callType == CallType.FIELD) {
                "$ownerVar.${field.name}"
            } else {
                "$ownerVar.${field.getter.jvmName}()"
            }
            binder.bindToStmt(stmtParamVar, indexVar, varName, scope)
        }
    }

    /**
     * @param ownerVar The entity / pojo that owns this field. It must own this field (not the
     * container pojo)
     * @param cursorVar The cursor variable
     * @param scope The code generation scope
     */
    private fun readFromCursor(ownerVar: String, cursorVar: String, scope: CodeGenScope) {
        fun doRead() {
            field.cursorValueReader?.let { reader ->
                scope.builder.apply {
                    when (field.setter.callType) {
                        CallType.FIELD -> {
                            val outFieldName = "$ownerVar.${field.setter.jvmName}"
                            reader.readFromCursor(outFieldName, cursorVar, indexVar, scope)
                        }
                        CallType.METHOD -> {
                            val tmpField = scope.getTmpVar(
                                "_tmp${field.name.capitalize(Locale.US)}"
                            )
                            addLocalVariable(tmpField, field.setter.type.asTypeName())
                            reader.readFromCursor(tmpField, cursorVar, indexVar, scope)
                            addStatement("%L.%L(%L)", ownerVar, field.setter.jvmName, tmpField)
                        }
                        CallType.CONSTRUCTOR -> {
                            // no-op
                        }
                    }
                }
            }
        }
        if (alwaysExists) {
            doRead()
        } else {
            scope.builder.apply {
                beginControlFlow("if (%L != -1)", indexVar).apply {
                    doRead()
                }
                endControlFlow()
            }
        }
    }

    /**
     * Reads the value into a temporary local variable.
     */
    fun readIntoTmpVar(
        cursorVar: String,
        typeName: XTypeName,
        scope: CodeGenScope
    ): String {
        val tmpField = scope.getTmpVar("_tmp${field.name.capitalize(Locale.US)}")
        scope.builder.apply {
            addLocalVariable(tmpField, typeName)
            if (alwaysExists) {
                field.cursorValueReader?.readFromCursor(tmpField, cursorVar, indexVar, scope)
            } else {
                beginControlFlow("if (%L == -1)", indexVar).apply {
                    addStatement("%L = %L", tmpField, typeName.toJavaPoet().defaultValue())
                }
                nextControlFlow("else").apply {
                    field.cursorValueReader?.readFromCursor(tmpField, cursorVar, indexVar, scope)
                }
                endControlFlow()
            }
        }
        return tmpField
    }

    /**
     * On demand node which is created based on the fields that were passed into this class.
     */
    private class Node(
        // root for me
        val varName: String,
        // set if I'm a FieldParent
        val fieldParent: EmbeddedField?
    ) {
        // whom do i belong
        var parentNode: Node? = null
        // these fields are my direct fields
        lateinit var directFields: List<FieldWithIndex>
        // these nodes are under me
        val subNodes = mutableListOf<Node>()

        fun allFields(): List<FieldWithIndex> {
            return directFields + subNodes.flatMap { it.allFields() }
        }
    }
}