From a4417ddda98fd0558fb4d802253e68a933704b59 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Mon, 27 Jan 2025 19:36:10 +0100 Subject: [PATCH] Add new hf protocol for ollama (#11449) https://huggingface.co/docs/hub/en/ollama Signed-off-by: Eric Curtin --- examples/run/run.cpp | 109 +++++++++++++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 35 deletions(-) diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 92a49eb744fda..8a0db74b62d05 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -319,6 +319,10 @@ class HttpClient { public: int init(const std::string & url, const std::vector & headers, const std::string & output_file, const bool progress, std::string * response_str = nullptr) { + if (std::filesystem::exists(output_file)) { + return 0; + } + std::string output_file_partial; curl = curl_easy_init(); if (!curl) { @@ -558,13 +562,14 @@ class LlamaData { } sampler = initialize_sampler(opt); + return 0; } private: #ifdef LLAMA_USE_CURL - int download(const std::string & url, const std::vector & headers, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { + int download(const std::string & url, const std::string & output_file, const bool progress, + const std::vector & headers = {}, std::string * response_str = nullptr) { HttpClient http; if (http.init(url, headers, output_file, progress, response_str)) { return 1; @@ -573,48 +578,85 @@ class LlamaData { return 0; } #else - int download(const std::string &, const std::vector &, const std::string &, const bool, + int download(const std::string &, const std::string &, const bool, const std::vector & = {}, std::string * = nullptr) { printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); + return 1; } #endif - int huggingface_dl(const std::string & model, const std::vector headers, const std::string & bn) { + // Helper function to handle model tag extraction and URL construction + std::pair extract_model_and_tag(std::string & model, const std::string & base_url) { + std::string model_tag = "latest"; + const size_t colon_pos = model.find(':'); + if (colon_pos != std::string::npos) { + model_tag = model.substr(colon_pos + 1); + model = model.substr(0, colon_pos); + } + + std::string url = base_url + model + "/manifests/" + model_tag; + + return { model, url }; + } + + // Helper function to download and parse the manifest + int download_and_parse_manifest(const std::string & url, const std::vector & headers, + nlohmann::json & manifest) { + std::string manifest_str; + int ret = download(url, "", false, headers, &manifest_str); + if (ret) { + return ret; + } + + manifest = nlohmann::json::parse(manifest_str); + + return 0; + } + + int huggingface_dl(std::string & model, const std::string & bn) { // Find the second occurrence of '/' after protocol string size_t pos = model.find('/'); pos = model.find('/', pos + 1); + std::string hfr, hff; + std::vector headers = { "User-Agent: llama-cpp", "Accept: application/json" }; + std::string url; + if (pos == std::string::npos) { - return 1; + auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/"); + hfr = model_name; + + nlohmann::json manifest; + int ret = download_and_parse_manifest(manifest_url, headers, manifest); + if (ret) { + return ret; + } + + hff = manifest["ggufFile"]["rfilename"]; + } else { + hfr = model.substr(0, pos); + hff = model.substr(pos + 1); } - const std::string hfr = model.substr(0, pos); - const std::string hff = model.substr(pos + 1); - const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff; - return download(url, headers, bn, true); + url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff; + + return download(url, bn, true, headers); } - int ollama_dl(std::string & model, const std::vector headers, const std::string & bn) { + int ollama_dl(std::string & model, const std::string & bn) { + const std::vector headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" }; if (model.find('/') == std::string::npos) { model = "library/" + model; } - std::string model_tag = "latest"; - size_t colon_pos = model.find(':'); - if (colon_pos != std::string::npos) { - model_tag = model.substr(colon_pos + 1); - model = model.substr(0, colon_pos); - } - - std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag; - std::string manifest_str; - const int ret = download(manifest_url, headers, "", false, &manifest_str); + auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/"); + nlohmann::json manifest; + int ret = download_and_parse_manifest(manifest_url, {}, manifest); if (ret) { return ret; } - nlohmann::json manifest = nlohmann::json::parse(manifest_str); - std::string layer; + std::string layer; for (const auto & l : manifest["layers"]) { if (l["mediaType"] == "application/vnd.ollama.image.model") { layer = l["digest"]; @@ -622,8 +664,9 @@ class LlamaData { } } - std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer; - return download(blob_url, headers, bn, true); + std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer; + + return download(blob_url, bn, true, headers); } std::string basename(const std::string & path) { @@ -653,22 +696,18 @@ class LlamaData { return ret; } - const std::string bn = basename(model_); - const std::vector headers = { "--header", - "Accept: application/vnd.docker.distribution.manifest.v2+json" }; + const std::string bn = basename(model_); if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) { rm_until_substring(model_, "://"); - ret = huggingface_dl(model_, headers, bn); + ret = huggingface_dl(model_, bn); } else if (string_starts_with(model_, "hf.co/")) { rm_until_substring(model_, "hf.co/"); - ret = huggingface_dl(model_, headers, bn); - } else if (string_starts_with(model_, "ollama://")) { - rm_until_substring(model_, "://"); - ret = ollama_dl(model_, headers, bn); + ret = huggingface_dl(model_, bn); } else if (string_starts_with(model_, "https://")) { - ret = download(model_, headers, bn, true); - } else { - ret = ollama_dl(model_, headers, bn); + ret = download(model_, bn, true); + } else { // ollama:// or nothing + rm_until_substring(model_, "://"); + ret = ollama_dl(model_, bn); } model_ = bn;