Skip to content

Commit

Permalink
feat(download): allow fetching model files with multiple partitions (T…
Browse files Browse the repository at this point in the history
…abbyML#3258)

* finish main logic

* add ut

* [autofix.ci] apply automated fixes

* feat: use indexed model name

* chore: apply review from meng

* chore: revert unnecessary downloader change

* chore: fix ut

* chore: donwload one file each time

* [autofix.ci] apply automated fixes

* chore: fix ut

* chore: fix review from meng

* chore: fix ci

* chore: revert multibar

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* chore: filter download address tests should be serial

* chore: use the workspace dep for serial_test

---------

Co-authored-by: leili <lilei@deeproute.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent 639c857 commit 779f785
Show file tree
Hide file tree
Showing 5 changed files with 416 additions and 91 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 5 additions & 10 deletions crates/llama-cpp-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use serde::Deserialize;
use supervisor::LlamaCppSupervisor;
use tabby_common::{
config::{HttpModelConfigBuilder, LocalModelConfig, ModelConfig},
registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH},
registry::{parse_model_id, ModelRegistry},
};
use tabby_inference::{ChatCompletionStream, CompletionOptions, CompletionStream, Embedding};

Expand Down Expand Up @@ -277,15 +277,10 @@ pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
}

async fn resolve_model_path(model_id: &str) -> String {
let path = PathBuf::from(model_id);
let path = if path.exists() {
path.join(GGML_MODEL_RELATIVE_PATH.as_str())
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
registry.get_model_path(name)
};
path.display().to_string()
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
let path = registry.get_model_entry_path(name);
path.unwrap().display().to_string()
}

#[derive(Deserialize)]
Expand Down
103 changes: 80 additions & 23 deletions crates/tabby-common/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ pub struct ModelInfo {
pub chat_template: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub urls: Option<Vec<String>>,

#[serde(skip_serializing_if = "Option::is_none")]
pub sha256: Option<String>,
// partition_urls is used for model download address
// if the model is partitioned, the addresses of each partition will be listed here,
// if there is only one partition, it will be the same as `urls`.
//
// will first try to the `urls`, if not found, will try this `partition_urls`.
//
// must make sure the first address is the entrypoint
#[serde(skip_serializing_if = "Option::is_none")]
pub partition_urls: Option<Vec<PartitionModelUrl>>,
}

#[derive(Serialize, Deserialize)]
pub struct PartitionModelUrl {
pub urls: Vec<String>,
pub sha256: String,
}

Expand Down Expand Up @@ -54,6 +71,18 @@ pub struct ModelRegistry {
pub models: Vec<ModelInfo>,
}

lazy_static! {
pub static ref LEGACY_GGML_MODEL_PATH: String =
format!("ggml{}model.gguf", std::path::MAIN_SEPARATOR_STR);
pub static ref GGML_MODEL_PARTITIONED_PREFIX: String = "model-00001-of-".into();
}

