FloatAudioMixingAlgorithm.java

/*
 * Copyright 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package androidx.media3.transformer;

import static androidx.media3.common.util.Assertions.checkArgument;

import android.annotation.SuppressLint;
import androidx.media3.common.C;
import androidx.media3.common.Format;
import androidx.media3.common.audio.AudioProcessor.AudioFormat;
import androidx.media3.common.audio.ChannelMixingMatrix;
import java.nio.ByteBuffer;

/**
 * An {@link AudioMixingAlgorithm} which mixes into float samples.
 *
 * <p>Sources may be 16-bit integer PCM or float PCM and must have the same sample rate as the
 * mixing format. Source samples are scaled by the coefficients of a {@link ChannelMixingMatrix}
 * and added to the samples already present in the mixing buffer. The accumulated float values
 * are not clamped.
 */
/* package */ class FloatAudioMixingAlgorithm implements AudioMixingAlgorithm {

  // Short.MIN_VALUE != -Short.MAX_VALUE so use different scaling factors for positive and
  // negative samples.
  private static final float SCALE_S16_FOR_NEGATIVE_INPUT = -1f / Short.MIN_VALUE;
  private static final float SCALE_S16_FOR_POSITIVE_INPUT = 1f / Short.MAX_VALUE;

  /** The format of the buffer that sources are mixed into. Always float PCM. */
  private final AudioFormat mixingAudioFormat;

  /**
   * Creates an instance.
   *
   * @param mixingAudioFormat The format of the mixing buffer. Its encoding must be {@link
   *     C#ENCODING_PCM_FLOAT} and its channel count must not be {@link Format#NO_VALUE}.
   */
  public FloatAudioMixingAlgorithm(AudioFormat mixingAudioFormat) {
    checkArgument(mixingAudioFormat.encoding == C.ENCODING_PCM_FLOAT);
    checkArgument(mixingAudioFormat.channelCount != Format.NO_VALUE);
    this.mixingAudioFormat = mixingAudioFormat;
  }

  /**
   * Returns whether a source with the given format can be mixed: its sample rate must equal the
   * mixing sample rate and its encoding must be 16-bit PCM or float PCM.
   */
  @Override
  @SuppressLint("SwitchIntDef")
  public boolean supportsSourceAudioFormat(AudioFormat sourceAudioFormat) {
    if (sourceAudioFormat.sampleRate != mixingAudioFormat.sampleRate) {
      return false;
    }
    switch (sourceAudioFormat.encoding) {
      case C.ENCODING_PCM_16BIT:
      case C.ENCODING_PCM_FLOAT:
        return true;
      default:
        return false;
    }
  }

  /**
   * Mixes {@code frameCount} frames read from {@code sourceBuffer} into {@code mixingBuffer}.
   *
   * <p>Samples are read and written with relative buffer operations starting at each buffer's
   * current position, so both positions are advanced past the mixed frames. If {@code frameCount}
   * is not positive, no samples are read or written.
   *
   * <p>All preconditions are validated with {@code checkArgument} before any sample is touched.
   *
   * @param sourceBuffer The buffer to read source samples from.
   * @param sourceAudioFormat The format of {@code sourceBuffer}; must be supported according to
   *     {@link #supportsSourceAudioFormat}.
   * @param channelMixingMatrix The coefficients to apply. Its input channel count must match the
   *     source format and its output channel count must match the mixing format.
   * @param frameCount The number of frames to mix. Both buffers must have at least this many
   *     frames remaining.
   * @param mixingBuffer The float buffer to accumulate into.
   * @return The {@code mixingBuffer} instance that was passed in.
   */
  @Override
  @SuppressLint("SwitchIntDef")
  public ByteBuffer mix(
      ByteBuffer sourceBuffer,
      AudioFormat sourceAudioFormat,
      ChannelMixingMatrix channelMixingMatrix,
      int frameCount,
      ByteBuffer mixingBuffer) {
    checkArgument(
        supportsSourceAudioFormat(sourceAudioFormat), "Source audio format is not supported.");
    checkArgument(
        channelMixingMatrix.getInputChannelCount() == sourceAudioFormat.channelCount,
        "Input channel count does not match source format.");
    checkArgument(
        channelMixingMatrix.getOutputChannelCount() == mixingAudioFormat.channelCount,
        "Output channel count does not match mixing format.");
    // The required byte counts are computed as longs: an int product can silently wrap for large
    // frame counts, which would let an undersized buffer pass the check and fail only part-way
    // through mixing, after the mixing buffer has already been modified.
    checkArgument(
        sourceBuffer.remaining() >= (long) frameCount * sourceAudioFormat.bytesPerFrame,
        "Source buffer is too small.");
    checkArgument(
        mixingBuffer.remaining() >= (long) frameCount * mixingAudioFormat.bytesPerFrame,
        "Mixing buffer is too small.");

    switch (sourceAudioFormat.encoding) {
      case C.ENCODING_PCM_FLOAT:
        return mixFloatIntoFloat(sourceBuffer, channelMixingMatrix, frameCount, mixingBuffer);
      case C.ENCODING_PCM_16BIT:
        return mixS16IntoFloat(sourceBuffer, channelMixingMatrix, frameCount, mixingBuffer);
      default:
        // Not reachable after the supportsSourceAudioFormat check above; kept as a safeguard.
        throw new IllegalArgumentException("Source encoding is not supported.");
    }
  }

  /**
   * Mixes float source frames into the float mixing buffer, delegating to the per-sample path
   * when the matrix reports itself as diagonal.
   */
  private static ByteBuffer mixFloatIntoFloat(
      ByteBuffer sourceBuffer,
      ChannelMixingMatrix channelMixingMatrix,
      int frameCount,
      ByteBuffer mixingBuffer) {
    if (channelMixingMatrix.isDiagonal()) {
      return mixFloatIntoFloatDiagonal(sourceBuffer, channelMixingMatrix, frameCount, mixingBuffer);
    }
    int sourceChannelCount = channelMixingMatrix.getInputChannelCount();
    // Reused for every frame: each output channel may depend on every input channel, so a whole
    // source frame is read before any output sample is written.
    float[] sourceFrame = new float[sourceChannelCount];
    for (int i = 0; i < frameCount; i++) {
      for (int sourceChannel = 0; sourceChannel < sourceChannelCount; sourceChannel++) {
        sourceFrame[sourceChannel] = sourceBuffer.getFloat();
      }
      mixFloatFrameIntoFloat(sourceFrame, channelMixingMatrix, mixingBuffer);
    }
    return mixingBuffer;
  }

  /**
   * Adds one source frame, weighted by the matrix coefficients, to the frame at the current
   * position of {@code mixingBuffer}, advancing that position by one output frame.
   */
  private static void mixFloatFrameIntoFloat(
      float[] sourceFrame, ChannelMixingMatrix channelMixingMatrix, ByteBuffer mixingBuffer) {
    int mixingChannelCount = channelMixingMatrix.getOutputChannelCount();
    for (int mixingChannel = 0; mixingChannel < mixingChannelCount; mixingChannel++) {
      // Absolute read so the position is left in place for the relative write below.
      float mixedSample = mixingBuffer.getFloat(mixingBuffer.position());
      for (int sourceChannel = 0; sourceChannel < sourceFrame.length; sourceChannel++) {
        mixedSample +=
            channelMixingMatrix.getMixingCoefficient(sourceChannel, mixingChannel)
                * sourceFrame[sourceChannel];
      }
      mixingBuffer.putFloat(mixedSample);
    }
  }

  /**
   * Mixes float source samples one at a time, using only the coefficient at {@code (c, c)} for
   * channel {@code c}.
   *
   * <p>The input channel count is used for both buffers, so this assumes a diagonal matrix has
   * equal input and output channel counts.
   */
  private static ByteBuffer mixFloatIntoFloatDiagonal(
      ByteBuffer sourceBuffer,
      ChannelMixingMatrix channelMixingMatrix,
      int frameCount,
      ByteBuffer mixingBuffer) {
    int channelCount = channelMixingMatrix.getInputChannelCount();
    for (int i = 0; i < frameCount; i++) {
      for (int c = 0; c < channelCount; c++) {
        float sourceSample = sourceBuffer.getFloat();
        // Absolute read so the position is left in place for the relative write below.
        float mixedSample =
            mixingBuffer.getFloat(mixingBuffer.position())
                + channelMixingMatrix.getMixingCoefficient(c, c) * sourceSample;
        mixingBuffer.putFloat(mixedSample);
      }
    }
    return mixingBuffer;
  }

  /**
   * Mixes 16-bit source frames into the float mixing buffer, converting each sample to float
   * first and delegating to the per-sample path when the matrix reports itself as diagonal.
   */
  private static ByteBuffer mixS16IntoFloat(
      ByteBuffer sourceBuffer,
      ChannelMixingMatrix channelMixingMatrix,
      int frameCount,
      ByteBuffer mixingBuffer) {
    if (channelMixingMatrix.isDiagonal()) {
      return mixS16IntoFloatDiagonal(sourceBuffer, channelMixingMatrix, frameCount, mixingBuffer);
    }
    int sourceChannelCount = channelMixingMatrix.getInputChannelCount();
    // Reused for every frame; holds the source frame already converted to float.
    float[] sourceFrame = new float[sourceChannelCount];
    for (int i = 0; i < frameCount; i++) {
      for (int sourceChannel = 0; sourceChannel < sourceChannelCount; sourceChannel++) {
        sourceFrame[sourceChannel] = s16ToFloat(sourceBuffer.getShort());
      }
      mixFloatFrameIntoFloat(sourceFrame, channelMixingMatrix, mixingBuffer);
    }
    return mixingBuffer;
  }

  /**
   * Mixes 16-bit source samples one at a time, converting each to float and using only the
   * coefficient at {@code (c, c)} for channel {@code c}.
   *
   * <p>The input channel count is used for both buffers, so this assumes a diagonal matrix has
   * equal input and output channel counts.
   */
  private static ByteBuffer mixS16IntoFloatDiagonal(
      ByteBuffer sourceBuffer,
      ChannelMixingMatrix channelMixingMatrix,
      int frameCount,
      ByteBuffer mixingBuffer) {
    int channelCount = channelMixingMatrix.getInputChannelCount();
    for (int i = 0; i < frameCount; i++) {
      for (int c = 0; c < channelCount; c++) {
        float sourceSample = s16ToFloat(sourceBuffer.getShort());
        // Absolute read so the position is left in place for the relative write below.
        float mixedSample =
            mixingBuffer.getFloat(mixingBuffer.position())
                + channelMixingMatrix.getMixingCoefficient(c, c) * sourceSample;
        mixingBuffer.putFloat(mixedSample);
      }
    }
    return mixingBuffer;
  }

  /**
   * Converts a signed 16-bit sample to float, scaling negative values by {@code 1 / 32768} and
   * non-negative values by {@code 1 / 32767} so that both extremes of the short range map to a
   * magnitude of one (up to float rounding).
   */
  private static float s16ToFloat(short shortValue) {
    return shortValue
        * (shortValue < 0 ? SCALE_S16_FOR_NEGATIVE_INPUT : SCALE_S16_FOR_POSITIVE_INPUT);
  }
}