ProjectionDecoder.java
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package androidx.media3.exoplayer.video.spherical;
import androidx.annotation.Nullable;
import androidx.media3.common.C;
import androidx.media3.common.util.ParsableBitArray;
import androidx.media3.common.util.ParsableByteArray;
import androidx.media3.common.util.Util;
import androidx.media3.exoplayer.video.spherical.Projection.Mesh;
import androidx.media3.exoplayer.video.spherical.Projection.SubMesh;
import java.util.ArrayList;
import java.util.zip.Inflater;
/**
* A decoder for the projection mesh.
*
* <p>The mesh boxes parsed are described at <a
* href="https://github.com/google/spatial-media/blob/master/docs/spherical-video-v2-rfc.md">
* Spherical Video V2 RFC</a>.
*
* <p>The decoder does not perform CRC checks at the moment.
*/
/* package */ final class ProjectionDecoder {
private static final int TYPE_YTMP = 0x79746d70;
private static final int TYPE_MSHP = 0x6d736870;
private static final int TYPE_RAW = 0x72617720;
private static final int TYPE_DFL8 = 0x64666c38;
private static final int TYPE_MESH = 0x6d657368;
private static final int TYPE_PROJ = 0x70726f6a;
// Limits to prevent a bad file from creating an OOM situation. We don't expect a mesh to
// exceed these limits.
private static final int MAX_COORDINATE_COUNT = 10_000;
private static final int MAX_VERTEX_COUNT = 32 * 1000;
private static final int MAX_TRIANGLE_INDICES = 128 * 1000;
private ProjectionDecoder() {}
/*
* Decodes the projection data.
*
* @param projectionData The projection data.
* @param stereoMode A {@link C.StereoMode} value.
* @return The projection or null if the data can't be decoded.
*/
@Nullable
public static Projection decode(byte[] projectionData, @C.StereoMode int stereoMode) {
ParsableByteArray input = new ParsableByteArray(projectionData);
// MP4 containers include the proj box but webm containers do not.
// Both containers use mshp.
ArrayList<Mesh> meshes = null;
try {
meshes = isProj(input) ? parseProj(input) : parseMshp(input);
} catch (ArrayIndexOutOfBoundsException ignored) {
// Do nothing.
}
if (meshes == null) {
return null;
} else {
switch (meshes.size()) {
case 1:
return new Projection(meshes.get(0), stereoMode);
case 2:
return new Projection(meshes.get(0), meshes.get(1), stereoMode);
case 0:
default:
return null;
}
}
}
/** Returns true if the input contains a proj box. Indicates MP4 container. */
private static boolean isProj(ParsableByteArray input) {
input.skipBytes(4); // size
int type = input.readInt();
input.setPosition(0);
return type == TYPE_PROJ;
}
@Nullable
private static ArrayList<Mesh> parseProj(ParsableByteArray input) {
input.skipBytes(8); // size and type.
int position = input.getPosition();
int limit = input.limit();
while (position < limit) {
int childEnd = position + input.readInt();
if (childEnd <= position || childEnd > limit) {
return null;
}
int childAtomType = input.readInt();
// Some early files named the atom ytmp rather than mshp.
if (childAtomType == TYPE_YTMP || childAtomType == TYPE_MSHP) {
input.setLimit(childEnd);
return parseMshp(input);
}
position = childEnd;
input.setPosition(position);
}
return null;
}
@Nullable
private static ArrayList<Mesh> parseMshp(ParsableByteArray input) {
int version = input.readUnsignedByte();
if (version != 0) {
return null;
}
input.skipBytes(7); // flags + crc.
int encoding = input.readInt();
if (encoding == TYPE_DFL8) {
ParsableByteArray output = new ParsableByteArray();
Inflater inflater = new Inflater(true);
try {
if (!Util.inflate(input, output, inflater)) {
return null;
}
} finally {
inflater.end();
}
input = output;
} else if (encoding != TYPE_RAW) {
return null;
}
return parseRawMshpData(input);
}
/** Parses MSHP data after the encoding_four_cc field. */
@Nullable
private static ArrayList<Mesh> parseRawMshpData(ParsableByteArray input) {
ArrayList<Mesh> meshes = new ArrayList<>();
int position = input.getPosition();
int limit = input.limit();
while (position < limit) {
int childEnd = position + input.readInt();
if (childEnd <= position || childEnd > limit) {
return null;
}
int childAtomType = input.readInt();
if (childAtomType == TYPE_MESH) {
Mesh mesh = parseMesh(input);
if (mesh == null) {
return null;
}
meshes.add(mesh);
}
position = childEnd;
input.setPosition(position);
}
return meshes;
}
@Nullable
private static Mesh parseMesh(ParsableByteArray input) {
// Read the coordinates.
int coordinateCount = input.readInt();
if (coordinateCount > MAX_COORDINATE_COUNT) {
return null;
}
float[] coordinates = new float[coordinateCount];
for (int coordinate = 0; coordinate < coordinateCount; coordinate++) {
coordinates[coordinate] = input.readFloat();
}
// Read the vertices.
int vertexCount = input.readInt();
if (vertexCount > MAX_VERTEX_COUNT) {
return null;
}
final double log2 = Math.log(2.0);
int coordinateCountSizeBits = (int) Math.ceil(Math.log(2.0 * coordinateCount) / log2);
ParsableBitArray bitInput = new ParsableBitArray(input.getData());
bitInput.setPosition(input.getPosition() * 8);
float[] vertices = new float[vertexCount * 5];
int[] coordinateIndices = new int[5];
int vertexIndex = 0;
for (int vertex = 0; vertex < vertexCount; vertex++) {
for (int i = 0; i < 5; i++) {
int coordinateIndex =
coordinateIndices[i] + decodeZigZag(bitInput.readBits(coordinateCountSizeBits));
if (coordinateIndex >= coordinateCount || coordinateIndex < 0) {
return null;
}
vertices[vertexIndex++] = coordinates[coordinateIndex];
coordinateIndices[i] = coordinateIndex;
}
}
// Pad to next byte boundary
bitInput.setPosition(((bitInput.getPosition() + 7) & ~7));
int subMeshCount = bitInput.readBits(32);
SubMesh[] subMeshes = new SubMesh[subMeshCount];
for (int i = 0; i < subMeshCount; i++) {
int textureId = bitInput.readBits(8);
int drawMode = bitInput.readBits(8);
int triangleIndexCount = bitInput.readBits(32);
if (triangleIndexCount > MAX_TRIANGLE_INDICES) {
return null;
}
int vertexCountSizeBits = (int) Math.ceil(Math.log(2.0 * vertexCount) / log2);
int index = 0;
float[] triangleVertices = new float[triangleIndexCount * 3];
float[] textureCoords = new float[triangleIndexCount * 2];
for (int counter = 0; counter < triangleIndexCount; counter++) {
index += decodeZigZag(bitInput.readBits(vertexCountSizeBits));
if (index < 0 || index >= vertexCount) {
return null;
}
triangleVertices[counter * 3] = vertices[index * 5];
triangleVertices[counter * 3 + 1] = vertices[index * 5 + 1];
triangleVertices[counter * 3 + 2] = vertices[index * 5 + 2];
textureCoords[counter * 2] = vertices[index * 5 + 3];
textureCoords[counter * 2 + 1] = vertices[index * 5 + 4];
}
subMeshes[i] = new SubMesh(textureId, triangleVertices, textureCoords, drawMode);
}
return new Mesh(subMeshes);
}
/**
* Decodes Zigzag encoding as described in
* https://developers.google.com/protocol-buffers/docs/encoding#signed-integers
*/
private static int decodeZigZag(int n) {
return (n >> 1) ^ -(n & 1);
}
}