diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 5980a786ff964e..40f2bcb008afb2 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -673,6 +673,40 @@ class LlamaData { return download(blob_url, bn, true, headers); } + int github_dl(const std::string & model, const std::string & bn) { + std::string repository = model; + std::string branch = "main"; + size_t at_pos = model.find('@'); + if (at_pos != std::string::npos) { + repository = model.substr(0, at_pos); + branch = model.substr(at_pos + 1); + } + + std::vector repo_parts; + size_t start = 0; + for (size_t end = 0; (end = repository.find('/', start)) != std::string::npos; start = end + 1) { + repo_parts.push_back(repository.substr(start, end - start)); + } + + repo_parts.push_back(repository.substr(start)); + if (repo_parts.size() < 3) { + printe("Invalid GitHub repository format\n"); + return 1; + } + + const std::string org = repo_parts[0]; + const std::string project = repo_parts[1]; + std::string project_path = repo_parts[2]; + for (size_t i = 3; i < repo_parts.size(); ++i) { + project_path += "/" + repo_parts[i]; + } + + const std::string url = + "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch + "/" + project_path; + + return download(url, bn, true); + } + std::string basename(const std::string & path) { const size_t pos = path.find_last_of("/\\"); if (pos == std::string::npos) { @@ -707,8 +741,12 @@ class LlamaData { } else if (string_starts_with(model_, "hf.co/")) { rm_until_substring(model_, "hf.co/"); ret = huggingface_dl(model_, bn); - } else if (string_starts_with(model_, "https://")) { + } else if (string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) { ret = download(model_, bn, true); + } else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) { + rm_until_substring(model_, "github://"); + rm_until_substring(model_, "github:"); + ret = github_dl(model_, bn); } else { // ollama:// or nothing rm_until_substring(model_, "://"); ret = ollama_dl(model_, bn);