From 0c3935fa58c30c94ac8c52f86fbe112668f75e9b Mon Sep 17 00:00:00 2001
From: jstzwj <1103870790@qq.com>
Date: Mon, 8 Jul 2024 22:46:10 +0800
Subject: [PATCH] download wikitext bug fix

---
 .github/dependabot.yml        |  11 ++
 .github/workflows/dev.yml     |  38 +++++
 .github/workflows/release.yml |  54 +++++++
 README.md                     |  26 +++-
 README_zh.md                  |  27 +++-
 olah/configs.py               |  14 +-
 olah/constants.py             |   3 +-
 olah/files.py                 | 257 ++++++++++++++++++++++++----------
 olah/lfs.py                   |   2 +-
 olah/meta.py                  |   6 +-
 olah/server.py                | 112 +++++++++++++--
 olah/utls.py                  |  65 +++++----
 static/index.html             |  22 +++
 tests/.gitignore              |   0
 14 files changed, 511 insertions(+), 126 deletions(-)
 create mode 100644 .github/dependabot.yml
 create mode 100644 .github/workflows/dev.yml
 create mode 100644 .github/workflows/release.yml
 create mode 100644 tests/.gitignore

diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000..91abb11
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,11 @@
+# To get started with Dependabot version updates, you'll need to specify which
+# package ecosystems to update and where the package manifests are located.
+# Please see the documentation for all configuration options:
+# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
+
+version: 2
+updates:
+  - package-ecosystem: "pip" # See documentation for possible values
+    directory: "/" # Location of package manifests
+    schedule:
+      interval: "weekly"
diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml
new file mode 100644
index 0000000..6475c93
--- /dev/null
+++ b/.github/workflows/dev.yml
@@ -0,0 +1,38 @@
+name: Olah GitHub Actions for Development
+run-name: Olah GitHub Actions for Development
+on:
+  push:
+    branches: [ "dev" ]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+
+    steps:
+      - name: Check out repository code
+        uses: actions/checkout@v4
+      - name: Set up Apache Arrow
+        run: |
+          sudo apt update
+          sudo apt install -y -V ca-certificates lsb-release wget
+          wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
+          sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
+          sudo apt update
+          sudo apt install -y -V libarrow-dev libarrow-glib-dev libparquet-dev libparquet-glib-dev
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install Olah
+        run: |
+          cd ${{ github.workspace }}
+          pip install --upgrade pip
+          pip install -e .
+
+      - name: Test Olah
+        run: |
+          cd ${{ github.workspace }}
+          python -m unittest discover olah/tests
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 0000000..0fa9e6f
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,54 @@
+name: Olah GitHub Actions to release
+run-name: Olah GitHub Actions to release
+on:
+  push:
+    tags:
+      - "[0-9]+.[0-9]+.[0-9]+"
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.12"]
+
+    steps:
+      - name: Check out repository code
+        uses: actions/checkout@v4
+      - name: Set up Apache Arrow
+        run: |
+          sudo apt update
+          sudo apt install -y -V ca-certificates lsb-release wget
+          wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
+          sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
+          sudo apt update
+          sudo apt install -y -V libarrow-dev libarrow-glib-dev libparquet-dev libparquet-glib-dev
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install Olah
+        run: |
+          cd ${{ github.workspace }}
+          pip install --upgrade pip
+          pip install -e .
+
+      - name: Test Olah
+        run: |
+          cd ${{ github.workspace }}
+          python -m unittest discover olah/tests
+
+      - name: Build Olah
+        run: |
+          cd ${{ github.workspace }}
+          pip install build
+          python -m build
+
+      - name: Release
+        uses: "marvinpinto/action-automatic-releases@latest"
+        with:
+          repo_token: "${{ secrets.GITHUB_TOKEN }}"
+          prerelease: true
+          files: |
+            dist/*.tar.gz
+            dist/*.whl
\ No newline at end of file
diff --git a/README.md b/README.md
index e57ca9c..eeb24e6 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,7 @@ Olah is self-hosted lightweight huggingface mirror service. `Olah` means `hello`
 Other languages: [中文](README_zh.md)
 
 ## Features
+* HuggingFace data cache
 * Models mirror
 * Datasets mirror
 * Spaces mirror
@@ -42,10 +43,17 @@ python -m olah.server
 ```
 
 Then set the Environment Variable `HF_ENDPOINT` to the mirror site (Here is http://localhost:8090).
+
+Linux:
 ```bash
 export HF_ENDPOINT=http://localhost:8090
 ```
 
+Windows PowerShell:
+```bash
+$env:HF_ENDPOINT = "http://localhost:8090"
+```
+
 Starting from now on, all download operations in the HuggingFace library will be proxied through this mirror site.
 ```python
 from huggingface_hub import snapshot_download
@@ -53,10 +61,24 @@ from huggingface_hub import snapshot_download
 snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model',
                   local_dir='./model_dir', resume_download=True,
                   max_workers=8)
+```
 
+Alternatively, you can download models and datasets with the huggingface-cli.
+```bash
+pip install -U huggingface_hub
 ```
 
-You can check the path `./repos` which stores all cached datasets and models.
+Download GPT-2:
+```bash
+huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2
+```
+
+Download WikiText:
+```bash
+huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext
+```
+
+You can check the path `./repos`, where Olah stores all cached datasets and models.
 
 ## Start the server
 Run the command in the console:
@@ -75,6 +97,8 @@ The default mirror cache path is `./repos`, you can change it by `--repos-path` parameter
 ```bash
 python -m olah.server --host localhost --port 8090 --repos-path ./hf_mirrors
 ```
+**Note that cached data cannot be migrated between different versions. Please delete the cache folder before upgrading to the latest version of Olah.**
+
 ## Future Work
 
 * Authentication
diff --git a/README_zh.md b/README_zh.md
index 92659ab..5794884 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -1,8 +1,10 @@
 Olah是一种自托管的轻量级HuggingFace镜像服务。`Olah`在丘丘人语中意味着`你好`。
 
 ## 特性
+* 数据缓存，减少下载流量
 * 模型镜像
 * 数据集镜像
+* 空间镜像
 
 ## 安装
 
@@ -39,10 +41,16 @@ python -m olah.server
 ```
 
 然后将环境变量`HF_ENDPOINT`设置为镜像站点(这里是http://localhost:8090)。
+Linux:
 ```bash
 export HF_ENDPOINT=http://localhost:8090
 ```
 
+Windows PowerShell:
+```bash
+$env:HF_ENDPOINT = "http://localhost:8090"
+```
+
 从现在开始，HuggingFace库中的所有下载操作都将通过此镜像站点代理进行。
 ```python
 from huggingface_hub import snapshot_download
@@ -53,7 +61,22 @@ snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model',
 ```
 
-您可以检查存储所有缓存的数据集和模型的路径`./repos`。
+或者你也可以使用huggingface cli直接下载模型和数据集。
+```bash
+pip install -U huggingface_hub
+```
+
+下载GPT2：
+```bash
+huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2
+```
+
+下载WikiText：
+```bash
+huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext
+```
+
+您可以查看路径`./repos`，其中存储了所有数据集和模型的缓存。
 
 ## 启动服务器
 在控制台运行以下命令：
@@ -72,6 +95,8 @@ python -m olah.server --host localhost --port 8090
 ```bash
 python -m olah.server --host localhost --port 8090 --repos-path ./hf_mirrors
 ```
+**注意，不同版本之间的缓存数据不能迁移，请删除缓存文件夹后再进行Olah的升级。**
+
 ## 许可证
 
 olah采用MIT许可证发布。
diff --git a/olah/configs.py b/olah/configs.py
index cf6de4e..3427288 100644
--- a/olah/configs.py
+++ b/olah/configs.py
@@ -6,6 +6,11 @@ import fnmatch
 
 DEFAULT_PROXY_RULES = [
+    {
+        "repo": "*",
+        "allow": True,
+        "use_re": False
+    },
     {
         "repo": "*/*",
         "allow": True,
