FakeTrackOutput.java

/*
 * Copyright (C) 2016 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package androidx.media3.test.utils;

import static com.google.common.truth.Truth.assertThat;

import androidx.annotation.Nullable;
import androidx.media3.common.C;
import androidx.media3.common.DataReader;
import androidx.media3.common.Format;
import androidx.media3.common.util.Assertions;
import androidx.media3.common.util.ParsableByteArray;
import androidx.media3.common.util.UnstableApi;
import androidx.media3.common.util.Util;
import androidx.media3.extractor.TrackOutput;
import androidx.media3.test.utils.Dumper.Dumpable;
import com.google.common.primitives.Bytes;
import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/** A fake {@link TrackOutput}. */
@UnstableApi
public final class FakeTrackOutput implements TrackOutput, Dumper.Dumpable {

  public static final Factory DEFAULT_FACTORY =
      (id, type) -> new FakeTrackOutput(/* deduplicateConsecutiveFormats= */ false);

  /** Factory for {@link FakeTrackOutput} instances. */
  public interface Factory {
    FakeTrackOutput create(int id, int type);
  }

  private final boolean deduplicateConsecutiveFormats;
  private final ArrayList<DumpableSampleInfo> sampleInfos;
  private final ArrayList<Dumpable> dumpables;

  private byte[] sampleData;
  private int formatCount;
  private boolean receivedSampleInFormat;

  @Nullable public Format lastFormat;

  public FakeTrackOutput(boolean deduplicateConsecutiveFormats) {
    this.deduplicateConsecutiveFormats = deduplicateConsecutiveFormats;
    sampleInfos = new ArrayList<>();
    dumpables = new ArrayList<>();
    sampleData = Util.EMPTY_BYTE_ARRAY;
    formatCount = 0;
    receivedSampleInFormat = true;
  }

  public void clear() {
    sampleInfos.clear();
    dumpables.clear();
    sampleData = Util.EMPTY_BYTE_ARRAY;
    formatCount = 0;
    receivedSampleInFormat = true;
  }

  @Override
  public void format(Format format) {
    if (!deduplicateConsecutiveFormats) {
      Assertions.checkState(
          receivedSampleInFormat,
          "deduplicateConsecutiveFormats=false so TrackOutput must receive at least one"
              + " sampleMetadata() call between format() calls.");
    } else if (!receivedSampleInFormat) {
      Dumpable dumpable = dumpables.remove(dumpables.size() - 1);
      formatCount--;
      Assertions.checkState(
          dumpable instanceof DumpableFormat,
          "receivedSampleInFormat=false so expected last dumpable to be a DumpableFormat. Found: "
              + dumpable.getClass().getCanonicalName());
    }
    receivedSampleInFormat = false;
    addFormat(format);
  }

  @Override
  public int sampleData(
      DataReader input, int length, boolean allowEndOfInput, @SampleDataPart int sampleDataPart)
      throws IOException {
    byte[] newData = new byte[length];
    int bytesAppended = input.read(newData, 0, length);
    if (bytesAppended == C.RESULT_END_OF_INPUT) {
      if (allowEndOfInput) {
        return C.RESULT_END_OF_INPUT;
      }
      throw new EOFException();
    }
    newData = Arrays.copyOf(newData, bytesAppended);
    sampleData = Bytes.concat(sampleData, newData);
    return bytesAppended;
  }

  @Override
  public void sampleData(ParsableByteArray data, int length, @SampleDataPart int sampleDataPart) {
    byte[] newData = new byte[length];
    data.readBytes(newData, 0, length);
    sampleData = Bytes.concat(sampleData, newData);
  }

  @Override
  public void sampleMetadata(
      long timeUs,
      @C.BufferFlags int flags,
      int size,
      int offset,
      @Nullable CryptoData cryptoData) {
    receivedSampleInFormat = true;
    if (lastFormat == null) {
      throw new IllegalStateException("TrackOutput must receive format before sampleMetadata");
    }
    if (lastFormat.maxInputSize != Format.NO_VALUE && size > lastFormat.maxInputSize) {
      throw new IllegalStateException("Sample size exceeds Format.maxInputSize");
    }
    if (dumpables.isEmpty()) {
      addFormat(lastFormat);
    }
    addSampleInfo(
        timeUs, flags, sampleData.length - offset - size, sampleData.length - offset, cryptoData);
  }

