Sonic.java

/*
 * Copyright (C) 2017 The Android Open Source Project
 * Copyright (C) 2010 Bill Cox, Sonic Library
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package androidx.media3.exoplayer.audio;

import static java.lang.Math.min;

import androidx.media3.common.util.Assertions;
import java.nio.ShortBuffer;
import java.util.Arrays;

/**
 * Sonic audio stream processor for time/pitch stretching.
 *
 * <p>Based on https://github.com/waywardgeek/sonic.
 */
/* package */ final class Sonic {

  private static final int MINIMUM_PITCH = 65;
  private static final int MAXIMUM_PITCH = 400;
  private static final int AMDF_FREQUENCY = 4000;
  private static final int BYTES_PER_SAMPLE = 2;

  private final int inputSampleRateHz;
  private final int channelCount;
  private final float speed;
  private final float pitch;
  private final float rate;
  private final int minPeriod;
  private final int maxPeriod;
  private final int maxRequiredFrameCount;
  private final short[] downSampleBuffer;

  private short[] inputBuffer;
  private int inputFrameCount;
  private short[] outputBuffer;
  private int outputFrameCount;
  private short[] pitchBuffer;
  private int pitchFrameCount;
  private int oldRatePosition;
  private int newRatePosition;
  private int remainingInputToCopyFrameCount;
  private int prevPeriod;
  private int prevMinDiff;
  private int minDiff;
  private int maxDiff;

  /**
   * Creates a new Sonic audio stream processor.
   *
   * @param inputSampleRateHz The sample rate of input audio, in hertz.
   * @param channelCount The number of channels in the input audio.
   * @param speed The speedup factor for output audio.
   * @param pitch The pitch factor for output audio.
   * @param outputSampleRateHz The sample rate for output audio, in hertz.
   */
  public Sonic(
      int inputSampleRateHz, int channelCount, float speed, float pitch, int outputSampleRateHz) {
    this.inputSampleRateHz = inputSampleRateHz;
    this.channelCount = channelCount;
    this.speed = speed;
    this.pitch = pitch;
    rate = (float) inputSampleRateHz / outputSampleRateHz;
    minPeriod = inputSampleRateHz / MAXIMUM_PITCH;
    maxPeriod = inputSampleRateHz / MINIMUM_PITCH;
    maxRequiredFrameCount = 2 * maxPeriod;
    downSampleBuffer = new short[maxRequiredFrameCount];
    inputBuffer = new short[maxRequiredFrameCount * channelCount];
    outputBuffer = new short[maxRequiredFrameCount * channelCount];
    pitchBuffer = new short[maxRequiredFrameCount * channelCount];
  }

  /**
   * Returns the number of bytes that have been input, but will not be processed until more input
   * data is provided.
   */
  public int getPendingInputBytes() {
    return inputFrameCount * channelCount * BYTES_PER_SAMPLE;
  }

  /**
   * Queues remaining data from {@code buffer}, and advances its position by the number of bytes
   * consumed.
   *
   * @param buffer A {@link ShortBuffer} containing input data between its position and limit.
   */
  public void queueInput(ShortBuffer buffer) {
    int framesToWrite = buffer.remaining() / channelCount;
    int bytesToWrite = framesToWrite * channelCount * 2;
    inputBuffer = ensureSpaceForAdditionalFrames(inputBuffer, inputFrameCount, framesToWrite);
    buffer.get(inputBuffer, inputFrameCount * channelCount, bytesToWrite / 2);
    inputFrameCount += framesToWrite;
    processStreamInput();
  }

  /**
   * Gets available output, outputting to the start of {@code buffer}. The buffer's position will be
   * advanced by the number of bytes written.
   *
   * @param buffer A {@link ShortBuffer} into which output will be written.
   */
  public void getOutput(ShortBuffer buffer) {
    int framesToRead = min(buffer.remaining() / channelCount, outputFrameCount);
    buffer.put(outputBuffer, 0, framesToRead * channelCount);
    outputFrameCount -= framesToRead;
    System.arraycopy(
        outputBuffer,
        framesToRead * channelCount,
        outputBuffer,
        0,
        outputFrameCount * channelCount);
  }

  /**
   * Forces generating output using whatever data has been queued already. No extra delay will be
   * added to the output, but flushing in the middle of words could introduce distortion.
   */
  public void queueEndOfStream() {
    int remainingFrameCount = inputFrameCount;
    float s = speed / pitch;
    float r = rate * pitch;
    int expectedOutputFrames =
        outputFrameCount + (int) ((remainingFrameCount / s + pitchFrameCount) / r + 0.5f);

    // Add enough silence to flush both input and pitch buffers.
    inputBuffer =
        ensureSpaceForAdditionalFrames(
            inputBuffer, inputFrameCount, remainingFrameCount + 2 * maxRequiredFrameCount);
    for (int xSample = 0; xSample < 2 * maxRequiredFrameCount * channelCount; xSample++) {
      inputBuffer[remainingFrameCount * channelCount + xSample] = 0;
    }
    inputFrameCount += 2 * maxRequiredFrameCount;
    processStreamInput();
    // Throw away any extra frames we generated due to the silence we added.
    if (outputFrameCount > expectedOutputFrames) {
      outputFrameCount = expectedOutputFrames;
    }
    // Empty input and pitch buffers.
    inputFrameCount = 0;
    remainingInputToCopyFrameCount = 0;
    pitchFrameCount = 0;
  }

  /** Clears state in preparation for receiving a new stream of input buffers. */
  public void flush() {
    inputFrameCount = 0;
    outputFrameCount = 0;
    pitchFrameCount = 0;
    oldRatePosition = 0;
    newRatePosition = 0;
    remainingInputToCopyFrameCount = 0;
    prevPeriod = 0;
    prevMinDiff = 0;
    minDiff = 0;
    maxDiff = 0;
  }

  /** Returns the size of output that can be read with {@link #getOutput(ShortBuffer)}, in bytes. */
  public int getOutputSize() {
    return outputFrameCount * channelCount * BYTES_PER_SAMPLE;
  }

  // Internal methods.

  /**
   * Returns {@code buffer} or a copy of it, such that there is enough space in the returned buffer
   * to store {@code newFrameCount} additional frames.
   *
   * @param buffer The buffer.
   * @param frameCount The number of frames already in the buffer.
   * @param additionalFrameCount The number of additional frames that need to be stored in the
   *     buffer.
   * @return A buffer with enough space for the additional frames.
   */
  private short[] ensureSpaceForAdditionalFrames(
      short[] buffer, int frameCount, int additionalFrameCount) {
    int currentCapacityFrames = buffer.length / channelCount;
    if (frameCount + additionalFrameCount <= currentCapacityFrames) {
      return buffer;
    } else {
      int newCapacityFrames = 3 * currentCapacityFrames / 2 + additionalFrameCount;
      return Arrays.copyOf(buffer, newCapacityFrames * channelCount);
    }
  }

  private void removeProcessedInputFrames(int positionFrames) {
    int remainingFrames = inputFrameCount - positionFrames;
    System.arraycopy(
        inputBuffer, positionFrames * channelCount, inputBuffer, 0, remainingFrames * channelCount);
    inputFrameCount = remainingFrames;
  }

  private void copyToOutput(short[] samples, int positionFrames, int frameCount) {
    outputBuffer = ensureSpaceForAdditionalFrames(outputBuffer, outputFrameCount, frameCount);
    System.arraycopy(
        samples,
        positionFrames * channelCount,
        outputBuffer,
        outputFrameCount * channelCount,
        frameCount * channelCount);
    outputFrameCount += frameCount;
  }

  private int copyInputToOutput(int positionFrames) {
    int frameCount = min(maxRequiredFrameCount, remainingInputToCopyFrameCount);
    copyToOutput(inputBuffer, positionFrames, frameCount);
    remainingInputToCopyFrameCount -= frameCount;
    return frameCount;
  }

  private void downSampleInput(short[] samples, int position, int skip) {
    // If skip is greater than one, average skip samples together and write them to the down-sample
    // buffer. If channelCount is greater than one, mix the channels together as we down sample.
    int frameCount = maxRequiredFrameCount / skip;
    int samplesPerValue = channelCount * skip;
    position *= channelCount;
    for (int i = 0; i < frameCount; i++) {
      int value = 0;
      for (int j = 0; j < samplesPerValue; j++) {
        value += samples[position + i * samplesPerValue + j];
      }
      value /= samplesPerValue;
      downSampleBuffer[i] = (short) value;
    }
  }

  private int findPitchPeriodInRange(short[] samples, int position, int minPeriod, int maxPeriod) {
    // Find the best frequency match in the range, and given a sample skip multiple. For now, just
    // find the pitch of the first channel.
    int bestPeriod = 0;
    int worstPeriod = 255;
    int minDiff = 1;
    int maxDiff = 0;
    position *= channelCount;
    for (int period = minPeriod; period <= maxPeriod; period++) {
      int diff = 0;
      for (int i = 0; i < period; i++) {
        short sVal = samples[position + i];
        short pVal = samples[position + period + i];
        diff += Math.abs(sVal - pVal);
      }
      // Note that the highest number of samples we add into diff will be less than 256, since we
      // skip samples. Thus, diff is a 24 bit number, and we can safely multiply by numSamples
      // without overflow.
      if (diff * bestPeriod < minDiff * period) {
        minDiff = diff;
        bestPeriod = period;
      }
      if (diff * worstPeriod > maxDiff * period) {
        maxDiff = diff;
        worstPeriod = period;
      }
    }
    this.minDiff = minDiff / bestPeriod;
    this.maxDiff = maxDiff / worstPeriod;
    return bestPeriod;
  }

  /**
   * Returns whether the previous pitch period estimate is a better approximation, which can occur
   * at the abrupt end of voiced words.
   */
  private boolean previousPeriodBetter(int minDiff, int maxDiff) {
    if (minDiff == 0 || prevPeriod == 0) {
      return false;
    }
    if (maxDiff > minDiff * 3) {
      // Got a reasonable match this period.
      return false;
    }
    if (minDiff * 2 <= prevMinDiff * 3) {
      // Mismatch is not that much greater this period.
      return false;
    }
    return true;
  }

  private int findPitchPeriod(short[] samples, int position) {
    // Find the pitch period. This is a critical step, and we may have to try multiple ways to get a
    // good answer. This version uses AMDF. To improve speed, we down sample by an integer factor
    // get in the 11 kHz range, and then do it again with a narrower frequency range without down
    // sampling.
    int period;
    int retPeriod;
    int skip = inputSampleRateHz > AMDF_FREQUENCY ? inputSampleRateHz / AMDF_FREQUENCY : 1;
    if (channelCount == 1 && skip == 1) {
      period = findPitchPeriodInRange(samples, position, minPeriod, maxPeriod);
    } else {
      downSampleInput(samples, position, skip);
      period = findPitchPeriodInRange(downSampleBuffer, 0, minPeriod / skip, maxPeriod / skip);
      if (skip != 1) {
        period *= skip;
        int minP = period - (skip * 4);
        int maxP = period + (skip * 4);
        if (minP < minPeriod) {
          minP = minPeriod;
        }
        if (maxP > maxPeriod) {
          maxP = maxPeriod;
        }
        if (channelCount == 1) {
          period = findPitchPeriodInRange(samples, position, minP, maxP);
        } else {
          downSampleInput(samples, position, 1);
          period = findPitchPeriodInRange(downSampleBuffer, 0, minP, maxP);
        }
      }
    }
    if (previousPeriodBetter(minDiff, maxDiff)) {
      retPeriod = prevPeriod;
    } else {
      retPeriod = period;
    }
    prevMinDiff = minDiff;
    prevPeriod = period;
    return retPeriod;
  }

  private void moveNewSamplesToPitchBuffer(int originalOutputFrameCount) {
    int frameCount = outputFrameCount - originalOutputFrameCount;
    pitchBuffer = ensureSpaceForAdditionalFrames(pitchBuffer, pitchFrameCount, frameCount);
    System.arraycopy(
        outputBuffer,
        originalOutputFrameCount * channelCount,
        pitchBuffer,
        pitchFrameCount * channelCount,
        frameCount * channelCount);
    outputFrameCount = originalOutputFrameCount;
    pitchFrameCount += frameCount;
  }

  private void removePitchFrames(int frameCount) {
    if (frameCount == 0) {
      return;
    }
    System.arraycopy(
        pitchBuffer,
        frameCount * channelCount,
        pitchBuffer,
        0,
        (pitchFrameCount - frameCount) * channelCount);
    pitchFrameCount -= frameCount;
  }

  private short interpolate(short[] in, int inPos, int oldSampleRate, int newSampleRate) {
    short left = in[inPos];
    short right = in[inPos + channelCount];
    int position = newRatePosition * oldSampleRate;
    int leftPosition = oldRatePosition * newSampleRate;
    int rightPosition = (oldRatePosition + 1) * newSampleRate;
    int ratio = rightPosition - position;
    int width = rightPosition - leftPosition;
    return (short) ((ratio * left + (width - ratio) * right) / width);
  }

  private void adjustRate(float rate, int originalOutputFrameCount) {
    if (outputFrameCount == originalOutputFrameCount) {
      return;
    }
    int newSampleRate = (int) (inputSampleRateHz / rate);
    int oldSampleRate = inputSampleRateHz;
    // Set these values to help with the integer math.
    while (newSampleRate > (1 << 14) || oldSampleRate > (1 << 14)) {
      newSampleRate /= 2;
      oldSampleRate /= 2;
    }
    moveNewSamplesToPitchBuffer(originalOutputFrameCount);
    // Leave at least one pitch sample in the buffer.
    for (int position = 0; position < pitchFrameCount - 1; position++) {
      while ((oldRatePosition + 1) * newSampleRate > newRatePosition * oldSampleRate) {
        outputBuffer =
            ensureSpaceForAdditionalFrames(
                outputBuffer, outputFrameCount, /* additionalFrameCount= */ 1);
        for (int i = 0; i < channelCount; i++) {
          outputBuffer[outputFrameCount * channelCount + i] =
              interpolate(pitchBuffer, position * channelCount + i, oldSampleRate, newSampleRate);
        }
        newRatePosition++;
        outputFrameCount++;
      }
      oldRatePosition++;
      if (oldRatePosition == oldSampleRate) {
        oldRatePosition = 0;
        Assertions.checkState(newRatePosition == newSampleRate);
        newRatePosition = 0;
      }
    }
    removePitchFrames(pitchFrameCount - 1);
  }

  private int skipPitchPeriod(short[] samples, int position, float speed, int period) {
    // Skip over a pitch period, and copy period/speed samples to the output.
    int newFrameCount;
    if (speed >= 2.0f) {
      newFrameCount = (int) (period / (speed - 1.0f));
    } else {
      newFrameCount = period;
      remainingInputToCopyFrameCount = (int) (period * (2.0f - speed) / (speed - 1.0f));
    }
    outputBuffer = ensureSpaceForAdditionalFrames(outputBuffer, outputFrameCount, newFrameCount);
    overlapAdd(
        newFrameCount,
        channelCount,
        outputBuffer,
        outputFrameCount,
        samples,
        position,
        samples,
        position + period);
    outputFrameCount += newFrameCount;
    return newFrameCount;
  }

  private int insertPitchPeriod(short[] samples, int position, float speed, int period) {
    // Insert a pitch period, and determine how much input to copy directly.
    int newFrameCount;
    if (speed < 0.5f) {
      newFrameCount = (int) (period * speed / (1.0f - speed));
    } else {
      newFrameCount = period;
      remainingInputToCopyFrameCount = (int) (period * (2.0f * speed - 1.0f) / (1.0f - speed));
    }
    outputBuffer =
        ensureSpaceForAdditionalFrames(outputBuffer, outputFrameCount, period + newFrameCount);
    System.arraycopy(
        samples,
        position * channelCount,
        outputBuffer,
        outputFrameCount * channelCount,
        period * channelCount);
    overlapAdd(
        newFrameCount,
        channelCount,
        outputBuffer,
        outputFrameCount + period,
        samples,
        position + period,
        samples,
        position);
    outputFrameCount += period + newFrameCount;
    return newFrameCount;
  }

  private void changeSpeed(float speed) {
    if (inputFrameCount < maxRequiredFrameCount) {
      return;
    }
    int frameCount = inputFrameCount;
    int positionFrames = 0;
    do {
      if (remainingInputToCopyFrameCount > 0) {
        positionFrames += copyInputToOutput(positionFrames);
      } else {
        int period = findPitchPeriod(inputBuffer, positionFrames);
        if (speed > 1.0) {
          positionFrames += period + skipPitchPeriod(inputBuffer, positionFrames, speed, period);
        } else {
          positionFrames += insertPitchPeriod(inputBuffer, positionFrames, speed, period);
        }
      }
    } while (positionFrames + maxRequiredFrameCount <= frameCount);
    removeProcessedInputFrames(positionFrames);
  }

  private void processStreamInput() {
    // Resample as many pitch periods as we have buffered on the input.
    int originalOutputFrameCount = outputFrameCount;
    float s = speed / pitch;
    float r = rate * pitch;
    if (s > 1.00001 || s < 0.99999) {
      changeSpeed(s);
    } else {
      copyToOutput(inputBuffer, 0, inputFrameCount);
      inputFrameCount = 0;
    }
    if (r != 1.0f) {
      adjustRate(r, originalOutputFrameCount);
    }
  }

  private static void overlapAdd(
      int frameCount,
      int channelCount,
      short[] out,
      int outPosition,
      short[] rampDown,
      int rampDownPosition,
      short[] rampUp,
      int rampUpPosition) {
    for (int i = 0; i < channelCount; i++) {
      int o = outPosition * channelCount + i;
      int u = rampUpPosition * channelCount + i;
      int d = rampDownPosition * channelCount + i;
      for (int t = 0; t < frameCount; t++) {
        out[o] = (short) ((rampDown[d] * (frameCount - t) + rampUp[u] * t) / frameCount);
        o += channelCount;
        d += channelCount;
        u += channelCount;
      }
    }
  }
}