@@ -14,6 +19,11 @@
 ]
 
 DEFAULT_CACHE_RULES = [
+    {
+        "repo": "*",
+        "allow": True,
+        "use_re": False
+    },
     {
         "repo": "*/*",
         "allow": True,
@@ -87,7 +97,7 @@ def __init__(self, path: Optional[str] = None) -> None:
         self.mirror_lfs_url = "http://localhost:8090"
 
         # accessibility
-        self.offline = True
+        self.offline = False
 
         self.proxy = OlahRuleList.from_list(DEFAULT_PROXY_RULES)
         self.cache = OlahRuleList.from_list(DEFAULT_CACHE_RULES)
@@ -100,7 +110,7 @@ def empty_str(self, s: str) -> Optional[str]:
         else:
             return s
 
-    def read_toml(self, path: str):
+    def read_toml(self, path: str) -> None:
         config = toml.load(path)
 
         if "basic" in config:
diff --git a/olah/constants.py b/olah/constants.py
index cf4cd1e..5dd5686 100644
--- a/olah/constants.py
+++ b/olah/constants.py
@@ -1,5 +1,4 @@
-
 WORKER_API_TIMEOUT = 15
 
 CHUNK_SIZE = 4096
-LFS_FILE_BLOCK = 64 * 1024 * 1024
\ No newline at end of file
+LFS_FILE_BLOCK = 64 * 1024 * 1024
diff --git a/olah/files.py b/olah/files.py
index 7e354b7..385b746 100644
--- a/olah/files.py
+++ b/olah/files.py
@@ -1,4 +1,3 @@
-
 import json
 import os
 import shutil
@@ -7,109 +6,219 @@
 from fastapi import Request
 import httpx
+from starlette.datastructures import URL
 from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT
-from olah.utls import check_cache_rules_hf
+from olah.utls import check_cache_rules_hf, get_org_repo
+
+
+async def _file_head_cache_stream(app, save_path: str, request: Request):
+    with open(save_path, "r", encoding="utf-8") as f:
+        response_headers = json.loads(f.read())
+        if "location" in response_headers:
+            response_headers["location"] = response_headers["location"].replace(
+                app.app_settings.hf_url, app.app_settings.mirror_url
+            )
+        yield response_headers
+
 
-async def file_head_generator(app, repo_type: Literal["model", "dataset"], org: str, repo: str, commit: str, file_path: str, request: Request):
+async def _file_head_realtime_stream(
+    app,
+    save_path: str,
+    url: str,
+    headers,
+    request: Request,
+    method="HEAD",
+    allow_cache=True,
+):
+    async with httpx.AsyncClient() as client:
+        async with client.stream(
+            method=method,
+            url=url,
+            headers=headers,
+            timeout=WORKER_API_TIMEOUT,
+        ) as response:
+            response_headers = response.headers
+            response_headers = {k: v for k, v in response_headers.items()}
+            if allow_cache:
+                with open(save_path, "w", encoding="utf-8") as f:
+                    f.write(json.dumps(response_headers, ensure_ascii=False))
+            if "location" in response_headers:
+                response_headers["location"] = response_headers["location"].replace(
+                    app.app_settings.hf_url, app.app_settings.mirror_url
+                )
+            yield response_headers
+
+            async for raw_chunk in response.aiter_raw():
+                if not raw_chunk:
+                    continue
+                yield raw_chunk
+
+
+async def file_head_generator(
+    app,
+    repo_type: Literal["models", "datasets"],
+    org: str,
+    repo: str,
+    commit: str,
+    file_path: str,
+    request: Request,
+):
     headers = {k: v for k, v in request.headers.items()}
     headers.pop("host")
 
     # save
     repos_path = app.app_settings.repos_path
-    save_path = os.path.join(repos_path, f"heads/{repo_type}s/{org}/{repo}/resolve_head/{commit}/{file_path}")
+    save_path = os.path.join(
+        repos_path, f"heads/{repo_type}/{org}/{repo}/resolve_head/{commit}/{file_path}"
+    )
     save_dir = os.path.dirname(save_path)
     if not os.path.exists(save_dir):
         os.makedirs(save_dir, exist_ok=True)
-    
+
     use_cache = os.path.exists(save_path)
     allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
 
     # proxy
     if use_cache:
-        with open(save_path, "r", encoding="utf-8") as f:
-            response_headers = json.loads(f.read())
-            if "location" in response_headers:
-                response_headers["location"] = response_headers["location"].replace(app.app_settings.hf_lfs_url, app.app_settings.mirror_lfs_url)
-            yield response_headers
+        return _file_head_cache_stream(app=app, save_path=save_path, request=request)
     else:
-        if repo_type == "model":
+        if repo_type == "models":
             url = f"{app.app_settings.hf_url}/{org}/{repo}/resolve/{commit}/{file_path}"
         else:
-            url = f"{app.app_settings.hf_url}/{repo_type}s/{org}/{repo}/resolve/{commit}/{file_path}"
+            url = f"{app.app_settings.hf_url}/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
+        return _file_head_realtime_stream(
+            app=app,
+            save_path=save_path,
+            url=url,
+            headers=headers,
+            request=request,
+            method="HEAD",
+            allow_cache=allow_cache,
+        )
Literal["model", "dataset"], org: str, repo: str, commit: str, file_path: str, request: Request): + with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file: + if not allow_cache: + temp_file = open(os.devnull, "wb") + async with client.stream( + method=method, + url=url, + headers=headers, + timeout=WORKER_API_TIMEOUT, + ) as response: + response_headers = response.headers + yield response_headers + + async for raw_chunk in response.aiter_raw(): + if not raw_chunk: + continue + temp_file.write(raw_chunk) + yield raw_chunk + if not allow_cache: + temp_file_path = None + else: + temp_file_path = temp_file.name + if temp_file_path is not None: + shutil.copyfile(temp_file_path, save_path) + finally: + if temp_file_path is not None: + os.remove(temp_file_path) + + +async def file_get_generator( + app, + repo_type: Literal["models", "datasets"], + org: str, + repo: str, + commit: str, + file_path: str, + request: Request, +): headers = {k: v for k, v in request.headers.items()} headers.pop("host") # save repos_path = app.app_settings.repos_path - save_path = os.path.join(repos_path, f"files/{repo_type}s/{org}/{repo}/resolve/{commit}/{file_path}") + save_path = os.path.join( + repos_path, f"files/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}" + ) save_dir = os.path.dirname(save_path) if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True) - + use_cache = os.path.exists(save_path) allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) # proxy if use_cache: - yield request.headers - with open(save_path, "rb") as f: - while True: - chunk = f.read(CHUNK_SIZE) - if not chunk: - break - yield chunk + return _file_cache_stream(save_path=save_path, request=request) else: - try: - temp_file_path = None - if repo_type == "model": - url = f"{app.app_settings.hf_url}/{org}/{repo}/resolve/{commit}/{file_path}" - else: - url = f"{app.app_settings.hf_url}/{repo_type}s/{org}/{repo}/resolve/{commit}/{file_path}" - async with httpx.AsyncClient() as client: - with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file: - if not allow_cache: - temp_file = open(os.devnull, 'wb') - async with client.stream( - method="GET", url=url, - headers=headers, - timeout=WORKER_API_TIMEOUT, - ) as response: - response_headers = response.headers - yield response_headers - - async for raw_chunk in response.aiter_raw(): - if not raw_chunk: - continue - temp_file.write(raw_chunk) - yield raw_chunk - if not allow_cache: - temp_file_path = None - else: - temp_file_path = temp_file.name - if temp_file_path is not None: - shutil.copyfile(temp_file_path, save_path) - finally: - if temp_file_path is not None: - os.remove(temp_file_path) + if repo_type == "models": + url = f"{app.app_settings.hf_url}/{org}/{repo}/resolve/{commit}/{file_path}" + else: + url = f"{app.app_settings.hf_url}/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}" + return _file_realtime_stream( + save_path=save_path, + url=url, + headers=headers, + request=request, + method="GET", + allow_cache=allow_cache, + ) + + +async def cdn_file_get_generator( + app, + repo_type: Literal["models", "datasets"], + org: str, + repo: str, + file_hash: str, + request: Request, +): + headers = {k: v for k, v in request.headers.items()} + headers.pop("host") + + org_repo = get_org_repo(org, repo) + # save + repos_path = app.app_settings.repos_path + save_path = os.path.join( + repos_path, f"files/{repo_type}/cdn/{org}/{repo}/{file_hash}" + ) + save_dir = os.path.dirname(save_path) + if not os.path.exists(save_dir): + 
diff --git a/olah/lfs.py b/olah/lfs.py
index 08f77f2..a0c3930 100644
--- a/olah/lfs.py
+++ b/olah/lfs.py
@@ -16,7 +16,7 @@ async def lfs_get_generator(app, repo_type: str, lfs_url: str, save_path: str, request: Request):
 
     # save
     repos_path = app.app_settings.repos_path
-    save_dir = os.path.join(repos_path, f"lfs/{repo_type}s/{save_path}")
+    save_dir = os.path.join(repos_path, f"lfs/{repo_type}/{save_path}")
     if not os.path.exists(save_dir):
         os.makedirs(save_dir, exist_ok=True)
 
diff --git a/olah/meta.py b/olah/meta.py
index 89e9915..efe9e00 100644
--- a/olah/meta.py
+++ b/olah/meta.py
@@ -52,20 +52,20 @@ async def meta_proxy_generator(app: FastAPI, headers: Dict[str, str], meta_url:
         if temp_file_path is not None:
             os.remove(temp_file_path)
 
-async def meta_generator(app: FastAPI, repo_type: Literal["model", "dataset"], org: str, repo: str, commit: str, request: Request):
+async def meta_generator(app: FastAPI, repo_type: Literal["models", "datasets"], org: str, repo: str, commit: str, request: Request):
     headers = {k: v for k, v in request.headers.items()}
     headers.pop("host")
 
     # save
     repos_path = app.app_settings.repos_path
-    save_dir = os.path.join(repos_path, f"api/{repo_type}s/{org}/{repo}/revision/{commit}")
+    save_dir = os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}")
     save_path = os.path.join(save_dir, "meta.json")
     if not os.path.exists(save_dir):
         os.makedirs(save_dir, exist_ok=True)
 
     use_cache = os.path.exists(save_path)
     allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)
 