  public void assertSampleCount(int count) {
    assertThat(sampleInfos).hasSize(count);
  }

  public void assertSample(
      int index, byte[] data, long timeUs, int flags, @Nullable CryptoData cryptoData) {
    byte[] actualData = getSampleData(index);
    assertThat(actualData).isEqualTo(data);
    assertThat(getSampleTimeUs(index)).isEqualTo(timeUs);
    assertThat(getSampleFlags(index)).isEqualTo(flags);
    assertThat(getSampleCryptoData(index)).isEqualTo(cryptoData);
  }

  public byte[] getSampleData(int index) {
    return Arrays.copyOfRange(sampleData, getSampleStartOffset(index), getSampleEndOffset(index));
  }

  private byte[] getSampleData(int fromIndex, int toIndex) {
    return Arrays.copyOfRange(sampleData, fromIndex, toIndex);
  }

  public long getSampleTimeUs(int index) {
    return sampleInfos.get(index).timeUs;
  }

  public int getSampleFlags(int index) {
    return sampleInfos.get(index).flags;
  }

  @Nullable
  public CryptoData getSampleCryptoData(int index) {
    return sampleInfos.get(index).cryptoData;
  }

  public int getSampleCount() {
    return sampleInfos.size();
  }

  public List<Long> getSampleTimesUs() {
    List<Long> sampleTimesUs = new ArrayList<>();
    for (DumpableSampleInfo sampleInfo : sampleInfos) {
      sampleTimesUs.add(sampleInfo.timeUs);
    }
    return Collections.unmodifiableList(sampleTimesUs);
  }

  @Override
  public void dump(Dumper dumper) {
    dumper.add("total output bytes", sampleData.length);
    dumper.add("sample count", sampleInfos.size());
    if (dumpables.isEmpty() && lastFormat != null) {
      new DumpableFormat(lastFormat, 0).dump(dumper);
    }
    for (int i = 0; i < dumpables.size(); i++) {
      dumpables.get(i).dump(dumper);
    }
  }

  private int getSampleStartOffset(int index) {
    return sampleInfos.get(index).startOffset;
  }

  private int getSampleEndOffset(int index) {
    return sampleInfos.get(index).endOffset;
  }

  private void addFormat(Format format) {
    lastFormat = format;
    dumpables.add(new DumpableFormat(format, formatCount));
    formatCount++;
  }

  private void addSampleInfo(
      long timeUs, int flags, int startOffset, int endOffset, @Nullable CryptoData cryptoData) {
    DumpableSampleInfo sampleInfo =
        new DumpableSampleInfo(timeUs, flags, startOffset, endOffset, cryptoData, getSampleCount());
    sampleInfos.add(sampleInfo);
    dumpables.add(sampleInfo);
  }

  private final class DumpableSampleInfo implements Dumper.Dumpable {
    public final long timeUs;
    public final int flags;
    public final int startOffset;
    public final int endOffset;
    @Nullable public final CryptoData cryptoData;
    public final int index;

    public DumpableSampleInfo(
        long timeUs,
        int flags,
        int startOffset,
        int endOffset,
        @Nullable CryptoData cryptoData,
        int index) {
      this.timeUs = timeUs;
      this.flags = flags;
      this.startOffset = startOffset;
      this.endOffset = endOffset;
      this.cryptoData = cryptoData;
      this.index = index;
    }

    @Override
    public void dump(Dumper dumper) {
      dumper
          .startBlock("sample " + index)
          .add("time", timeUs)
          .add("flags", flags)
          .add("data", getSampleData(startOffset, endOffset));
      if (cryptoData != null) {
        dumper.add("crypto mode", cryptoData.cryptoMode);
        dumper.add("encryption key", cryptoData.encryptionKey);
      }
      dumper.endBlock();
    }

    @Override
    public boolean equals(@Nullable Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }
      DumpableSampleInfo that = (DumpableSampleInfo) o;
      return timeUs == that.timeUs
          && flags == that.flags
          && startOffset == that.startOffset
          && endOffset == that.endOffset
          && index == that.index
          && Util.areEqual(cryptoData, that.cryptoData);
    }

    @Override
    public int hashCode() {
      int result = (int) timeUs;
      result = 31 * result + flags;
      result = 31 * result + startOffset;
      result = 31 * result + endOffset;
      result = 31 * result + (cryptoData == null ? 0 : cryptoData.hashCode());
      result = 31 * result + index;
      return result;
    }
  }
}