/*
* Copyright (C) 2016 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package androidx.media3.datasource.okhttp;
import static androidx.media3.common.util.Util.castNonNull;
import static androidx.media3.datasource.HttpUtil.buildRangeRequestHeader;
import static java.lang.Math.min;
import android.net.Uri;
import androidx.annotation.Nullable;
import androidx.media3.common.C;
import androidx.media3.common.MediaLibraryInfo;
import androidx.media3.common.PlaybackException;
import androidx.media3.common.util.Assertions;
import androidx.media3.common.util.UnstableApi;
import androidx.media3.common.util.Util;
import androidx.media3.datasource.BaseDataSource;
import androidx.media3.datasource.DataSource;
import androidx.media3.datasource.DataSourceException;
import androidx.media3.datasource.DataSpec;
import androidx.media3.datasource.HttpDataSource;
import androidx.media3.datasource.HttpUtil;
import androidx.media3.datasource.TransferListener;
import com.google.common.base.Predicate;
import com.google.common.net.HttpHeaders;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import okhttp3.CacheControl;
import okhttp3.Call;
import okhttp3.HttpUrl;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
/**
* An {@link HttpDataSource} that delegates to Square's {@link Call.Factory}.
*
* <p>Note: HTTP request headers will be set using all parameters passed via (in order of decreasing
* priority) the {@code dataSpec}, {@link #setRequestProperty} and the default parameters used to
* construct the instance.
*/
@UnstableApi
public class OkHttpDataSource extends BaseDataSource implements HttpDataSource {
// Registers this module's identifier with MediaLibraryInfo when the class is initialized.
static {
MediaLibraryInfo.registerModule("media3.datasource.okhttp");
}
/** A {@link DataSource.Factory} that creates {@link OkHttpDataSource} instances. */
public static final class Factory implements HttpDataSource.Factory {

  private final Call.Factory callFactory;
  private final RequestProperties defaultRequestProperties;

  @Nullable private String userAgent;
  @Nullable private CacheControl cacheControl;
  @Nullable private Predicate<String> contentTypePredicate;
  @Nullable private TransferListener transferListener;

  /**
   * Constructs a factory.
   *
   * @param callFactory The {@link Call.Factory} (typically an {@link OkHttpClient}) that the
   *     sources created by this factory use.
   */
  public Factory(Call.Factory callFactory) {
    this.callFactory = callFactory;
    this.defaultRequestProperties = new RequestProperties();
  }

  /**
   * Clears the current default request properties and sets them to the given values.
   *
   * @param defaultRequestProperties The default request properties to use.
   * @return This factory.
   */
  @Override
  public final Factory setDefaultRequestProperties(Map<String, String> defaultRequestProperties) {
    this.defaultRequestProperties.clearAndSet(defaultRequestProperties);
    return this;
  }

  /**
   * Sets the user agent to send with requests.
   *
   * <p>Defaults to {@code null}, in which case the default user agent of the underlying {@link
   * OkHttpClient} is used.
   *
   * @param userAgent The user agent to send, or {@code null} to fall back to the default user
   *     agent of the underlying {@link OkHttpClient}.
   * @return This factory.
   */
  public Factory setUserAgent(@Nullable String userAgent) {
    this.userAgent = userAgent;
    return this;
  }

  /**
   * Sets the {@link CacheControl} to apply to requests.
   *
   * <p>Defaults to {@code null}.
   *
   * @param cacheControl The cache control to apply.
   * @return This factory.
   */
  public Factory setCacheControl(@Nullable CacheControl cacheControl) {
    this.cacheControl = cacheControl;
    return this;
  }

  /**
   * Sets a {@link Predicate} that is applied to the content type of each response. When the
   * predicate rejects a content type, {@link OkHttpDataSource#open(DataSpec)} throws an {@link
   * HttpDataSource.InvalidContentTypeException}.
   *
   * <p>Defaults to {@code null}.
   *
   * @param contentTypePredicate The content type {@link Predicate}, or {@code null} to clear a
   *     predicate that was previously set.
   * @return This factory.
   */
  public Factory setContentTypePredicate(@Nullable Predicate<String> contentTypePredicate) {
    this.contentTypePredicate = contentTypePredicate;
    return this;
  }

