SsimHelper.java
/*
* Copyright 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package androidx.media3.test.utils;
import static androidx.media3.common.util.Assertions.checkNotNull;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import android.content.Context;
import android.media.Image;
import android.media.MediaCodec;
import androidx.annotation.Nullable;
import androidx.annotation.RequiresApi;
import androidx.media3.common.util.UnstableApi;
import java.io.IOException;
import java.nio.ByteBuffer;
/**
* A helper for calculating SSIM score for transcoded videos.
*
* <p>SSIM (Structural Similarity) Index is a statistical measurement of the similarity between two
* images. The mean SSIM score (taken between multiple frames) of two videos is a metric to
* determine the similarity of the videos. SSIM does not measure the absolute difference of the two
* images like MSE (mean squared error), but rather outputs the human perceptual difference. A
* higher SSIM score signifies higher similarity, while a SSIM score of 1 means the two images are
* exactly the same.
*
* <p>SSIM is traditionally computed with the luminance channel (Y), this class uses the luma
* channel (Y') because the {@linkplain MediaCodec decoder} decodes to luma.
*/
@UnstableApi
@RequiresApi(21)
public final class SsimHelper {
/** The default comparison interval. */
public static final int DEFAULT_COMPARISON_INTERVAL = 11;
private static final int DECODED_IMAGE_CHANNEL_COUNT = 3;
private static final int MAX_IMAGE_READER_IMAGES_ALLOWED = 1;
/**
* Returns the mean SSIM score between the reference and the distorted video.
*
* <p>The method compares every {@link #DEFAULT_COMPARISON_INTERVAL n-th} frame from both videos.
*
* @param context The {@link Context}.
* @param referenceVideoPath The path to the reference video file, which must be in {@linkplain
* Context#getAssets() Assets}.
* @param distortedVideoPath The path to the distorted video file.
* @throws IOException When unable to open the provided video paths.
*/
public static double calculate(
Context context, String referenceVideoPath, String distortedVideoPath)
throws IOException, InterruptedException {
VideoDecodingWrapper referenceDecodingWrapper =
new VideoDecodingWrapper(
context,
referenceVideoPath,
DEFAULT_COMPARISON_INTERVAL,
MAX_IMAGE_READER_IMAGES_ALLOWED);
VideoDecodingWrapper distortedDecodingWrapper =
new VideoDecodingWrapper(
context,
distortedVideoPath,
DEFAULT_COMPARISON_INTERVAL,
MAX_IMAGE_READER_IMAGES_ALLOWED);
@Nullable byte[] referenceLumaBuffer = null;
@Nullable byte[] distortedLumaBuffer = null;
double accumulatedSsim = 0.0;
int comparedImagesCount = 0;
try {
while (true) {
@Nullable Image referenceImage = referenceDecodingWrapper.runUntilComparisonFrameOrEnded();
@Nullable Image distortedImage = distortedDecodingWrapper.runUntilComparisonFrameOrEnded();
if (referenceImage == null) {
assertThat(distortedImage).isNull();
break;
}
checkNotNull(distortedImage);
int width = referenceImage.getWidth();
int height = referenceImage.getHeight();
assertThat(distortedImage.getWidth()).isEqualTo(width);
assertThat(distortedImage.getHeight()).isEqualTo(height);
if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) {
referenceLumaBuffer = new byte[width * height];
}
if (distortedLumaBuffer == null || distortedLumaBuffer.length != width * height) {
distortedLumaBuffer = new byte[width * height];
}
try {
accumulatedSsim +=
MssimCalculator.calculate(
extractLumaChannelBuffer(referenceImage, referenceLumaBuffer),
extractLumaChannelBuffer(distortedImage, distortedLumaBuffer),
width,
height);
} finally {
referenceImage.close();
distortedImage.close();
}
comparedImagesCount++;
}
} finally {
referenceDecodingWrapper.close();
distortedDecodingWrapper.close();
}
assertWithMessage("Input had no frames.").that(comparedImagesCount).isGreaterThan(0);
return accumulatedSsim / comparedImagesCount;
}
/**
* Extracts, sets and returns the buffer of the luma (Y') channel of the image.
*
* @param image The {@link Image} in YUV format.
* @param lumaChannelBuffer The buffer where the extracted luma values are stored.
* @return The {@code lumaChannelBuffer} for convenience.
*/
private static byte[] extractLumaChannelBuffer(Image image, byte[] lumaChannelBuffer) {
// This method is invoked on the main thread.
// `image` should contain YUV channels.
Image.Plane[] imagePlanes = image.getPlanes();
assertThat(imagePlanes).hasLength(DECODED_IMAGE_CHANNEL_COUNT);
Image.Plane lumaPlane = imagePlanes[0];
int rowStride = lumaPlane.getRowStride();
int pixelStride = lumaPlane.getPixelStride();
int width = image.getWidth();
int height = image.getHeight();
ByteBuffer lumaByteBuffer = lumaPlane.getBuffer();
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
lumaChannelBuffer[y * width + x] = lumaByteBuffer.get(y * rowStride + x * pixelStride);
}
}
return lumaChannelBuffer;
}
private SsimHelper() {
// Prevent instantiation.
}
}