-    meta_url = f"{app.app_settings.hf_url}/api/{repo_type}s/{org}/{repo}/revision/{commit}"
+    meta_url = f"{app.app_settings.hf_url}/api/{repo_type}/{org}/{repo}/revision/{commit}"
     # proxy
     if use_cache:
         async for item in meta_cache_generator(app, save_path):
diff --git a/olah/server.py b/olah/server.py
index 842de8c..32bdd5d 100644
--- a/olah/server.py
+++ b/olah/server.py
@@ -4,13 +4,13 @@
 import argparse
 import tempfile
 import shutil
-from typing import Annotated, Union
+from typing import Annotated, Optional, Union
 from fastapi import FastAPI, Header, Request
 from fastapi.responses import HTMLResponse, StreamingResponse, Response
 import httpx
 from pydantic import BaseSettings
 from olah.configs import OlahConfig
-from olah.files import file_get_generator, file_head_generator
+from olah.files import cdn_file_get_generator, file_get_generator, file_head_generator
 from olah.lfs import lfs_get_generator
 from olah.meta import meta_generator
 from olah.utls import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf
@@ -26,8 +26,15 @@ class AppSettings(BaseSettings):
     mirror_url: str = "http://localhost:8090"
     mirror_lfs_url: str = "http://localhost:8090"
 