  /**
   * Sets the {@link TransferListener} that is added to every created source.
   *
   * <p>Defaults to {@code null}.
   *
   * <p>See {@link DataSource#addTransferListener(TransferListener)}.
   *
   * @param transferListener The listener to add.
   * @return This factory.
   */
  public Factory setTransferListener(@Nullable TransferListener transferListener) {
    this.transferListener = transferListener;
    return this;
  }

  /**
   * Creates a new {@link OkHttpDataSource} configured with this factory's current settings, and
   * adds the transfer listener to it if one has been set.
   */
  @Override
  public OkHttpDataSource createDataSource() {
    OkHttpDataSource source =
        new OkHttpDataSource(
            callFactory, userAgent, cacheControl, defaultRequestProperties, contentTypePredicate);
    @Nullable TransferListener listener = transferListener;
    if (listener != null) {
      source.addTransferListener(listener);
    }
    return source;
  }
}
// Creates the HTTP call that is executed by each open().
private final Call.Factory callFactory;
// Headers set via setRequestProperty(); merged over defaultRequestProperties when building requests.
private final RequestProperties requestProperties;
// User-Agent header value added to each request, if non-null.
@Nullable private final String userAgent;
// Cache control applied to each request, if non-null.
@Nullable private final CacheControl cacheControl;
// Lowest-priority request headers, if non-null.
@Nullable private final RequestProperties defaultRequestProperties;
// Applied to the response content type in open(); a rejection fails the open.
@Nullable private Predicate<String> contentTypePredicate;
// The DataSpec passed to the most recent open() call.
@Nullable private DataSpec dataSpec;
// The current response; reset to null by closeConnectionQuietly().
@Nullable private Response response;
// The body stream of the current response; reset to null by closeConnectionQuietly().
@Nullable private InputStream responseByteStream;
// Set to true by open() just before transferStarted() is reported; cleared by close().
private boolean opened;
// Number of bytes to read for the current open(), or C.LENGTH_UNSET if unbounded.
private long bytesToRead;
// Number of bytes returned by read() since the last open() (skipped bytes are not counted).
private long bytesRead;
/**
 * Creates an instance with no user agent, cache control, default request properties or content
 * type predicate.
 *
 * @param callFactory The {@link Call.Factory} used to create calls. Must not be null.
 * @deprecated Use {@link OkHttpDataSource.Factory} instead.
 */
@Deprecated
public OkHttpDataSource(Call.Factory callFactory) {
  this(
      callFactory,
      /* userAgent= */ null,
      /* cacheControl= */ null,
      /* defaultRequestProperties= */ null,
      /* contentTypePredicate= */ null);
}
/**
 * Creates an instance with the given user agent and no cache control, default request properties
 * or content type predicate.
 *
 * @param callFactory The {@link Call.Factory} used to create calls. Must not be null.
 * @param userAgent The user agent to send with requests, or {@code null} to not add one.
 * @deprecated Use {@link OkHttpDataSource.Factory} instead.
 */
@Deprecated
public OkHttpDataSource(Call.Factory callFactory, @Nullable String userAgent) {
  this(
      callFactory,
      userAgent,
      /* cacheControl= */ null,
      /* defaultRequestProperties= */ null,
      /* contentTypePredicate= */ null);
}
/**
* Creates an instance without a content type predicate.
*
* @param callFactory The {@link Call.Factory} used to create calls.
* @param userAgent The user agent added to requests, or {@code null} to not add one.
* @param cacheControl The {@link CacheControl} applied to requests, or {@code null}.
* @param defaultRequestProperties The lowest-priority request headers, or {@code null}.
* @deprecated Use {@link OkHttpDataSource.Factory} instead.
*/
@Deprecated
public OkHttpDataSource(
Call.Factory callFactory,
@Nullable String userAgent,
@Nullable CacheControl cacheControl,
@Nullable RequestProperties defaultRequestProperties) {
this(
callFactory,
userAgent,
cacheControl,
defaultRequestProperties,
/* contentTypePredicate= */ null);
}
/**
 * Primary constructor that every other constructor and the {@link Factory} funnel into.
 *
 * @param callFactory The {@link Call.Factory} used to create calls. Must not be null.
 * @param userAgent The user agent added to requests, or {@code null} to not add one.
 * @param cacheControl The {@link CacheControl} applied to requests, or {@code null}.
 * @param defaultRequestProperties The lowest-priority request headers, or {@code null}.
 * @param contentTypePredicate The predicate applied to response content types, or {@code null}.
 */
private OkHttpDataSource(
    Call.Factory callFactory,
    @Nullable String userAgent,
    @Nullable CacheControl cacheControl,
    @Nullable RequestProperties defaultRequestProperties,
    @Nullable Predicate<String> contentTypePredicate) {
  super(/* isNetwork= */ true);
  // Validate the only mandatory dependency first.
  this.callFactory = Assertions.checkNotNull(callFactory);
  this.requestProperties = new RequestProperties();
  this.defaultRequestProperties = defaultRequestProperties;
  this.contentTypePredicate = contentTypePredicate;
  this.cacheControl = cacheControl;
  this.userAgent = userAgent;
}
/**
* Sets the content type {@link Predicate} that {@link #open(DataSpec)} applies to the content
* type of the response. If the predicate rejects the content type, an {@link
* HttpDataSource.InvalidContentTypeException} is thrown from {@link #open(DataSpec)}.
*
* @param contentTypePredicate The content type {@link Predicate}, or {@code null} to clear a
* predicate that was previously set.
* @deprecated Use {@link OkHttpDataSource.Factory#setContentTypePredicate(Predicate)} instead.
*/
@Deprecated
public void setContentTypePredicate(@Nullable Predicate<String> contentTypePredicate) {
this.contentTypePredicate = contentTypePredicate;
}
/**
 * Returns the URL of the request that produced the current response, or {@code null} if there is
 * no current response.
 */
@Override
@Nullable
public Uri getUri() {
  if (response == null) {
    return null;
  }
  HttpUrl requestUrl = response.request().url();
  return Uri.parse(requestUrl.toString());
}
/** Returns the status code of the current response, or -1 if there is no current response. */
@Override
public int getResponseCode() {
  if (response != null) {
    return response.code();
  }
  return -1;
}
/**
 * Returns the headers of the current response as a multimap, or an empty map if there is no
 * current response.
 */
@Override
public Map<String, List<String>> getResponseHeaders() {
  if (response == null) {
    return Collections.emptyMap();
  }
  return response.headers().toMultimap();
}
/**
 * Sets a request header that is sent with subsequent requests made by this source.
 *
 * @param name The header name. Must not be null.
 * @param value The header value. Must not be null.
 */
@Override
public void setRequestProperty(String name, String value) {
  // Arguments are evaluated left to right, so name is still validated before value.
  requestProperties.set(Assertions.checkNotNull(name), Assertions.checkNotNull(value));
}
/**
 * Removes a request header previously set with {@link #setRequestProperty(String, String)}.
 *
 * @param name The header name. Must not be null.
 */
@Override
public void clearRequestProperty(String name) {
  requestProperties.remove(Assertions.checkNotNull(name));
}
/**
* Clears all request headers that were set with {@link #setRequestProperty(String, String)}. The
* default request properties passed at construction are held separately and are not touched here.
*/
@Override
public void clearAllRequestProperties() {
requestProperties.clear();
}
/**
* Opens the source to read the data described by {@code dataSpec}.
*
* <p>The request is built by {@code makeRequest} and executed with a blocking {@code execute()}
* call. If the server answers a request for a non-zero position with a 200 rather than a 206, the
* bytes preceding the requested position are read and discarded before this method returns.
*
* @param dataSpec The {@link DataSpec} defining the data to be read.
* @return The number of bytes that can be read: {@code dataSpec.length} if it is set, otherwise
* the response's content length minus any skipped bytes, or {@link C#LENGTH_UNSET} if the
* content length is unknown. For a 416 response whose document size equals the requested
* position, {@code dataSpec.length} if it is set, otherwise 0.
* @throws HttpDataSourceException If the request cannot be built or executed, the response code
* is not successful, the content type is rejected by the content type predicate, or skipping
* to the requested position fails.
*/
@Override
public long open(DataSpec dataSpec) throws HttpDataSourceException {
this.dataSpec = dataSpec;
// Reset the per-open byte counters.
bytesRead = 0;
bytesToRead = 0;
transferInitializing(dataSpec);
Request request = makeRequest(dataSpec);
Response response;
ResponseBody responseBody;
try {
// Execute the call and keep both the response and its body stream for read() and close().
this.response = callFactory.newCall(request).execute();
response = this.response;
responseBody = Assertions.checkNotNull(response.body());
responseByteStream = responseBody.byteStream();
} catch (IOException e) {
throw HttpDataSourceException.createForIOException(
e, dataSpec, HttpDataSourceException.TYPE_OPEN);
}
int responseCode = response.code();
// Check for a valid response code.
if (!response.isSuccessful()) {
if (responseCode == 416) {
// A 416 whose reported document size equals the requested position is treated as a
// successful open. bytesToRead is still 0 here, so non-empty reads report end of input.
long documentSize =
HttpUtil.getDocumentSize(response.headers().get(HttpHeaders.CONTENT_RANGE));
if (dataSpec.position == documentSize) {
opened = true;
transferStarted(dataSpec);
return dataSpec.length != C.LENGTH_UNSET ? dataSpec.length : 0;
}
}
// Best-effort read of the error body; an empty array is used if it cannot be read.
byte[] errorResponseBody;
try {
errorResponseBody = Util.toByteArray(Assertions.checkNotNull(responseByteStream));
} catch (IOException e) {
errorResponseBody = Util.EMPTY_BYTE_ARRAY;
}
// Take the headers, then release the connection before throwing.
Map<String, List<String>> headers = response.headers().toMultimap();
closeConnectionQuietly();
// A 416 that was not handled above is reported with a position-out-of-range cause.
@Nullable
IOException cause =
responseCode == 416
? new DataSourceException(PlaybackException.ERROR_CODE_IO_READ_POSITION_OUT_OF_RANGE)
: null;
throw new InvalidResponseCodeException(
responseCode, response.message(), cause, headers, dataSpec, errorResponseBody);
}
// Check for a valid content type.
@Nullable MediaType mediaType = responseBody.contentType();
String contentType = mediaType != null ? mediaType.toString() : "";
if (contentTypePredicate != null && !contentTypePredicate.apply(contentType)) {
closeConnectionQuietly();
throw new InvalidContentTypeException(contentType, dataSpec);
}
// If we requested a range starting from a non-zero position and received a 200 rather than a
// 206, then the server does not support partial requests. We'll need to manually skip to the
// requested position.
long bytesToSkip = responseCode == 200 && dataSpec.position != 0 ? dataSpec.position : 0;
// Determine the length of the data to be read, after skipping.
if (dataSpec.length != C.LENGTH_UNSET) {
bytesToRead = dataSpec.length;
} else {
long contentLength = responseBody.contentLength();
bytesToRead = contentLength != -1 ? (contentLength - bytesToSkip) : C.LENGTH_UNSET;
}
// The transfer is reported as started before skipping; skipFully reports the skipped bytes
// through bytesTransferred.
opened = true;
transferStarted(dataSpec);
try {
skipFully(bytesToSkip, dataSpec);
} catch (HttpDataSourceException e) {
// Release the connection if skipping fails. opened stays true, so a later close() still
// calls transferEnded().
closeConnectionQuietly();
throw e;
}
return bytesToRead;
}
/**
 * Reads up to {@code length} bytes into {@code buffer} starting at {@code offset}, delegating to
 * {@code readInternal}.
 *
 * @param buffer The buffer into which the read data is stored.
 * @param offset The start offset into {@code buffer} at which data is written.
 * @param length The maximum number of bytes to read.
 * @return The number of bytes read, or {@link C#RESULT_END_OF_INPUT} if the end of the opened
 *     range is reached.
 * @throws HttpDataSourceException If an {@link IOException} occurs while reading; the exception
 *     is created with {@link HttpDataSourceException#TYPE_READ}.
 */
@Override
public int read(byte[] buffer, int offset, int length) throws HttpDataSourceException {
  int result;
  try {
    result = readInternal(buffer, offset, length);
  } catch (IOException e) {
    throw HttpDataSourceException.createForIOException(
        e, castNonNull(dataSpec), HttpDataSourceException.TYPE_READ);
  }
  return result;
}
/**
 * Closes the source. If it is currently marked as opened, the end of the transfer is reported and
 * the connection is released; otherwise this is a no-op.
 */
@Override
public void close() {
  if (!opened) {
    return;
  }
  opened = false;
  transferEnded();
  closeConnectionQuietly();
}
/**
 * Builds the OkHttp {@link Request} for the given {@link DataSpec}. No connection is made here;
 * the request is executed by {@link #open(DataSpec)}.
 *
 * <p>Headers are merged in increasing priority: the default request properties, the properties
 * set on this instance, then the headers of {@code dataSpec}. The range, user agent and {@code
 * Accept-Encoding: identity} headers are then added where applicable.
 *
 * @param dataSpec The {@link DataSpec} describing the data to request.
 * @return The request to execute.
 * @throws HttpDataSourceException If the URI of {@code dataSpec} cannot be parsed as an {@link
 *     HttpUrl}.
 */
private Request makeRequest(DataSpec dataSpec) throws HttpDataSourceException {
  @Nullable HttpUrl parsedUrl = HttpUrl.parse(dataSpec.uri.toString());
  if (parsedUrl == null) {
    throw new HttpDataSourceException(
        "Malformed URL",
        dataSpec,
        PlaybackException.ERROR_CODE_FAILED_RUNTIME_CHECK,
        HttpDataSourceException.TYPE_OPEN);
  }

  Request.Builder requestBuilder = new Request.Builder().url(parsedUrl);
  if (cacheControl != null) {
    requestBuilder.cacheControl(cacheControl);
  }

  // Later putAll calls overwrite earlier entries, which yields the documented priority order.
  Map<String, String> mergedHeaders = new HashMap<>();
  if (defaultRequestProperties != null) {
    mergedHeaders.putAll(defaultRequestProperties.getSnapshot());
  }
  mergedHeaders.putAll(requestProperties.getSnapshot());
  mergedHeaders.putAll(dataSpec.httpRequestHeaders);
  for (Map.Entry<String, String> entry : mergedHeaders.entrySet()) {
    requestBuilder.header(entry.getKey(), entry.getValue());
  }

  @Nullable String range = buildRangeRequestHeader(dataSpec.position, dataSpec.length);
  if (range != null) {
    requestBuilder.addHeader(HttpHeaders.RANGE, range);
  }
  if (userAgent != null) {
    requestBuilder.addHeader(HttpHeaders.USER_AGENT, userAgent);
  }
  boolean gzipAllowed = dataSpec.isFlagSet(DataSpec.FLAG_ALLOW_GZIP);
  if (!gzipAllowed) {
    requestBuilder.addHeader(HttpHeaders.ACCEPT_ENCODING, "identity");
  }

  @Nullable byte[] payload = dataSpec.httpBody;
  if (payload == null && dataSpec.httpMethod == DataSpec.HTTP_METHOD_POST) {
    // OkHttp requires a non-null body for POST requests.
    payload = Util.EMPTY_BYTE_ARRAY;
  }
  @Nullable RequestBody requestBody = payload != null ? RequestBody.create(null, payload) : null;
  return requestBuilder.method(dataSpec.getHttpMethodString(), requestBody).build();
}
/**
 * Attempts to skip the specified number of bytes in full by reading them from the response stream
 * and discarding them. Every discarded chunk is reported through {@code bytesTransferred}.
 *
 * @param bytesToSkip The number of bytes to skip.
 * @param dataSpec The {@link DataSpec}, attached to any exception that is thrown.
 * @throws HttpDataSourceException If the thread is interrupted during the operation, or an error
 *     occurs while reading from the source, or if the data ended before skipping the specified
 *     number of bytes. For interruption and read errors, the underlying {@link IOException} is
 *     available as the cause.
 */
private void skipFully(long bytesToSkip, DataSpec dataSpec) throws HttpDataSourceException {
  if (bytesToSkip == 0) {
    return;
  }
  byte[] skipBuffer = new byte[4096];
  try {
    while (bytesToSkip > 0) {
      int readLength = (int) min(bytesToSkip, skipBuffer.length);
      int read = castNonNull(responseByteStream).read(skipBuffer, 0, readLength);
      // isInterrupted() does not clear the flag, so callers can still observe the interrupt.
      if (Thread.currentThread().isInterrupted()) {
        throw new InterruptedIOException();
      }
      if (read == -1) {
        // The stream ended before the requested position was reached.
        throw new HttpDataSourceException(
            dataSpec,
            PlaybackException.ERROR_CODE_IO_READ_POSITION_OUT_OF_RANGE,
            HttpDataSourceException.TYPE_OPEN);
      }
      bytesToSkip -= read;
      bytesTransferred(read);
    }
  } catch (IOException e) {
    if (e instanceof HttpDataSourceException) {
      // Already the right type (thrown above); rethrow it unchanged.
      throw (HttpDataSourceException) e;
    }
    // Keep the original failure as the cause so the root error (for example an
    // InterruptedIOException) is not lost when wrapping.
    throw new HttpDataSourceException(
        e,
        dataSpec,
        PlaybackException.ERROR_CODE_IO_UNSPECIFIED,
        HttpDataSourceException.TYPE_OPEN);
  }
}
/**
 * Reads at most {@code readLength} bytes from the response stream into {@code buffer}, beginning
 * at index {@code offset}, without reading past the end of the opened range.
 *
 * <p>Bytes that are read are added to the running total for the current open and reported through
 * {@code bytesTransferred}.
 *
 * @param buffer The buffer into which the read data should be stored.
 * @param offset The start offset into {@code buffer} at which data should be written.
 * @param readLength The maximum number of bytes to read.
 * @return The number of bytes read (0 if {@code readLength} is 0), or {@link
 *     C#RESULT_END_OF_INPUT} if the end of the opened range or of the stream is reached.
 * @throws IOException If an error occurs reading from the source.
 */
private int readInternal(byte[] buffer, int offset, int readLength) throws IOException {
  if (readLength == 0) {
    return 0;
  }
  int maxBytes = readLength;
  if (bytesToRead != C.LENGTH_UNSET) {
    // Clamp the read so that it never goes beyond the opened range.
    long remaining = bytesToRead - bytesRead;
    if (remaining == 0) {
      return C.RESULT_END_OF_INPUT;
    }
    maxBytes = (int) min(maxBytes, remaining);
  }
  int count = castNonNull(responseByteStream).read(buffer, offset, maxBytes);
  if (count == -1) {
    return C.RESULT_END_OF_INPUT;
  }
  bytesRead += count;
  bytesTransferred(count);
  return count;
}
/**
 * Closes the body of the current response, if there is one, and clears the references to the
 * response and its byte stream.
 */
private void closeConnectionQuietly() {
  if (response != null) {
    ResponseBody body = Assertions.checkNotNull(response.body());
    body.close();
    response = null;
  }
  responseByteStream = null;
}
}