/*
* Copyright 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package androidx.remotecallback.compiler;
import com.squareup.javapoet.AnnotationSpec;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.CodeBlock;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.annotation.processing.Messager;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.AnnotationMirror;
import javax.lang.model.element.AnnotationValue;
import javax.lang.model.element.Element;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Modifier;
import javax.lang.model.element.VariableElement;
import javax.lang.model.type.ExecutableType;
import javax.lang.model.type.TypeMirror;
import javax.tools.Diagnostic;
/**
*
*/
public class CallableMethod {
private static final String EXTERNAL_INPUT = "androidx.remotecallback.ExternalInput";
private static final String BYTE = "byte";
private static final String CHAR = "char";
private static final String SHORT = "short";
private static final String INT = "int";
private static final String LONG = "long";
private static final String FLOAT = "float";
private static final String DOUBLE = "double";
private static final String BOOLEAN = "boolean";
private static final String BYTE_ARRAY = "byte[]";
private static final String CHAR_ARRAY = "char[]";
private static final String SHORT_ARRAY = "short[]";
private static final String INT_ARRAY = "int[]";
private static final String LONG_ARRAY = "long[]";
private static final String FLOAT_ARRAY = "float[]";
private static final String DOUBLE_ARRAY = "double[]";
private static final String BOOLEAN_ARRAY = "boolean[]";
private static final String STRING_ARRAY = "java.lang.String[]";
private static final String CONTEXT = "android.content.Context";
private static final String STRING = "java.lang.String";
private static final String URI = "android.net.Uri";
private static final String OBJ_BYTE = "java.lang.Byte";
private static final String CHARACTER = "java.lang.Character";
private static final String OBJ_SHORT = "java.lang.Short";
private static final String INTEGER = "java.lang.Integer";
private static final String OBJ_LONG = "java.lang.Long";
private static final String OBJ_FLOAT = "java.lang.Float";
private static final String OBJ_DOUBLE = "java.lang.Double";
private static final String OBJ_BOOLEAN = "java.lang.Boolean";
private final Element mElement;
private final ArrayList<TypeMirror> mTypes = new ArrayList<>();
private final ArrayList<String> mNames = new ArrayList<>();
private final ArrayList<String> mExtInputKeys = new ArrayList<>();
private final String mClsName;
private final ProcessingEnvironment mEnv;
private TypeMirror mReturnType;
public CallableMethod(String name, Element element, ProcessingEnvironment env) {
mClsName = name;
mElement = element;
mEnv = env;
init();
}
/**
* Get the name of the method this class is representing/tracking.
*/
public String getName() {
return mElement.getSimpleName().toString();
}
private void init() {
ExecutableType type = (ExecutableType) mElement.asType();
ExecutableElement element = (ExecutableElement) mElement;
List<? extends TypeMirror> types = type.getParameterTypes();
List<? extends VariableElement> vars = element.getParameters();
mReturnType = element.getReturnType();
for (int i = 0; i < types.size(); i++) {
mTypes.add(types.get(i));
AnnotationMirror mirror = findAnnotation(vars.get(i), EXTERNAL_INPUT);
mExtInputKeys.add(mirror != null ? getValue(mirror, "value", null) : null);
mNames.add(vars.get(i).getSimpleName().toString());
}
}
private AnnotationMirror findAnnotation(VariableElement element, String cls) {
for (AnnotationMirror mirror: element.getAnnotationMirrors()) {
if (mirror.getAnnotationType().toString().equals(cls)) {
return mirror;
}
}
return null;
}
private String getValue(AnnotationMirror annotation, String name, String defValue) {
Map<? extends ExecutableElement, ? extends AnnotationValue> elementValues =
annotation.getElementValues();
for (ExecutableElement av: elementValues.keySet()) {
if (Objects.equals(av.getSimpleName().toString(), name)) {
AnnotationValue v = elementValues.get(av);
return v != null ? v.toString() : av.getDefaultValue().getValue().toString();
}
}
if (defValue != null) {
return defValue;
}
mEnv.getMessager().printMessage(Diagnostic.Kind.ERROR, "Can't find annotation value");
return null;
}
/**
* Generate code and add the methods to the specified class/method.
*/
public void addMethods(TypeSpec.Builder genClass, MethodSpec.Builder runBuilder,
ProcessingEnvironment env, Messager messager) {
// Validate types
for (int i = 0; i < mTypes.size(); i++) {
if (checkType(mTypes.get(i).toString(), messager)) {
messager.printMessage(Diagnostic.Kind.ERROR,
"Invalid type " + mTypes.get(i));
return;
}
}
if (!"androidx.remotecallback.RemoteCallback".equals(mReturnType.toString())) {
messager.printMessage(Diagnostic.Kind.ERROR,
"RemoteCallable methods must return RemoteCallback.LOCAL.");
return;
}
ClassName callbackHandlerRegistry = ClassName.get("androidx.remotecallback",
"CallbackHandlerRegistry");
ClassName callbackHandler = ClassName.get("androidx.remotecallback",
"CallbackHandlerRegistry.CallbackHandler");
ClassName remoteInputHolder = ClassName.get("androidx.remotecallback",
"RemoteInputHolder");
ClassName bundle = ClassName.get("android.os", "Bundle");
ClassName context = ClassName.get("android.content", "Context");
CodeBlock.Builder code = CodeBlock.builder();
String methodName = mElement.getSimpleName().toString();
code.add("$L.registerCallbackHandler($L.class, $S, ", callbackHandlerRegistry, mClsName,
methodName);
code.beginControlFlow("new $L<$L>()", callbackHandler, mClsName);
// Begin executeCallback implementation ------------------------------------------------
code.beginControlFlow(" public void executeCallback($L context, $L receiver, $L args)",
context, mClsName, bundle);
StringBuilder methodCall = new StringBuilder();
methodCall.append("receiver.");
methodCall.append(mElement.getSimpleName());
methodCall.append("(");
for (int i = 0; i < mNames.size(); i++) {
// Pass the parameter to the method call.
if (i != 0) {
methodCall.append(", ");
}
methodCall.append("p" + i);
if (mTypes.get(i).toString().equals(context.toString())) {
code.addStatement("$L p" + i + " = context", mTypes.get(i));
continue;
}
code.addStatement("$L p" + i, mTypes.get(i));
String key = mExtInputKeys.get(i) != null ? mExtInputKeys.get(i) : getBundleKey(i);
// Generate code to extract the value.
code.addStatement("p$L = $L", i, getBundleParam(mTypes.get(i).toString(), key));
}
methodCall.append(")");
// Add the method call as the last thing.
code.addStatement(methodCall.toString());
code.endControlFlow();
// End executeCallback implementation --------------------------------------------------
code.endControlFlow();
code.add(");\n");
runBuilder.addCode(code.build());
// Start assembleArguments implementation ----------------------------------------------
code = CodeBlock.builder();
ClassName remoteCallback = ClassName.get("androidx.remotecallback", "RemoteCallback");
MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName)
.returns(remoteCallback)
.addAnnotation(AnnotationSpec.builder(ClassName.bestGuess("Override")).build())
.addModifiers(Modifier.PUBLIC);
code.addStatement("$L b = new $L()", bundle, bundle);
for (int i = 0; i < mNames.size(); i++) {
builder.addParameter(TypeName.get(mTypes.get(i)), "p" + i);
if (mTypes.get(i).toString().equals(context.toString())) {
continue;
}
boolean isNative = isNative(mTypes.get(i).toString());
// Only fill in value if the argument has a value.
if (!isNative) code.beginControlFlow("if (p$L != null)", i);
// Otherwise just need to place the arg value.
code.addStatement("b.put$L($L, ($L) p$L)",
getTypeMethod(mTypes.get(i).toString()),
getBundleKey(i), mTypes.get(i), i);
// No value present, need an explicit null for security.
if (!isNative) code.nextControlFlow("else");
if (!isNative) code.addStatement("b.putString($L, null)", getBundleKey(i));
if (!isNative) code.endControlFlow();
}
code.addStatement(
"return androidx.remotecallback.CallbackHandlerRegistry.stubToRemoteCallback("
+ "this, $L.class, b, $S)",
mClsName, mElement.getSimpleName());
builder.addCode(code.build());
genClass.addMethod(builder.build());
}
private int countArgs(ClassName context) {
int ct = 0;
for (int i = 0; i < mTypes.size(); i++) {
if (mTypes.get(i).toString().equals(context.toString())) {
continue;
}
ct++;
}
return ct;
}
private String getBundleParam(String type, int index) {
String key = getBundleKey(index);
return getBundleParam(type, key);
}
private boolean isNative(String type) {
switch (type) {
case BYTE:
case CHAR:
case SHORT:
case INT:
case LONG:
case FLOAT:
case DOUBLE:
case BOOLEAN:
return true;
}
return false;
}
private String getBundleParam(String type, String key) {
switch (type) {
case BYTE:
return "args.getByte(" + key + ", (byte) 0)";
case CHAR:
return "args.getChar(" + key + ", (char) 0)";
case SHORT:
return "args.getShort(" + key + ", (short) 0)";
case INT:
return "args.getInt(" + key + ", 0)";
case LONG:
return "args.getLong(" + key + ", 0)";
case FLOAT:
return "args.getFloat(" + key + ", 0f)";
case DOUBLE:
return "args.getDouble(" + key + ", 0.0)";
case BOOLEAN:
return "args.getBoolean(" + key + ", false)";
}
return "(" + type + ") args.get(" + key + ")";
}
private String getTypeMethod(String type) {
switch (type) {
case BYTE:
return "Byte";
case CHAR:
return "Char";
case SHORT:
return "Short";
case INT:
return "Int";
case LONG:
return "Long";
case FLOAT:
return "Float";
case DOUBLE:
return "Double";
case BOOLEAN:
return "Boolean";
case STRING:
return "String";
case URI:
return "Parcelable";
case BYTE_ARRAY:
return "ByteArray";
case CHAR_ARRAY:
return "CharArray";
case SHORT_ARRAY:
return "ShortArray";
case INT_ARRAY:
return "IntArray";
case LONG_ARRAY:
return "LongArray";
case FLOAT_ARRAY:
return "FloatArray";
case DOUBLE_ARRAY:
return "DoubleArray";
case BOOLEAN_ARRAY:
return "BooleanArray";
case STRING_ARRAY:
return "StringArray";
case OBJ_BYTE:
return "Byte";
case CHARACTER:
return "Char";
case OBJ_SHORT:
return "Short";
case INTEGER:
return "Int";
case OBJ_LONG:
return "Long";
case OBJ_FLOAT:
return "Float";
case OBJ_DOUBLE:
return "Double";
case OBJ_BOOLEAN:
return "Boolean";
}
throw new RuntimeException("Invalid type " + type);
}
private String getBundleKey(int index) {
return "\"p" + index + "\"";
}
private boolean checkType(String type, Messager messager) {
switch (type) {
case BYTE:
case CHAR:
case SHORT:
case INT:
case LONG:
case FLOAT:
case DOUBLE:
case BOOLEAN:
case STRING:
case CONTEXT:
case BYTE_ARRAY:
case CHAR_ARRAY:
case SHORT_ARRAY:
case INT_ARRAY:
case LONG_ARRAY:
case FLOAT_ARRAY:
case DOUBLE_ARRAY:
case BOOLEAN_ARRAY:
case STRING_ARRAY:
case URI:
case OBJ_BYTE:
case CHARACTER:
case OBJ_SHORT:
case INTEGER:
case OBJ_LONG:
case OBJ_FLOAT:
case OBJ_DOUBLE:
case OBJ_BOOLEAN:
return false;
default:
return true;
}
}
}