RelationCollector.kt

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.room.vo

import androidx.room.ext.CollectionTypeNames
import androidx.room.ext.CommonTypeNames
import androidx.room.ext.L
import androidx.room.ext.N
import androidx.room.ext.T
import androidx.room.ext.asDeclaredType
import androidx.room.ext.findTypeElement
import androidx.room.ext.requireTypeElement
import androidx.room.ext.requireTypeMirror
import androidx.room.ext.typeName
import androidx.room.parser.ParsedQuery
import androidx.room.parser.SQLTypeAffinity
import androidx.room.parser.SqlParser
import androidx.room.processor.Context
import androidx.room.processor.ProcessorErrors.cannotFindQueryResultAdapter
import androidx.room.processor.ProcessorErrors.relationAffinityMismatch
import androidx.room.processor.ProcessorErrors.relationJunctionChildAffinityMismatch
import androidx.room.processor.ProcessorErrors.relationJunctionParentAffinityMismatch
import androidx.room.solver.CodeGenScope
import androidx.room.solver.query.parameter.QueryParameterAdapter
import androidx.room.solver.query.result.RowAdapter
import androidx.room.solver.query.result.SingleColumnRowAdapter
import androidx.room.verifier.DatabaseVerificationErrors
import androidx.room.writer.QueryWriter
import androidx.room.writer.RelationCollectorMethodWriter
import com.squareup.javapoet.ClassName
import com.squareup.javapoet.CodeBlock
import com.squareup.javapoet.ParameterizedTypeName
import com.squareup.javapoet.TypeName
import stripNonJava
import java.nio.ByteBuffer
import java.util.ArrayList
import java.util.HashSet
import javax.lang.model.type.TypeMirror

/**
 * Internal class that is used to manage fetching 1/N to N relationships.
 *
 * One collector is created per [Relation] by [createCollectors]. [writeInitCode] must be called
 * before any other `write*` method because it assigns [varName], the name of the temporary map
 * (key -> relation value) that every other generated snippet references.
 *
 * @property relation the relation being fetched.
 * @property affinity SQL affinity of the relating key columns; selects the `Cursor` getter used
 * to read keys in [readKey].
 * @property mapTypeName type of the temporary key map (`LongSparseArray`, `ArrayMap` or
 * `java.util.HashMap`, chosen by `temporaryMapTypeFor`).
 * @property keyTypeName reference type of the relationship key (`Long`, `Double`, `String` or
 * `ByteBuffer`).
 * @property relationTypeName type stored per key: an `ArrayList`/`HashSet` of the relation's pojo
 * when [relationTypeIsCollection], otherwise the pojo type itself.
 * @property queryWriter writes the load-all query with the collected keys bound to it.
 * @property rowAdapter adapter used to read the rows returned by the load-all query.
 * @property loadAllQuery the parsed load-all query.
 * @property relationTypeIsCollection whether the relation field has a parameterized
 * (collection) type, so that several values are accumulated per key.
 */
