/*
* Copyright 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package androidx.room.solver.query.result
import androidx.room.compiler.codegen.CodeLanguage
import androidx.room.compiler.codegen.XCodeBlock
import androidx.room.compiler.codegen.XTypeName
import androidx.room.compiler.processing.XNullability
import androidx.room.compiler.processing.XType
import androidx.room.ext.CommonTypeNames
import androidx.room.ext.KotlinTypeNames
import androidx.room.solver.CodeGenScope
import androidx.room.solver.query.result.MultimapQueryResultAdapter.MapType.Companion.isSparseArray
import androidx.room.vo.ColumnIndexVar
/**
* This is an intermediary adapter class that enables nested multimap return types in DAOs.
*
* The [MapValueResultAdapter] sealed class is extended by 2 classes, [NestedMapValueResultAdapter]
* and [EndMapValueResultAdapter]. These adapters are wrappers for the adapters at different levels
* of nested maps. Each level of nesting of a map is represented by a [NestedMapValueResultAdapter],
* except the innermost level which is represented by an [EndMapValueResultAdapter].
*
* For example, if a DAO method returns a `Map<A, Map<B, Map<C, D>>>`, `Map<C, D>` is represented
* by an [EndMapValueResultAdapter], and the outer 2 levels are represented by a
* [NestedMapValueResultAdapter] each.
*
* A [NestedMapValueResultAdapter] can wrap either another [NestedMapValueResultAdapter] or an
* [EndMapValueResultAdapter], whereas an [EndMapValueResultAdapter] does not wrap another adapter
* and only contains row adapters for the innermost map.
*/
sealed class MapValueResultAdapter(
val rowAdapters: List<RowAdapter>
) {
/**
* True if this adapter requires key checking due to its values being passed by reference.
*
* [NestedMapValueResultAdapter] always returns `true`, while [EndMapValueResultAdapter] returns
* `true` only when its value is a collection type. When `true`, the wrapping
* [NestedMapValueResultAdapter] generates code that reuses the values container already stored
* under the key and only creates and puts a new container if the key is absent.
*/
abstract fun requiresContainsKeyCheck(): Boolean
/**
* Left-Hand-Side of a Map value type arg initialization.
*
* This is the type name used to declare the local variable that holds the map value in the
* generated code, and it is also used as the value type argument of the enclosing map type.
*/
abstract fun getDeclarationTypeName(): XTypeName
/**
* Right-Hand-Side of a Map value type arg initialization.
*
* This is the type name that gets instantiated in the generated code when a new values
* container has to be created for a key.
*
* @param language the language of the generated code, which implementations may use to pick
* the type to instantiate.
*/
abstract fun getInstantiationTypeName(language: CodeLanguage): XTypeName
/**
* Reports the `isMigratedToDriver` state of the adapters backing this adapter.
*
* [NestedMapValueResultAdapter] delegates to the adapter it wraps and
* [EndMapValueResultAdapter] delegates to its value row adapter.
*
* NOTE(review): the exact meaning of "migrated to driver" is defined by the row adapter's
* `isMigratedToDriver`, which is not visible in this file - confirm before relying on it.
*/
abstract fun isMigratedToDriver(): Boolean
/**
* Generates, into the builder of [scope], the code that reads this level of the map from the
* current row and stores it.
*
* @param scope the scope providing the code builder and the temporary variable names.
* @param valuesVarName name of the variable in the generated code that receives the values,
* i.e. the map to put into or, for collection values, the collection to add to.
* @param cursorVarName name of the variable in the generated code from which column values
* are read.
* @param dupeColumnsIndexAdapter if not null, it is asked for the column index variables of a
* row adapter's mapping; otherwise the row adapter's default index adapter is used.
* @param genPutValueCode callback that generates the code putting a value into the enclosing
* map. It receives the expression of the value to put (a variable name or the literal `null`)
* and whether the put must only happen when the key is not present yet. Defaults to a no-op.
*/
abstract fun convert(
scope: CodeGenScope,
valuesVarName: String,
cursorVarName: String,
dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?,
genPutValueCode: (String, Boolean) -> Unit = { _, _ -> }
)
/**
* Generates, into the builder of [scope], a check that makes the generated code `continue`
* when all the columns matched by this adapter's row adapter are null in the current row:
* the key row adapter for a [NestedMapValueResultAdapter] and the value row adapter for an
* [EndMapValueResultAdapter].
*
* @param scope the scope providing the code builder and the language of the generated code.
* @param cursorVarName name of the variable in the generated code from which column values
* are read.
* @param dupeColumnsIndexAdapter if not null, it is asked for the column index variables of
* the row adapter's mapping; otherwise the row adapter's default index adapter is used.
*/
abstract fun generateContinueColumnCheck(
scope: CodeGenScope,
cursorVarName: String,
dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?
)
/**
* A [NestedMapValueResultAdapter] contains the key information and the value map information
* of any level of a nested map that is not the innermost "End" map.
*
* The [convert] function implementation for a [NestedMapValueResultAdapter] generates code that
* resolves the key of the map and delegates to the value map's [NestedMapValueResultAdapter] or
* [EndMapValueResultAdapter] (based on the level of nesting) to resolve the value map
* conversion.
*/
class NestedMapValueResultAdapter(
    private val keyRowAdapter: RowAdapter,
    private val keyTypeArg: XType,
    private val mapType: MultimapQueryResultAdapter.MapType,
    private val mapValueResultAdapter: MapValueResultAdapter
) : MapValueResultAdapter(
    rowAdapters = listOf(keyRowAdapter) + mapValueResultAdapter.rowAdapters
) {
    private val keyTypeName = keyTypeArg.asTypeName()

    // The value of this level is always a map, so it is always handled by reference.
    override fun requiresContainsKeyCheck(): Boolean = true

    // Declared type of this level's map: key and value type args for the regular map flavors,
    // only the value type arg for the sparse array flavors.
    override fun getDeclarationTypeName() =
        mapValueResultAdapter.getDeclarationTypeName().let { valueTypeName ->
            when (mapType) {
                MultimapQueryResultAdapter.MapType.DEFAULT,
                MultimapQueryResultAdapter.MapType.ARRAY_MAP ->
                    mapType.className.parametrizedBy(keyTypeName, valueTypeName)
                MultimapQueryResultAdapter.MapType.LONG_SPARSE,
                MultimapQueryResultAdapter.MapType.INT_SPARSE ->
                    mapType.className.parametrizedBy(valueTypeName)
            }
        }

    // Instantiated type of this level's map. Note that the value type arg is still the
    // declaration type of the wrapped adapter.
    override fun getInstantiationTypeName(language: CodeLanguage) =
        mapValueResultAdapter.getDeclarationTypeName().let { valueTypeName ->
            when (mapType) {
                MultimapQueryResultAdapter.MapType.DEFAULT -> {
                    // LinkedHashMap is used as impl to preserve key ordering for ordered
                    // query results.
                    val linkedHashMap = when (language) {
                        CodeLanguage.JAVA -> CommonTypeNames.LINKED_HASH_MAP
                        CodeLanguage.KOTLIN -> KotlinTypeNames.LINKED_HASH_MAP
                    }
                    linkedHashMap.parametrizedBy(keyTypeName, valueTypeName)
                }
                MultimapQueryResultAdapter.MapType.ARRAY_MAP ->
                    mapType.className.parametrizedBy(keyTypeName, valueTypeName)
                MultimapQueryResultAdapter.MapType.LONG_SPARSE,
                MultimapQueryResultAdapter.MapType.INT_SPARSE ->
                    mapType.className.parametrizedBy(valueTypeName)
            }
        }

    override fun isMigratedToDriver(): Boolean = mapValueResultAdapter.isMigratedToDriver()

    // Generates the code that reads this level's key and then delegates to the wrapped adapter
    // to resolve the value stored under that key.
    override fun convert(
        scope: CodeGenScope,
        valuesVarName: String,
        cursorVarName: String,
        dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?,
        genPutValueCode: (String, Boolean) -> Unit
    ) {
        val builder = scope.builder

        // Declare a local for this level's key and read it from the current row.
        val keyVarName = scope.getTmpVar("_key")
        builder.addLocalVariable(keyVarName, keyTypeName)
        keyRowAdapter.convert(keyVarName, cursorVarName, scope)

        // If the wrapped adapter's values are held by reference (nested map or collection end
        // value), the wrapped adapter writes into the container stored under this level's key,
        // otherwise it keeps writing into the map that was handed to this adapter.
        @Suppress("NAME_SHADOWING") // On purpose, so the param cannot be misused below
        val valuesVarName = if (mapValueResultAdapter.requiresContainsKeyCheck()) {
            generateGetOrCreateValues(
                scope = scope,
                parentValuesVarName = valuesVarName,
                keyVarName = keyVarName,
                cursorVarName = cursorVarName,
                dupeColumnsIndexAdapter = dupeColumnsIndexAdapter
            )
        } else {
            valuesVarName
        }

        @Suppress("NAME_SHADOWING") // On purpose, the param is intentionally not used
        val genPutValueCode: (String, Boolean) -> Unit = { valueVarName, doKeyCheck ->
            if (doKeyCheck) {
                // For consistency purposes, in the one-to-one object mapping case, if
                // multiple values are encountered for the same key, only the first ever
                // encountered mapping is kept.
                val keyAbsentCheck = if (mapType.isSparseArray()) {
                    "if (%L.get(%L) == null)"
                } else {
                    "if (!%L.containsKey(%L))"
                }
                builder.beginControlFlow(keyAbsentCheck, valuesVarName, keyVarName)
            }
            builder.addStatement("%L.put(%L, %L)", valuesVarName, keyVarName, valueVarName)
            if (doKeyCheck) {
                builder.endControlFlow()
            }
        }

        mapValueResultAdapter.convert(
            scope = scope,
            valuesVarName = valuesVarName,
            cursorVarName = cursorVarName,
            dupeColumnsIndexAdapter = dupeColumnsIndexAdapter,
            genPutValueCode = genPutValueCode
        )
    }

    // Generates the code that declares a values variable and assigns it either the container
    // already stored under the key or a newly created container that is also put under the key.
    // Returns the name of the declared values variable.
    private fun generateGetOrCreateValues(
        scope: CodeGenScope,
        parentValuesVarName: String,
        keyVarName: String,
        cursorVarName: String,
        dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?
    ): String {
        val builder = scope.builder
        val sparse = mapType.isSparseArray()

        val nestedValuesVarName = scope.getTmpVar("_values")
        builder.addLocalVariable(
            nestedValuesVarName,
            mapValueResultAdapter.getDeclarationTypeName()
        )

        // Key already present: reuse the stored container.
        val keyPresentCheck = if (sparse) {
            "if (%L.get(%L) != null)"
        } else {
            "if (%L.containsKey(%L))"
        }
        builder.beginControlFlow(keyPresentCheck, parentValuesVarName, keyVarName)
        val getFunction = when (builder.language) {
            CodeLanguage.JAVA -> "get"
            CodeLanguage.KOTLIN -> if (sparse) "get" else "getValue"
        }
        builder.addStatement(
            "%L = %L.%L(%L)",
            nestedValuesVarName,
            parentValuesVarName,
            getFunction,
            keyVarName
        )

        // Key absent: create a new container and store it under the key.
        builder.nextControlFlow("else")
        builder.addStatement(
            "%L = %L",
            nestedValuesVarName,
            XCodeBlock.ofNewInstance(
                builder.language,
                mapValueResultAdapter.getInstantiationTypeName(builder.language)
            )
        )
        builder.addStatement(
            "%L.put(%L, %L)",
            parentValuesVarName,
            keyVarName,
            nestedValuesVarName
        )
        builder.endControlFlow()

        // Perform key columns null check, in a nested mapping we still add
        // the key with an empty map as the value entry.
        mapValueResultAdapter.generateContinueColumnCheck(
            scope,
            cursorVarName,
            dupeColumnsIndexAdapter
        )
        return nestedValuesVarName
    }

    // Skips the current row in the generated code when all the key columns are null.
    override fun generateContinueColumnCheck(
        scope: CodeGenScope,
        cursorVarName: String,
        dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?
    ) {
        val builder = scope.builder
        val keyColumnsNullCheck = getContinueColumnNullCheck(
            language = scope.language,
            cursorVarName = cursorVarName,
            rowAdapter = keyRowAdapter,
            dupeColumnsIndexAdapter = dupeColumnsIndexAdapter
        )
        builder.add(keyColumnsNullCheck)
    }
}
/**
* An [EndMapValueResultAdapter] contains only the value information regarding the innermost
* map of the returned nested map.
*
* The [convert] function implementation for an [EndMapValueResultAdapter] uses the value row
* adapter to read the value of the innermost map, regardless of whether that value is a
* collection type or not.
*/
class EndMapValueResultAdapter(
    private val valueRowAdapter: RowAdapter,
    private val valueTypeArg: XType,
    private val valueCollectionType: MultimapQueryResultAdapter.CollectionValueType?
) : MapValueResultAdapter(
    rowAdapters = listOf(valueRowAdapter)
) {
    // Only collection values are held by reference and therefore need the key check.
    override fun requiresContainsKeyCheck(): Boolean = valueCollectionType != null

    // The type name used to declare the result map value
    // For Map<Foo, Bar> it is Bar
    // For Map<Foo, List<Bar>> it is the collection type itself, e.g. List<Bar>
    override fun getDeclarationTypeName(): XTypeName {
        val elementTypeName = valueTypeArg.asTypeName()
        return valueCollectionType?.className?.parametrizedBy(elementTypeName)
            ?: elementTypeName
    }

    // The type name of the concrete result map value, i.e. the type that gets instantiated
    // For Map<Foo, Bar> it is Bar
    // For Map<Foo, List<Bar>> it is ArrayList<Bar>
    // For Map<Foo, Set<Bar>> it is HashSet<Bar>
    override fun getInstantiationTypeName(language: CodeLanguage): XTypeName {
        val elementTypeName = valueTypeArg.asTypeName()
        val concreteCollection = when (valueCollectionType) {
            MultimapQueryResultAdapter.CollectionValueType.LIST -> CommonTypeNames.ARRAY_LIST
            MultimapQueryResultAdapter.CollectionValueType.SET -> CommonTypeNames.HASH_SET
            else -> return elementTypeName
        }
        return concreteCollection.parametrizedBy(elementTypeName)
    }

    override fun isMigratedToDriver(): Boolean = valueRowAdapter.isMigratedToDriver()

    // Generates the code that reads the value from the current row and either adds it to the
    // values collection (collection value) or hands it to genPutValueCode (single value).
    override fun convert(
        scope: CodeGenScope,
        valuesVarName: String,
        cursorVarName: String,
        dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?,
        genPutValueCode: (String, Boolean) -> Unit
    ) {
        val builder = scope.builder
        val valueVarName = scope.getTmpVar("_value")

        // If we have a collection type, then this means that we have a 1-to-many mapping
        // as opposed to a 1-to-1 mapping: read the value and add it to the collection.
        if (valueCollectionType != null) {
            builder.addLocalVariable(valueVarName, valueTypeArg.asTypeName())
            valueRowAdapter.convert(valueVarName, cursorVarName, scope)
            builder.addStatement("%L.add(%L)", valuesVarName, valueVarName)
            return
        }

        // 1-to-1 mapping from here on.
        check(valueRowAdapter is QueryMappedRowAdapter)
        val matchedIndexVars =
            dupeColumnsIndexAdapter?.getIndexVarsForMapping(valueRowAdapter.mapping)
                ?: valueRowAdapter.getDefaultIndexAdapter().getIndexVars()
        val allColumnsNull = getColumnNullCheckCode(
            language = scope.language,
            cursorVarName = cursorVarName,
            indexVars = matchedIndexVars
        )

        // Perform value columns null check, in a 1-to-1 mapping we still add the key
        // with a null value entry if permitted.
        builder.beginControlFlow("if (%L)", allColumnsNull)
        val nullValueNotAllowed = builder.language == CodeLanguage.KOTLIN &&
            valueTypeArg.nullability == XNullability.NONNULL
        if (nullValueNotAllowed) {
            builder.addStatement(
                "error(%S)",
                "The column(s) of the map value object of type " +
                    "'$valueTypeArg' are NULL but the map's value type " +
                    "argument expect it to be NON-NULL"
            )
        } else {
            genPutValueCode("null", false)
            builder.addStatement("continue")
        }
        builder.endControlFlow()

        // Value columns are present: read the value and put it, keeping the first mapping
        // encountered for a key.
        builder.addLocalVariable(valueVarName, valueTypeArg.asTypeName())
        valueRowAdapter.convert(valueVarName, cursorVarName, scope)
        genPutValueCode(valueVarName, true)
    }

    // Skips the current row in the generated code when all the value columns are null.
    override fun generateContinueColumnCheck(
        scope: CodeGenScope,
        cursorVarName: String,
        dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?
    ) {
        val builder = scope.builder
        val valueColumnsNullCheck = getContinueColumnNullCheck(
            language = scope.language,
            cursorVarName = cursorVarName,
            rowAdapter = valueRowAdapter,
            dupeColumnsIndexAdapter = dupeColumnsIndexAdapter
        )
        builder.add(valueColumnsNullCheck)
    }
}
/**
 * Utility method that returns a code block containing an `if` statement that makes the
 * generated code `continue` when all the fields matched by [rowAdapter] are null.
 *
 * The column index variables are taken from [dupeColumnsIndexAdapter] when it provides them
 * for the row adapter's mapping, otherwise from the row adapter's default index adapter.
 *
 * @throws IllegalStateException if [rowAdapter] is not a [QueryMappedRowAdapter].
 */
protected fun getContinueColumnNullCheck(
    language: CodeLanguage,
    rowAdapter: RowAdapter,
    cursorVarName: String,
    dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?
): XCodeBlock {
    val continueBlock = XCodeBlock.builder(language)
    check(rowAdapter is QueryMappedRowAdapter)
    val matchedIndexVars =
        dupeColumnsIndexAdapter?.getIndexVarsForMapping(rowAdapter.mapping)
            ?: rowAdapter.getDefaultIndexAdapter().getIndexVars()
    val allColumnsNull = getColumnNullCheckCode(
        language = language,
        cursorVarName = cursorVarName,
        indexVars = matchedIndexVars
    )
    continueBlock.beginControlFlow("if (%L)", allColumnsNull)
    continueBlock.addStatement("continue")
    continueBlock.endControlFlow()
    return continueBlock.build()
}
/**
 * Generates a code expression that verifies if all matched fields are null.
 *
 * The expression is one `isNull` call on [cursorVarName] per entry of [indexVars], joined
 * with `&&`.
 */
protected fun getColumnNullCheckCode(
    language: CodeLanguage,
    cursorVarName: String,
    indexVars: List<ColumnIndexVar>
): XCodeBlock {
    val checkCode = XCodeBlock.builder(language)
    // Java output uses the `%W` placeholder around the operator (presumably a wrappable
    // space - confirm against XCodeBlock), Kotlin output uses a plain space.
    val andSeparator = when (language) {
        CodeLanguage.JAVA -> "%W&&%W"
        CodeLanguage.KOTLIN -> " && "
    }
    val isNullChecks = indexVars.map { columnIndexVar ->
        XCodeBlock.of(
            language,
            "%L.isNull(%L)",
            cursorVarName,
            columnIndexVar.indexVar
        )
    }
    // One `%L` placeholder per check, each filled with the matching code block.
    val format = isNullChecks.joinToString(separator = andSeparator) { "%L" }
    checkCode.add(format, *isNullChecks.toTypedArray())
    return checkCode.build()
}
}