Skip to content

Commit

Permalink
feat: add upload (#7)
Browse files Browse the repository at this point in the history
* feat: add upload

* fix: argument passing

* fix: PUT to s3 requires Content-Length header

* docs: add contributing docs

* refacto: use `FromPyObject` trait for conversion

* docs: add quality dependency to hugginface_hub cli

* refacto: remove lfs / huggingface_hub specific code

* fix: clippy warnings

* fix: PR comments

* feat: defer completion call to caller

* feat: remove reqwest `json` feature

* feat: change API to `Vec` instead of custom obj
  • Loading branch information
McPatate authored Mar 22, 2023
1 parent 25893e6 commit 088bff9
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 14 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ futures = "0.3.25"
openssl = { version = "0.10.44", features = ["vendored"] }
pyo3 = { version = "0.18.1", features = ["extension-module"] }
rand = "0.8.5"
reqwest = "0.11.13"
reqwest = { version = "0.11" , features = ["stream"] }
tokio = { version = "1.23.0", features = ["rt", "rt-multi-thread", "fs"] }
tokio-util = { version = "0.7", features = ["codec"]}

42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# HF Transfer

Speed up file transfers with the Hub.

Supports download and upload.

## Contributing

```sh
python3 -m venv ~/.venv/hf_transfer
source ~/.venv/hf_transfer/bin/activate
pip install maturin
maturin develop
```

### `huggingface_hub`

If you are working on changes with `huggingface_hub`

```sh
git clone git@github.com:huggingface/huggingface_hub.git
# git clone https://github.com/huggingface/huggingface_hub.git

cd huggingface_hub
python3 -m pip install -e ".[quality]"
```

You can use the following test script:

```py
import os

# os.environ["HF_ENDPOINT"] = "http://localhost:5564"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import HfApi, logging

logging.set_verbosity_debug()
hf = HfApi()
hf.upload_file(path_or_fileobj="/path/to/my/repo/some_file", path_in_repo="some_file", repo_id="my/repo", repo_type="model")
```

203 changes: 190 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use rand::{thread_rng, Rng};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_RANGE, RANGE};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_LENGTH, CONTENT_RANGE, RANGE};
use std::collections::HashMap;
use std::io::SeekFrom;
use std::sync::Arc;

use std::fs::remove_file;
use std::io::SeekFrom;
use std::path::Path;
use tokio::io::AsyncSeekExt;
use std::sync::Arc;
use tokio::fs::OpenOptions;
use tokio::io::AsyncWriteExt;
use tokio::io::{AsyncReadExt, AsyncSeekExt};
use tokio::sync::Semaphore;
use tokio::time::sleep;
use tokio_util::codec::{BytesCodec, FramedRead};

const BASE_WAIT_TIME: usize = 300;
const MAX_WAIT_TIME: usize = 10_000;

/// max_files: Number of open file handles, which determines the maximum number of parallel downloads
/// parallel_failures: Number of maximum failures of different chunks in parallel (cannot exceed max_files)
/// max_retries: Number of maximum attempts per chunk. (Retries are exponentially backed off + jitter)
/// number of threads can be tuned by environment variable `TOKIO_WORKER_THREADS` as documented in https://docs.rs/tokio/latest/tokio/runtime/struct.Builder.html#method.worker_threads
///
/// The number of threads can be tuned by the environment variable `TOKIO_WORKER_THREADS` as documented in
/// https://docs.rs/tokio/latest/tokio/runtime/struct.Builder.html#method.worker_threads
#[pyfunction]
#[pyo3(signature = (url, filename, max_files, chunk_size, parallel_failures=0, max_retries=0, headers=None))]
fn download(
Expand Down Expand Up @@ -70,6 +77,55 @@ fn download(
})
}

/// parts_urls: Dictionary consisting of part numbers as keys and the associated url as values
/// completion_url: The url that should be called when the upload is finished
/// max_files: Number of open file handles, which determines the maximum number of parallel uploads
/// parallel_failures: Number of maximum failures of different chunks in parallel (cannot exceed max_files)
/// max_retries: Number of maximum attempts per chunk. (Retries are exponentially backed off + jitter)
///
/// The number of threads can be tuned by the environment variable `TOKIO_WORKER_THREADS` as documented in
/// https://docs.rs/tokio/latest/tokio/runtime/struct.Builder.html#method.worker_threads
///
/// See https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html for more information
/// on the multipart upload
#[pyfunction]
#[pyo3(signature = (file_path, parts_urls, chunk_size, max_files, parallel_failures=0, max_retries=0))]
fn multipart_upload(
file_path: String,
parts_urls: Vec<String>,
chunk_size: u64,
max_files: usize,
parallel_failures: usize,
max_retries: usize,
) -> PyResult<Vec<HashMap<String, String>>> {
if parallel_failures > max_files {
return Err(PyException::new_err(
"Error parallel_failures cannot be > max_files".to_string(),
));
}
if (parallel_failures == 0) != (max_retries == 0) {
return Err(PyException::new_err(
"For retry mechanism you need to set both `parallel_failures` and `max_retries`"
.to_string(),
));
}

tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(async {
upload_async(
file_path,
parts_urls,
chunk_size,
max_files,
parallel_failures,
max_retries,
)
.await
})
}

