SeparableConvolutionShaderProgram.java

/*
 * Copyright 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package androidx.media3.effect;

import android.content.Context;
import android.graphics.Bitmap;
import android.opengl.GLES20;
import android.opengl.GLUtils;
import androidx.annotation.CallSuper;
import androidx.annotation.RequiresApi;
import androidx.media3.common.GlObjectsProvider;
import androidx.media3.common.GlTextureInfo;
import androidx.media3.common.VideoFrameProcessingException;
import androidx.media3.common.util.Assertions;
import androidx.media3.common.util.GlProgram;
import androidx.media3.common.util.GlUtil;
import androidx.media3.common.util.Size;
import androidx.media3.common.util.UnstableApi;
import com.google.common.util.concurrent.MoreExecutors;
import java.io.IOException;
import java.nio.ShortBuffer;
import java.util.concurrent.Executor;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;

/**
 * A {@link GlShaderProgram} for performing separable convolutions.
 *
 * <p>A single {@link ConvolutionFunction1D} is applied horizontally on a first pass and vertically
 * on a second pass.
 *
 * <p>The horizontal pass renders the input texture into an intermediate texture whose width matches
 * the output width and whose height matches the input height. The vertical pass then renders the
 * intermediate texture into the output texture. The convolution function is sampled into a one-row
 * FP16 lookup texture that both passes read from.
 */
@RequiresApi(26) // Uses Bitmap.Config.RGBA_F16.
@UnstableApi
public class SeparableConvolutionShaderProgram implements GlShaderProgram {
  private static final String VERTEX_SHADER_PATH = "shaders/vertex_shader_transformation_es2.glsl";
  private static final String FRAGMENT_SHADER_PATH =
      "shaders/fragment_shader_separable_convolution_es2.glsl";

  // Constants for sampling the convolution function into the function lookup table (LUT) texture.
  // TODO (b/282767994): Fix TAP hanging issue and update samples per texel.
  private static final int RASTER_SAMPLES_PER_TEXEL = 5;
  // Apply some padding in the function LUT to avoid any issues from GL sampling off the texture.
  private static final int FUNCTION_LUT_PADDING = RASTER_SAMPLES_PER_TEXEL;

  // BEGIN COPIED FP16 code.
  // Source: libcore/luni/src/main/java/libcore/util/FP16.java
  private static final int FP16_EXPONENT_BIAS = 15;
  private static final int FP16_SIGN_SHIFT = 15;
  private static final int FP16_EXPONENT_SHIFT = 10;
  private static final int FP32_SIGN_SHIFT = 31;
  private static final int FP32_EXPONENT_SHIFT = 23;
  private static final int FP32_SHIFTED_EXPONENT_MASK = 0xff;
  private static final int FP32_SIGNIFICAND_MASK = 0x7fffff;
  private static final int FP32_EXPONENT_BIAS = 127;
  // END FP16 copied code.

  private final GlProgram glProgram;
  private final boolean useHdr;
  private final SeparableConvolution convolution;
  private final float scaleWidth;
  private final float scaleHeight;

  private GlShaderProgram.InputListener inputListener;
  private GlShaderProgram.OutputListener outputListener;
  private GlShaderProgram.ErrorListener errorListener;
  private Executor errorListenerExecutor;
  // Whether the output texture has been handed to the output listener and not yet released.
  private boolean outputTextureInUse;
  private GlTextureInfo outputTexture;
  private GlTextureInfo intermediateTexture;
  private GlTextureInfo functionLutTexture; // Values for the function LUT as a texture.
  private float functionLutTexelStep;
  private float functionLutCenterX;
  private float functionLutDomainStart;
  private float functionLutWidth;
  private Size outputSize;
  private Size lastInputSize;
  private Size intermediateSize;
  private @MonotonicNonNull ConvolutionFunction1D lastConvolutionFunction;

