diff --git a/Cargo.lock b/Cargo.lock index 82d3dee35e57..eebeaf59b631 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5107,6 +5107,7 @@ version = "0.19.0-dev.0" dependencies = [ "aim-downloader", "anyhow", + "serial_test 3.1.1", "tabby-common", "tokio-retry", "tracing", diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs index b62942352311..9b9bcb18db55 100644 --- a/crates/llama-cpp-server/src/lib.rs +++ b/crates/llama-cpp-server/src/lib.rs @@ -10,7 +10,7 @@ use serde::Deserialize; use supervisor::LlamaCppSupervisor; use tabby_common::{ config::{HttpModelConfigBuilder, LocalModelConfig, ModelConfig}, - registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH}, + registry::{parse_model_id, ModelRegistry}, }; use tabby_inference::{ChatCompletionStream, CompletionOptions, CompletionStream, Embedding}; @@ -277,15 +277,10 @@ pub async fn create_embedding(config: &ModelConfig) -> Arc { } async fn resolve_model_path(model_id: &str) -> String { - let path = PathBuf::from(model_id); - let path = if path.exists() { - path.join(GGML_MODEL_RELATIVE_PATH.as_str()) - } else { - let (registry, name) = parse_model_id(model_id); - let registry = ModelRegistry::new(registry).await; - registry.get_model_path(name) - }; - path.display().to_string() + let (registry, name) = parse_model_id(model_id); + let registry = ModelRegistry::new(registry).await; + let path = registry.get_model_entry_path(name); + path.unwrap().display().to_string() } #[derive(Deserialize)] diff --git a/crates/tabby-common/src/registry.rs b/crates/tabby-common/src/registry.rs index 862e2388c02c..f476eed68670 100644 --- a/crates/tabby-common/src/registry.rs +++ b/crates/tabby-common/src/registry.rs @@ -15,6 +15,23 @@ pub struct ModelInfo { pub chat_template: Option, #[serde(skip_serializing_if = "Option::is_none")] pub urls: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + pub sha256: Option, + // partition_urls is used for model download address + // if the model is partitioned, the addresses of each partition will be listed here, + // if there is only one partition, it will be the same as `urls`. + // + // will first try to the `urls`, if not found, will try this `partition_urls`. + // + // must make sure the first address is the entrypoint + #[serde(skip_serializing_if = "Option::is_none")] + pub partition_urls: Option>, +} + +#[derive(Serialize, Deserialize)] +pub struct PartitionModelUrl { + pub urls: Vec, pub sha256: String, } @@ -54,6 +71,18 @@ pub struct ModelRegistry { pub models: Vec, } +lazy_static! { + pub static ref LEGACY_GGML_MODEL_PATH: String = + format!("ggml{}model.gguf", std::path::MAIN_SEPARATOR_STR); + pub static ref GGML_MODEL_PARTITIONED_PREFIX: String = "model-00001-of-".into(); +} + +// model registry tree structure +// root: ~/.tabby/models/TabbyML +// +// fn get_model_root_dir(model_name) -> {root}/{model_name} +// +// fn get_model_dir(model_name) -> {root}/{model_name}/ggml impl ModelRegistry { pub async fn new(registry: &str) -> Self { Self { @@ -69,29 +98,61 @@ impl ModelRegistry { } } - fn get_model_dir(&self, name: &str) -> PathBuf { + // get_model_store_dir returns {root}/{name}/ggml, e.g.. ~/.tabby/models/TabbyML/StarCoder-1B/ggml + pub fn get_model_store_dir(&self, name: &str) -> PathBuf { + self.get_model_dir(name).join("ggml") + } + + // get_model_dir returns {root}/{name}, e.g. ~/.tabby/models/TabbyML/StarCoder-1B + pub fn get_model_dir(&self, name: &str) -> PathBuf { models_dir().join(&self.name).join(name) } - pub fn migrate_model_path(&self, name: &str) -> Result<(), std::io::Error> { - let model_path = self.get_model_path(name); + // get_model_path returns the entrypoint of the model, + // will look for the file with the prefix "00001-of-" + pub fn get_model_entry_path(&self, name: &str) -> Option { + for entry in fs::read_dir(self.get_model_store_dir(name)).ok()? { + let entry = entry.expect("Error reading directory entry"); + let file_name = entry.file_name(); + let file_name_str = file_name.to_string_lossy(); + + // Check if the file name starts with the specified prefix + if file_name_str.starts_with(GGML_MODEL_PARTITIONED_PREFIX.as_str()) { + return Some(entry.path()); // Return the full path as PathBuf + } + } + + None + } + + pub fn migrate_legacy_model_path(&self, name: &str) -> Result<(), std::io::Error> { let old_model_path = self .get_model_dir(name) - .join(LEGACY_GGML_MODEL_RELATIVE_PATH.as_str()); - - if !model_path.exists() && old_model_path.exists() { - std::fs::rename(&old_model_path, &model_path)?; - #[cfg(target_family = "unix")] - std::os::unix::fs::symlink(&model_path, &old_model_path)?; - #[cfg(target_family = "windows")] - std::os::windows::fs::symlink_file(&model_path, &old_model_path)?; + .join(LEGACY_GGML_MODEL_PATH.as_str()); + + if old_model_path.exists() { + return self.migrate_model_path(name, &old_model_path); } + Ok(()) } pub fn get_model_path(&self, name: &str) -> PathBuf { self.get_model_dir(name) - .join(GGML_MODEL_RELATIVE_PATH.as_str()) + .join(LEGACY_GGML_MODEL_PATH.as_str()) + } + + pub fn migrate_model_path( + &self, + name: &str, + old_model_path: &PathBuf, + ) -> Result<(), std::io::Error> { + // legacy model always has a single file + let model_path = self + .get_model_store_dir(name) + .join("model-00001-of-00001.gguf"); + std::fs::rename(old_model_path, &model_path)?; + Ok(()) } pub fn save_model_info(&self, name: &str) { @@ -120,13 +181,6 @@ pub fn parse_model_id(model_id: &str) -> (&str, &str) { } } -lazy_static! { - pub static ref LEGACY_GGML_MODEL_RELATIVE_PATH: String = - format!("ggml{}q8_0.v2.gguf", std::path::MAIN_SEPARATOR_STR); - pub static ref GGML_MODEL_RELATIVE_PATH: String = - format!("ggml{}model.gguf", std::path::MAIN_SEPARATOR_STR); -} - #[cfg(test)] mod tests { use temp_testdir::TempDir; @@ -142,7 +196,7 @@ mod tests { let registry = ModelRegistry::new("TabbyML").await; let dir = registry.get_model_dir("StarCoder-1B"); - let old_model_path = dir.join(LEGACY_GGML_MODEL_RELATIVE_PATH.as_str()); + let old_model_path = dir.join(LEGACY_GGML_MODEL_PATH.as_str()); tokio::fs::create_dir_all(old_model_path.parent().unwrap()) .await .unwrap(); @@ -153,8 +207,11 @@ mod tests { .await .unwrap(); - registry.migrate_model_path("StarCoder-1B").unwrap(); - assert!(registry.get_model_path("StarCoder-1B").exists()); - assert!(old_model_path.exists()); + registry.migrate_legacy_model_path("StarCoder-1B").unwrap(); + assert!(registry + .get_model_entry_path("StarCoder-1B") + .unwrap() + .exists()); + assert!(!old_model_path.exists()); } } diff --git a/crates/tabby-download/Cargo.toml b/crates/tabby-download/Cargo.toml index f1087d7555c3..355c58ba18d8 100644 --- a/crates/tabby-download/Cargo.toml +++ b/crates/tabby-download/Cargo.toml @@ -9,3 +9,6 @@ tabby-common = { path = "../tabby-common" } anyhow = { workspace = true } tracing = { workspace = true } tokio-retry = "0.3.0" + +[dev-dependencies] +serial_test = { workspace = true } diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs index cb05301b1e5f..753d93496739 100644 --- a/crates/tabby-download/src/lib.rs +++ b/crates/tabby-download/src/lib.rs @@ -1,24 +1,71 @@ //! Responsible for downloading ML models for use with tabby. -use std::{ - fs::{self}, - path::Path, -}; +use std::fs; use aim_downloader::{bar::WrappedBar, error::DownloadError, hash::HashChecker, https}; -use anyhow::{anyhow, bail, Result}; -use tabby_common::registry::{parse_model_id, ModelRegistry}; +use anyhow::{bail, Result}; +use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry}; use tokio_retry::{ strategy::{jitter, ExponentialBackoff}, Retry, }; use tracing::{info, warn}; -fn select_by_download_host(url: &String) -> bool { - if let Ok(host) = std::env::var("TABBY_DOWNLOAD_HOST") { - url.contains(&host) - } else { - true - } +pub fn get_download_host() -> String { + std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or_else(|_| "huggingface.co".to_string()) +} + +pub fn get_huggingface_mirror_host() -> Option { + std::env::var("TABBY_HUGGINGFACE_HOST_OVERRIDE").ok() +} + +pub fn filter_download_address(model_info: &ModelInfo) -> Vec<(String, String)> { + let download_host = get_download_host(); + if let Some(urls) = &model_info.urls { + if !urls.is_empty() { + let url = model_info + .urls + .iter() + .flatten() + .find(|f| f.contains(&download_host)); + if let Some(url) = url { + if let Some(mirror_host) = get_huggingface_mirror_host() { + return vec![( + url.replace("huggingface.co", &mirror_host), + model_info.sha256.clone().unwrap_or_default(), + )]; + } + return vec![( + url.to_owned(), + model_info.sha256.clone().unwrap_or_default(), + )]; + } + } + }; + + model_info + .partition_urls + .iter() + .flatten() + .map(|x| -> (String, String) { + let url = x.urls.iter().find(|f| f.contains(&download_host)); + if let Some(url) = url { + if let Some(mirror_host) = get_huggingface_mirror_host() { + return ( + url.replace("huggingface.co", &mirror_host), + x.sha256.clone(), + ); + } + return (url.to_owned(), x.sha256.clone()); + } + panic!("No download URLs available for <{}>", model_info.name); + }) + .collect() +} + +macro_rules! partitioned_file_name { + ($index:expr, $total:expr) => { + format!("model-{:05}-of-{:05}.gguf", $index + 1, $total) + }; } async fn download_model_impl( @@ -27,63 +74,79 @@ async fn download_model_impl( prefer_local_file: bool, ) -> Result<()> { let model_info = registry.get_model_info(name); - registry.save_model_info(name); + registry.migrate_legacy_model_path(name)?; - registry.migrate_model_path(name)?; - let model_path = registry.get_model_path(name); - if model_path.exists() { - if !prefer_local_file { - info!("Checking model integrity.."); - if HashChecker::check(&model_path.display().to_string(), &model_info.sha256).is_ok() { - return Ok(()); + let urls = filter_download_address(model_info); + if urls.is_empty() { + bail!( + "No download URLs available for <{}/{}>", + registry.name, + model_info.name + ); + } + + if !prefer_local_file { + info!("Checking model integrity.."); + + let mut sha256_matched = true; + for (index, url) in urls.iter().enumerate() { + if HashChecker::check( + partitioned_file_name!(index + 1, urls.len()).as_str(), + &url.1, + ) + .is_err() + { + sha256_matched = false; + break; } - warn!( - "Checksum doesn't match for <{}/{}>, re-downloading...", - registry.name, name - ); - fs::remove_file(&model_path)?; - } else { + } + + if sha256_matched { return Ok(()); } + + warn!( + "Checksum doesn't match for <{}/{}>, re-downloading...", + registry.name, name + ); + + fs::remove_dir_all(registry.get_model_dir(name))?; } - let Some(model_url) = model_info - .urls - .iter() - .flatten() - .find(|x| select_by_download_host(x)) - else { - return Err(anyhow!("No valid url for model <{}>", model_info.name)); - }; + // prepare for download + let dir = registry.get_model_store_dir(name); + fs::create_dir_all(dir)?; + registry.save_model_info(name); - // Replace the huggingface.co domain with the mirror host if it is set. - let model_url = if let Ok(host) = std::env::var("TABBY_HUGGINGFACE_HOST_OVERRIDE") { - model_url.replace("huggingface.co", &host) - } else { - model_url.to_owned() - }; + for (index, url) in urls.iter().enumerate() { + let dir = registry + .get_model_store_dir(name) + .to_string_lossy() + .into_owned(); + let filename: String = partitioned_file_name!(index + 1, urls.len()); + let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2); - let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2); - let download_job = Retry::spawn(strategy, || { - download_file(&model_url, model_path.as_path(), &model_info.sha256) - }); - download_job.await?; - Ok(()) -} + Retry::spawn(strategy, move || { + let dir = dir.clone(); + let filename = filename.clone(); -async fn download_file(url: &str, path: &Path, expected_sha256: &str) -> Result<()> { - let dir = path - .parent() - .ok_or_else(|| anyhow!("Must not be in root directory"))?; - fs::create_dir_all(dir)?; + download_file(&url.0, dir, filename, &url.1) + }) + .await?; + } - let filename = path - .to_str() - .ok_or_else(|| anyhow!("Could not convert filename to UTF-8"))?; - let intermediate_filename = filename.to_owned() + ".tmp"; + Ok(()) +} +async fn download_file( + url: &str, + dir: String, + filename: String, + expected_sha256: &str, +) -> Result<()> { + let fullpath = format! {"{}{}{}", dir, std::path::MAIN_SEPARATOR, filename}; + let intermediate_filename = fullpath.clone() + ".tmp"; let mut bar = WrappedBar::new(0, url, false); - if let Err(e) = https::HTTPSHandler::get(url, &intermediate_filename, &mut bar, expected_sha256).await { @@ -97,7 +160,7 @@ async fn download_file(url: &str, path: &Path, expected_sha256: &str) -> Result< } } - fs::rename(intermediate_filename, filename)?; + fs::rename(intermediate_filename, fullpath)?; Ok(()) } @@ -111,3 +174,209 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) { .await .unwrap_or_else(handler) } + +#[cfg(test)] +mod tests { + // filter_download_address tests should be serial because they rely on environment variables + use serial_test::serial; + use tabby_common::registry::{ModelInfo, PartitionModelUrl}; + + #[test] + #[serial(filter_download_address)] + fn test_filter_download_address() { + // multiple urls + let model_info = ModelInfo { + name: "test".to_string(), + urls: Some(vec![ + "https://huggingface.co/test".to_string(), + "https://huggingface.co/test2".to_string(), + "https://modelscope.co/test2".to_string(), + ]), + sha256: Some("test_sha256".to_string()), + prompt_template: None, + chat_template: None, + partition_urls: None, + }; + let urls = super::filter_download_address(&model_info); + assert_eq!(urls.len(), 1); + assert_eq!(urls[0].0, "https://huggingface.co/test"); + + // single url + let model_info = ModelInfo { + name: "test".to_string(), + urls: Some(vec![ + "https://huggingface.co/test".to_string(), + "https://modelscope.co/test2".to_string(), + ]), + sha256: Some("test_sha256".to_string()), + prompt_template: None, + chat_template: None, + partition_urls: None, + }; + let urls = super::filter_download_address(&model_info); + assert_eq!(urls.len(), 1); + assert_eq!(urls[0].0, "https://huggingface.co/test"); + } + + #[test] + #[serial(filter_download_address)] + fn test_filter_download_address_multiple_partitions() { + let model_info = ModelInfo { + name: "test".to_string(), + urls: None, + sha256: Some("test_sha256".to_string()), + prompt_template: None, + chat_template: None, + partition_urls: Some(vec![ + PartitionModelUrl { + urls: vec![ + "https://huggingface.co/part1".to_string(), + "https://modelscope.co/part1".to_string(), + ], + sha256: "test_sha256_1".to_string(), + }, + PartitionModelUrl { + urls: vec![ + "https://huggingface.co/part2".to_string(), + "https://modelscope.co/part2".to_string(), + ], + sha256: "test_sha256_2".to_string(), + }, + ]), + }; + let urls = super::filter_download_address(&model_info); + assert_eq!(urls.len(), 2); + assert_eq!(urls[0].0, "https://huggingface.co/part1"); + assert_eq!(urls[0].1, "test_sha256_1"); + assert_eq!(urls[1].0, "https://huggingface.co/part2"); + assert_eq!(urls[1].1, "test_sha256_2"); + } + + #[test] + #[serial(filter_download_address)] + fn test_filter_download_address_single_partition() { + let model_info = ModelInfo { + name: "test".to_string(), + urls: None, + sha256: Some("test_sha256".to_string()), + prompt_template: None, + chat_template: None, + partition_urls: Some(vec![PartitionModelUrl { + urls: vec!["https://huggingface.co/part1".to_string()], + sha256: "test_sha256_1".to_string(), + }]), + }; + let urls = super::filter_download_address(&model_info); + assert_eq!(urls.len(), 1); + assert_eq!(urls[0].0, "https://huggingface.co/part1"); + assert_eq!(urls[0].1, "test_sha256_1"); + } + + #[test] + #[serial(filter_download_address)] + fn test_filter_download_address_prefer_urls() { + let model_info = ModelInfo { + name: "test".to_string(), + urls: Some(vec!["https://huggingface.co/test".to_string()]), + sha256: Some("test_sha256".to_string()), + prompt_template: None, + chat_template: None, + partition_urls: Some(vec![PartitionModelUrl { + urls: vec!["https://modelscope.co/test".to_string()], + sha256: "test_sha256".to_string(), + }]), + }; + let urls = super::filter_download_address(&model_info); + assert_eq!(urls.len(), 1); + assert_eq!(urls[0].0, "https://huggingface.co/test"); + assert_eq!(urls[0].1, "test_sha256"); + } + + #[test] + #[serial(filter_download_address)] + fn test_filter_download_address_huggingface_override_urls() { + std::env::set_var("TABBY_HUGGINGFACE_HOST_OVERRIDE", "modelscope.co"); + let model_info = ModelInfo { + name: "test".to_string(), + urls: Some(vec!["https://huggingface.co/test".to_string()]), + sha256: Some("test_sha256".to_string()), + prompt_template: None, + chat_template: None, + partition_urls: None, + }; + let urls = super::filter_download_address(&model_info); + assert_eq!(urls.len(), 1); + assert_eq!(urls[0].0, "https://modelscope.co/test"); + assert_eq!(urls[0].1, "test_sha256"); + // must reset the env, or it will affect other tests + std::env::remove_var("TABBY_HUGGINGFACE_HOST_OVERRIDE"); + } + + #[test] + #[serial(filter_download_address)] + fn test_filter_download_address_huggingface_override_partitioned() { + std::env::set_var("TABBY_HUGGINGFACE_HOST_OVERRIDE", "modelscope.co"); + let model_info = ModelInfo { + name: "test".to_string(), + urls: None, + sha256: Some("test_sha256".to_string()), + prompt_template: None, + chat_template: None, + partition_urls: Some(vec![ + PartitionModelUrl { + urls: vec!["https://huggingface.co/part1".to_string()], + sha256: "test_sha256_1".to_string(), + }, + PartitionModelUrl { + urls: vec!["https://huggingface.co/part2".to_string()], + sha256: "test_sha256_2".to_string(), + }, + ]), + }; + let urls = super::filter_download_address(&model_info); + assert_eq!(urls.len(), 2); + assert_eq!(urls[0].0, "https://modelscope.co/part1"); + assert_eq!(urls[0].1, "test_sha256_1"); + assert_eq!(urls[1].0, "https://modelscope.co/part2"); + assert_eq!(urls[1].1, "test_sha256_2"); + // must reset the env, or it will affect other tests + std::env::remove_var("TABBY_HUGGINGFACE_HOST_OVERRIDE"); + } + + #[test] + #[serial(filter_download_address)] + fn test_filter_download_address_download_host() { + std::env::set_var("TABBY_DOWNLOAD_HOST", "modelscope.co"); + let model_info = ModelInfo { + name: "test".to_string(), + urls: None, + sha256: Some("test_sha256".to_string()), + prompt_template: None, + chat_template: None, + partition_urls: Some(vec![ + PartitionModelUrl { + urls: vec![ + "https://huggingface.co/part1".to_string(), + "https://modelscope.co/part1".to_string(), + ], + sha256: "test_sha256_1".to_string(), + }, + PartitionModelUrl { + urls: vec![ + "https://huggingface.co/part2".to_string(), + "https://modelscope.co/part2".to_string(), + ], + sha256: "test_sha256_2".to_string(), + }, + ]), + }; + let urls = super::filter_download_address(&model_info); + assert_eq!(urls.len(), 2); + assert_eq!(urls[0].0, "https://modelscope.co/part1"); + assert_eq!(urls[0].1, "test_sha256_1"); + assert_eq!(urls[1].0, "https://modelscope.co/part2"); + assert_eq!(urls[1].1, "test_sha256_2"); + // must reset the env, or it will affect other tests + std::env::remove_var("TABBY_DOWNLOAD_HOST"); + } +}