fn jitter() -> usize {
thread_rng().gen_range(0..=500)
}
Expand Down Expand Up @@ -149,24 +205,24 @@ async fn download_async(
handles.push(tokio::spawn(async move {
let mut chunk = download_chunk(&client, &url, &filename, start, stop, headers.clone()).await;
let mut i = 0;
if parallel_failures > 0{
if parallel_failures > 0 {
while let Err(dlerr) = chunk {
if i >= max_retries {
return Err(PyException::new_err(format!(
"Failed after too many retries ({max_retries:?}): {dlerr:?}"
)));
}
let parallel_failure_permit = parallel_failures_semaphore.clone().try_acquire_owned().map_err(|err| {
PyException::new_err(format!(
"Failed too many failures in parallel ({parallel_failures:?}): {dlerr:?} ({err:?})"
))
})?;

let wait_time = exponential_backoff(300, i, 10_000);
let wait_time = exponential_backoff(BASE_WAIT_TIME, i, MAX_WAIT_TIME);
sleep(tokio::time::Duration::from_millis(wait_time as u64)).await;

chunk = download_chunk(&client, &url, &filename, start, stop, headers.clone()).await;
i += 1;
if i > max_retries{
return Err(PyException::new_err(format!(
"Failed after too many retries ({max_retries:?}): {dlerr:?}"
)));
}
drop(parallel_failure_permit);
}
}
Expand Down Expand Up @@ -221,9 +277,130 @@ async fn download_chunk(
Ok(())
}

async fn upload_async(
file_path: String,
parts_urls: Vec<String>,
chunk_size: u64,
max_files: usize,
parallel_failures: usize,
max_retries: usize,
) -> PyResult<Vec<HashMap<String, String>>> {
let client = reqwest::Client::new();

let mut handles = vec![];
let semaphore = Arc::new(Semaphore::new(max_files));
let parallel_failures_semaphore = Arc::new(Semaphore::new(parallel_failures));

for (part_number, part_url) in parts_urls.iter().enumerate() {
let url = part_url.to_string();
let path = file_path.to_owned();
let client = client.clone();

let start = (part_number as u64) * chunk_size;
let permit = semaphore
.clone()
.acquire_owned()
.await
.map_err(|err| PyException::new_err(format!("Error acquiring semaphore: {err}")))?;
let parallel_failures_semaphore = parallel_failures_semaphore.clone();
handles.push(tokio::spawn(async move {
let mut chunk = upload_chunk(&client, &url, &path, start, chunk_size).await;
let mut i = 0;
if parallel_failures > 0 {
while let Err(ul_err) = chunk {
if i >= max_retries {
return Err(PyException::new_err(format!(
"Failed after too many retries ({max_retries:?}): {ul_err:?}"
)));
}

let parallel_failure_permit = parallel_failures_semaphore.clone().try_acquire_owned().map_err(|err| {
PyException::new_err(format!(
"Failed too many failures in parallel ({parallel_failures:?}): {ul_err:?} ({err:?})"
))
})?;

let wait_time = exponential_backoff(BASE_WAIT_TIME, i, MAX_WAIT_TIME);
sleep(tokio::time::Duration::from_millis(wait_time as u64)).await;

chunk = upload_chunk(&client, &url, &path, start, chunk_size).await;
i += 1;
drop(parallel_failure_permit);
}
}
drop(permit);
chunk
}));
}

let results: Vec<Result<PyResult<HashMap<String, String>>, tokio::task::JoinError>> =
futures::future::join_all(handles).await;

let results: PyResult<Vec<HashMap<String, String>>> = results
.into_iter()
.flat_map(|res| match res {
Ok(Ok(response)) => Some(Ok(response)),
Ok(Err(err)) => Some(Err(err)),
Err(err) => Some(Err(PyException::new_err(format!(
"Error occurred while uploading: {err}"
)))),
})
.collect();

results
}

async fn upload_chunk(
client: &reqwest::Client,
url: &str,
path: &str,
start: u64,
chunk_size: u64,
) -> PyResult<HashMap<String, String>> {
let mut options = OpenOptions::new();
let mut file = options.read(true).open(path).await?;
let file_size = file.metadata().await?.len();
let bytes_transfered = std::cmp::min(file_size - start, chunk_size);

file.seek(SeekFrom::Start(start)).await?;
let chunk = file.take(chunk_size);

let response = client
.put(url)
.header(CONTENT_LENGTH, bytes_transfered)
.body(reqwest::Body::wrap_stream(FramedRead::new(
chunk,
BytesCodec::new(),
)))
.send()
.await
.map_err(|err| PyException::new_err(format!("Error sending chunk: {err}")))?
.error_for_status()
.map_err(|err| {
PyException::new_err(format!(
"Server responded with error status code while upload chunk: {err}"
))
})?;

let mut headers = HashMap::new();
for (name, value) in response.headers().into_iter() {
headers.insert(
name.to_string(),
value
.to_str()
.map_err(|err| {
PyException::new_err(format!("Response header contains non ASCII chars: {err}"))
})?
.to_owned(),
);
}
Ok(headers)
}

/// A Python module implemented in Rust.
#[pymodule]
fn hf_transfer(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(download, m)?)?;
m.add_function(wrap_pyfunction!(multipart_upload, m)?)?;
Ok(())
}

0 comments on commit 088bff9

Please sign in to comment.