-@app.get("/api/{repo_type}s/{org}/{repo}")
-async def meta_proxy(repo_type: str, org: str, repo: str, request: Request):
+@app.get("/api/{repo_type}/{org_repo}")
+async def meta_proxy(repo_type: str, org_repo: str, request: Request):
+    if "/" in org_repo and org_repo.count("/") != 1:
+        return Response(content="This repository is not accessible.", status_code=404)
+    if "/" in org_repo:
+        org, repo = org_repo.split("/")
+    else:
+        org = None
+        repo = org_repo
     if not await check_proxy_rules_hf(app, repo_type, org, repo):
         return Response(content="This repository is forbidden by the mirror.", status_code=403)
     if not await check_commit_hf(app, repo_type, org, repo, None):
@@ -37,8 +44,25 @@ async def meta_proxy(repo_type: str, org: str, repo: str, request: Request):
     headers = await generator.__anext__()
     return StreamingResponse(generator, headers=headers)
 
-@app.get("/api/{repo_type}s/{org}/{repo}/revision/{commit}")
-async def meta_proxy(repo_type: str, org: str, repo: str, commit: str, request: Request):
+@app.get("/api/{repo_type}/{org}/{repo}/revision/{commit}")
+async def meta_proxy_commit2(repo_type: str, org: str, repo: str, commit: str, request: Request):
+    if not await check_proxy_rules_hf(app, repo_type, org, repo):
+        return Response(content="This repository is forbidden by the mirror. ", status_code=403)
+    if not await check_commit_hf(app, repo_type, org, repo, commit):
+        return Response(content="This repository is not accessible. ", status_code=404)
+    generator = meta_generator(app, repo_type, org, repo, commit, request)
+    headers = await generator.__anext__()
+    return StreamingResponse(generator, headers=headers)
+
+@app.get("/api/{repo_type}/{org_repo}/revision/{commit}")
+async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request):
+    if "/" in org_repo and org_repo.count("/") != 1:
+        return Response(content="This repository is not accessible.", status_code=404)
+    if "/" in org_repo:
+        org, repo = org_repo.split("/")
+    else:
+        org = None
+        repo = org_repo
     if not await check_proxy_rules_hf(app, repo_type, org, repo):
         return Response(content="This repository is forbidden by the mirror. ", status_code=403)
     if not await check_commit_hf(app, repo_type, org, repo, commit):
