Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing headers for private repos. #5

Merged
merged 1 commit into from
Mar 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use rand::{thread_rng, Rng};
use reqwest::header::{CONTENT_RANGE, RANGE};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_RANGE, RANGE};
use std::collections::HashMap;
use std::io::SeekFrom;
use std::sync::Arc;

Expand All @@ -15,14 +16,15 @@ use tokio::time::sleep;
/// parallel_failures: Maximum number of different chunks allowed to fail in parallel (cannot exceed max_files)
/// max_retries: Maximum number of attempts per chunk. (Retries use exponential backoff plus jitter)
#[pyfunction]
#[pyo3(signature = (url, filename, max_files, chunk_size, parallel_failures=0, max_retries=0))]
#[pyo3(signature = (url, filename, max_files, chunk_size, parallel_failures=0, max_retries=0, headers=None))]
fn download(
url: String,
filename: String,
max_files: usize,
chunk_size: usize,
parallel_failures: usize,
max_retries: usize,
headers: Option<HashMap<String, String>>,
) -> PyResult<()> {
if parallel_failures > max_files {
return Err(PyException::new_err(
Expand All @@ -46,6 +48,7 @@ fn download(
chunk_size,
parallel_failures,
max_retries,
headers,
)
.await
})
Expand Down Expand Up @@ -81,14 +84,31 @@ async fn download_async(
chunk_size: usize,
parallel_failures: usize,
max_retries: usize,
input_headers: Option<HashMap<String, String>>,
) -> PyResult<()> {
let client = reqwest::Client::new();

let mut headers = HeaderMap::new();
if let Some(input_headers) = input_headers {
for (k, v) in input_headers {
let k: HeaderName = k
.try_into()
.map_err(|err| PyException::new_err(format!("Invalid header: {err:?}")))?;
let v: HeaderValue = v
.try_into()
.map_err(|err| PyException::new_err(format!("Invalid header value: {err:?}")))?;
headers.insert(k, v);
}
};

let response = client
.get(&url)
.header(RANGE, "bytes=0-0".to_string())
.headers(headers.clone())
.header(RANGE, "bytes=0-0")
.send()
.await
.map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;

let content_range = response
.headers()
.get(CONTENT_RANGE)
Expand Down Expand Up @@ -116,6 +136,7 @@ async fn download_async(
let url = url.clone();
let filename = filename.clone();
let client = client.clone();
let headers = headers.clone();

let stop = std::cmp::min(start + chunk_size - 1, length);
let permit = semaphore
Expand All @@ -125,7 +146,7 @@ async fn download_async(
.map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
let parallel_failures_semaphore = parallel_failures_semaphore.clone();
handles.push(tokio::spawn(async move {
let mut chunk = download_chunk(&client, &url, &filename, start, stop).await;
let mut chunk = download_chunk(&client, &url, &filename, start, stop, headers.clone()).await;
let mut i = 0;
if parallel_failures > 0{
while let Err(dlerr) = chunk {
Expand All @@ -138,7 +159,7 @@ async fn download_async(
let wait_time = exponential_backoff(300, i, 10_000);
sleep(tokio::time::Duration::from_millis(wait_time as u64)).await;

chunk = download_chunk(&client, &url, &filename, start, stop).await;
chunk = download_chunk(&client, &url, &filename, start, stop, headers.clone()).await;
i += 1;
if i > max_retries{
return Err(PyException::new_err(format!(
Expand Down Expand Up @@ -167,6 +188,7 @@ async fn download_chunk(
filename: &str,
start: usize,
stop: usize,
headers: HeaderMap,
) -> PyResult<()> {
// Build the `Range` header value covering this chunk's byte span.
let range = format!("bytes={start}-{stop}");
Expand All @@ -181,6 +203,7 @@ async fn download_chunk(
.map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
let response = client
.get(url)
.headers(headers)
.header(RANGE, range)
.send()
.await
Expand Down