diff --git a/Cargo.lock b/Cargo.lock index 2f902de..627fc5b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -197,6 +197,17 @@ dependencies = [ "slab", ] +[[package]] +name = "getrandom" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "h2" version = "0.3.15" @@ -238,6 +249,7 @@ dependencies = [ "futures", "openssl", "pyo3", + "rand", "reqwest", "tokio", ] @@ -408,9 +420,9 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "memoffset" -version = "0.6.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" dependencies = [ "autocfg", ] @@ -569,6 +581,12 @@ version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.47" @@ -580,9 +598,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.17.3" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "268be0c73583c183f2b14052337465768c07726936a260f480f0857cb95ba543" +checksum = "06a3d8e8a46ab2738109347433cb7b96dffda2e4a218b03ef27090238886b147" dependencies = [ "cfg-if", "indoc", @@ -597,9 +615,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.17.3" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28fcd1e73f06ec85bf3280c48c67e731d8290ad3d730f8be9dc07946923005c8" +checksum = "75439f995d07ddfad42b192dfcf3bc66a7ecfd8b4a1f5f6f046aa5c2c5d7677d" dependencies = [ "once_cell", "target-lexicon", @@ -607,9 +625,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.17.3" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f6cb136e222e49115b3c51c32792886defbfb0adead26a688142b346a0b9ffc" +checksum = "839526a5c07a17ff44823679b68add4a58004de00512a95b6c1c98a6dcac0ee5" dependencies = [ "libc", "pyo3-build-config", @@ -617,9 +635,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.17.3" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94144a1266e236b1c932682136dc35a9dee8d3589728f68130c7c3861ef96b28" +checksum = "bd44cf207476c6a9760c4653559be4f206efafb924d3e4cbf2721475fc0d6cc5" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -629,9 +647,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.17.3" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8df9be978a2d2f0cdebabb03206ed73b11314701a5bfe71b0d753b81997777f" +checksum = "dc1f43d8e30460f36350d18631ccf85ded64c059829208fe680904c65bcd0a4c" dependencies = [ "proc-macro2", "quote", @@ -647,6 +665,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.2.16" diff --git a/Cargo.toml b/Cargo.toml index ac6a9f2..dcec38d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ crate-type = ["cdylib"] [dependencies] futures = "0.3.25" openssl = { version = "0.10.44", features = ["vendored"] } -pyo3 = { version = "0.17.3", features = ["extension-module"] } +pyo3 = { version = "0.18.1", features = ["extension-module"] } +rand = "0.8.5" reqwest = "0.11.13" tokio = { version = "1.23.0", features = ["rt", "rt-multi-thread", "fs"] } diff --git a/src/lib.rs b/src/lib.rs index cbda66f..6aad899 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,21 +1,54 @@ use pyo3::exceptions::PyException; use pyo3::prelude::*; +use rand::{thread_rng, Rng}; use reqwest::header::{CONTENT_RANGE, RANGE}; use std::io::SeekFrom; use std::sync::Arc; +use std::fs::remove_file; +use std::path::Path; use tokio::io::AsyncSeekExt; use tokio::io::AsyncWriteExt; use tokio::sync::Semaphore; -use std::path::Path; -use std::fs::remove_file; +use tokio::time::sleep; +/// parallel_failures: Number of maximum failures of different chunks in parallel (cannot exceed max_files) +/// max_retries: Number of maximum attempts per chunk. (Retries are exponentially backed off + jitter) #[pyfunction] -fn download(url: String, filename: String, max_files: usize, chunk_size: usize) -> PyResult<()> { +#[pyo3(signature = (url, filename, max_files, chunk_size, parallel_failures=0, max_retries=0))] +fn download( + url: String, + filename: String, + max_files: usize, + chunk_size: usize, + parallel_failures: usize, + max_retries: usize, +) -> PyResult<()> { + if parallel_failures > max_files { + return Err(PyException::new_err( + "Error parallel_failures cannot be > max_files".to_string(), + )); + } + if (parallel_failures == 0) != (max_retries == 0) { + return Err(PyException::new_err( + "For retry mechanism you need to set both `parallel_failures` and `max_retries`" + .to_string(), + )); + } tokio::runtime::Builder::new_multi_thread() .enable_all() .build()? - .block_on(async { download_async(url, filename.clone(), max_files, chunk_size).await }) + .block_on(async { + download_async( + url, + filename.clone(), + max_files, + chunk_size, + parallel_failures, + max_retries, + ) + .await + }) .map_err(|err| { let path = Path::new(&filename); if path.exists() { @@ -33,11 +66,21 @@ fn download(url: String, filename: String, max_files: usize, chunk_size: usize) }) } +fn jitter() -> usize { + thread_rng().gen_range(0..=500) +} + +pub fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize { + (base_wait_time + n.pow(2) + jitter()).min(max) +} + async fn download_async( url: String, filename: String, max_files: usize, chunk_size: usize, + parallel_failures: usize, + max_retries: usize, ) -> PyResult<()> { let client = reqwest::Client::new(); let response = client @@ -66,6 +109,7 @@ async fn download_async( let mut handles = vec![]; let semaphore = Arc::new(Semaphore::new(max_files)); + let parallel_failures_semaphore = Arc::new(Semaphore::new(parallel_failures)); let chunk_size = chunk_size; for start in (0..length).step_by(chunk_size) { @@ -79,8 +123,31 @@ async fn download_async( .acquire_owned() .await .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; + let parallel_failures_semaphore = parallel_failures_semaphore.clone(); handles.push(tokio::spawn(async move { - let chunk = download_chunk(client, url, filename, start, stop).await; + let mut chunk = download_chunk(&client, &url, &filename, start, stop).await; + let mut i = 0; + if parallel_failures > 0{ + while let Err(dlerr) = chunk { + let parallel_failure_permit = parallel_failures_semaphore.clone().try_acquire_owned().map_err(|err| { + PyException::new_err(format!( + "Failed too many failures in parallel ({parallel_failures:?}): {dlerr:?} ({err:?})" + )) + })?; + + let wait_time = exponential_backoff(300, i, 10_000); + sleep(tokio::time::Duration::from_millis(wait_time as u64)).await; + + chunk = download_chunk(&client, &url, &filename, start, stop).await; + i += 1; + if i > max_retries{ + return Err(PyException::new_err(format!( + "Failed after too many retries ({max_retries:?}): {dlerr:?}" + ))); + } + drop(parallel_failure_permit); + } + } drop(permit); chunk })); @@ -95,9 +162,9 @@ async fn download_async( } async fn download_chunk( - client: reqwest::Client, - url: String, - filename: String, + client: &reqwest::Client, + url: &str, + filename: &str, start: usize, stop: usize, ) -> PyResult<()> { @@ -117,6 +184,8 @@ async fn download_chunk( .header(RANGE, range) .send() .await + .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))? + .error_for_status() .map_err(|err| PyException::new_err(format!("Error while downloading: {err:?}")))?; let content = response .bytes()