", status_code=404) + commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) + generator = await file_head_generator(app, repo_type, org, repo, commit_sha, file_path, request) + headers = await generator.__anext__() + return StreamingResponse(generator, headers=headers) + +@app.head("/{repo_type}/{org_repo}/resolve/{commit}/{file_path:path}") +@app.head("/{org_repo}/resolve/{commit}/{file_path:path}") +async def file_head_proxy(org_repo: str, commit: str, file_path: str, request: Request, repo_type: str = "models"): + if "/" in org_repo and org_repo.count("/") != 1: + return Response(content="This repository is not accessible.", status_code=404) + if "/" in org_repo: + org, repo = org_repo.split("/") + else: + org = None + repo = org_repo + if not await check_proxy_rules_hf(app, repo_type, org, repo): return Response(content="This repository is forbidden by the mirror. ", status_code=403) if not await check_commit_hf(app, repo_type, org, repo, commit): return Response(content="This repository is not accessible. ", status_code=404) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) - generator = file_head_generator(app, repo_type, org, repo, commit_sha, file_path, request) + generator = await file_head_generator(app, repo_type, org, repo, commit_sha, file_path, request) headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) -@app.get("/{repo_type}s/{org}/{repo}/resolve/{commit}/{file_path:path}") +@app.get("/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path:path}") @app.get("/{org}/{repo}/resolve/{commit}/{file_path:path}") -async def file_proxy(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str = "model"): +async def file_proxy2(org: str, repo: str, commit: str, file_path: str, request: Request, repo_type: str = "models"): + if not await check_proxy_rules_hf(app, repo_type, org, repo): + return Response(content="This repository is forbidden by the mirror. ", status_code=403) + if not await check_commit_hf(app, repo_type, org, repo, commit): + return Response(content="This repository is not accessible. ", status_code=404) + commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) + generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path, request) + headers = await generator.__anext__() + return StreamingResponse(generator, headers=headers) + +@app.get("/{repo_type}/{org_repo}/resolve/{commit}/{file_path:path}") +@app.get("/{org_repo}/resolve/{commit}/{file_path:path}") +async def file_proxy(org_repo: str, commit: str, file_path: str, request: Request, repo_type: str = "models"): + if "/" in org_repo and org_repo.count("/") != 1: + return Response(content="This repository is not accessible.", status_code=404) + if "/" in org_repo: + org, repo = org_repo.split("/") + else: + org = None + repo = org_repo + if not await check_proxy_rules_hf(app, repo_type, org, repo): return Response(content="This repository is forbidden by the mirror. ", status_code=403) if not await check_commit_hf(app, repo_type, org, repo, commit): return Response(content="This repository is not accessible. 
", status_code=404) commit_sha = await get_commit_hf(app, repo_type, org, repo, commit) - generator = file_get_generator(app, repo_type, org, repo, commit_sha, file_path, request) + generator = await file_get_generator(app, repo_type, org, repo, commit_sha, file_path, request) headers = await generator.__anext__() return StreamingResponse(generator, headers=headers) +@app.get("/{repo_type}/{org_repo}/{hash_file}") +async def cdn_file_proxy(org_repo: str, hash_file: str, request: Request, repo_type: str = "models"): + if "/" in org_repo and org_repo.count("/") != 1: + return Response(content="This repository is not accessible.", status_code=404) + if "/" in org_repo: + org, repo = org_repo.split("/") + else: + org = None + repo = org_repo + + if not await check_proxy_rules_hf(app, repo_type, org, repo): + return Response(content="This repository is forbidden by the mirror. ", status_code=403) + + generator = await cdn_file_get_generator(app, repo_type, org, repo, hash_file, request) + headers = await generator.__anext__() + return StreamingResponse(generator, headers=headers) + + @app.get("/repos/{dir1}/{dir2}/{hash_repo}/{hash_file}") async def lfs_proxy(dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request): - repo_type = "model" + repo_type = "models" lfs_url = f"{app.app_settings.hf_lfs_url}/repos/{dir1}/{dir2}/{hash_repo}/{hash_file}" save_path = f"{dir1}/{dir2}/{hash_repo}/{hash_file}" generator = lfs_get_generator(app, repo_type, lfs_url, save_path, request) @@ -82,7 +164,7 @@ async def lfs_proxy(dir1: str, dir2: str, hash_repo: str, hash_file: str, reques @app.get("/datasets/hendrycks_test/{hash_file}") async def lfs_proxy(hash_file: str, request: Request): - repo_type = "dataset" + repo_type = "datasets" lfs_url = f"{app.app_settings.hf_lfs_url}/datasets/hendrycks_test/{hash_file}" save_path = f"hendrycks_test/{hash_file}" generator = lfs_get_generator(app, repo_type, lfs_url, save_path, request) @@ -158,7 +240,7 @@ def is_default_value(args, arg_name): host=args.host, port=args.port, log_level="info", - reload=False, + reload=True, ssl_keyfile=args.ssl_key, ssl_certfile=args.ssl_cert ) diff --git a/olah/utls.py b/olah/utls.py index 462f404..9613618 100644 --- a/olah/utls.py +++ b/olah/utls.py @@ -8,16 +8,23 @@ from olah.configs import OlahConfig from olah.constants import WORKER_API_TIMEOUT -def get_meta_save_path(repos_path: str, repo_type: str, org: str, repo: str, commit: str) -> str: - return os.path.join(repos_path, f"api/{repo_type}s/{org}/{repo}/revision/{commit}") +def get_org_repo(org: Optional[str], repo: str) -> str: + if org is None: + org_repo = repo + else: + org_repo = f"{org}/{repo}" + return org_repo + +def get_meta_save_path(repos_path: str, repo_type: str, org: Optional[str], repo: str, commit: str) -> str: + return os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}") -def get_meta_save_dir(repos_path: str, repo_type: str, org: str, repo: str) -> str: - return os.path.join(repos_path, f"api/{repo_type}s/{org}/{repo}/revision") +def get_meta_save_dir(repos_path: str, repo_type: str, org: Optional[str], repo: str) -> str: + return os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision") -def get_file_save_path(repos_path: str, repo_type: str, org: str, repo: str, commit: str, file_path: str) -> str: - return os.path.join(repos_path, f"heads/{repo_type}s/{org}/{repo}/resolve_head/{commit}/{file_path}") +def get_file_save_path(repos_path: str, repo_type: str, org: Optional[str], repo: str, commit: str, file_path: 
@@ -55,35 +62,39 @@ async def get_commit_hf_offline(app, repo_type: Literal["model", "dataset", "space"], org: str, repo: str, commit: str) -> str:
 
     return obj["sha"]
 
