Skip to content

Commit

Permalink
chore(download): properly implement partitioned model download (Tabby…
Browse files Browse the repository at this point in the history
…ML#3294)

* bug: fix ignored local model

* feat: should redownload if part of models not found

* test: fix sha256 check returned error

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
zwpaper and autofix-ci[bot] authored Oct 21, 2024
1 parent 3c0bc99 commit ff3410d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 30 deletions.
20 changes: 13 additions & 7 deletions crates/aim-downloader/src/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ impl HashChecker {
pub fn check(filename: &str, expected_hash: &str) -> Result<(), ValidateError> {
let mut result = Ok(());
if filename != "stdout" && (!expected_hash.is_empty()) {
let actual_hash = HashChecker::sha256sum(filename);
let actual_hash = HashChecker::sha256sum(filename)?;
if actual_hash != expected_hash {
result = Err(ValidateError::Sha256Mismatch);
}
Expand All @@ -22,15 +22,21 @@ impl HashChecker {
result
}

fn sha256sum(filename: &str) -> String {
fn sha256sum(filename: &str) -> Result<String, ValidateError> {
let mut hasher = Sha256::new();
let mut file = fs::File::open(filename).unwrap();

io::copy(&mut file, &mut hasher).unwrap();
let mut file = fs::File::open(filename).map_err(|e| {
println!("Can not open {filename}:\n {e}");
ValidateError::Sha256Mismatch
})?;

io::copy(&mut file, &mut hasher).map_err(|e| {
println!("Can not read {filename}:\n {e}");
ValidateError::Sha256Mismatch
})?;
let computed_hash = hasher.finalize();
drop(file);

format!("{computed_hash:x}")
Ok(format!("{computed_hash:x}"))
}
}

Expand Down Expand Up @@ -61,7 +67,7 @@ mod tests {
fn test_sha256sum_api() {
let expected = "21d7847124bfb9d9a9d44af6f00d8003006c44b9ef9ba458b5d4d3fc5f81bde5";

let actual = HashChecker::sha256sum("LICENCE.md");
let actual = HashChecker::sha256sum("LICENCE.md").unwrap();

assert_eq!(actual, expected);
}
Expand Down
57 changes: 34 additions & 23 deletions crates/tabby-download/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,34 +85,45 @@ async fn download_model_impl(
);
}

if !prefer_local_file {
info!("Checking model integrity..");

let mut sha256_matched = true;
for (index, url) in urls.iter().enumerate() {
if HashChecker::check(
partitioned_file_name!(index + 1, urls.len()).as_str(),
&url.1,
)
.is_err()
{
sha256_matched = false;
break;
}
let mut model_existed = true;
for (index, _) in urls.iter().enumerate() {
if fs::metadata(
registry
.get_model_store_dir(name)
.join(partitioned_file_name!(index, urls.len())),
)
.is_err()
{
model_existed = false;
break;
}
}

if sha256_matched {
return Ok(());
}
if model_existed && prefer_local_file {
return Ok(());
}

warn!(
"Checksum doesn't match for <{}/{}>, re-downloading...",
registry.name, name
);
info!("Checking model integrity..");

fs::remove_dir_all(registry.get_model_dir(name))?;
let mut sha256_matched = true;
for (index, url) in urls.iter().enumerate() {
if HashChecker::check(partitioned_file_name!(index, urls.len()).as_str(), &url.1).is_err() {
sha256_matched = false;
break;
}
}

if sha256_matched {
return Ok(());
}

warn!(
"Checksum doesn't match for <{}/{}>, re-downloading...",
registry.name, name
);

fs::remove_dir_all(registry.get_model_dir(name))?;

// prepare for download
let dir = registry.get_model_store_dir(name);
fs::create_dir_all(dir)?;
Expand All @@ -123,7 +134,7 @@ async fn download_model_impl(
.get_model_store_dir(name)
.to_string_lossy()
.into_owned();
let filename: String = partitioned_file_name!(index + 1, urls.len());
let filename: String = partitioned_file_name!(index, urls.len());
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);

Retry::spawn(strategy, move || {
Expand Down

0 comments on commit ff3410d

Please sign in to comment.