TextClassification.java

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.textclassifier;

import android.annotation.SuppressLint;
import android.app.PendingIntent;
import android.app.RemoteAction;
import android.content.Context;
import android.graphics.drawable.Drawable;
import android.net.Uri;
import android.os.Build;
import android.os.Bundle;
import android.text.SpannableString;
import android.text.SpannableStringBuilder;
import android.text.TextUtils;
import android.util.Log;
import android.view.View;

import androidx.annotation.FloatRange;
import androidx.annotation.IntRange;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.RequiresApi;
import androidx.annotation.RestrictTo;
import androidx.collection.ArrayMap;
import androidx.core.app.RemoteActionCompat;
import androidx.core.graphics.drawable.IconCompat;
import androidx.core.os.LocaleListCompat;
import androidx.core.util.Preconditions;
import androidx.textclassifier.TextClassifier.EntityType;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Information for generating a widget to handle classified text.
 *
 * <p>A TextClassification object contains icons, labels, and intents that may be used to build a
 * widget that can be used to act on classified text.
 *
 * <p>e.g. building a menu that lets the user choose how to act on a piece of text:
 *
 * <pre>{@code
 *   // Called preferably outside the UiThread.
 *   TextClassification classification = textClassifier.classifyText(allText, 10, 25);
 *
 *   // Called on the UiThread.
 *   for (RemoteActionCompat action : classification.getActions()) {
 *       MenuItem item = menu.add(action.getTitle());
 *       item.setContentDescription(action.getContentDescription());
 *       item.setOnMenuItemClickListener(v -> {
 *           try {
 *               action.getActionIntent().send();
 *           } catch (PendingIntent.CanceledException e) {
 *               // The action is no longer available; ignore the click.
 *           }
 *           return true;
 *       });
 *       if (action.shouldShowIcon()) {
 *           item.setIcon(action.getIcon().loadDrawable(context));
 *       }
 *   }
 * }</pre>
 */
public final class TextClassification {

    private static final String LOG_TAG = "TextClassification";

    private static final String EXTRA_TEXT = "text";
    private static final String EXTRA_ACTIONS = "actions";
    private static final String EXTRA_ENTITY_CONFIDENCE = "entity_conf";
    private static final String EXTRA_ID = "id";
    private static final String EXTRA_EXTRAS = "extras";

    /**
     * Placeholder icon for actions converted from a platform result that has no icon. Actions
     * using it are marked with {@code setShouldShowIcon(false)} and it must never be rendered.
     */
    private static final IconCompat NO_ICON =
            IconCompat.createWithData(new byte[0], 0, 0);

    /**
     * @hide
     */
    @RestrictTo(RestrictTo.Scope.LIBRARY)
    static final TextClassification EMPTY = new TextClassification.Builder().build();

    @Nullable private final String mText;
    @NonNull private final List<RemoteActionCompat> mActions;
    @NonNull private final EntityConfidence mEntityConfidence;
    @Nullable private final String mId;
    @NonNull private final Bundle mExtras;

    TextClassification(
            @Nullable String text,
            @NonNull List<RemoteActionCompat> actions,
            @NonNull EntityConfidence entityConfidence,
            @Nullable String id,
            @NonNull Bundle extras) {
        mText = text;
        mActions = actions;
        mEntityConfidence = entityConfidence;
        mId = id;
        mExtras = extras;
    }

    /**
     * Gets the classified text.
     */
    @Nullable
    public CharSequence getText() {
        return mText;
    }

    /**
     * Returns the number of entity types found in the classified text.
     */
    @IntRange(from = 0)
    public int getEntityTypeCount() {
        return mEntityConfidence.getEntities().size();
    }

    /**
     * Returns the entity type at the specified index. Entities are ordered from high confidence
     * to low confidence.
     *
     * @throws IndexOutOfBoundsException if the specified index is out of range.
     * @see #getEntityTypeCount() for the number of entities available.
     */
    @NonNull
    public @EntityType String getEntityType(int index) {
        return mEntityConfidence.getEntities().get(index);
    }