-async def get_commit_hf(app, repo_type: Literal["model", "dataset", "space"], org: str, repo: str, commit: str) -> str:
-    url = f"{app.app_settings.hf_url}/api/{repo_type}s/{org}/{repo}/revision/{commit}"
-    if app.app_settings.offline:
-        return get_commit_hf_offline(app, repo_type, org, repo, commit)
+async def get_commit_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, commit: str) -> str:
+    org_repo = get_org_repo(org, repo)
+    url = f"{app.app_settings.hf_url}/api/{repo_type}/{org_repo}/revision/{commit}"
+    if app.app_settings.config.offline:
+        return await get_commit_hf_offline(app, repo_type, org, repo, commit)
     try:
         async with httpx.AsyncClient() as client:
             response = await client.get(url, timeout=WORKER_API_TIMEOUT)
             if response.status_code != 200:
-                return get_commit_hf_offline(app, repo_type, org, repo, commit)
+                return await get_commit_hf_offline(app, repo_type, org, repo, commit)
             obj = json.loads(response.text)
         return obj.get("sha", None)
     except:
-        return get_commit_hf_offline(app, repo_type, org, repo, commit)
+        return await get_commit_hf_offline(app, repo_type, org, repo, commit)
 
-async def check_commit_hf(app, repo_type: Literal["model", "dataset", "space"], org: str, repo: str, commit: Optional[str]=None) -> bool:
+async def check_commit_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, commit: Optional[str]=None) -> bool:
+    org_repo = get_org_repo(org, repo)
     if commit is None:
