/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package androidx.room.writer
import androidx.room.compiler.codegen.CodeLanguage
import androidx.room.compiler.codegen.XCodeBlock
import androidx.room.compiler.codegen.XCodeBlock.Builder.Companion.addLocalVal
import androidx.room.compiler.codegen.XFunSpec
import androidx.room.compiler.codegen.XMemberName.Companion.packageMember
import androidx.room.compiler.codegen.XTypeName
import androidx.room.ext.AndroidTypeNames
import androidx.room.ext.CollectionTypeNames
import androidx.room.ext.CollectionsSizeExprCode
import androidx.room.ext.CommonTypeNames
import androidx.room.ext.InvokeWithLambdaParameter
import androidx.room.ext.KotlinTypeNames
import androidx.room.ext.LambdaSpec
import androidx.room.ext.MapKeySetExprCode
import androidx.room.ext.RoomMemberNames
import androidx.room.ext.RoomTypeNames
import androidx.room.ext.RoomTypeNames.RELATION_UTIL
import androidx.room.ext.SQLiteDriverTypeNames
import androidx.room.ext.stripNonJava
import androidx.room.solver.CodeGenScope
import androidx.room.solver.query.result.PojoRowAdapter
import androidx.room.vo.RelationCollector
/**
 * Writes the shared function that fetches the relations of a POJO and assigns them into the
 * given map.
 *
 * The generated function body is assembled in three steps:
 * 1. return early when the input map has no keys,
 * 2. hand off to a recursive-fetch utility (and return) when the map holds more entries than the
 *    bind parameter limit,
 * 3. otherwise run the relation query and copy every matching row into the map.
 *
 * @param collector describes the relation to load: its map/key types, query writer and row
 *   adapter.
 * @param useDriverApi when `true` the generated function takes a connection parameter and reads
 *   rows from a prepared statement; when `false` it queries through
 *   [RoomMemberNames.DB_UTIL_QUERY] and reads rows from a cursor.
 */