    /**
     * Returns the confidence score for the specified entity. The value ranges from
     * 0 (low confidence) to 1 (high confidence). 0 indicates that the entity was not found for the
     * classified text.
     */
    @FloatRange(from = 0.0, to = 1.0)
    public float getConfidenceScore(@EntityType String entity) {
        return mEntityConfidence.getConfidenceScore(entity);
    }

    /**
     * Returns a list of actions that may be performed on the text. The list is ordered based on
     * the likelihood that a user will use the action, with the most likely action appearing first.
     *
     * <p>The returned list cannot be modified.
     */
    @NonNull
    public List<RemoteActionCompat> getActions() {
        return mActions;
    }

    /**
     * Returns the id for this object.
     */
    @Nullable
    public String getId() {
        return mId;
    }

    /**
     * Returns the extended, vendor specific data.
     *
     * <p><b>NOTE: </b>Each call to this method returns a new bundle copy so clients should
     * prefer to hold a reference to the returned bundle rather than frequently calling this
     * method. Avoid updating the content of this bundle. On pre-O devices, the values in the
     * Bundle are not deep copied.
     */
    @NonNull
    public Bundle getExtras() {
        return BundleUtils.deepCopy(mExtras);
    }

    @Override
    public String toString() {
        return String.format(Locale.US,
                "TextClassification {text=%s, entities=%s, actions=%s, id=%s}",
                mText, mEntityConfidence, mActions, mId);
    }

    /**
     * Adds this classification to a Bundle that can be read back with the same parameters
     * to {@link #createFromBundle(Bundle)}.
     */
    @NonNull
    public Bundle toBundle() {
        final Bundle bundle = new Bundle();
        bundle.putString(EXTRA_TEXT, mText);
        BundleUtils.putRemoteActionList(bundle, EXTRA_ACTIONS, mActions);
        BundleUtils.putMap(bundle, EXTRA_ENTITY_CONFIDENCE, mEntityConfidence.getConfidenceMap());
        bundle.putString(EXTRA_ID, mId);
        bundle.putBundle(EXTRA_EXTRAS, mExtras);
        return bundle;
    }

    /**
     * Extracts a classification from a bundle that was added using {@link #toBundle()}.
     *
     * @throws IllegalArgumentException if the bundle does not hold a classification written by
     *      {@link #toBundle()} (e.g. the entity confidence map or the action list is missing).
     */
    @NonNull
    public static TextClassification createFromBundle(@NonNull Bundle bundle) {
        final Builder builder = new Builder()
                .setText(bundle.getString(EXTRA_TEXT))
                .setId(bundle.getString(EXTRA_ID))
                .setExtras(bundle.getBundle(EXTRA_EXTRAS));
        for (Map.Entry<String, Float> entityConfidence : BundleUtils.getFloatStringMapOrThrow(
                bundle, EXTRA_ENTITY_CONFIDENCE).entrySet()) {
            builder.setEntityType(entityConfidence.getKey(), entityConfidence.getValue());
        }
        for (RemoteActionCompat action : BundleUtils.getRemoteActionListOrThrow(
                bundle, EXTRA_ACTIONS)) {
            builder.addAction(action);
        }
        return builder.build();
    }


    /**
     * Converts a platform classification into this library's representation.
     *
     * <p>On P+ the platform actions are converted directly. On O, where the platform object only
     * carries a single legacy intent/label/icon, that legacy data is turned into one action when
     * both an intent and a non-empty label are present.
     *
     * @hide
     */
    @RestrictTo(RestrictTo.Scope.LIBRARY)
    @RequiresApi(26)
    @SuppressWarnings("deprecation") // To support O
    @NonNull
    static TextClassification fromPlatform(
            @NonNull Context context,
            @NonNull android.view.textclassifier.TextClassification textClassification) {
        Preconditions.checkNotNull(textClassification);

        Builder builder = new TextClassification.Builder()
                .setText(textClassification.getText());

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
            builder.setId(textClassification.getId());
        }