data class RelationCollector(
    val relation: Relation,
    val affinity: SQLTypeAffinity,
    val mapTypeName: ParameterizedTypeName,
    val keyTypeName: TypeName,
    val relationTypeName: TypeName,
    val queryWriter: QueryWriter,
    val rowAdapter: RowAdapter,
    val loadAllQuery: ParsedQuery,
    val relationTypeIsCollection: Boolean
) {
    // variable name of map containing keys to relation collections, set when writing the code
    // generator in writeInitCode
    lateinit var varName: String

    /**
     * Emits the declaration of the temporary key map and stores its name in [varName].
     */
    fun writeInitCode(scope: CodeGenScope) {
        varName = scope.getTmpVar(
                "_collection${relation.field.getPath().stripNonJava().capitalize()}")
        scope.builder().apply {
            addStatement("final $T $L = new $T()", mapTypeName, varName, mapTypeName)
        }
    }

    /**
     * Emits code that extracts the parent key of the current cursor row, if present, and adds it
     * to the map of relations to fetch.
     *
     * For collection relations a new, empty collection is stored only when the key is not already
     * mapped. Otherwise the key is mapped to `null` as a placeholder.
     */
    fun writeReadParentKeyCode(
        cursorVarName: String,
        fieldsWithIndices: List<FieldWithIndex>,
        scope: CodeGenScope
    ) {
        val indexVar = parentIndexVar(fieldsWithIndices)
        scope.builder().apply {
            readKey(cursorVarName, indexVar, scope) { tmpVar ->
                if (relationTypeIsCollection) {
                    val tmpCollectionVar = scope.getTmpVar(
                        "_tmp${relation.field.name.stripNonJava().capitalize()}Collection")
                    addStatement("$T $L = $L.get($L)", relationTypeName, tmpCollectionVar,
                        varName, tmpVar)
                    // Only create (and store) a collection the first time this key is seen.
                    beginControlFlow("if ($L == null)", tmpCollectionVar).apply {
                        addStatement("$L = new $T()", tmpCollectionVar, relationTypeName)
                        addStatement("$L.put($L, $L)", varName, tmpVar, tmpCollectionVar)
                    }
                    endControlFlow()
                } else {
                    addStatement("$L.put($L, null)", varName, tmpVar)
                }
            }
        }
    }

    /**
     * Emits code that declares a temp variable and fills it with the fetched relation for the
     * current cursor row's key.
     *
     * For collection relations the variable defaults to an empty collection when nothing was
     * found. Otherwise it stays `null`.
     *
     * @return the temp variable name paired with the relation field it should be assigned to.
     */
    fun writeReadCollectionIntoTmpVar(
        cursorVarName: String,
        fieldsWithIndices: List<FieldWithIndex>,
        scope: CodeGenScope
    ): Pair<String, Field> {
        val indexVar = parentIndexVar(fieldsWithIndices)
        val tmpvarNameSuffix = if (relationTypeIsCollection) "Collection" else ""
        val tmpRelationVar = scope.getTmpVar(
                "_tmp${relation.field.name.stripNonJava().capitalize()}$tmpvarNameSuffix")
        scope.builder().apply {
            addStatement("$T $L = null", relationTypeName, tmpRelationVar)
            readKey(cursorVarName, indexVar, scope) { tmpVar ->
                addStatement("$L = $L.get($L)", tmpRelationVar, varName, tmpVar)
            }
            if (relationTypeIsCollection) {
                beginControlFlow("if ($L == null)", tmpRelationVar).apply {
                    addStatement("$L = new $T()", tmpRelationVar, relationTypeName)
                }
                endControlFlow()
            }
        }
        return tmpRelationVar to relation.field
    }

    /**
     * Emits a call to the relationship-fetching method, which is obtained through
     * `getOrCreateMethod` with a [RelationCollectorMethodWriter]. The key map is passed as the
     * call's argument.
     */
    fun writeCollectionCode(scope: CodeGenScope) {
        val method = scope.writer
                .getOrCreateMethod(RelationCollectorMethodWriter(this))
        scope.builder().apply {
            addStatement("$N($L)", method, varName)
        }
    }

    /**
     * Emits code that reads the relationship key from column [indexVar] of [cursorVarName] into a
     * new `final` temp variable, then runs [postRead] with that variable's name.
     *
     * - The `Cursor` getter is chosen from [affinity].
     * - When the map is a `LongSparseArray`, the key is declared with the unboxed primitive type.
     * - `ByteBuffer` keys are wrapped around the blob.
     * - When the parent field is nullable, the read and [postRead] are guarded by `isNull`.
     *
     * NOTE(review): if the parent field is not among the cursor indices, [indexVar] is `null` and
     * the emitted code would reference `null`. Callers presumably guarantee it is present.
     */
    fun readKey(
        cursorVarName: String,
        indexVar: String?,
        scope: CodeGenScope,
        postRead: CodeBlock.Builder.(String) -> Unit
    ) {
        val cursorGetter = when (affinity) {
            SQLTypeAffinity.INTEGER -> "getLong"
            SQLTypeAffinity.REAL -> "getDouble"
            SQLTypeAffinity.TEXT -> "getString"
            SQLTypeAffinity.BLOB -> "getBlob"
            else -> {
                "getString"
            }
        }
        scope.builder().apply {
            val keyType = if (mapTypeName.rawType == CollectionTypeNames.LONG_SPARSE_ARRAY) {
                keyTypeName.unbox()
            } else {
                keyTypeName
            }
            val tmpVar = scope.getTmpVar("_tmpKey")
            fun addKeyReadStatement() {
                if (keyTypeName == TypeName.get(ByteBuffer::class.java)) {
                    addStatement("final $T $L = $T.wrap($L.$L($L))",
                        keyType, tmpVar, keyTypeName, cursorVarName, cursorGetter, indexVar)
                } else {
                    addStatement("final $T $L = $L.$L($L)",
                        keyType, tmpVar, cursorVarName, cursorGetter, indexVar)
                }
                this.postRead(tmpVar)
            }
            if (relation.parentField.nonNull) {
                addKeyReadStatement()
            } else {
                beginControlFlow("if (!$L.isNull($L))", cursorVarName, indexVar).apply {
                    addKeyReadStatement()
                }
                endControlFlow()
            }
        }
    }

    // Cursor index variable of the relation's parent (key) field, or null when that field is not
    // part of fieldsWithIndices.
    private fun parentIndexVar(fieldsWithIndices: List<FieldWithIndex>): String? =
        fieldsWithIndices.firstOrNull { it.field === relation.parentField }?.indexVar

    /**
     * Adapter for binding the keys of a LongSparseArray into query arguments. This special
     * adapter is only used for binding the relationship query whose keys have INTEGER affinity.
     */
    private class LongSparseArrayKeyQueryParameterAdapter : QueryParameterAdapter(true) {
        override fun bindToStmt(
            inputVarName: String,
            stmtVarName: String,
            startIndexVarName: String,
            scope: CodeGenScope
        ) {
            scope.builder().apply {
                val itrIndexVar = "i"
                val itrItemVar = scope.getTmpVar("_item")
                // Every occurrence of the loop index, including the increment, uses itrIndexVar.
                beginControlFlow("for (int $L = 0; $L < $L.size(); $L++)",
                        itrIndexVar, itrIndexVar, inputVarName, itrIndexVar).apply {
                    addStatement("long $L = $L.keyAt($L)", itrItemVar, inputVarName, itrIndexVar)
                    addStatement("$L.bindLong($L, $L)", stmtVarName, startIndexVarName, itrItemVar)
                    addStatement("$L ++", startIndexVarName)
                }
                endControlFlow()
            }
        }

        override fun getArgCount(
            inputVarName: String,
            outputVarName: String,
            scope: CodeGenScope
        ) {
            scope.builder().addStatement("final $T $L = $L.size()",
                    TypeName.INT, outputVarName, inputVarName)
        }
    }

    companion object {

        private val LONG_SPARSE_ARRAY_KEY_QUERY_PARAM_ADAPTER =
                LongSparseArrayKeyQueryParameterAdapter()

        /**
         * Creates one collector per relation.
         *
         * For each relation this builds and (when a database verifier is available) verifies the
         * load-all query. Errors are reported through the forked context. Relations for which no
         * row adapter can be found are reported and omitted from the result.
         */
        fun createCollectors(
            baseContext: Context,
            relations: List<Relation>
        ): List<RelationCollector> {
            return relations.mapNotNull { relation ->
                val context = baseContext.fork(
                    element = relation.field.element,
                    forceSuppressedWarnings = setOf(Warning.CURSOR_MISMATCH))
                val affinity = affinityFor(context, relation)
                val keyType = keyTypeFor(context, affinity)
                val (relationTypeName, isRelationCollection) = relationTypeFor(relation)
                val tmpMapType = temporaryMapTypeFor(context, affinity, keyType, relationTypeName)

                val loadAllQuery = relation.createLoadAllSql()
                val parsedQuery = SqlParser.parse(loadAllQuery)
                context.checker.check(parsedQuery.errors.isEmpty(), relation.field.element,
                        parsedQuery.errors.joinToString("\n"))
                if (parsedQuery.errors.isEmpty()) {
                    val resultInfo = context.databaseVerifier?.analyze(loadAllQuery)
                    parsedQuery.resultInfo = resultInfo
                    if (resultInfo?.error != null) {
                        context.logger.e(relation.field.element,
                                DatabaseVerificationErrors.cannotVerifyQuery(resultInfo.error))
                    }
                }

                val queryParam = keyQueryParameterFor(context, affinity, tmpMapType)
                val queryWriter = QueryWriter(
                        parameters = listOf(queryParam),
                        sectionToParamMapping = listOf(
                                parsedQuery.bindSections.first() to queryParam),
                        query = parsedQuery
                )

                val rowAdapter = rowAdapterFor(context, relation, parsedQuery)
                if (rowAdapter == null) {
                    context.logger.e(relation.field.element,
                        cannotFindQueryResultAdapter(relation.pojoType.toString()))
                    null
                } else {
                    RelationCollector(
                            relation = relation,
                            affinity = affinity,
                            mapTypeName = tmpMapType,
                            keyTypeName = keyType,
                            relationTypeName = relationTypeName,
                            queryWriter = queryWriter,
                            rowAdapter = rowAdapter,
                            loadAllQuery = parsedQuery,
                            relationTypeIsCollection = isRelationCollection
                    )
                }
            }
        }

        // Builds the single parameter of the load-all query.
        // - LongSparseArray map: the map itself, bound via LONG_SPARSE_ARRAY_KEY_QUERY_PARAM_ADAPTER.
        // - Otherwise: a java.util.Set of keys whose element type follows the key affinity.
        private fun keyQueryParameterFor(
            context: Context,
            affinity: SQLTypeAffinity,
            tmpMapType: ParameterizedTypeName
        ): QueryParameter {
            return if (tmpMapType.rawType == CollectionTypeNames.LONG_SPARSE_ARRAY) {
                val longSparseArrayElement = context.processingEnv
                        .requireTypeElement(CollectionTypeNames.LONG_SPARSE_ARRAY)
                QueryParameter(
                        name = RelationCollectorMethodWriter.PARAM_MAP_VARIABLE,
                        sqlName = RelationCollectorMethodWriter.PARAM_MAP_VARIABLE,
                        type = longSparseArrayElement.asDeclaredType(),
                        queryParamAdapter = LONG_SPARSE_ARRAY_KEY_QUERY_PARAM_ADAPTER
                )
            } else {
                val keyTypeMirror = keyTypeMirrorFor(context, affinity)
                val set = context.processingEnv.requireTypeElement("java.util.Set")
                val keySet = context.processingEnv.typeUtils.getDeclaredType(set, keyTypeMirror)
                QueryParameter(
                        name = RelationCollectorMethodWriter.KEY_SET_VARIABLE,
                        sqlName = RelationCollectorMethodWriter.KEY_SET_VARIABLE,
                        type = keySet,
                        queryParamAdapter = context.typeAdapterStore.findQueryParameterAdapter(
                            keySet)
                )
            }
        }

        // Picks the row adapter for the load-all query.
        // - Single-column projection with a 1- or 2-column result: prefer a cursor value reader
        //   for the first column.
        //   NOTE(review): the optional second column is presumably the junction key; confirm
        //   against Relation.createLoadAllSql.
        // - Otherwise, or when no reader exists: fall back to a row adapter for the full response.
        // Returns null when no adapter can be found.
        private fun rowAdapterFor(
            context: Context,
            relation: Relation,
            parsedQuery: ParsedQuery
        ): RowAdapter? {
            // row adapter that matches full response
            fun getDefaultRowAdapter(): RowAdapter? {
                return context.typeAdapterStore.findRowAdapter(relation.pojoType, parsedQuery)
            }
            val resultInfo = parsedQuery.resultInfo
            return if (relation.projection.size == 1 && resultInfo != null &&
                    (resultInfo.columns.size == 1 || resultInfo.columns.size == 2)) {
                // check for a column adapter first
                val cursorReader = context.typeAdapterStore.findCursorValueReader(
                        relation.pojoType, resultInfo.columns.first().type)
                if (cursorReader == null) {
                    getDefaultRowAdapter()
                } else {
                    SingleColumnRowAdapter(cursorReader)
                }
            } else {
                getDefaultRowAdapter()
            }
        }

        // Gets and check the affinity of the relating columns.
        // - Mismatch: a RELATION_TYPE_MISMATCH warning is logged and TEXT is used.
        // - With a junction: both sides are checked, but only the parent vs. junction-parent
        //   result is returned, because readKey reads the key from the parent field's column.
        private fun affinityFor(context: Context, relation: Relation): SQLTypeAffinity {
            fun checkAffinity(
                first: SQLTypeAffinity?,
                second: SQLTypeAffinity?,
                onAffinityMismatch: () -> Unit
            ) = if (first != null && first == second) {
                first
            } else {
                onAffinityMismatch()
                SQLTypeAffinity.TEXT
            }

            val parentAffinity = relation.parentField.cursorValueReader?.affinity()
            val childAffinity = relation.entityField.cursorValueReader?.affinity()
            val junctionParentAffinity =
                relation.junction?.parentField?.cursorValueReader?.affinity()
            val junctionChildAffinity =
                relation.junction?.entityField?.cursorValueReader?.affinity()
            return if (relation.junction != null) {
                checkAffinity(childAffinity, junctionChildAffinity) {
                    context.logger.w(Warning.RELATION_TYPE_MISMATCH, relation.field.element,
                        relationJunctionChildAffinityMismatch(
                            childColumn = relation.entityField.columnName,
                            junctionChildColumn = relation.junction.entityField.columnName,
                            childAffinity = childAffinity,
                            junctionChildAffinity = junctionChildAffinity))
                }
                checkAffinity(parentAffinity, junctionParentAffinity) {
                    context.logger.w(Warning.RELATION_TYPE_MISMATCH, relation.field.element,
                        relationJunctionParentAffinityMismatch(
                            parentColumn = relation.parentField.columnName,
                            junctionParentColumn = relation.junction.parentField.columnName,
                            parentAffinity = parentAffinity,
                            junctionParentAffinity = junctionParentAffinity))
                }
            } else {
                checkAffinity(parentAffinity, childAffinity) {
                    context.logger.w(Warning.RELATION_TYPE_MISMATCH, relation.field.element,
                        relationAffinityMismatch(
                            parentColumn = relation.parentField.columnName,
                            childColumn = relation.entityField.columnName,
                            parentAffinity = parentAffinity,
                            childAffinity = childAffinity))
                }
            }
        }

        // Gets the resulting relation type name. (i.e. the Pojo's @Relation field type name.)
        // Parameterized field types become a concrete collection of the pojo: HashSet for a Set
        // field, ArrayList for anything else. The second component reports whether a collection
        // is used.
        private fun relationTypeFor(relation: Relation): Pair<TypeName, Boolean> {
            // Local copy so the `is` check below smart-casts without an explicit cast.
            val fieldTypeName = relation.field.typeName
            return if (fieldTypeName is ParameterizedTypeName) {
                val collectionClass = if (fieldTypeName.rawType == CommonTypeNames.SET) {
                    ClassName.get(HashSet::class.java)
                } else {
                    ClassName.get(ArrayList::class.java)
                }
                ParameterizedTypeName.get(collectionClass, relation.pojoTypeName) to true
            } else {
                relation.pojoTypeName to false
            }
        }

        // Gets the type name of the temporary key map.
        // - LongSparseArray: for INTEGER keys, when the class can be found.
        // - ArrayMap: otherwise, when that class can be found.
        // - java.util.HashMap: otherwise.
        private fun temporaryMapTypeFor(
            context: Context,
            affinity: SQLTypeAffinity,
            keyType: TypeName,
            relationTypeName: TypeName
        ): ParameterizedTypeName {
            val canUseLongSparseArray = context.processingEnv
                .findTypeElement(CollectionTypeNames.LONG_SPARSE_ARRAY) != null
            val canUseArrayMap = context.processingEnv
                .findTypeElement(CollectionTypeNames.ARRAY_MAP) != null
            return when {
                canUseLongSparseArray && affinity == SQLTypeAffinity.INTEGER -> {
                    ParameterizedTypeName.get(CollectionTypeNames.LONG_SPARSE_ARRAY,
                        relationTypeName)
                }
                canUseArrayMap -> {
                    ParameterizedTypeName.get(CollectionTypeNames.ARRAY_MAP,
                        keyType, relationTypeName)
                }
                else -> {
                    ParameterizedTypeName.get(ClassName.get(java.util.HashMap::class.java),
                        keyType, relationTypeName)
                }
            }
        }

        // Gets the type mirror of the relationship key (String when the affinity is not one of
        // INTEGER / REAL / TEXT / BLOB).
        private fun keyTypeMirrorFor(context: Context, affinity: SQLTypeAffinity): TypeMirror {
            val processingEnv = context.processingEnv
            return when (affinity) {
                SQLTypeAffinity.INTEGER -> processingEnv.requireTypeMirror("java.lang.Long")
                SQLTypeAffinity.REAL -> processingEnv.requireTypeMirror("java.lang.Double")
                SQLTypeAffinity.TEXT -> context.COMMON_TYPES.STRING
                SQLTypeAffinity.BLOB -> processingEnv.requireTypeMirror("java.nio.ByteBuffer")
                else -> {
                    context.COMMON_TYPES.STRING
                }
            }
        }

        // Gets the type name of the relationship key; mirrors keyTypeMirrorFor.
        private fun keyTypeFor(context: Context, affinity: SQLTypeAffinity): TypeName {
            return when (affinity) {
                SQLTypeAffinity.INTEGER -> TypeName.LONG.box()
                SQLTypeAffinity.REAL -> TypeName.DOUBLE.box()
                SQLTypeAffinity.TEXT -> TypeName.get(String::class.java)
                SQLTypeAffinity.BLOB -> TypeName.get(ByteBuffer::class.java)
                else -> {
                    // no affinity select from type
                    context.COMMON_TYPES.STRING.typeName()
                }
            }
        }
    }
}