  /**
   * Creates an instance.
   *
   * @param context The {@link Context}.
   * @param useHdr Whether input textures come from an HDR source. If {@code true}, colors will be
   *     in linear RGB BT.2020. If {@code false}, colors will be in linear RGB BT.709.
   * @param convolution The {@link SeparableConvolution} to apply in each direction.
   * @param scaleWidth The scaling factor used to determine the width of the output relative to the
   *     input.
   * @param scaleHeight The scaling factor used to determine the height of the output relative to
   *     the input.
   * @throws VideoFrameProcessingException If a problem occurs while reading shader files.
   */
  public SeparableConvolutionShaderProgram(
      Context context,
      boolean useHdr,
      SeparableConvolution convolution,
      float scaleWidth,
      float scaleHeight)
      throws VideoFrameProcessingException {
    this.useHdr = useHdr;
    this.convolution = convolution;
    this.scaleWidth = scaleWidth;
    this.scaleHeight = scaleHeight;
    inputListener = new InputListener() {};
    outputListener = new OutputListener() {};
    errorListener = (frameProcessingException) -> {};
    errorListenerExecutor = MoreExecutors.directExecutor();
    functionLutTexture = GlTextureInfo.UNSET;
    intermediateTexture = GlTextureInfo.UNSET;
    outputTexture = GlTextureInfo.UNSET;
    lastInputSize = Size.ZERO;
    intermediateSize = Size.ZERO;
    outputSize = Size.ZERO;
    lastConvolutionFunction = null;

    try {
      glProgram = new GlProgram(context, VERTEX_SHADER_PATH, FRAGMENT_SHADER_PATH);
    } catch (IOException | GlUtil.GlException e) {
      throw new VideoFrameProcessingException(e);
    }
  }

  @Override
  public final void setInputListener(InputListener inputListener) {
    this.inputListener = inputListener;
    if (!outputTextureInUse) {
      inputListener.onReadyToAcceptInputFrame();
    }
  }

  @Override
  public final void setOutputListener(OutputListener outputListener) {
    this.outputListener = outputListener;
  }

  @Override
  public final void setErrorListener(Executor errorListenerExecutor, ErrorListener errorListener) {
    this.errorListenerExecutor = errorListenerExecutor;
    this.errorListener = errorListener;
  }

  @Override
  public final void queueInputFrame(
      GlObjectsProvider glObjectsProvider, GlTextureInfo inputTexture, long presentationTimeUs) {
    Assertions.checkState(
        !outputTextureInUse,
        "The shader program does not currently accept input frames. Release prior output frames"
            + " first.");
    try {
      ensureTexturesAreConfigured(
          glObjectsProvider, new Size(inputTexture.width, inputTexture.height), presentationTimeUs);
      outputTextureInUse = true;
      renderHorizontal(inputTexture);
      renderVertical();

      // Both passes have already issued their draw calls, and the output framebuffer is still
      // focused. No further draw call follows the hook: it would re-run whatever program,
      // uniforms and framebuffer were left bound (the vertical pass's, or a subclass's).
      onBlurRendered(inputTexture);

      inputListener.onInputFrameProcessed(inputTexture);
      outputListener.onOutputFrameAvailable(outputTexture, presentationTimeUs);
    } catch (GlUtil.GlException e) {
      errorListenerExecutor.execute(
          () -> errorListener.onError(VideoFrameProcessingException.from(e, presentationTimeUs)));
    }
  }

  @Override
  public final void releaseOutputFrame(GlTextureInfo outputTexture) {
    outputTextureInUse = false;
    inputListener.onReadyToAcceptInputFrame();
  }

  @Override
  public final void signalEndOfCurrentInputStream() {
    outputListener.onCurrentOutputStreamEnded();
  }

  @Override
  public final void flush() {
    outputTextureInUse = false;
    inputListener.onFlush();
    inputListener.onReadyToAcceptInputFrame();
  }

  @Override
  @CallSuper
  public void release() throws VideoFrameProcessingException {
    try {
      outputTexture.release();
      intermediateTexture.release();
      functionLutTexture.release();
      glProgram.delete();
    } catch (GlUtil.GlException e) {
      throw new VideoFrameProcessingException(e);
    }
  }

  /**
   * Called when the blur has been rendered onto the frame.
   *
   * <p>When this is called, the output texture's framebuffer is still focused. Implementations
   * that want additional content in the output must issue their own draw calls.
   *
   * <p>The default implementation is a no-op.
   *
   * @param inputTexture The input texture.
   * @throws GlUtil.GlException If an error occurs.
   */
  protected void onBlurRendered(GlTextureInfo inputTexture) throws GlUtil.GlException {
    // Do nothing.
  }