        final int entityCount = textClassification.getEntityCount();
        for (int i = 0; i < entityCount; i++) {
            String entity = textClassification.getEntity(i);
            builder.setEntityType(entity, textClassification.getConfidenceScore(entity));
        }

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
            List<RemoteAction> actions = textClassification.getActions();
            for (RemoteAction action : actions) {
                builder.addAction(RemoteActionCompat.createFromRemoteAction(action));
            }
        } else {
            if (textClassification.getIntent() != null
                    && !TextUtils.isEmpty(textClassification.getLabel())) {
                builder.addAction(createRemoteActionCompat(context, textClassification));
            }
        }
        return builder.build();
    }

    /**
     * Converts the legacy intent/label/icon of a platform {@link TextClassification} object to a
     * {@link RemoteActionCompat} object. The caller must ensure that the intent and the label are
     * not null. The classified text may be null.
     *
     * <p>If the platform object has no icon, the placeholder {@link #NO_ICON} is used and the
     * action is marked as not showing its icon.
     */
    @RequiresApi(26)
    @SuppressWarnings("deprecation") //To support O
    @NonNull
    private static RemoteActionCompat createRemoteActionCompat(
            @NonNull Context context,
            @NonNull android.view.textclassifier.TextClassification textClassification) {
        final String text = textClassification.getText();
        PendingIntent pendingIntent =
                PendingIntent.getActivity(
                        context,
                        // The text only seeds the request code; it is not guaranteed non-null.
                        text == null ? 0 : text.hashCode(),
                        textClassification.getIntent(),
                        PendingIntent.FLAG_UPDATE_CURRENT);

        Drawable drawable = textClassification.getIcon();
        CharSequence label = textClassification.getLabel();
        IconCompat icon;
        if (drawable == null) {
            // Placeholder, should never be shown.
            icon = NO_ICON;
        } else {
            icon = ConvertUtils.createIconFromDrawable(drawable);
        }
        RemoteActionCompat remoteAction = new RemoteActionCompat(icon, label, label, pendingIntent);
        remoteAction.setShouldShowIcon(drawable != null);
        return remoteAction;
    }

    /**
     * Converts this object into a platform
     * {@code android.view.textclassifier.TextClassification}.
     *
     * <p>On P+ all actions are converted. On every supported version the first action (if any)
     * is also exposed through the legacy label/icon/click-listener fields so that O can use it.
     *
     * @hide
     */
    // Lint does not know @EntityType in platform and here are same.
    @SuppressWarnings("deprecation") // To support O
    @RestrictTo(RestrictTo.Scope.LIBRARY)
    @SuppressLint("WrongConstant")
    @RequiresApi(26)
    @NonNull
    Object toPlatform(@NonNull Context context) {
        Preconditions.checkNotNull(context);

        android.view.textclassifier.TextClassification.Builder builder =
                new android.view.textclassifier.TextClassification.Builder()
                        .setText(getText() == null ? null : getText().toString());

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
            builder.setId(getId());
        }

        final int entityCount = getEntityTypeCount();
        for (int i = 0; i < entityCount; i++) {
            String entity = getEntityType(i);
            builder.setEntityType(entity, getConfidenceScore(entity));
        }

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
            List<RemoteActionCompat> actions = getActions();
            for (RemoteActionCompat action : actions) {
                builder.addAction(action.toRemoteAction());
            }
        }

        if (!getActions().isEmpty()) {
            final RemoteActionCompat firstAction = getActions().get(0);
            builder.setLabel(firstAction.getTitle().toString())
                    .setOnClickListener(new View.OnClickListener() {
                        @Override
                        public void onClick(View view) {
                            try {
                                firstAction.getActionIntent().send();
                            } catch (PendingIntent.CanceledException e) {
                                Log.e(TextClassifier.DEFAULT_LOG_TAG, "Failed to start action ", e);
                            }
                        }
                    });
            // Actions that should not show an icon may carry the NO_ICON placeholder; loading it
            // would produce an empty drawable, so only expose icons that are meant to be shown.
            if (firstAction.shouldShowIcon()) {
                builder.setIcon(firstAction.getIcon().loadDrawable(context));
            }
        }

        return builder.build();
    }

    /**
     * Builder for building {@link TextClassification} objects.
     *
     * <p>e.g.
     *
     * <pre>{@code
     *   TextClassification classification = new TextClassification.Builder()
     *          .setText(classifiedText)
     *          .setEntityType(TextClassifier.TYPE_EMAIL, 0.9)
     *          .setEntityType(TextClassifier.TYPE_OTHER, 0.1)
     *          .addAction(remoteAction1)
     *          .addAction(remoteAction2)
     *          .build();
     * }</pre>
     */
    public static final class Builder {

        @Nullable private String mText;
        @NonNull private final List<RemoteActionCompat> mActions = new ArrayList<>();
        @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
        @Nullable private String mId;
        @Nullable private Bundle mExtras;

        /**
         * Sets the classified text.
         */
        @NonNull
        public Builder setText(@Nullable CharSequence text) {
            mText = text == null ? null : text.toString();
            return this;
        }

        /**
         * Sets an entity type for the classification result and assigns a confidence score.
         * If a confidence score had already been set for the specified entity type, this will
         * override that score.
         *
         * @param confidenceScore a value from 0 (low confidence) to 1 (high confidence).
         *      0 implies the entity does not exist for the classified text.
         *      Values greater than 1 are clamped to 1.
         */
        @NonNull
        public Builder setEntityType(
                @NonNull @EntityType String type,
                @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
            mEntityConfidence.put(type, confidenceScore);
            return this;
        }

        /**
         * Adds an action that may be performed on the classified text. Actions should be added in
         * order of likelihood that the user will use them, with the most likely action being added
         * first.
         */
        @NonNull
        public Builder addAction(@NonNull RemoteActionCompat action) {
            Preconditions.checkArgument(action != null);
            mActions.add(action);
            return this;
        }

        /**
         * Sets an id for the TextClassification object.
         */
        @NonNull
        public Builder setId(@Nullable String id) {
            mId = id;
            return this;
        }

        /**
         * Sets the extended, vendor specific data.
         */
        @NonNull
        public Builder setExtras(@Nullable Bundle extras) {
            mExtras = extras;
            return this;
        }

        /**
         * Builds and returns a {@link TextClassification} object.
         *
         * <p>The builder may be reused afterwards; later changes do not affect objects that were
         * already built.
         */
        @NonNull
        public TextClassification build() {
            // Snapshot the actions: otherwise reusing this builder, or mutating the list returned
            // by getActions(), would alter previously built objects (including the shared EMPTY).
            final List<RemoteActionCompat> actions =
                    Collections.unmodifiableList(new ArrayList<>(mActions));
            return new TextClassification(
                    mText, actions, new EntityConfidence(mEntityConfidence), mId,
                    mExtras == null ? Bundle.EMPTY : BundleUtils.deepCopy(mExtras));
        }
    }

    /**
     * A request object for generating TextClassification.
     */
    public static final class Request {
        private static final String EXTRA_TEXT = "text";
        private static final String EXTRA_START_INDEX = "start";
        private static final String EXTRA_END_INDEX = "end";
        private static final String EXTRA_DEFAULT_LOCALES = "locales";
        private static final String EXTRA_REFERENCE_TIME = "reftime";
        private static final String EXTRA_EXTRAS = "extras";

        private final CharSequence mText;
        private final int mStartIndex;
        private final int mEndIndex;
        @Nullable private final LocaleListCompat mDefaultLocales;
        @Nullable private final Long mReferenceTime;
        @NonNull private final Bundle mExtras;

        Request(
                CharSequence text,
                int startIndex,
                int endIndex,
                LocaleListCompat defaultLocales,
                Long referenceTime,
                Bundle extras) {
            mText = text;
            mStartIndex = startIndex;
            mEndIndex = endIndex;
            mDefaultLocales = defaultLocales;
            mReferenceTime = referenceTime;
            mExtras = extras;
        }

        /**
         * Returns the text providing context for the text to classify (which is specified
         *      by the sub sequence starting at startIndex and ending at endIndex)
         */
        @NonNull
        public CharSequence getText() {
            return mText;
        }

        /**
         * Returns start index of the text to classify.
         */
        @IntRange(from = 0)
        public int getStartIndex() {
            return mStartIndex;
        }

        /**
         * Returns end index of the text to classify.
         */
        @IntRange(from = 0)
        public int getEndIndex() {
            return mEndIndex;
        }

        /**
         * @return ordered list of locale preferences that can be used to disambiguate
         *      the provided text.
         */
        @Nullable
        public LocaleListCompat getDefaultLocales() {
            return mDefaultLocales;
        }

        /**
         * @return reference time based on which relative dates (e.g. "tomorrow") should be
         *      interpreted. This should be milliseconds from the epoch of
         *      1970-01-01T00:00:00Z(UTC timezone).
         */
        @Nullable
        public Long getReferenceTime() {
            return mReferenceTime;
        }

        /**
         * Returns the extended, vendor specific data.
         *
         * <p><b>NOTE: </b>Each call to this method returns a new bundle copy so clients should
         * prefer to hold a reference to the returned bundle rather than frequently calling this
         * method. Avoid updating the content of this bundle. On pre-O devices, the values in the
         * Bundle are not deep copied.
         */
        @NonNull
        public Bundle getExtras() {
            return BundleUtils.deepCopy(mExtras);
        }

        /**
         * Converts a platform request into this library's representation. Extras are not
         * carried over.
         *
         * @hide
         */
        @RestrictTo(RestrictTo.Scope.LIBRARY)
        @RequiresApi(28)
        @NonNull
        static TextClassification.Request fromPlatform(
                @NonNull android.view.textclassifier.TextClassification.Request request) {
            return new TextClassification.Request.Builder(
                    request.getText(), request.getStartIndex(), request.getEndIndex())
                    .setReferenceTime(ConvertUtils.zonedDateTimeToUtcMs(request.getReferenceTime()))
                    .setDefaultLocales(ConvertUtils.wrapLocalList(request.getDefaultLocales()))
                    .build();
        }

        /**
         * Converts this request into a platform
         * {@code android.view.textclassifier.TextClassification.Request}. Extras are not
         * carried over.
         *
         * @hide
         */
        @RestrictTo(RestrictTo.Scope.LIBRARY)
        @RequiresApi(28)
        @NonNull
        Object toPlatform() {
            return new android.view.textclassifier.TextClassification.Request.Builder(
                    mText, mStartIndex, mEndIndex)
                    .setDefaultLocales(ConvertUtils.unwrapLocalListCompat(getDefaultLocales()))
                    .setReferenceTime(ConvertUtils.createZonedDateTimeFromUtc(mReferenceTime))
                    .build();
        }

        /**
         * A builder for building TextClassification requests.
         */
        public static final class Builder {

            private final CharSequence mText;
            private final int mStartIndex;
            private final int mEndIndex;
            @Nullable private Bundle mExtras;

            @Nullable private LocaleListCompat mDefaultLocales;
            @Nullable private Long mReferenceTime = null;

            /**
             * @param text text providing context for the text to classify (which is specified
             *      by the sub sequence starting at startIndex and ending at endIndex)
             * @param startIndex start index of the text to classify
             * @param endIndex end index of the text to classify
             * @throws IllegalArgumentException if {@code text} is null, {@code startIndex} is
             *      negative, {@code endIndex} exceeds the text length, or
             *      {@code endIndex <= startIndex}.
             */
            public Builder(
                    @NonNull CharSequence text,
                    @IntRange(from = 0) int startIndex,
                    @IntRange(from = 0) int endIndex) {
                Preconditions.checkArgument(text != null);
                Preconditions.checkArgument(startIndex >= 0);
                Preconditions.checkArgument(endIndex <= text.length());
                Preconditions.checkArgument(endIndex > startIndex);
                mText = text;
                mStartIndex = startIndex;
                mEndIndex = endIndex;
            }

            /**
             * @param defaultLocales ordered list of locale preferences that may be used to
             *      disambiguate the provided text. If no locale preferences exist, set this to null
             *      or an empty locale list.
             *
             * @return this builder
             */
            @NonNull
            public Builder setDefaultLocales(@Nullable LocaleListCompat defaultLocales) {
                mDefaultLocales = defaultLocales;
                return this;
            }

            /**
             * @param referenceTime reference time based on which relative dates (e.g. "tomorrow")
             *      should be interpreted. This should usually be the time when the text was
             *      originally composed and should be milliseconds from the epoch of
             *      1970-01-01T00:00:00Z(UTC timezone). For example, if there is a message saying
             *      "see you 10 days later", and the message was composed yesterday, text classifier
             *      will then realize it is indeed means 9 days later from now and classify the text
             *      accordingly. If no reference time is set, now is used.
             *
             * @return this builder
             */
            @NonNull
            public Builder setReferenceTime(@Nullable Long referenceTime) {
                mReferenceTime = referenceTime;
                return this;
            }

            /**
             * Sets the extended, vendor specific data.
             *
             * @return this builder
             */
            @NonNull
            public Builder setExtras(@Nullable Bundle extras) {
                mExtras = extras;
                return this;
            }

            /**
             * Builds and returns the request object.
             */
            @NonNull
            public Request build() {
                return new Request(
                        normalizeIfUri(mText, mStartIndex, mEndIndex),
                        mStartIndex, mEndIndex, mDefaultLocales, mReferenceTime,
                        mExtras == null ? Bundle.EMPTY : BundleUtils.deepCopy(mExtras));
            }

            /**
             * Ensures the package manager can recognize a url scheme that is not all lowercase.
             * b/123640937
             *
             * <p>If {@code text[startIndex, endIndex)} parses as a URI whose scheme is not
             * lowercase, and lowercasing the scheme does not change the span's length, returns a
             * copy of {@code text} with that span replaced by the normalized URI. In every other
             * case (including any exception while parsing) returns {@code text} unchanged.
             */
            @NonNull
            private static CharSequence normalizeIfUri(
                    CharSequence text, int startIndex, int endIndex) {
                try {
                    // TODO: Skip if running Android Q.
                    final Uri uri = Uri.parse(text.subSequence(startIndex, endIndex).toString());
                    final String scheme = uri.getScheme();
                    final String lower = scheme == null ? null : scheme.toLowerCase(Locale.ROOT);
                    if (lower != null && !scheme.equals(lower)) {
                        final String normalized = uri.buildUpon().scheme(lower).build().toString();
                        // Only replace when indices stay valid for the rest of the request.
                        if (normalized.length() == (endIndex - startIndex)) {
                            return new SpannableString(
                                    new SpannableStringBuilder(text)
                                            .replace(startIndex, endIndex, normalized));
                        }
                    }
                } catch (Exception e) {
                    // Catching to ensure no crashes from this method.
                    Log.e(LOG_TAG, "Error fixing uri scheme", e);
                }
                return text;
            }
        }

        /**
         * Adds this Request to a Bundle that can be read back with the same parameters
         * to {@link #createFromBundle(Bundle)}.
         */
        @NonNull
        public Bundle toBundle() {
            final Bundle bundle = new Bundle();
            bundle.putCharSequence(EXTRA_TEXT, mText);
            bundle.putInt(EXTRA_START_INDEX, mStartIndex);
            bundle.putInt(EXTRA_END_INDEX, mEndIndex);
            BundleUtils.putLocaleList(bundle, EXTRA_DEFAULT_LOCALES, mDefaultLocales);
            BundleUtils.putLong(bundle, EXTRA_REFERENCE_TIME, mReferenceTime);
            bundle.putBundle(EXTRA_EXTRAS, mExtras);
            return bundle;
        }

        /**
         * Extracts a Request from a bundle that was added using {@link #toBundle()}.
         *
         * @throws IllegalArgumentException if the stored text or indices are invalid (see
         *      {@link Builder#Builder(CharSequence, int, int)}).
         */
        @NonNull
        public static Request createFromBundle(@NonNull Bundle bundle) {
            final Builder builder = new Builder(
                    bundle.getCharSequence(EXTRA_TEXT),
                    bundle.getInt(EXTRA_START_INDEX),
                    bundle.getInt(EXTRA_END_INDEX))
                    .setDefaultLocales(BundleUtils.getLocaleList(bundle, EXTRA_DEFAULT_LOCALES))
                    .setReferenceTime(BundleUtils.getLong(bundle, EXTRA_REFERENCE_TIME))
                    .setExtras(bundle.getBundle(EXTRA_EXTRAS));
            return builder.build();
        }
    }
}