From e76c553c8dc996141f604228396694b4673ed95d Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Mon, 25 Nov 2024 21:39:52 +0800 Subject: [PATCH] move app_settings to app.state --- src/olah/proxy/commits.py | 4 +- src/olah/proxy/files.py | 26 ++++---- src/olah/proxy/lfs.py | 4 +- src/olah/proxy/meta.py | 4 +- src/olah/proxy/pathsinfo.py | 4 +- src/olah/proxy/tree.py | 4 +- src/olah/server.py | 115 +++++++++++++++++++---------------- src/olah/utils/repo_utils.py | 16 ++--- src/olah/utils/rule_utils.py | 10 +-- 9 files changed, 99 insertions(+), 88 deletions(-) diff --git a/src/olah/proxy/commits.py b/src/olah/proxy/commits.py index 4e3323e..0f6a42e 100644 --- a/src/olah/proxy/commits.py +++ b/src/olah/proxy/commits.py @@ -81,7 +81,7 @@ async def commits_generator( headers["authorization"] = authorization # save - repos_path = app.app_settings.config.repos_path + repos_path = app.state.app_settings.config.repos_path save_dir = os.path.join( repos_path, f"api/{repo_type}/{org}/{repo}/commits/{commit}" ) @@ -92,7 +92,7 @@ async def commits_generator( org_repo = get_org_repo(org, repo) commits_url = urljoin( - app.app_settings.config.hf_url_base(), + app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/commits/{commit}", ) # proxy diff --git a/src/olah/proxy/files.py b/src/olah/proxy/files.py index 21caab2..8125f81 100644 --- a/src/olah/proxy/files.py +++ b/src/olah/proxy/files.py @@ -267,7 +267,7 @@ async def _file_chunk_head( allow_cache: bool, file_size: int, ): - if not app.app_settings.config.offline: + if not app.state.app_settings.config.offline: async with client.stream( method=method, url=url, @@ -333,17 +333,17 @@ async def _file_realtime_stream( ) else: hf_url = urljoin( - app.app_settings.config.hf_lfs_url_base(), get_url_tail(clean_url) + app.state.app_settings.config.hf_lfs_url_base(), get_url_tail(clean_url) ) else: if urlparse(url).netloc in [ - app.app_settings.config.hf_netloc, - app.app_settings.config.hf_lfs_netloc, + app.state.app_settings.config.hf_netloc, + app.state.app_settings.config.hf_lfs_netloc, ]: hf_url = url else: hf_url = urljoin( - app.app_settings.config.hf_lfs_url_base(), get_url_tail(url) + app.state.app_settings.config.hf_lfs_url_base(), get_url_tail(url) ) request_headers = {k: v for k, v in request.headers.items()} @@ -409,7 +409,7 @@ async def _file_realtime_stream( etag = await _resource_etag( hf_url=hf_url, authorization=request.headers.get("authorization", None), - offline=app.app_settings.config.offline, + offline=app.state.app_settings.config.offline, ) response_headers["etag"] = etag @@ -466,7 +466,7 @@ async def file_get_generator( ): org_repo = get_org_repo(org, repo) # save - repos_path = app.app_settings.config.repos_path + repos_path = app.state.app_settings.config.repos_path head_path = os.path.join( repos_path, f"heads/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}" ) @@ -482,12 +482,12 @@ async def file_get_generator( # proxy if repo_type == "models": url = urljoin( - app.app_settings.config.hf_url_base(), + app.state.app_settings.config.hf_url_base(), f"/{org_repo}/resolve/{commit}/{file_path}", ) else: url = urljoin( - app.app_settings.config.hf_url_base(), + app.state.app_settings.config.hf_url_base(), f"/{repo_type}/{org_repo}/resolve/{commit}/{file_path}", ) return _file_realtime_stream( @@ -520,7 +520,7 @@ async def cdn_file_get_generator( org_repo = get_org_repo(org, repo) # save - repos_path = app.app_settings.config.repos_path + repos_path = app.state.app_settings.config.repos_path head_path = os.path.join( repos_path, f"heads/{repo_type}/{org}/{repo}/cdn/{file_hash}" ) @@ -535,10 +535,10 @@ async def cdn_file_get_generator( # proxy # request_url = urlparse(str(request.url)) - # if request_url.netloc == app.app_settings.config.hf_lfs_netloc: - # redirected_url = urljoin(app.app_settings.config.mirror_lfs_url_base(), get_url_tail(request_url)) + # if request_url.netloc == app.state.app_settings.config.hf_lfs_netloc: + # redirected_url = urljoin(app.state.app_settings.config.mirror_lfs_url_base(), get_url_tail(request_url)) # else: - # redirected_url = urljoin(app.app_settings.config.mirror_url_base(), get_url_tail(request_url)) + # redirected_url = urljoin(app.state.app_settings.config.mirror_url_base(), get_url_tail(request_url)) return _file_realtime_stream( app=app, diff --git a/src/olah/proxy/lfs.py b/src/olah/proxy/lfs.py index 87f2693..249f5dc 100644 --- a/src/olah/proxy/lfs.py +++ b/src/olah/proxy/lfs.py @@ -17,7 +17,7 @@ async def lfs_head_generator( app, dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request ): # save - repos_path = app.app_settings.config.repos_path + repos_path = app.state.app_settings.config.repos_path head_path = os.path.join( repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}" ) @@ -47,7 +47,7 @@ async def lfs_get_generator( app, dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request ): # save - repos_path = app.app_settings.config.repos_path + repos_path = app.state.app_settings.config.repos_path head_path = os.path.join( repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}" ) diff --git a/src/olah/proxy/meta.py b/src/olah/proxy/meta.py index bcb6a7d..e239648 100644 --- a/src/olah/proxy/meta.py +++ b/src/olah/proxy/meta.py @@ -77,7 +77,7 @@ async def meta_generator( headers["authorization"] = authorization # save - repos_path = app.app_settings.config.repos_path + repos_path = app.state.app_settings.config.repos_path save_dir = os.path.join( repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}" ) @@ -89,7 +89,7 @@ async def meta_generator( org_repo = get_org_repo(org, repo) meta_url = urljoin( - app.app_settings.config.hf_url_base(), + app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}", ) # proxy diff --git a/src/olah/proxy/pathsinfo.py b/src/olah/proxy/pathsinfo.py index c0c7e23..4cb7563 100644 --- a/src/olah/proxy/pathsinfo.py +++ b/src/olah/proxy/pathsinfo.py @@ -73,7 +73,7 @@ async def pathsinfo_generator( headers["authorization"] = authorization # save - repos_path = app.app_settings.config.repos_path + repos_path = app.state.app_settings.config.repos_path final_content = [] for path in paths: @@ -88,7 +88,7 @@ async def pathsinfo_generator( org_repo = get_org_repo(org, repo) pathsinfo_url = urljoin( - app.app_settings.config.hf_url_base(), + app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/paths-info/{commit}", ) # proxy diff --git a/src/olah/proxy/tree.py b/src/olah/proxy/tree.py index 92eb338..1a29946 100644 --- a/src/olah/proxy/tree.py +++ b/src/olah/proxy/tree.py @@ -83,7 +83,7 @@ async def tree_generator( headers["authorization"] = authorization # save - repos_path = app.app_settings.config.repos_path + repos_path = app.state.app_settings.config.repos_path save_dir = os.path.join( repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}/{path}" ) @@ -94,7 +94,7 @@ async def tree_generator( org_repo = get_org_repo(org, repo) tree_url = urljoin( - app.app_settings.config.hf_url_base(), + app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/tree/{commit}/{path}", ) # proxy diff --git a/src/olah/server.py b/src/olah/server.py index 3ee3dc9..9f498e4 100644 --- a/src/olah/server.py +++ b/src/olah/server.py @@ -6,12 +6,14 @@ # https://opensource.org/licenses/MIT. from contextlib import asynccontextmanager +import datetime import os import glob import argparse +import sys import time import traceback -from typing import Annotated, List, Optional, Union +from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union from urllib.parse import urljoin from fastapi import FastAPI, Header, Request, Form from fastapi.responses import ( @@ -32,7 +34,6 @@ from olah.proxy.tree import tree_generator from olah.utils.disk_utils import convert_bytes_to_human_readable, convert_to_bytes, get_folder_size, sort_files_by_access_time, sort_files_by_modify_time, sort_files_by_size from olah.utils.url_utils import clean_path -from olah.utils.zip_utils import decompress_data BASE_SETTINGS = False if not BASE_SETTINGS: @@ -88,53 +89,61 @@ async def check_connection(url: str) -> bool: except httpx.TimeoutException: return False +# from pympler import tracker, classtracker +# tr = tracker.SummaryTracker() +# cr = classtracker.ClassTracker() +# from olah.cache.bitset import Bitset +# from olah.cache.olah_cache import OlahCache, OlahCacheHeader +# cr.track_class(Bitset) +# cr.track_class(OlahCacheHeader) +# cr.track_class(OlahCache) -@repeat_every(seconds=60 * 5) +@repeat_every(seconds=5) async def check_hf_connection() -> None: - if app.app_settings.config.offline: + if app.state.app_settings.config.offline: return - scheme = app.app_settings.config.hf_scheme - netloc = app.app_settings.config.hf_netloc + scheme = app.state.app_settings.config.hf_scheme + netloc = app.state.app_settings.config.hf_netloc hf_online_status = await check_connection( f"{scheme}://{netloc}/datasets/Salesforce/wikitext/resolve/main/.gitattributes" ) if not hf_online_status: - logger.error("Failed to reach Huggingface Site.") - + print("Failed to reach Huggingface Site.", file=sys.stderr) @repeat_every(seconds=60 * 60) async def check_disk_usage() -> None: - if app.app_settings.config.offline: + if app.state.app_settings.config.offline: return - if app.app_settings.config.cache_size_limit is None: + if app.state.app_settings.config.cache_size_limit is None: return - limit_size = app.app_settings.config.cache_size_limit - current_size = get_folder_size(app.app_settings.config.repos_path) + limit_size = app.state.app_settings.config.cache_size_limit + current_size = get_folder_size(app.state.app_settings.config.repos_path) limit_size_h = convert_bytes_to_human_readable(limit_size) current_size_h = convert_bytes_to_human_readable(current_size) if current_size < limit_size: return - logger.warning( + print( f"Cache size exceeded! Limit: {limit_size_h}, Current: {current_size_h}." ) - logger.info("Cleaning...") - files_path = os.path.join(app.app_settings.config.repos_path, "files") - lfs_path = os.path.join(app.app_settings.config.repos_path, "lfs") + print("Cleaning...") + files_path = os.path.join(app.state.app_settings.config.repos_path, "files") + lfs_path = os.path.join(app.state.app_settings.config.repos_path, "lfs") - if app.app_settings.config.cache_clean_strategy == "LRU": + files: Sequence[Tuple[str, Union[int, datetime.datetime]]] = [] + if app.state.app_settings.config.cache_clean_strategy == "LRU": files = sort_files_by_access_time(files_path) + sort_files_by_access_time( lfs_path ) files = sorted(files, key=lambda x: x[1]) - elif app.app_settings.config.cache_clean_strategy == "FIFO": + elif app.state.app_settings.config.cache_clean_strategy == "FIFO": files = sort_files_by_modify_time(files_path) + sort_files_by_modify_time( lfs_path ) files = sorted(files, key=lambda x: x[1]) - elif app.app_settings.config.cache_clean_strategy == "LARGE_FIRST": + elif app.state.app_settings.config.cache_clean_strategy == "LARGE_FIRST": files = sort_files_by_size(files_path) + sort_files_by_size(lfs_path) files = sorted(files, key=lambda x: x[1], reverse=True) @@ -144,11 +153,11 @@ async def check_disk_usage() -> None: filesize = os.path.getsize(filepath) os.remove(filepath) current_size -= filesize - logger.info(f"Remove file: {filepath}. File Size: {convert_bytes_to_human_readable(filesize)}") + print(f"Remove file: {filepath}. File Size: {convert_bytes_to_human_readable(filesize)}") - current_size = get_folder_size(app.app_settings.config.repos_path) + current_size = get_folder_size(app.state.app_settings.config.repos_path) current_size_h = convert_bytes_to_human_readable(current_size) - logger.info(f"Cleaning finished. Limit: {limit_size_h}, Current: {current_size_h}.") + print(f"Cleaning finished. Limit: {limit_size_h}, Current: {current_size_h}.") @asynccontextmanager @@ -184,16 +193,16 @@ async def custom_404_handler(_, __): # File Meta Info API Hooks # See also: https://huggingface.co/docs/hub/api#repo-listing-api # ====================== -async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, method: str, authorization: Optional[str]) -> Response: +async def meta_proxy_common(repo_type: Literal["models", "datasets", "spaces"], org: str, repo: str, commit: str, method: str, authorization: Optional[str]) -> Response: # FIXME: do not show the private repos to other user besides owner, even though the repo was cached if repo_type not in REPO_TYPES_MAPPING.keys(): return error_page_not_found() if not await check_proxy_rules_hf(app, repo_type, org, repo): return error_repo_not_found() # Check Mirror Path - for mirror_path in app.app_settings.config.mirrors_path: + for mirror_path in app.state.app_settings.config.mirrors_path: + git_path = os.path.join(mirror_path, repo_type, org, repo) try: - git_path = os.path.join(mirror_path, repo_type, org, repo) if os.path.exists(git_path): local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) meta_data = local_repo.get_meta(commit) @@ -206,7 +215,7 @@ async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, me # Proxy the HF File Meta try: - if not app.app_settings.config.offline: + if not app.state.app_settings.config.offline: if not await check_commit_hf(app, repo_type, org, repo, commit=None, authorization=authorization, ): @@ -221,7 +230,7 @@ async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, me if commit_sha is None: return error_repo_not_found() # if branch name and online mode, refresh branch info - if not app.app_settings.config.offline and commit_sha != commit: + if not app.state.app_settings.config.offline and commit_sha != commit: generator = meta_generator( app=app, repo_type=repo_type, @@ -268,7 +277,7 @@ async def meta_proxy(repo_type: str, org_repo: str, request: Request): org, repo = parse_org_repo(org_repo) if org is None and repo is None: return error_repo_not_found() - if not app.app_settings.config.offline: + if not app.state.app_settings.config.offline: new_commit = await get_newest_commit_hf( app, repo_type, @@ -293,7 +302,7 @@ async def meta_proxy(repo_type: str, org_repo: str, request: Request): @app.head("/api/{repo_type}/{org}/{repo}") @app.get("/api/{repo_type}/{org}/{repo}") async def meta_proxy(repo_type: str, org: str, repo: str, request: Request): - if not app.app_settings.config.offline: + if not app.state.app_settings.config.offline: new_commit = await get_newest_commit_hf( app, repo_type, @@ -368,9 +377,9 @@ async def tree_proxy_common( if not await check_proxy_rules_hf(app, repo_type, org, repo): return error_repo_not_found() # Check Mirror Path - for mirror_path in app.app_settings.config.mirrors_path: + for mirror_path in app.state.app_settings.config.mirrors_path: + git_path = os.path.join(mirror_path, repo_type, org, repo) try: - git_path = os.path.join(mirror_path, repo_type, org, repo) if os.path.exists(git_path): local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) tree_data = local_repo.get_tree(commit, path, recursive=recursive, expand=expand) @@ -383,7 +392,7 @@ async def tree_proxy_common( # Proxy the HF File Meta try: - if not app.app_settings.config.offline: + if not app.state.app_settings.config.offline: if not await check_commit_hf(app, repo_type, org, repo, commit=None, authorization=authorization, ): @@ -398,7 +407,7 @@ async def tree_proxy_common( if commit_sha is None: return error_repo_not_found() # if branch name and online mode, refresh branch info - if not app.app_settings.config.offline and commit_sha != commit: + if not app.state.app_settings.config.offline and commit_sha != commit: generator = tree_generator( app=app, repo_type=repo_type, @@ -512,9 +521,9 @@ async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: st if not await check_proxy_rules_hf(app, repo_type, org, repo): return error_repo_not_found() # Check Mirror Path - for mirror_path in app.app_settings.config.mirrors_path: + for mirror_path in app.state.app_settings.config.mirrors_path: + git_path = os.path.join(mirror_path, repo_type, org, repo) try: - git_path = os.path.join(mirror_path, repo_type, org, repo) if os.path.exists(git_path): local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) pathsinfo_data = local_repo.get_pathinfos(commit, paths) @@ -527,7 +536,7 @@ async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: st # Proxy the HF File pathsinfo try: - if not app.app_settings.config.offline: + if not app.state.app_settings.config.offline: if not await check_commit_hf(app, repo_type, org, repo, commit=None, authorization=authorization, ): @@ -541,7 +550,7 @@ async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: st ) if commit_sha is None: return error_repo_not_found() - if not app.app_settings.config.offline and commit_sha != commit: + if not app.state.app_settings.config.offline and commit_sha != commit: generator = pathsinfo_generator( app=app, repo_type=repo_type, @@ -639,7 +648,7 @@ async def commits_proxy_common(repo_type: str, org: str, repo: str, commit: str, if not await check_proxy_rules_hf(app, repo_type, org, repo): return error_repo_not_found() # Check Mirror Path - for mirror_path in app.app_settings.config.mirrors_path: + for mirror_path in app.state.app_settings.config.mirrors_path: try: git_path = os.path.join(mirror_path, repo_type, org, repo) if os.path.exists(git_path): @@ -654,7 +663,7 @@ async def commits_proxy_common(repo_type: str, org: str, repo: str, commit: str, # Proxy the HF File Commits try: - if not app.app_settings.config.offline: + if not app.state.app_settings.config.offline: if not await check_commit_hf(app, repo_type, org, repo, commit=None, authorization=authorization, ): @@ -669,7 +678,7 @@ async def commits_proxy_common(repo_type: str, org: str, repo: str, commit: str, if commit_sha is None: return error_repo_not_found() # if branch name and online mode, refresh branch info - if not app.app_settings.config.offline and commit_sha != commit: + if not app.state.app_settings.config.offline and commit_sha != commit: generator = commits_generator( app=app, repo_type=repo_type, @@ -754,11 +763,11 @@ async def whoami_v2(request: Request): Sensitive Information!!! """ new_headers = {k.lower(): v for k, v in request.headers.items()} - new_headers["host"] = app.app_settings.config.hf_netloc + new_headers["host"] = app.state.app_settings.config.hf_netloc async with httpx.AsyncClient() as client: response = await client.request( method="GET", - url=urljoin(app.app_settings.config.hf_url_base(), "/api/whoami-v2"), + url=urljoin(app.state.app_settings.config.hf_url_base(), "/api/whoami-v2"), headers=new_headers, timeout=10, ) @@ -787,9 +796,9 @@ async def file_head_common( return error_repo_not_found() # Check Mirror Path - for mirror_path in app.app_settings.config.mirrors_path: + for mirror_path in app.state.app_settings.config.mirrors_path: + git_path = os.path.join(mirror_path, repo_type, org, repo) try: - git_path = os.path.join(mirror_path, repo_type, org, repo) if os.path.exists(git_path): local_repo = LocalMirrorRepo(git_path, repo_type, org, repo) head = local_repo.get_file_head(commit_hash=commit, path=file_path) @@ -802,7 +811,7 @@ async def file_head_common( # Proxy the HF File Head try: - if not app.app_settings.config.offline and not await check_commit_hf( + if not app.state.app_settings.config.offline and not await check_commit_hf( app, repo_type, org, @@ -922,7 +931,7 @@ async def file_get_common( if not await check_proxy_rules_hf(app, repo_type, org, repo): return error_repo_not_found() # Check Mirror Path - for mirror_path in app.app_settings.config.mirrors_path: + for mirror_path in app.state.app_settings.config.mirrors_path: try: git_path = os.path.join(mirror_path, repo_type, org, repo) if os.path.exists(git_path): @@ -935,7 +944,7 @@ async def file_get_common( logger.warning(f"Local repository {git_path} is not a valid git reposity.") continue try: - if not app.app_settings.config.offline and not await check_commit_hf( + if not app.state.app_settings.config.offline and not await check_commit_hf( app, repo_type, org, @@ -1081,16 +1090,16 @@ async def index(request: Request): "index.html", { "request": request, - "scheme": app.app_settings.config.mirror_scheme, - "netloc": app.app_settings.config.mirror_netloc, + "scheme": app.state.app_settings.config.mirror_scheme, + "netloc": app.state.app_settings.config.mirror_netloc, }, ) @app.get("/repos", response_class=HTMLResponse) async def repos(request: Request): - datasets_repos = glob.glob(os.path.join(app.app_settings.config.repos_path, "api/datasets/*/*")) - models_repos = glob.glob(os.path.join(app.app_settings.config.repos_path, "api/models/*/*")) - spaces_repos = glob.glob(os.path.join(app.app_settings.config.repos_path, "api/spaces/*/*")) + datasets_repos = glob.glob(os.path.join(app.state.app_settings.config.repos_path, "api/datasets/*/*")) + models_repos = glob.glob(os.path.join(app.state.app_settings.config.repos_path, "api/models/*/*")) + spaces_repos = glob.glob(os.path.join(app.state.app_settings.config.repos_path, "api/spaces/*/*")) datasets_repos = [get_org_repo(*repo.split("/")[-2:]) for repo in datasets_repos] models_repos = [get_org_repo(*repo.split("/")[-2:]) for repo in models_repos] spaces_repos = [get_org_repo(*repo.split("/")[-2:]) for repo in spaces_repos] @@ -1222,7 +1231,7 @@ def is_default_value(args, arg_name): time.sleep(0.2) # Init app settings - app.app_settings = AppSettings(config=config) + app.state.app_settings = AppSettings(config=config) return args def main(): diff --git a/src/olah/utils/repo_utils.py b/src/olah/utils/repo_utils.py index f25c386..6873803 100644 --- a/src/olah/utils/repo_utils.py +++ b/src/olah/utils/repo_utils.py @@ -145,7 +145,7 @@ async def get_newest_commit_hf_offline( The newest commit hash as a string. """ - repos_path = app.app_settings.config.repos_path + repos_path = app.state.app_settings.config.repos_path save_dir = get_meta_save_dir(repos_path, repo_type, org, repo) files = glob.glob(os.path.join(save_dir, "*", "meta_head.json")) @@ -184,9 +184,9 @@ async def get_newest_commit_hf( """ url = urljoin( - app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org}/{repo}" + app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org}/{repo}" ) - if app.app_settings.config.offline: + if app.state.app_settings.config.offline: return await get_newest_commit_hf_offline(app, repo_type, org, repo) try: async with httpx.AsyncClient() as client: @@ -224,7 +224,7 @@ async def get_commit_hf_offline( Returns: The commit SHA as a string if available in the offline cache, or None if the information is not cached. """ - repos_path = app.app_settings.config.repos_path + repos_path = app.state.app_settings.config.repos_path save_path = get_meta_save_path(repos_path, repo_type, org, repo, commit) if os.path.exists(save_path): request_cache = await read_cache_request(save_path) @@ -261,10 +261,10 @@ async def get_commit_hf( """ org_repo = get_org_repo(org, repo) url = urljoin( - app.app_settings.config.hf_url_base(), + app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}", ) - if app.app_settings.config.offline: + if app.state.app_settings.config.offline: return await get_commit_hf_offline(app, repo_type, org, repo, commit) try: headers = {} @@ -309,11 +309,11 @@ async def check_commit_hf( org_repo = get_org_repo(org, repo) if commit is None: url = urljoin( - app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}" + app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}" ) else: url = urljoin( - app.app_settings.config.hf_url_base(), + app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}", ) diff --git a/src/olah/utils/rule_utils.py b/src/olah/utils/rule_utils.py index 0ad95e7..55f4b2e 100644 --- a/src/olah/utils/rule_utils.py +++ b/src/olah/utils/rule_utils.py @@ -7,27 +7,29 @@ from typing import Dict, Literal, Optional, Tuple, Union + +from fastapi import FastAPI from olah.configs import OlahConfig from .repo_utils import get_org_repo async def check_proxy_rules_hf( - app, + app: FastAPI, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, ) -> bool: - config: OlahConfig = app.app_settings.config + config: OlahConfig = app.state.app_settings.config org_repo = get_org_repo(org, repo) return config.proxy.allow(org_repo) async def check_cache_rules_hf( - app, + app: FastAPI, repo_type: Optional[Literal["models", "datasets", "spaces"]], org: Optional[str], repo: str, ) -> bool: - config: OlahConfig = app.app_settings.config + config: OlahConfig = app.state.app_settings.config org_repo = get_org_repo(org, repo) return config.cache.allow(org_repo)