class RelationCollectorFunctionWriter(
    private val collector: RelationCollector,
    private val useDriverApi: Boolean
) : TypeWriter.SharedFunctionSpec(
    // Both flavors share the same suffix; only the prefix depends on the API in use.
    baseName = (if (useDriverApi) "fetchRelationship" else "fetchCompatRelationship") +
        collector.relation.entity.tableName.stripNonJava() +
        "As" +
        collector.relation.pojoTypeName.toString(CodeLanguage.JAVA).stripNonJava(),
) {
    companion object {
        const val PARAM_MAP_VARIABLE = "_map"
        const val PARAM_CONNECTION_VARIABLE = "_connection"
        const val KEY_SET_VARIABLE = "__mapKeySet"

        /**
         * Bind parameter limit emitted as a literal into driver API code. The non-driver path
         * emits a reference to `MAX_BIND_PARAMETER_CNT` on [RoomTypeNames.ROOM_DB] instead.
         *
         * NOTE(review): presumably this mirrors that constant's value (SQLite's historical
         * default variable limit) - confirm before changing either one.
         */
        private const val DRIVER_MAX_BIND_PARAMETER_CNT = "999"
    }

    // Which map implementation the collector uses; this selects the emptiness / size expressions
    // and the recursive-fetch utility function.
    private val usingLongSparseArray =
        collector.mapTypeName.rawTypeName == CollectionTypeNames.LONG_SPARSE_ARRAY
    private val usingArrayMap =
        collector.mapTypeName.rawTypeName == CollectionTypeNames.ARRAY_MAP

    /**
     * Builds the key for this function spec from the map type, entity type, key column, POJO
     * type, load-all SQL and the API flavor.
     */
    override fun getUniqueKey(): String {
        val relation = collector.relation
        return "RelationCollectorMethodWriter" +
            "-${collector.mapTypeName}" +
            "-${relation.entity.typeName.toString(CodeLanguage.JAVA)}" +
            "-${relation.entityField.columnName}" +
            "-${relation.pojoTypeName}" +
            "-${relation.createLoadAllSql()}" +
            "-$useDriverApi"
    }

    /**
     * Generates the function body into a fresh [CodeGenScope] and then adds the parameters and
     * the generated code to [builder].
     *
     * @param methodName name of the function being generated; used for the recursive call.
     * @param writer the type writer the scope is bound to.
     * @param builder the function builder that receives the parameters and the body.
     */
    override fun prepare(methodName: String, writer: TypeWriter, builder: XFunSpec.Builder) {
        val scope = CodeGenScope(writer = writer, useDriverApi = useDriverApi)
        scope.builder.apply {
            // Check the input map key set for emptiness, returning early as no fetching is needed.
            addIsInputEmptyCheck()

            // Check if the input map key set exceeds the bind parameter limit, if so do a
            // recursive fetch.
            beginControlFlow(
                "if (%L > %L)",
                if (usingLongSparseArray) {
                    XCodeBlock.of(language, "%L.size()", PARAM_MAP_VARIABLE)
                } else {
                    CollectionsSizeExprCode(language, PARAM_MAP_VARIABLE)
                },
                if (useDriverApi) {
                    DRIVER_MAX_BIND_PARAMETER_CNT
                } else {
                    XCodeBlock.of(
                        language,
                        "%T.MAX_BIND_PARAMETER_CNT",
                        RoomTypeNames.ROOM_DB
                    )
                }
            ).apply {
                addRecursiveFetchCall(scope, methodName)
                addStatement("return")
            }.endControlFlow()

            createStmtAndReturn(scope)
        }
        builder.apply {
            // The connection parameter only exists in the driver API flavor and comes first.
            if (useDriverApi) {
                addParameter(SQLiteDriverTypeNames.CONNECTION, PARAM_CONNECTION_VARIABLE)
            }
            addParameter(collector.mapTypeName, PARAM_MAP_VARIABLE)
            addCode(scope.generate())
        }
    }

    /**
     * Emits the code that creates the SQL query, acquires the statement (driver API) or cursor
     * (non-driver API), binds the arguments and then collects the rows into the map.
     */
    private fun XCodeBlock.Builder.createStmtAndReturn(
        scope: CodeGenScope
    ) {
        // Create SQL query, acquire statement and bind parameters.
        val stmtVar = scope.getTmpVar("_stmt")
        val cursorVar = "_cursor"
        val sqlQueryVar = scope.getTmpVar("_sql")

        if (useDriverApi) {
            // The connection is a function parameter declared as PARAM_CONNECTION_VARIABLE, so
            // the generated code must use exactly that name. Reserve it in the scope so no
            // later temporary can reuse it, and fail fast if it was already handed out rather
            // than emitting a reference to an undeclared variable.
            val reservedConnectionName = scope.getTmpVar(PARAM_CONNECTION_VARIABLE)
            check(reservedConnectionName == PARAM_CONNECTION_VARIABLE) {
                "Connection parameter name '$PARAM_CONNECTION_VARIABLE' is already taken in " +
                    "this scope (got '$reservedConnectionName')."
            }
            val listSizeVars = collector.queryWriter.prepareQuery(sqlQueryVar, scope)
            addLocalVal(
                stmtVar,
                SQLiteDriverTypeNames.STATEMENT,
                "%L.prepare(%L)",
                PARAM_CONNECTION_VARIABLE,
                sqlQueryVar
            )
            collector.queryWriter.bindArgs(stmtVar, listSizeVars, scope)
        } else {
            collector.queryWriter.prepareReadAndBind(sqlQueryVar, stmtVar, scope)
            // Perform query and get a Cursor. The third query argument is "true" only when the
            // row adapter is a PojoRowAdapter that has relation collectors of its own.
            val shouldCopyCursor = collector.rowAdapter.let {
                it is PojoRowAdapter && it.relationCollectors.isNotEmpty()
            }
            addLocalVariable(
                name = cursorVar,
                typeName = AndroidTypeNames.CURSOR,
                assignExpr = XCodeBlock.of(
                    language,
                    "%M(%N, %L, %L, %L)",
                    RoomMemberNames.DB_UTIL_QUERY,
                    DaoWriter.DB_PROPERTY_NAME,
                    stmtVar,
                    if (shouldCopyCursor) "true" else "false",
                    "null"
                )
            )
        }
        // In driver mode rows are read straight from the statement, otherwise from the cursor.
        addRelationCollectorCode(scope, if (useDriverApi) stmtVar else cursorVar)
    }

    /**
     * Emits the `try` / `finally` block that resolves the key column index, iterates over the
     * rows and places each converted item into the map, closing the row source at the end.
     *
     * @param scope the scope used for temporary variable names and by the row adapter.
     * @param cursorVar name of the variable the rows are read from; despite the name this is
     *   the statement variable when the driver API is used.
     */
    private fun XCodeBlock.Builder.addRelationCollectorCode(
        scope: CodeGenScope,
        cursorVar: String
    ) {
        val relation = collector.relation
        beginControlFlow("try").apply {
            // Gets index of the column to be used as key
            val itemKeyIndexVar = "_itemKeyIndex"
            if (relation.junction != null) {
                // When using a junction table the relationship map is keyed on the parent
                // reference column of the junction table, the same column used in the WHERE IN
                // clause, this column is the rightmost column in the generated SELECT
                // clause.
                val junctionParentColumnIndex = relation.projection.size
                addStatement("// _junction.%L", relation.junction.parentField.columnName)
                addLocalVal(
                    itemKeyIndexVar,
                    XTypeName.PRIMITIVE_INT,
                    "%L",
                    junctionParentColumnIndex
                )
            } else {
                addLocalVal(
                    name = itemKeyIndexVar,
                    typeName = XTypeName.PRIMITIVE_INT,
                    assignExprFormat = "%M(%L, %S)",
                    if (useDriverApi) {
                        RoomTypeNames.STATEMENT_UTIL.packageMember("getColumnIndex")
                    } else { RoomMemberNames.CURSOR_UTIL_GET_COLUMN_INDEX },
                    cursorVar,
                    relation.entityField.columnName
                )
            }

            // Check if index of column is not -1, indicating the column for the key is not in
            // the result, can happen if the user specified a bad projection in @Relation.
            // The early return still runs the `finally` block below, so the source is closed.
            beginControlFlow("if (%L == -1)", itemKeyIndexVar).apply {
                addStatement("return")
            }
            endControlFlow()

            // Prepare item column indices
            collector.rowAdapter.onCursorReady(cursorVarName = cursorVar, scope = scope)
            val tmpVarName = scope.getTmpVar("_item")
            val stepName = if (scope.useDriverApi) "step" else "moveToNext"
            beginControlFlow("while (%L.$stepName())", cursorVar).apply {
                // Read key from the cursor, convert row to item and place it on map
                collector.readKey(
                    cursorVarName = cursorVar,
                    indexVar = itemKeyIndexVar,
                    keyReader = collector.entityKeyColumnReader,
                    scope = scope
                ) { keyVar ->
                    if (collector.relationTypeIsCollection) {
                        // Collection relation: append the item to the collection already stored
                        // under the key, skipping rows whose key has no entry in the map.
                        val relationVar = scope.getTmpVar("_tmpRelation")
                        addLocalVal(
                            relationVar,
                            collector.relationTypeName.copy(nullable = true),
                            "%L.get(%L)",
                            PARAM_MAP_VARIABLE, keyVar
                        )
                        beginControlFlow("if (%L != null)", relationVar)
                        addLocalVariable(tmpVarName, relation.pojoTypeName)
                        collector.rowAdapter.convert(tmpVarName, cursorVar, scope)
                        addStatement("%L.add(%L)", relationVar, tmpVarName)
                        endControlFlow()
                    } else {
                        // Single-item relation: store the item under the key, but only for keys
                        // that are already present in the map.
                        beginControlFlow("if (%N.containsKey(%L))", PARAM_MAP_VARIABLE, keyVar)
                        addLocalVariable(tmpVarName, relation.pojoTypeName)
                        collector.rowAdapter.convert(tmpVarName, cursorVar, scope)
                        addStatement("%N.put(%L, %L)", PARAM_MAP_VARIABLE, keyVar, tmpVarName)
                        endControlFlow()
                    }
                }
            }
            endControlFlow()
        }
        nextControlFlow("finally").apply {
            addStatement("%L.close()", cursorVar)
        }
        endControlFlow()
    }

    /**
     * Emits an early `return` for an empty input map. For a long sparse array the map itself is
     * checked; for other map types the key set is first stored in [KEY_SET_VARIABLE] and that
     * variable is checked.
     */
    private fun XCodeBlock.Builder.addIsInputEmptyCheck() {
        // Both branches end by opening the `if` control flow, so the trailing `apply` adds the
        // `return` statement regardless of which branch was taken.
        if (usingLongSparseArray) {
            beginControlFlow("if (%L.isEmpty())", PARAM_MAP_VARIABLE)
        } else {
            val keySetType = CommonTypeNames.SET.parametrizedBy(collector.keyTypeName)
            addLocalVariable(
                name = KEY_SET_VARIABLE,
                typeName = keySetType,
                assignExpr = MapKeySetExprCode(language, PARAM_MAP_VARIABLE)
            )
            beginControlFlow("if (%L.isEmpty())", KEY_SET_VARIABLE)
        }.apply {
            addStatement("return")
        }
        endControlFlow()
    }

    /**
     * Emits a call to the recursive-fetch utility matching the map type in use, passing a lambda
     * that calls the generated function again with the map it receives.
     *
     * @param scope the scope used to allocate the lambda parameter name.
     * @param methodName name of the generated function to call from within the lambda.
     */
    private fun XCodeBlock.Builder.addRecursiveFetchCall(
        scope: CodeGenScope,
        methodName: String,
    ) {
        val utilFunction =
            RELATION_UTIL.let {
                when {
                    usingLongSparseArray ->
                        it.packageMember("recursiveFetchLongSparseArray")
                    usingArrayMap ->
                        it.packageMember("recursiveFetchArrayMap")
                    else -> when (language) {
                        CodeLanguage.JAVA -> it.packageMember("recursiveFetchHashMap")
                        CodeLanguage.KOTLIN -> it.packageMember("recursiveFetchMap")
                    }
                }
            }
        val paramName = scope.getTmpVar("_tmpMap")
        val recursiveFetchBlock = InvokeWithLambdaParameter(
            scope = scope,
            functionName = utilFunction,
            argFormat = listOf("%L", "%L"),
            args = listOf(PARAM_MAP_VARIABLE, collector.relationTypeIsCollection),
            lambdaSpec = object : LambdaSpec(
                parameterTypeName = collector.mapTypeName,
                parameterName = paramName,
                returnTypeName = KotlinTypeNames.UNIT,
                javaLambdaSyntaxAvailable = scope.javaLambdaSyntaxAvailable
            ) {
                override fun XCodeBlock.Builder.body(scope: CodeGenScope) {
                    // The driver flavor forwards the connection parameter as first argument.
                    val recursiveCall = if (useDriverApi) {
                        XCodeBlock.of(
                            language,
                            "%L(%L, %L)",
                            methodName, PARAM_CONNECTION_VARIABLE, paramName
                        )
                    } else {
                        XCodeBlock.of(
                            language,
                            "%L(%L)",
                            methodName, paramName
                        )
                    }
                    addStatement("%L", recursiveCall)
                    // The lambda returns Unit, which Java code has to return explicitly.
                    if (language == CodeLanguage.JAVA) {
                        addStatement("return %T.INSTANCE", KotlinTypeNames.UNIT)
                    }
                }
            }
        )
        add("%L", recursiveFetchBlock)
    }
}