Merge pull request #3 from huggingface/max_retry_seq
Alternative retry implementation.
Narsil authored Mar 13, 2023
2 parents 2de163c + f11c1bc commit eea07e3
Showing 3 changed files with 139 additions and 21 deletions.
72 changes: 60 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
@@ -11,6 +11,7 @@ crate-type = ["cdylib"]
[dependencies]
futures = "0.3.25"
openssl = { version = "0.10.44", features = ["vendored"] }
pyo3 = { version = "0.17.3", features = ["extension-module"] }
pyo3 = { version = "0.18.1", features = ["extension-module"] }
rand = "0.8.5"
reqwest = "0.11.13"
tokio = { version = "1.23.0", features = ["rt", "rt-multi-thread", "fs"] }
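Note: the new rand dependency backs the jitter helper added in src/lib.rs below.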
85 changes: 77 additions & 8 deletions src/lib.rs
@@ -1,21 +1,54 @@
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use rand::{thread_rng, Rng};
use reqwest::header::{CONTENT_RANGE, RANGE};
use std::io::SeekFrom;
use std::sync::Arc;

use std::fs::remove_file;
use std::path::Path;
use tokio::io::AsyncSeekExt;
use tokio::io::AsyncWriteExt;
use tokio::sync::Semaphore;
use std::path::Path;
use std::fs::remove_file;
use tokio::time::sleep;

/// parallel_failures: Maximum number of chunks allowed to be failing (and retrying) at the same time (cannot exceed max_files)
/// max_retries: Maximum number of retry attempts per chunk (retries back off with added jitter)
#[pyfunction]
fn download(url: String, filename: String, max_files: usize, chunk_size: usize) -> PyResult<()> {
#[pyo3(signature = (url, filename, max_files, chunk_size, parallel_failures=0, max_retries=0))]
fn download(
url: String,
filename: String,
max_files: usize,
chunk_size: usize,
parallel_failures: usize,
max_retries: usize,
) -> PyResult<()> {
if parallel_failures > max_files {
return Err(PyException::new_err(
"Error parallel_failures cannot be > max_files".to_string(),
));
}
if (parallel_failures == 0) != (max_retries == 0) {
return Err(PyException::new_err(
"For retry mechanism you need to set both `parallel_failures` and `max_retries`"
.to_string(),
));
}
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(async { download_async(url, filename.clone(), max_files, chunk_size).await })
.block_on(async {
download_async(
url,
filename.clone(),
max_files,
chunk_size,
parallel_failures,
max_retries,
)
.await
})
.map_err(|err| {
let path = Path::new(&filename);
if path.exists() {
@@ -33,11 +66,21 @@ fn download(url: String, filename: String, max_files: usize, chunk_size: usize)
})
}
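The two guards above run before the runtime is built: the first keeps the retry budget within the download concurrency, and the second is an XOR on the two knobs being zero, so retries are either fully enabled or fully disabled. A small standalone illustration of that second check; the validate helper is illustrative, not part of the diff:

fn validate(parallel_failures: usize, max_retries: usize) -> Result<(), &'static str> {
    // Mirrors the guard in download(): reject mixed settings where exactly
    // one of the two retry knobs is zero.
    if (parallel_failures == 0) != (max_retries == 0) {
        return Err("set both `parallel_failures` and `max_retries`, or neither");
    }
    Ok(())
}

fn main() {
    assert!(validate(0, 0).is_ok()); // retries disabled
    assert!(validate(3, 5).is_ok()); // retries enabled
    assert!(validate(3, 0).is_err()); // mixed: rejected
    assert!(validate(0, 5).is_err()); // mixed: rejected
}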

fn jitter() -> usize {
thread_rng().gen_range(0..=500)
}

pub fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
(base_wait_time + n.pow(2) + jitter()).min(max)
}
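Despite the helper's name, the growth term here is quadratic (n.pow(2)) rather than exponential, so with the arguments the retry loop passes below (base 300 ms, cap 10_000 ms) the wait stays near the base for small retry counts; jitter() adds a further random 0..=500 ms on top. A standalone sketch of the deterministic part, with an illustrative backoff_without_jitter helper:

fn backoff_without_jitter(base_wait_time: usize, n: usize, max: usize) -> usize {
    // Same formula as exponential_backoff above, minus the random jitter term.
    (base_wait_time + n.pow(2)).min(max)
}

fn main() {
    for n in 0..5 {
        // n is the retry index the loop below passes as `i`.
        // Prints 300, 301, 304, 309, 316 ms for n = 0..5.
        println!("retry {n}: {} ms + jitter", backoff_without_jitter(300, n, 10_000));
    }
}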

async fn download_async(
url: String,
filename: String,
max_files: usize,
chunk_size: usize,
parallel_failures: usize,
max_retries: usize,
) -> PyResult<()> {
let client = reqwest::Client::new();
let response = client
@@ -66,6 +109,7 @@ async fn download_async(

let mut handles = vec![];
let semaphore = Arc::new(Semaphore::new(max_files));
let parallel_failures_semaphore = Arc::new(Semaphore::new(parallel_failures));

let chunk_size = chunk_size;
for start in (0..length).step_by(chunk_size) {
@@ -79,8 +123,31 @@
.acquire_owned()
.await
.map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
let parallel_failures_semaphore = parallel_failures_semaphore.clone();
handles.push(tokio::spawn(async move {
let chunk = download_chunk(client, url, filename, start, stop).await;
let mut chunk = download_chunk(&client, &url, &filename, start, stop).await;
let mut i = 0;
if parallel_failures > 0 {
while let Err(dlerr) = chunk {
let parallel_failure_permit = parallel_failures_semaphore.clone().try_acquire_owned().map_err(|err| {
PyException::new_err(format!(
"Failed too many failures in parallel ({parallel_failures:?}): {dlerr:?} ({err:?})"
))
})?;

let wait_time = exponential_backoff(300, i, 10_000);
sleep(tokio::time::Duration::from_millis(wait_time as u64)).await;

chunk = download_chunk(&client, &url, &filename, start, stop).await;
i += 1;
if i > max_retries {
return Err(PyException::new_err(format!(
"Failed after too many retries ({max_retries:?}): {dlerr:?}"
)));
}
drop(parallel_failure_permit);
}
}
drop(permit);
chunk
}));
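The loop above gates every retry behind a permit from parallel_failures_semaphore: a chunk may only retry while fewer than parallel_failures chunks are failing at the same time, and try_acquire_owned (rather than acquire_owned) makes the download abort instead of queueing once that budget is spent. A stripped-down sketch of just that gating pattern, assuming tokio with the rt-multi-thread and macros features; fake_download is an illustrative stand-in for download_chunk:

use std::sync::Arc;
use tokio::sync::Semaphore;

async fn fake_download(attempt: usize) -> Result<(), String> {
    // Fail the first two attempts to exercise the retry path.
    if attempt < 2 {
        Err("transient error".to_string())
    } else {
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), String> {
    // At most 2 tasks may hold a "currently failing" permit at once,
    // mirroring parallel_failures_semaphore in the diff.
    let retry_budget = Arc::new(Semaphore::new(2));
    let mut attempt = 0;
    let mut outcome = fake_download(attempt).await;
    while outcome.is_err() {
        // try_acquire_owned fails fast once the budget is spent instead of
        // queueing; this is what surfaces "too many failures in parallel".
        let permit = retry_budget
            .clone()
            .try_acquire_owned()
            .map_err(|e| format!("retry budget exhausted: {e}"))?;
        attempt += 1;
        outcome = fake_download(attempt).await;
        drop(permit); // release the budget as soon as this retry has run
    }
    outcome
}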
@@ -95,9 +162,9 @@
}

async fn download_chunk(
client: reqwest::Client,
url: String,
filename: String,
client: &reqwest::Client,
url: &str,
filename: &str,
start: usize,
stop: usize,
) -> PyResult<()> {
@@ -117,6 +184,8 @@ async fn download_chunk(
.header(RANGE, range)
.send()
.await
.map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?
.error_for_status()
.map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?;
let content = response
.bytes()
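Beyond switching download_chunk to borrowed parameters, which lets the retry loop call it repeatedly without cloning the client, URL, and filename on every attempt, the functional change in this hunk is error_for_status(): HTTP 4xx/5xx responses to the range request now surface as Err and feed the retry loop, instead of their error body being written to disk as chunk data. A minimal standalone sketch of that request pattern, assuming reqwest 0.11; fetch_range and its Vec<u8> return type are illustrative:

use reqwest::header::RANGE;

// Illustrative, not from the diff: a ranged GET that treats HTTP error
// statuses as failures, as the updated download_chunk now does.
async fn fetch_range(
    client: &reqwest::Client,
    url: &str,
    start: usize,
    stop: usize,
) -> Result<Vec<u8>, reqwest::Error> {
    let response = client
        .get(url)
        .header(RANGE, format!("bytes={start}-{stop}")) // inclusive byte range
        .send()
        .await?
        .error_for_status()?; // 4xx/5xx become Err, and thus retryable
    Ok(response.bytes().await?.to_vec())
}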