-        url = f"{app.app_settings.hf_url}/api/{repo_type}s/{org}/{repo}"
+        url = f"{app.app_settings.hf_url}/api/{repo_type}/{org_repo}"
f"{app.app_settings.hf_url}/api/{repo_type}/{org_repo}" else: - url = f"{app.app_settings.hf_url}/api/{repo_type}s/{org}/{repo}/revision/{commit}" + url = f"{app.app_settings.hf_url}/api/{repo_type}/{org_repo}/revision/{commit}" + async with httpx.AsyncClient() as client: - response = await client.get(url, - timeout=WORKER_API_TIMEOUT) - return response.status_code == 200 + response = await client.get(url, timeout=WORKER_API_TIMEOUT) + return response.status_code in [200, 307] -async def check_proxy_rules_hf(app, repo_type: Literal["model", "dataset", "space"], org: str, repo: str) -> bool: +async def check_proxy_rules_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str) -> bool: config: OlahConfig = app.app_settings.config - return config.proxy.allow(f"{org}/{repo}") + org_repo = get_org_repo(org, repo) + return config.proxy.allow(f"{org_repo}") -async def check_cache_rules_hf(app, repo_type: Literal["model", "dataset", "space"], org: str, repo: str) -> bool: +async def check_cache_rules_hf(app, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str) -> bool: config: OlahConfig = app.app_settings.config - return config.cache.allow(f"{org}/{repo}") + org_repo = get_org_repo(org, repo) + return config.cache.allow(f"{org_repo}") diff --git a/static/index.html b/static/index.html index d1a2c03..c30797c 100644 --- a/static/index.html +++ b/static/index.html @@ -8,9 +8,16 @@