diff --git a/src/lib.rs b/src/lib.rs index 6aad899..8735016 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ use pyo3::exceptions::PyException; use pyo3::prelude::*; use rand::{thread_rng, Rng}; -use reqwest::header::{CONTENT_RANGE, RANGE}; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_RANGE, RANGE}; +use std::collections::HashMap; use std::io::SeekFrom; use std::sync::Arc; @@ -15,7 +16,7 @@ use tokio::time::sleep; /// parallel_failures: Number of maximum failures of different chunks in parallel (cannot exceed max_files) /// max_retries: Number of maximum attempts per chunk. (Retries are exponentially backed off + jitter) #[pyfunction] -#[pyo3(signature = (url, filename, max_files, chunk_size, parallel_failures=0, max_retries=0))] +#[pyo3(signature = (url, filename, max_files, chunk_size, parallel_failures=0, max_retries=0, headers=None))] fn download( url: String, filename: String, @@ -23,6 +24,7 @@ fn download( chunk_size: usize, parallel_failures: usize, max_retries: usize, + headers: Option>, ) -> PyResult<()> { if parallel_failures > max_files { return Err(PyException::new_err( @@ -46,6 +48,7 @@ fn download( chunk_size, parallel_failures, max_retries, + headers, ) .await }) @@ -81,14 +84,31 @@ async fn download_async( chunk_size: usize, parallel_failures: usize, max_retries: usize, + input_headers: Option>, ) -> PyResult<()> { let client = reqwest::Client::new(); + + let mut headers = HeaderMap::new(); + if let Some(input_headers) = input_headers { + for (k, v) in input_headers { + let k: HeaderName = k + .try_into() + .map_err(|err| PyException::new_err(format!("Invalid header: {err:?}")))?; + let v: HeaderValue = v + .try_into() + .map_err(|err| PyException::new_err(format!("Invalid header value: {err:?}")))?; + headers.insert(k, v); + } + }; + let response = client .get(&url) - .header(RANGE, "bytes=0-0".to_string()) + .headers(headers.clone()) + .header(RANGE, "bytes=0-0") .send() .await .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; + let content_range = response .headers() .get(CONTENT_RANGE) @@ -116,6 +136,7 @@ async fn download_async( let url = url.clone(); let filename = filename.clone(); let client = client.clone(); + let headers = headers.clone(); let stop = std::cmp::min(start + chunk_size - 1, length); let permit = semaphore @@ -125,7 +146,7 @@ async fn download_async( .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; let parallel_failures_semaphore = parallel_failures_semaphore.clone(); handles.push(tokio::spawn(async move { - let mut chunk = download_chunk(&client, &url, &filename, start, stop).await; + let mut chunk = download_chunk(&client, &url, &filename, start, stop, headers.clone()).await; let mut i = 0; if parallel_failures > 0{ while let Err(dlerr) = chunk { @@ -138,7 +159,7 @@ async fn download_async( let wait_time = exponential_backoff(300, i, 10_000); sleep(tokio::time::Duration::from_millis(wait_time as u64)).await; - chunk = download_chunk(&client, &url, &filename, start, stop).await; + chunk = download_chunk(&client, &url, &filename, start, stop, headers.clone()).await; i += 1; if i > max_retries{ return Err(PyException::new_err(format!( @@ -167,6 +188,7 @@ async fn download_chunk( filename: &str, start: usize, stop: usize, + headers: HeaderMap, ) -> PyResult<()> { // Process each socket concurrently. let range = format!("bytes={start}-{stop}"); @@ -181,6 +203,7 @@ async fn download_chunk( .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; let response = client .get(url) + .headers(headers) .header(RANGE, range) .send() .await