FakeExoMediaDrm.java

/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.media3.test.utils;

import android.media.DeniedByServerException;
import android.media.MediaCryptoException;
import android.media.MediaDrmException;
import android.media.NotProvisionedException;
import android.media.ResourceBusyException;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.PersistableBundle;
import androidx.annotation.Nullable;
import androidx.annotation.RequiresApi;
import androidx.media3.common.C;
import androidx.media3.common.DrmInitData;
import androidx.media3.common.util.Assertions;
import androidx.media3.common.util.UnstableApi;
import androidx.media3.common.util.Util;
import androidx.media3.decoder.CryptoConfig;
import androidx.media3.exoplayer.drm.ExoMediaDrm;
import androidx.media3.exoplayer.drm.MediaDrmCallback;
import androidx.media3.exoplayer.drm.MediaDrmCallbackException;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Bytes;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A fake implementation of {@link ExoMediaDrm} for use in tests.
 *
 * <p>{@link LicenseServer} can be used to respond to interactions stemming from {@link
 * #getKeyRequest(byte[], List, int, HashMap)} and {@link #provideKeyResponse(byte[], byte[])}.
 *
 * <p>Currently only supports streaming key requests.
 */
// TODO: Consider replacing this with a Robolectric ShadowMediaDrm so we can use a real
//  FrameworkMediaDrm.
@RequiresApi(29)
@UnstableApi
public final class FakeExoMediaDrm implements ExoMediaDrm {

  /**
   * Fluent builder for {@link FakeExoMediaDrm} instances.
   *
   * <p>Every option has a default, so {@code new Builder().build()} yields a usable instance.
   */
  public static class Builder {
    private boolean enforceValidKeyResponses = true;
    private int provisionsRequired = 0;
    private boolean throwNotProvisionedExceptionFromGetKeyRequest = false;
    private int maxConcurrentSessions = Integer.MAX_VALUE;

    /** Constructs a builder with every option at its default value. */
    public Builder() {}

    /**
     * Sets whether responses handed to {@link FakeExoMediaDrm#provideKeyResponse(byte[], byte[])}
     * are validated, i.e. whether they must have been produced by a {@link LicenseServer}.
     *
     * <p>The default is {@code true}.
     *
     * @param enforceValidKeyResponses Whether unrecognised key responses should be rejected.
     * @return This builder, for chaining.
     */
    public Builder setEnforceValidKeyResponses(boolean enforceValidKeyResponses) {
      this.enforceValidKeyResponses = enforceValidKeyResponses;
      return this;
    }

    /**
     * Sets the number of successful provisioning round trips the {@link FakeExoMediaDrm} needs
     * before it counts as provisioned.
     *
     * <p>While unprovisioned, the instance throws {@link NotProvisionedException} from methods
     * that declare it, until enough valid responses have been passed to {@link
     * FakeExoMediaDrm#provideProvisionResponse(byte[])}.
     *
     * <p>The default is 0, meaning the device starts out provisioned.
     *
     * @param provisionsRequired The number of valid provisioning responses required.
     * @return This builder, for chaining.
     */
    public Builder setProvisionsRequired(int provisionsRequired) {
      this.provisionsRequired = provisionsRequired;
      return this;
    }

    /**
     * Makes the {@link FakeExoMediaDrm} report missing provisioning by throwing {@link
     * NotProvisionedException} from {@link FakeExoMediaDrm#getKeyRequest(byte[], List, int,
     * HashMap)} rather than from {@link FakeExoMediaDrm#openSession()}, which is the default.
     *
     * @return This builder, for chaining.
     */
    public Builder throwNotProvisionedExceptionFromGetKeyRequest() {
      throwNotProvisionedExceptionFromGetKeyRequest = true;
      return this;
    }

    /**
     * Caps the number of sessions that may be open at the same time.
     *
     * <p>Once the cap is reached, further calls to {@link FakeExoMediaDrm#openSession()} throw
     * {@link ResourceBusyException}.
     *
     * <p>The default is {@link Integer#MAX_VALUE}.
     *
     * @param maxConcurrentSessions The maximum number of simultaneously open sessions.
     * @return This builder, for chaining.
     */
    public Builder setMaxConcurrentSessions(int maxConcurrentSessions) {
      this.maxConcurrentSessions = maxConcurrentSessions;
      return this;
    }

    /**
     * Builds a {@link FakeExoMediaDrm} whose reference count starts at 1. The caller owns that
     * reference and must call {@link FakeExoMediaDrm#release()} once the instance is no longer
     * needed.
     */
    public FakeExoMediaDrm build() {
      return new FakeExoMediaDrm(
          enforceValidKeyResponses,
          provisionsRequired,
          throwNotProvisionedExceptionFromGetKeyRequest,
          maxConcurrentSessions);
    }
  }

  /**
   * The provision request returned from {@link #getProvisionRequest()}. A {@link LicenseServer}
   * answers a request carrying this data with {@link #VALID_PROVISION_RESPONSE}.
   */
  public static final ProvisionRequest FAKE_PROVISION_REQUEST =
      new ProvisionRequest(TestUtil.createByteArray(7, 8, 9), "bar.test");
  /**
   * The only provisioning response that {@link #provideProvisionResponse(byte[])} counts towards
   * the required number of provisioning round trips.
   */
  public static final ImmutableList<Byte> VALID_PROVISION_RESPONSE =
      TestUtil.createByteList(4, 5, 6);

  /** Key for use with the Map returned from {@link FakeExoMediaDrm#queryKeyStatus(byte[])}. */
  public static final String KEY_STATUS_KEY = "KEY_STATUS";
  /** Value for use with the Map returned from {@link FakeExoMediaDrm#queryKeyStatus(byte[])}. */
  public static final String KEY_STATUS_AVAILABLE = "AVAILABLE";
  /** Value for use with the Map returned from {@link FakeExoMediaDrm#queryKeyStatus(byte[])}. */
  public static final String KEY_STATUS_UNAVAILABLE = "UNAVAILABLE";

  // Sentinel key responses produced by LicenseServer#executeKeyRequest and recognised by
  // provideKeyResponse.
  private static final ImmutableList<Byte> VALID_KEY_RESPONSE = TestUtil.createByteList(1, 2, 3);
  private static final ImmutableList<Byte> KEY_DENIED_RESPONSE = TestUtil.createByteList(9, 8, 7);
  // Note: these bytes are the same as VALID_PROVISION_RESPONSE. The two values are only ever
  // compared against different inputs (key responses vs. provisioning responses).
  private static final ImmutableList<Byte> PROVISIONING_REQUIRED_RESPONSE =
      TestUtil.createByteList(4, 5, 6);

  // Immutable configuration, set once by the constructor.
  private final boolean enforceValidKeyResponses;
  private final int provisionsRequired;
  private final int maxConcurrentSessions;
  private final boolean throwNotProvisionedExceptionFromGetKeyRequest;
  // Backing stores for setPropertyByteArray/getPropertyByteArray and
  // setPropertyString/getPropertyString.
  private final Map<String, byte[]> byteProperties;
  private final Map<String, String> stringProperties;
  // IDs returned from openSession() that have not yet been passed to closeSession().
  private final Set<List<Byte>> openSessionIds;
  // IDs of sessions for which provideKeyResponse() has accepted a response.
  private final Set<List<Byte>> sessionIdsWithValidKeys;
  // Incremented for every session that is opened; its value feeds session ID generation.
  private final AtomicInteger sessionIdGenerator;

  // Number of valid provisioning responses received since construction or the last call to
  // resetProvisioning().
  private int provisionsReceived;
  // Starts at 1; adjusted by acquire() and release().
  private int referenceCount;
  // The listener notified by triggerEvent(), or null if none is attached.
  @Nullable private OnEventListener onEventListener;

  /**
   * Creates an instance with no limit on concurrent sessions, with valid key responses enforced
   * and with no provisioning required.
   *
   * @deprecated Use {@link Builder} instead.
   */
  @Deprecated
  public FakeExoMediaDrm() {
    this(
        /* enforceValidKeyResponses= */ true,
        /* provisionsRequired= */ 0,
        /* throwNotProvisionedExceptionFromGetKeyRequest= */ false,
        /* maxConcurrentSessions= */ Integer.MAX_VALUE);
  }

  /**
   * Creates an instance that enforces valid key responses, requires no provisioning and allows at
   * most {@code maxConcurrentSessions} sessions to be open at once.
   *
   * @param maxConcurrentSessions The maximum number of simultaneously open sessions.
   * @deprecated Use {@link Builder} instead.
   */
  @Deprecated
  public FakeExoMediaDrm(int maxConcurrentSessions) {
    this(
        /* enforceValidKeyResponses= */ true,
        /* provisionsRequired= */ 0,
        /* throwNotProvisionedExceptionFromGetKeyRequest= */ false,
        maxConcurrentSessions);
  }

  /**
   * Creates an instance from fully specified configuration. Used by {@link Builder#build()} and
   * the deprecated public constructors.
   *
   * <p>The new instance has no open sessions, no properties, no provisioning responses recorded
   * and a reference count of 1.
   */
  private FakeExoMediaDrm(
      boolean enforceValidKeyResponses,
      int provisionsRequired,
      boolean throwNotProvisionedExceptionFromGetKeyRequest,
      int maxConcurrentSessions) {
    // Configuration.
    this.enforceValidKeyResponses = enforceValidKeyResponses;
    this.provisionsRequired = provisionsRequired;
    this.throwNotProvisionedExceptionFromGetKeyRequest =
        throwNotProvisionedExceptionFromGetKeyRequest;
    this.maxConcurrentSessions = maxConcurrentSessions;

    // Mutable state, initially empty.
    this.stringProperties = new HashMap<>();
    this.byteProperties = new HashMap<>();
    this.sessionIdsWithValidKeys = new HashSet<>();
    this.openSessionIds = new HashSet<>();
    this.sessionIdGenerator = new AtomicInteger();

    // The creator holds the first reference.
    this.referenceCount = 1;
  }

  // ExoMediaDrm implementation

  /**
   * Attaches the listener that {@link #triggerEvent} dispatches to, replacing any previous one.
   *
   * @param listener The listener to notify, or {@code null} to detach the current listener.
   */
  @Override
  public void setOnEventListener(@Nullable OnEventListener listener) {
    onEventListener = listener;
  }

  /**
   * Ignores the listener: this fake keeps no reference to it and never invokes it.
   *
   * @param listener Unused.
   */
  @Override
  public void setOnKeyStatusChangeListener(@Nullable OnKeyStatusChangeListener listener) {
    // Do nothing.
  }

  /**
   * Ignores the listener: this fake keeps no reference to it and never invokes it.
   *
   * @param listener Unused.
   */
  @Override
  public void setOnExpirationUpdateListener(@Nullable OnExpirationUpdateListener listener) {
    // Do nothing.
  }

  /**
   * Opens a new session and returns its ID.
   *
   * <p>The ID is 10 bytes of test data generated from an incrementing counter, and is recorded as
   * open until it is passed to {@link #closeSession(byte[])}.
   *
   * @return The ID of the newly opened session.
   * @throws NotProvisionedException If fewer provisioning responses have been received than are
   *     required, unless the instance is configured to report that from {@code getKeyRequest}.
   * @throws ResourceBusyException If the maximum number of concurrent sessions is already open.
   * @throws MediaDrmException If the generated ID is already in the set of open sessions.
   * @throws IllegalStateException If the reference count is not positive.
   */
  @Override
  public byte[] openSession() throws MediaDrmException {
    Assertions.checkState(referenceCount > 0);

    boolean provisioned = provisionsReceived >= provisionsRequired;
    if (!provisioned && !throwNotProvisionedExceptionFromGetKeyRequest) {
      throw new NotProvisionedException("Not provisioned.");
    }

    int openSessionCount = openSessionIds.size();
    if (openSessionCount >= maxConcurrentSessions) {
      throw new ResourceBusyException("Too many sessions open. max=" + maxConcurrentSessions);
    }

    byte[] sessionId =
        TestUtil.buildTestData(/* length= */ 10, sessionIdGenerator.incrementAndGet());
    boolean newlyAdded = openSessionIds.add(toByteList(sessionId));
    if (newlyAdded) {
      return sessionId;
    }
    throw new MediaDrmException(
        Util.formatInvariant(
            "Generated sessionId[%s] clashes with already-open session",
            sessionIdGenerator.get()));
  }

  /**
   * Closes a session previously returned from {@link #openSession()}.
   *
   * @param sessionId The ID of the session to close.
   * @throws IllegalStateException If the reference count is not positive, or if {@code sessionId}
   *     does not identify a currently open session.
   */
  @Override
  public void closeSession(byte[] sessionId) {
    Assertions.checkState(referenceCount > 0);
    // TODO: Store closed session IDs too?
    boolean wasOpen = openSessionIds.remove(toByteList(sessionId));
    Assertions.checkState(wasOpen);
  }

  /**
   * Builds a streaming key request for an open session.
   *
   * <p>The request data is a serialized form of {@code schemeDatas}, {@code keyType} and {@code
   * optionalParameters}, which {@link LicenseServer} knows how to read back. The license server
   * URL of the returned request is always empty.
   *
   * @param scope The ID of an open session.
   * @param schemeDatas The scheme data to embed in the request. Must not be null.
   * @param keyType The type of key requested. Only {@code KEY_TYPE_STREAMING} is supported.
   * @param optionalParameters Extra parameters to embed in the request, or null for none.
   * @return A renewal request if the session already holds valid keys, otherwise an initial
   *     request.
   * @throws NotProvisionedException If the instance is configured to report missing provisioning
   *     from this method and fewer provisioning responses have been received than are required.
   * @throws UnsupportedOperationException If {@code keyType} is an offline or release type.
   * @throws IllegalArgumentException If {@code keyType} is not recognised.
   * @throws IllegalStateException If the reference count is not positive or {@code scope} is not
   *     an open session.
   */
  @Override
  public KeyRequest getKeyRequest(
      byte[] scope,
      @Nullable List<DrmInitData.SchemeData> schemeDatas,
      int keyType,
      @Nullable HashMap<String, String> optionalParameters)
      throws NotProvisionedException {
    Assertions.checkState(referenceCount > 0);
    if (keyType == KEY_TYPE_OFFLINE || keyType == KEY_TYPE_RELEASE) {
      throw new UnsupportedOperationException("Offline key requests are not supported.");
    }
    Assertions.checkArgument(keyType == KEY_TYPE_STREAMING, "Unrecognised keyType: " + keyType);

    boolean provisioned = provisionsReceived >= provisionsRequired;
    if (!provisioned && throwNotProvisionedExceptionFromGetKeyRequest) {
      throw new NotProvisionedException("Not provisioned.");
    }

    ImmutableList<Byte> sessionId = toByteList(scope);
    Assertions.checkState(openSessionIds.contains(sessionId));
    List<DrmInitData.SchemeData> requiredSchemeDatas = Assertions.checkNotNull(schemeDatas);

    Map<String, String> parameters =
        optionalParameters == null ? ImmutableMap.of() : optionalParameters;
    byte[] serializedRequest =
        new KeyRequestData(requiredSchemeDatas, keyType, parameters).toByteArray();

    // A session that already holds keys is asking for a renewal rather than a first license.
    @KeyRequest.RequestType int requestType;
    if (sessionIdsWithValidKeys.contains(sessionId)) {
      requestType = KeyRequest.REQUEST_TYPE_RENEWAL;
    } else {
      requestType = KeyRequest.REQUEST_TYPE_INITIAL;
    }
    return new KeyRequest(serializedRequest, /* licenseServerUrl= */ "", requestType);
  }

  /**
   * Interprets a key response for the session identified by {@code scope}.
   *
   * <p>The response is compared against fixed byte patterns, in this order: the key-denied
   * pattern, the provisioning-required pattern and the valid-response pattern. If none of the
   * failure cases applies, the session is recorded as holding valid keys.
   *
   * @param scope The ID of the session the response applies to.
   * @param response The raw key response.
   * @return An empty byte array.
   * @throws DeniedByServerException If {@code response} is the key-denied pattern.
   * @throws NotProvisionedException If {@code response} is the provisioning-required pattern.
   * @throws IllegalArgumentException If valid key responses are enforced and {@code response} is
   *     not the valid-response pattern.
   * @throws IllegalStateException If the reference count is not positive.
   */
  @Override
  public byte[] provideKeyResponse(byte[] scope, byte[] response)
      throws NotProvisionedException, DeniedByServerException {
    Assertions.checkState(referenceCount > 0);
    // A view is sufficient here because it is only used for comparisons within this call.
    List<Byte> responseAsList = Bytes.asList(response);
    if (responseAsList.equals(KEY_DENIED_RESPONSE)) {
      throw new DeniedByServerException("Key request denied");
    }
    if (responseAsList.equals(PROVISIONING_REQUIRED_RESPONSE)) {
      throw new NotProvisionedException("Provisioning required");
    }
    if (enforceValidKeyResponses && !responseAsList.equals(VALID_KEY_RESPONSE)) {
      throw new IllegalArgumentException(
          "Unrecognised response. scope="
              + Util.toHexString(scope)
              + ", response="
              + Util.toHexString(response));
    }
    // Store an immutable copy rather than a Bytes.asList view: a view is backed by the caller's
    // array, so a later mutation of that array would change the hash of an element already in
    // the set and make lookups for this session fail.
    sessionIdsWithValidKeys.add(toByteList(scope));
    return Util.EMPTY_BYTE_ARRAY;
  }

  /**
   * Returns {@link #FAKE_PROVISION_REQUEST}; the same instance is returned on every call.
   *
   * @throws IllegalStateException If the reference count is not positive.
   */
  @Override
  public ProvisionRequest getProvisionRequest() {
    Assertions.checkState(referenceCount > 0);
    return FAKE_PROVISION_REQUEST;
  }

  /**
   * Records a provisioning response.
   *
   * <p>Only a response equal to {@link #VALID_PROVISION_RESPONSE} counts towards the number of
   * provisioning round trips required; any other response is silently ignored.
   *
   * @param response The raw provisioning response.
   * @throws IllegalStateException If the reference count is not positive.
   */
  @Override
  public void provideProvisionResponse(byte[] response) throws DeniedByServerException {
    Assertions.checkState(referenceCount > 0);
    List<Byte> responseBytes = Bytes.asList(response);
    if (!VALID_PROVISION_RESPONSE.equals(responseBytes)) {
      return;
    }
    provisionsReceived++;
  }

  /**
   * Reports whether an open session holds valid keys.
   *
   * @param sessionId The ID of an open session.
   * @return A single-entry map from {@link #KEY_STATUS_KEY} to {@link #KEY_STATUS_AVAILABLE} if
   *     a key response has been accepted for the session, or to {@link #KEY_STATUS_UNAVAILABLE}
   *     otherwise.
   * @throws IllegalStateException If the reference count is not positive or {@code sessionId} is
   *     not an open session.
   */
  @Override
  public Map<String, String> queryKeyStatus(byte[] sessionId) {
    Assertions.checkState(referenceCount > 0);
    ImmutableList<Byte> sessionKey = toByteList(sessionId);
    Assertions.checkState(openSessionIds.contains(sessionKey));

    String keyStatus;
    if (sessionIdsWithValidKeys.contains(sessionKey)) {
      keyStatus = KEY_STATUS_AVAILABLE;
    } else {
      keyStatus = KEY_STATUS_UNAVAILABLE;
    }
    return ImmutableMap.of(KEY_STATUS_KEY, keyStatus);
  }

  /**
   * Always returns {@code false}, regardless of the session or MIME type.
   *
   * @param sessionId Unused.
   * @param mimeType Unused.
   */
  @Override
  public boolean requiresSecureDecoder(byte[] sessionId, String mimeType) {
    return false;
  }

  /**
   * Increments the reference count.
   *
   * @throws IllegalStateException If the reference count is not positive, i.e. the instance has
   *     already been fully released.
   */
  @Override
  public void acquire() {
    Assertions.checkState(referenceCount > 0);
    referenceCount++;
  }

  /**
   * Decrements the reference count. Most other methods fail their state check once the count is
   * no longer positive.
   */
  @Override
  public void release() {
    // NOTE(review): unlike acquire(), there is no state check here, so releasing more often than
    // acquiring drives the count negative. Confirm whether tests rely on that being observable
    // via getReferenceCount() before adding a guard.
    referenceCount--;
  }

  /**
   * Not supported by this fake.
   *
   * @throws UnsupportedOperationException Always.
   */
  @Override
  public void restoreKeys(byte[] sessionId, byte[] keySetId) {
    throw new UnsupportedOperationException();
  }

  /**
   * Always returns {@code null}; this fake does not produce metrics.
   *
   * @throws IllegalStateException If the reference count is not positive.
   */
  @Nullable
  @Override
  public PersistableBundle getMetrics() {
    Assertions.checkState(referenceCount > 0);

    return null;
  }

  /**
   * Returns the string most recently stored under {@code propertyName} by {@link
   * #setPropertyString(String, String)}.
   *
   * @param propertyName The name of the property to read.
   * @return The stored value.
   * @throws IllegalArgumentException If no value is stored under {@code propertyName}.
   * @throws IllegalStateException If the reference count is not positive.
   */
  @Override
  public String getPropertyString(String propertyName) {
    Assertions.checkState(referenceCount > 0);
    @Nullable String storedValue = stringProperties.get(propertyName);
    if (storedValue != null) {
      return storedValue;
    }
    throw new IllegalArgumentException("Unrecognized propertyName: " + propertyName);
  }

  /**
   * Returns the byte array most recently stored under {@code propertyName} by {@link
   * #setPropertyByteArray(String, byte[])}. The stored array itself is returned, not a copy.
   *
   * @param propertyName The name of the property to read.
   * @return The stored value.
   * @throws IllegalArgumentException If no value is stored under {@code propertyName}.
   * @throws IllegalStateException If the reference count is not positive.
   */
  @Override
  public byte[] getPropertyByteArray(String propertyName) {
    Assertions.checkState(referenceCount > 0);
    @Nullable byte[] storedValue = byteProperties.get(propertyName);
    if (storedValue != null) {
      return storedValue;
    }
    throw new IllegalArgumentException("Unrecognized propertyName: " + propertyName);
  }

  /**
   * Stores {@code value} under {@code propertyName}, replacing any previous value, so that it can
   * be read back with {@link #getPropertyString(String)}.
   *
   * @param propertyName The name of the property to set.
   * @param value The value to store.
   * @throws IllegalStateException If the reference count is not positive.
   */
  @Override
  public void setPropertyString(String propertyName, String value) {
    Assertions.checkState(referenceCount > 0);
    stringProperties.put(propertyName, value);
  }

  /**
   * Stores {@code value} under {@code propertyName}, replacing any previous value, so that it can
   * be read back with {@link #getPropertyByteArray(String)}.
   *
   * @param propertyName The name of the property to set.
   * @param value The value to store. The array is stored as-is, without copying.
   * @throws IllegalStateException If the reference count is not positive.
   */
  @Override
  public void setPropertyByteArray(String propertyName, byte[] value) {
    Assertions.checkState(referenceCount > 0);
    byteProperties.put(propertyName, value);
  }

  /**
   * Returns a new {@link FakeCryptoConfig} for an open session.
   *
   * @param sessionId The ID of an open session.
   * @throws IllegalStateException If the reference count is not positive or {@code sessionId} is
   *     not an open session.
   */
  @Override
  public CryptoConfig createCryptoConfig(byte[] sessionId) throws MediaCryptoException {
    Assertions.checkState(referenceCount > 0);
    Assertions.checkState(openSessionIds.contains(toByteList(sessionId)));
    return new FakeCryptoConfig();
  }

  /** Returns {@link FakeCryptoConfig#TYPE}. */
  @Override
  public @C.CryptoType int getCryptoType() {
    return FakeCryptoConfig.TYPE;
  }

  // Methods to facilitate testing

  /**
   * Returns the current reference count: 1 after construction, incremented by {@link #acquire()}
   * and decremented by {@link #release()}.
   */
  public int getReferenceCount() {
    return referenceCount;
  }

  /**
   * Calls {@link OnEventListener#onEvent(ExoMediaDrm, byte[], int, int, byte[])} on the attached
   * listener (if present) once for each open session ID which passes {@code sessionIdPredicate},
   * passing the provided values for {@code event}, {@code extra} and {@code data}.
   *
   * <p>The set of open sessions is captured before the first callback, so a listener may open or
   * close sessions from within the callback without disturbing the dispatch. Sessions opened
   * during dispatch are not notified; sessions closed during dispatch may still be.
   *
   * @param sessionIdPredicate Selects which open sessions the event is delivered for.
   * @param event The event type passed through to the listener.
   * @param extra The extra value passed through to the listener.
   * @param data The event data passed through to the listener. May be null.
   */
  public void triggerEvent(
      Predicate<byte[]> sessionIdPredicate, int event, int extra, @Nullable byte[] data) {
    @Nullable OnEventListener onEventListener = this.onEventListener;
    if (onEventListener == null) {
      return;
    }
    // Iterate over a snapshot: openSessionIds is a HashSet, whose iterator throws
    // ConcurrentModificationException if openSession()/closeSession() is called re-entrantly
    // from the listener while the live set is being iterated.
    for (List<Byte> sessionId : ImmutableList.copyOf(openSessionIds)) {
      byte[] sessionIdArray = Bytes.toArray(sessionId);
      if (sessionIdPredicate.apply(sessionIdArray)) {
        onEventListener.onEvent(this, sessionIdArray, event, extra, data);
      }
    }
  }

  /**
   * Resets the provisioning state of this instance, so it requires {@link
   * Builder#setProvisionsRequired(int) provisionsRequired} (possibly zero) provision operations
   * before it's operational again.
   *
   * <p>Only the count of received provisioning responses is cleared; open sessions and their key
   * state are left untouched.
   */
  public void resetProvisioning() {
    provisionsReceived = 0;
  }

  /**
   * Returns an immutable copy of {@code byteArray} as a list, giving session IDs value-based
   * {@code equals}/{@code hashCode} so they can be stored in and looked up from hash sets.
   */
  private static ImmutableList<Byte> toByteList(byte[] byteArray) {
    return ImmutableList.copyOf(Bytes.asList(byteArray));
  }

  /**
   * A license server implementation to interact with {@link FakeExoMediaDrm}.
   *
   * <p>It grants keys only for an allow-list of scheme data lists, and records every provision
   * request and key request it receives so that tests can inspect them afterwards.
   */
  public static class LicenseServer implements MediaDrmCallback {

    private final ImmutableSet<ImmutableList<DrmInitData.SchemeData>> allowedSchemeDatas;
    private final List<ImmutableList<Byte>> receivedProvisionRequests;
    private final List<ImmutableList<DrmInitData.SchemeData>> receivedSchemeDatas;

    // When set, the next key request is answered with the provisioning-required response and
    // the flag is cleared again.
    private boolean nextResponseIndicatesProvisioningRequired;

    /**
     * Returns a server that grants key requests whose scheme data equals one of the given lists
     * and denies all others.
     */
    @SafeVarargs
    public static LicenseServer allowingSchemeDatas(List<DrmInitData.SchemeData>... schemeDatas) {
      return new LicenseServer(copyToImmutableSet(schemeDatas));
    }

    /**
     * Returns a server that answers its first key request with a provisioning-required response,
     * and afterwards grants key requests whose scheme data equals one of the given lists and
     * denies all others.
     */
    @SafeVarargs
    public static LicenseServer requiringProvisioningThenAllowingSchemeDatas(
        List<DrmInitData.SchemeData>... schemeDatas) {
      LicenseServer server = new LicenseServer(copyToImmutableSet(schemeDatas));
      server.nextResponseIndicatesProvisioningRequired = true;
      return server;
    }

    /** Copies each of the given lists into an immutable list and collects them into a set. */
    private static ImmutableSet<ImmutableList<DrmInitData.SchemeData>> copyToImmutableSet(
        List<DrmInitData.SchemeData>[] schemeDatas) {
      ImmutableSet.Builder<ImmutableList<DrmInitData.SchemeData>> allowed =
          ImmutableSet.builder();
      for (List<DrmInitData.SchemeData> schemeData : schemeDatas) {
        allowed.add(ImmutableList.copyOf(schemeData));
      }
      return allowed.build();
    }

    private LicenseServer(ImmutableSet<ImmutableList<DrmInitData.SchemeData>> allowedSchemeDatas) {
      this.allowedSchemeDatas = allowedSchemeDatas;
      this.receivedProvisionRequests = new ArrayList<>();
      this.receivedSchemeDatas = new ArrayList<>();
    }

    /** Returns the data of every provision request received so far, in order of arrival. */
    public ImmutableList<ImmutableList<Byte>> getReceivedProvisionRequests() {
      return ImmutableList.copyOf(receivedProvisionRequests);
    }

    /** Returns the scheme data of every key request received so far, in order of arrival. */
    public ImmutableList<ImmutableList<DrmInitData.SchemeData>> getReceivedSchemeDatas() {
      return ImmutableList.copyOf(receivedSchemeDatas);
    }

    /**
     * Records the request, then returns {@link #VALID_PROVISION_RESPONSE} if the request data
     * matches that of {@link #FAKE_PROVISION_REQUEST}, or an empty array otherwise.
     */
    @Override
    public byte[] executeProvisionRequest(UUID uuid, ProvisionRequest request)
        throws MediaDrmCallbackException {
      byte[] requestData = request.getData();
      receivedProvisionRequests.add(ImmutableList.copyOf(Bytes.asList(requestData)));
      boolean isExpectedRequest = Arrays.equals(requestData, FAKE_PROVISION_REQUEST.getData());
      return isExpectedRequest ? Bytes.toArray(VALID_PROVISION_RESPONSE) : Util.EMPTY_BYTE_ARRAY;
    }

    /**
     * Records the scheme data carried by the request, then returns the provisioning-required
     * response if one is pending, the valid key response if the scheme data is allowed, or the
     * key-denied response otherwise.
     */
    @Override
    public byte[] executeKeyRequest(UUID uuid, KeyRequest request)
        throws MediaDrmCallbackException {
      ImmutableList<DrmInitData.SchemeData> requestedSchemeDatas =
          KeyRequestData.fromByteArray(request.getData()).schemeDatas;
      receivedSchemeDatas.add(requestedSchemeDatas);

      if (nextResponseIndicatesProvisioningRequired) {
        // The provisioning-required response is only ever sent once.
        nextResponseIndicatesProvisioningRequired = false;
        return Bytes.toArray(PROVISIONING_REQUIRED_RESPONSE);
      }
      if (allowedSchemeDatas.contains(requestedSchemeDatas)) {
        return Bytes.toArray(VALID_KEY_RESPONSE);
      }
      return Bytes.toArray(KEY_DENIED_RESPONSE);
    }
  }

  /**
   * The fields of a key request, in a form that {@link #getKeyRequest(byte[], List, int,
   * HashMap)} can serialize into bytes and {@link LicenseServer#executeKeyRequest(UUID,
   * KeyRequest)} can deserialize again.
   *
   * <p>The parcel layout is: the scheme data list, the key type, the list of optional parameter
   * keys, then the list of optional parameter values in matching order.
   */
  private static class KeyRequestData implements Parcelable {
    public final ImmutableList<DrmInitData.SchemeData> schemeDatas;
    public final int type;
    public final ImmutableMap<String, String> optionalParameters;

    /** Creates an instance holding immutable copies of the given collections. */
    public KeyRequestData(
        List<DrmInitData.SchemeData> schemeDatas,
        int type,
        Map<String, String> optionalParameters) {
      this.type = type;
      this.schemeDatas = ImmutableList.copyOf(schemeDatas);
      this.optionalParameters = ImmutableMap.copyOf(optionalParameters);
    }

    /** Creates an instance by reading the layout written by {@link #writeToParcel}. */
    public KeyRequestData(Parcel in) {
      List<DrmInitData.SchemeData> parcelledSchemeDatas =
          in.readParcelableList(
              new ArrayList<DrmInitData.SchemeData>(),
              DrmInitData.SchemeData.class.getClassLoader());
      schemeDatas = ImmutableList.copyOf(parcelledSchemeDatas);
      type = in.readInt();

      // Keys and values were written as two parallel lists; pair them back up by index.
      List<String> parameterKeys = Assertions.checkNotNull(in.createStringArrayList());
      List<String> parameterValues = Assertions.checkNotNull(in.createStringArrayList());
      int parameterCount = parameterKeys.size();
      Assertions.checkArgument(parameterCount == parameterValues.size());
      ImmutableMap.Builder<String, String> parameters = ImmutableMap.builder();
      for (int index = 0; index < parameterCount; index++) {
        parameters.put(parameterKeys.get(index), parameterValues.get(index));
      }
      optionalParameters = parameters.buildOrThrow();
    }

    /** Serializes this instance into a byte array via a temporary {@link Parcel}. */
    public byte[] toByteArray() {
      Parcel scratch = Parcel.obtain();
      try {
        writeToParcel(scratch, /* flags= */ 0);
        return scratch.marshall();
      } finally {
        scratch.recycle();
      }
    }

    /** Deserializes an instance from bytes produced by {@link #toByteArray()}. */
    public static KeyRequestData fromByteArray(byte[] bytes) {
      Parcel scratch = Parcel.obtain();
      try {
        scratch.unmarshall(bytes, 0, bytes.length);
        // Unmarshalling leaves the position at the end of the data; rewind before reading.
        scratch.setDataPosition(0);
        return CREATOR.createFromParcel(scratch);
      } finally {
        scratch.recycle();
      }
    }

    @Override
    public boolean equals(@Nullable Object obj) {
      if (this == obj) {
        return true;
      }
      if (!(obj instanceof KeyRequestData)) {
        return false;
      }
      KeyRequestData other = (KeyRequestData) obj;
      return type == other.type
          && schemeDatas.equals(other.schemeDatas)
          && optionalParameters.equals(other.optionalParameters);
    }

    @Override
    public int hashCode() {
      return Objects.hash(schemeDatas, type, optionalParameters);
    }

    // Parcelable implementation.

    @Override
    public int describeContents() {
      return 0;
    }

    @Override
    public void writeToParcel(Parcel dest, int flags) {
      dest.writeParcelableList(schemeDatas, flags);
      dest.writeInt(type);
      // Written as two parallel lists, keys first, in the map's iteration order.
      dest.writeStringList(optionalParameters.keySet().asList());
      dest.writeStringList(optionalParameters.values().asList());
    }

    public static final Parcelable.Creator<KeyRequestData> CREATOR =
        new Parcelable.Creator<KeyRequestData>() {

          @Override
          public KeyRequestData createFromParcel(Parcel in) {
            return new KeyRequestData(in);
          }

          @Override
          public KeyRequestData[] newArray(int size) {
            return new KeyRequestData[size];
          }
        };
  }
}