// model registry tree structure
// root: ~/.tabby/models/TabbyML
//
// fn get_model_root_dir(model_name) -> {root}/{model_name}
//
// fn get_model_dir(model_name) -> {root}/{model_name}/ggml
impl ModelRegistry {
pub async fn new(registry: &str) -> Self {
Self {
Expand All @@ -69,29 +98,61 @@ impl ModelRegistry {
}
}

fn get_model_dir(&self, name: &str) -> PathBuf {
// get_model_store_dir returns {root}/{name}/ggml, e.g.. ~/.tabby/models/TabbyML/StarCoder-1B/ggml
pub fn get_model_store_dir(&self, name: &str) -> PathBuf {
self.get_model_dir(name).join("ggml")
}

// get_model_dir returns {root}/{name}, e.g. ~/.tabby/models/TabbyML/StarCoder-1B
pub fn get_model_dir(&self, name: &str) -> PathBuf {
models_dir().join(&self.name).join(name)
}

pub fn migrate_model_path(&self, name: &str) -> Result<(), std::io::Error> {
let model_path = self.get_model_path(name);
// get_model_path returns the entrypoint of the model,
// will look for the file with the prefix "00001-of-"
pub fn get_model_entry_path(&self, name: &str) -> Option<PathBuf> {
for entry in fs::read_dir(self.get_model_store_dir(name)).ok()? {
let entry = entry.expect("Error reading directory entry");
let file_name = entry.file_name();
let file_name_str = file_name.to_string_lossy();

// Check if the file name starts with the specified prefix
if file_name_str.starts_with(GGML_MODEL_PARTITIONED_PREFIX.as_str()) {
return Some(entry.path()); // Return the full path as PathBuf
}
}

None
}

pub fn migrate_legacy_model_path(&self, name: &str) -> Result<(), std::io::Error> {
let old_model_path = self
.get_model_dir(name)
.join(LEGACY_GGML_MODEL_RELATIVE_PATH.as_str());

if !model_path.exists() && old_model_path.exists() {
std::fs::rename(&old_model_path, &model_path)?;
#[cfg(target_family = "unix")]
std::os::unix::fs::symlink(&model_path, &old_model_path)?;
#[cfg(target_family = "windows")]
std::os::windows::fs::symlink_file(&model_path, &old_model_path)?;
.join(LEGACY_GGML_MODEL_PATH.as_str());

if old_model_path.exists() {
return self.migrate_model_path(name, &old_model_path);
}

Ok(())
}

pub fn get_model_path(&self, name: &str) -> PathBuf {
self.get_model_dir(name)
.join(GGML_MODEL_RELATIVE_PATH.as_str())
.join(LEGACY_GGML_MODEL_PATH.as_str())
}

pub fn migrate_model_path(
&self,
name: &str,
old_model_path: &PathBuf,
) -> Result<(), std::io::Error> {
// legacy model always has a single file
let model_path = self
.get_model_store_dir(name)
.join("model-00001-of-00001.gguf");
std::fs::rename(old_model_path, &model_path)?;
Ok(())
}

pub fn save_model_info(&self, name: &str) {
Expand Down Expand Up @@ -120,13 +181,6 @@ pub fn parse_model_id(model_id: &str) -> (&str, &str) {
}
}

lazy_static! {
pub static ref LEGACY_GGML_MODEL_RELATIVE_PATH: String =
format!("ggml{}q8_0.v2.gguf", std::path::MAIN_SEPARATOR_STR);
pub static ref GGML_MODEL_RELATIVE_PATH: String =
format!("ggml{}model.gguf", std::path::MAIN_SEPARATOR_STR);
}

#[cfg(test)]
mod tests {
use temp_testdir::TempDir;
Expand All @@ -142,7 +196,7 @@ mod tests {
let registry = ModelRegistry::new("TabbyML").await;
let dir = registry.get_model_dir("StarCoder-1B");

let old_model_path = dir.join(LEGACY_GGML_MODEL_RELATIVE_PATH.as_str());
let old_model_path = dir.join(LEGACY_GGML_MODEL_PATH.as_str());
tokio::fs::create_dir_all(old_model_path.parent().unwrap())
.await
.unwrap();
Expand All @@ -153,8 +207,11 @@ mod tests {
.await
.unwrap();

registry.migrate_model_path("StarCoder-1B").unwrap();
assert!(registry.get_model_path("StarCoder-1B").exists());
assert!(old_model_path.exists());
registry.migrate_legacy_model_path("StarCoder-1B").unwrap();
assert!(registry
.get_model_entry_path("StarCoder-1B")
.unwrap()
.exists());
assert!(!old_model_path.exists());
}
}
3 changes: 3 additions & 0 deletions crates/tabby-download/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ tabby-common = { path = "../tabby-common" }
anyhow = { workspace = true }
tracing = { workspace = true }
tokio-retry = "0.3.0"

[dev-dependencies]
serial_test = { workspace = true }
Loading

0 comments on commit 779f785

Please sign in to comment.