  /**
   * Runs one convolution pass, reading from {@code inputTexId} into the currently focused
   * framebuffer.
   *
   * @param inputTexId The texture to convolve.
   * @param isHorizontal Whether to convolve along the x axis (reading the input texture) or along
   *     the y axis (reading the intermediate texture).
   * @throws GlUtil.GlException If a GL error occurs.
   */
  private void renderOnePass(int inputTexId, boolean isHorizontal) throws GlUtil.GlException {
    // The source extent along the convolution axis: the input width for the horizontal pass, the
    // intermediate height for the vertical pass.
    int size = isHorizontal ? lastInputSize.getWidth() : intermediateSize.getHeight();
    glProgram.use();
    glProgram.setSamplerTexIdUniform("uTexSampler", inputTexId, /* texUnitIndex= */ 0);
    glProgram.setIntUniform("uIsHorizontal", isHorizontal ? 1 : 0);
    glProgram.setFloatUniform("uSourceTexelSize", 1.0f / size);
    glProgram.setFloatUniform("uSourceFullSize", (float) size);
    glProgram.setFloatUniform("uConvStartTexels", functionLutDomainStart);
    glProgram.setFloatUniform("uConvWidthTexels", functionLutWidth);
    glProgram.setFloatUniform("uFunctionLookupStepSize", functionLutTexelStep);
    glProgram.setFloatsUniform("uFunctionLookupCenter", new float[] {functionLutCenterX, 0.5f});
    glProgram.setSamplerTexIdUniform(
        "uFunctionLookupSampler", functionLutTexture.texId, /* texUnitIndex= */ 1);
    glProgram.bindAttributesAndUniforms();

    // The four-vertex triangle strip forms a quad.
    GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, /* first= */ 0, /* count= */ 4);
    GlUtil.checkGlError();
  }

  /**
   * Sets the full-screen quad geometry and identity transforms on the program, and returns the
   * output size for the given input size.
   *
   * @param inputSize The size of the input frame.
   * @return The input size scaled by {@code scaleWidth} and {@code scaleHeight}, truncated to ints.
   */
  private Size configure(Size inputSize) {
    // Draw the frame on the entire normalized device coordinate space, from -1 to 1, for x and y.
    glProgram.setBufferAttribute(
        "aFramePosition",
        GlUtil.getNormalizedCoordinateBounds(),
        GlUtil.HOMOGENEOUS_COORDINATE_VECTOR_SIZE);
    float[] identityMatrix = GlUtil.create4x4IdentityMatrix();
    glProgram.setFloatsUniform("uTransformationMatrix", identityMatrix);
    glProgram.setFloatsUniform("uTexTransformationMatrix", identityMatrix);

    return new Size(
        (int) (inputSize.getWidth() * scaleWidth), (int) (inputSize.getHeight() * scaleHeight));
  }

  /** Renders the horizontal pass from {@code inputTexture} into the intermediate texture. */
  private void renderHorizontal(GlTextureInfo inputTexture) throws GlUtil.GlException {
    GlUtil.focusFramebufferUsingCurrentContext(
        intermediateTexture.fboId, intermediateTexture.width, intermediateTexture.height);
    GlUtil.clearFocusedBuffers();
    renderOnePass(inputTexture.texId, /* isHorizontal= */ true);
  }

  /** Renders the vertical pass from the intermediate texture into the output texture. */
  private void renderVertical() throws GlUtil.GlException {
    GlUtil.focusFramebufferUsingCurrentContext(
        outputTexture.fboId, outputTexture.width, outputTexture.height);
    GlUtil.clearFocusedBuffers();
    renderOnePass(intermediateTexture.texId, /* isHorizontal= */ false);
  }

  /**
   * Refreshes the function LUT when the convolution function for {@code presentationTimeUs}
   * differs from the last one used. Reallocates the intermediate and output textures when the
   * input size changes.
   */
  private void ensureTexturesAreConfigured(
      GlObjectsProvider glObjectsProvider, Size inputSize, long presentationTimeUs)
      throws GlUtil.GlException {
    ConvolutionFunction1D currentConvolutionFunction =
        convolution.getConvolution(presentationTimeUs);
    if (!currentConvolutionFunction.equals(lastConvolutionFunction)) {
      updateFunctionTexture(glObjectsProvider, currentConvolutionFunction);
      lastConvolutionFunction = currentConvolutionFunction;
    }

    // Only update intermediate and output textures if the size changes.
    if (inputSize.equals(lastInputSize)) {
      return;
    }

    outputSize = configure(inputSize);
    // If there is a size change with the filtering (for example, a scaling operation), the first
    // pass is applied horizontally. As a result, the width of the intermediate texture matches the
    // output width, while its height is unchanged from the input.
    intermediateSize = new Size(outputSize.getWidth(), inputSize.getHeight());
    intermediateTexture =
        configurePixelTexture(glObjectsProvider, intermediateTexture, intermediateSize);
    outputTexture = configurePixelTexture(glObjectsProvider, outputTexture, outputSize);

    this.lastInputSize = inputSize;
  }

  /**
   * Creates a function lookup table for the convolution, and stores it in a 16b floating point
   * texture for GPU access.
   */
  private void updateFunctionTexture(
      GlObjectsProvider glObjectsProvider, ConvolutionFunction1D convolutionFunction)
      throws GlUtil.GlException {

    int lutRasterSize =
        (int)
            Math.ceil(
                convolutionFunction.width() * RASTER_SAMPLES_PER_TEXEL + 2 * FUNCTION_LUT_PADDING);

    // The function LUT is mapped to [0, 1] texture coords. We need to calculate what change
    // in texture coordinates corresponds exactly with a size of 1 texel (or pixel) in the function.
    // This is basically 1 / function_width, but due to the ceil() call above, it needs to be
    // calculated based on the actual raster size.
    this.functionLutTexelStep = 1.0f / ((float) lutRasterSize / RASTER_SAMPLES_PER_TEXEL);

    // The function values are stored in an FP16 texture. Setting FP16 values in a Bitmap requires
    // multiple steps. For each step, calculate the function value as a float, then use
    // fp16FromFloat() to convert it to FP16 bits stored in a short.
    ShortBuffer functionShortBuffer = ShortBuffer.allocate(lutRasterSize * 4);
    float rasterSampleStep = 1.0f / RASTER_SAMPLES_PER_TEXEL;
    float functionDomainStart = convolutionFunction.domainStart();
    int index = 0;

    for (int i = 0; i < lutRasterSize; i++) {
      float sampleValue = 0.0f;
      int unpaddedI = i - FUNCTION_LUT_PADDING;
      float samplePosition = functionDomainStart + unpaddedI * rasterSampleStep;

      // Entries in the leading padding stay 0. The upper bound is inclusive, so the sample at
      // unpaddedI == lutRasterSize - 2 * FUNCTION_LUT_PADDING (the domain end when the function
      // width is a whole number of raster samples) is also evaluated.
      if (unpaddedI >= 0 && i <= lutRasterSize - FUNCTION_LUT_PADDING) {
        sampleValue = convolutionFunction.value(samplePosition);
      }

      // Convert float to half (fp16) and read out the bits as a short.
      // Texture for Bitmap is RGBA_F16, so we store the function value in RGB channels and 1.0
      // in A.
      short shortEncodedValue = fp16FromFloat(sampleValue);

      // Set RGB
      functionShortBuffer.put(index++, shortEncodedValue);
      functionShortBuffer.put(index++, shortEncodedValue);
      functionShortBuffer.put(index++, shortEncodedValue);

      // Set Alpha
      functionShortBuffer.put(index++, fp16FromFloat(1.0f));
    }

    // Calculate the center of the function in the raster.  The formula below is a slight
    // adjustment on (value - min) / (max - min), where value = 0 at center and
    // rasterSampleStep * lutRasterSize is equal to (max - min) over the range of the raster
    // samples, which may be slightly different than the difference between the function's max
    // and min domain values.
    // To find the value associated at position 0 in the texture, is the value corresponding with
    // the leading edge position of the first sample.  This needs to account for the padding and
    // the 1/2 texel offsets used in texture lookups (index 0 is centered at 0.5 / numTexels).
    float minValueWithPadding =
        functionDomainStart - rasterSampleStep * (FUNCTION_LUT_PADDING + 0.5f);
    this.functionLutCenterX = -minValueWithPadding / (rasterSampleStep * lutRasterSize);
    this.functionLutDomainStart = convolutionFunction.domainStart();
    this.functionLutWidth = convolutionFunction.width();

    // TODO(b/276982847): Use alternative to Bitmap to create function LUT texture.
    Bitmap functionLookupBitmap =
        Bitmap.createBitmap(lutRasterSize, /* height= */ 1, Bitmap.Config.RGBA_F16);
    functionLookupBitmap.copyPixelsFromBuffer(functionShortBuffer);

    try {
      // Create new GL texture if needed.
      if (functionLutTexture == GlTextureInfo.UNSET || functionLutTexture.width != lutRasterSize) {
        functionLutTexture.release();

        // Need to use high precision to force 16FP color.
        int functionLutTextureId =
            GlUtil.createTexture(
                lutRasterSize, /* height= */ 1, /* useHighPrecisionColorComponents= */ true);

        functionLutTexture =
            glObjectsProvider.createBuffersForTexture(
                functionLutTextureId, lutRasterSize, /* height= */ 1);
      }
      // texImage2D uploads into whatever texture is currently bound to GL_TEXTURE_2D. When an
      // existing LUT texture is reused, that binding may belong to another texture (for example
      // one bound by a previous render pass), so bind the LUT texture explicitly.
      GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, functionLutTexture.texId);
      GLUtils.texImage2D(
          GLES20.GL_TEXTURE_2D, /* level= */ 0, functionLookupBitmap, /* border= */ 0);
    } finally {
      // The Bitmap is only a staging buffer for the upload; free its pixel memory eagerly.
      functionLookupBitmap.recycle();
    }
    GlUtil.checkGlError();
  }

  /**
   * Returns a texture (with framebuffer) of the given size. The existing texture is reused if it
   * already has that size; otherwise it is released and a new one is created.
   */
  private GlTextureInfo configurePixelTexture(
      GlObjectsProvider glObjectsProvider, GlTextureInfo existingTexture, Size size)
      throws GlUtil.GlException {
    if (size.getWidth() == existingTexture.width && size.getHeight() == existingTexture.height) {
      return existingTexture;
    }

    existingTexture.release();
    int texId = GlUtil.createTexture(size.getWidth(), size.getHeight(), useHdr);

    return glObjectsProvider.createBuffersForTexture(texId, size.getWidth(), size.getHeight());
  }

  // BEGIN COPIED FP16 code.
  // Source: libcore/luni/src/main/java/libcore/util/FP16.java
  // Float to half float conversion, copied from FP16.  This code is introduced in API26, so the
  // one required method is copied here.
  private static short fp16FromFloat(float f) {
    int bits = Float.floatToRawIntBits(f);
    int s = bits >>> FP32_SIGN_SHIFT;
    int e = (bits >>> FP32_EXPONENT_SHIFT) & FP32_SHIFTED_EXPONENT_MASK;
    int m = bits & FP32_SIGNIFICAND_MASK;
    int outE = 0;
    int outM = 0;
    if (e == 0xff) { // Infinite or NaN
      outE = 0x1f;
      outM = (m != 0) ? 0x200 : 0;
    } else {
      e = e - FP32_EXPONENT_BIAS + FP16_EXPONENT_BIAS;
      if (e >= 0x1f) { // Overflow
        outE = 0x1f;
      } else if (e <= 0) { // Underflow
        if (e >= -10) {
          // The fp32 value is a normalized float less than MIN_NORMAL,
          // we convert to a denorm fp16
          m |= 0x800000;
          int shift = 14 - e;
          outM = m >>> shift;
          int lowm = m & ((1 << shift) - 1);
          int hway = 1 << (shift - 1);
          // if above halfway or exactly halfway and outM is odd
          if (lowm + (outM & 1) > hway) {
            // Round to nearest even
            // Can overflow into exponent bit, which surprisingly is OK.
            // This increment relies on the +outM in the return statement below
            outM++;
          }
        }
      } else {
        outE = e;
        outM = m >>> 13;
        // if above halfway or exactly halfway and outM is odd
        if ((m & 0x1fff) + (outM & 0x1) > 0x1000) {
          // Round to nearest even
          // Can overflow into exponent bit, which surprisingly is OK.
          // This increment relies on the +outM in the return statement below
          outM++;
        }
      }
    }
    // The outM is added here as the +1 increments for outM above can
    // cause an overflow in the exponent bit which is OK.
    return (short) ((s << FP16_SIGN_SHIFT) | ((outE << FP16_EXPONENT_SHIFT) + outM));
  }
  // END FP16 copied code.
}