Skip to content

Commit

Permalink
format all files
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jul 14, 2024
1 parent b7e8875 commit e9d9ba8
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 65 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2023 Vtuber Plan
Copyright (c) 2024 Vtuber Plan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
52 changes: 23 additions & 29 deletions olah/configs.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,26 @@

# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from typing import List, Optional
import toml
import re
import fnmatch

DEFAULT_PROXY_RULES = [
{
"repo": "*",
"allow": True,
"use_re": False
},
{
"repo": "*/*",
"allow": True,
"use_re": False
}
{"repo": "*", "allow": True, "use_re": False},
{"repo": "*/*", "allow": True, "use_re": False},
]

DEFAULT_CACHE_RULES = [
{
"repo": "*",
"allow": True,
"use_re": False
},
{
"repo": "*/*",
"allow": True,
"use_re": False
}
{"repo": "*", "allow": True, "use_re": False},
{"repo": "*/*", "allow": True, "use_re": False},
]


class OlahRule(object):
def __init__(self) -> None:
self.repo = ""
Expand All @@ -48,40 +38,42 @@ def from_dict(data) -> "OlahRule":
if "use_re" in data:
out.use_re = data["use_re"]
return out

def match(self, repo_name: str) -> bool:
if self.use_re:
return self.match_re(repo_name)
else:
return self.match_fn(repo_name)

def match_fn(self, repo_name: str) -> bool:
return fnmatch.fnmatch(repo_name, self.repo)

def match_re(self, repo_name: str) -> bool:
return re.match(self.repo, repo_name) is not None


class OlahRuleList(object):
def __init__(self) -> None:
self.rules: List[OlahRule] = []

@staticmethod
def from_list(data) -> "OlahRuleList":
out = OlahRuleList()
for item in data:
out.rules.append(OlahRule.from_dict(item))
return out

def clear(self):
self.rules.clear()

def allow(self, repo_name: str) -> bool:
allow = False
for rule in self.rules:
if rule.match(repo_name):
allow = rule.allow
return allow


class OlahConfig(object):
def __init__(self, path: Optional[str] = None) -> None:

Expand All @@ -107,7 +99,7 @@ def __init__(self, path: Optional[str] = None) -> None:

if path is not None:
self.read_toml(path)

def hf_url_base(self) -> str:
return f"{self.hf_scheme}://{self.hf_netloc}"

Expand Down Expand Up @@ -143,7 +135,9 @@ def read_toml(self, path: str) -> None:

self.mirror_scheme = basic.get("mirror-scheme", self.mirror_scheme)
self.mirror_netloc = basic.get("mirror-netloc", self.mirror_netloc)
self.mirror_lfs_netloc = basic.get("mirror-lfs-netloc", self.mirror_lfs_netloc)
self.mirror_lfs_netloc = basic.get(
"mirror-lfs-netloc", self.mirror_lfs_netloc
)

if "accessibility" in config:
accessibility = config["accessibility"]
Expand Down
6 changes: 6 additions & 0 deletions olah/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

WORKER_API_TIMEOUT = 15
CHUNK_SIZE = 4096
Expand Down
13 changes: 9 additions & 4 deletions olah/files.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import json
import os
import shutil
import tempfile
from typing import Dict, Literal, Optional
from fastapi import Request

Expand Down Expand Up @@ -296,7 +301,7 @@ async def _file_realtime_stream(

async def file_get_generator(
app,
repo_type: Literal["models", "datasets"],
repo_type: Literal["models", "datasets", "spaces"],
org: str,
repo: str,
commit: str,
Expand Down Expand Up @@ -337,7 +342,7 @@ async def file_get_generator(

async def cdn_file_get_generator(
app,
repo_type: Literal["models", "datasets"],
repo_type: Literal["models", "datasets", "spaces"],
org: str,
repo: str,
file_hash: str,
Expand Down
18 changes: 11 additions & 7 deletions olah/lfs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import datetime
import json
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import os
from typing import Literal
from fastapi import FastAPI, Header, Request
Expand All @@ -10,8 +15,7 @@


async def lfs_head_generator(
app,
dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request
app, dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request
):
# save
repos_path = app.app_settings.repos_path
Expand Down Expand Up @@ -39,9 +43,9 @@ async def lfs_head_generator(
commit=None,
)


async def lfs_get_generator(
app,
dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request
app, dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request
):
# save
repos_path = app.app_settings.repos_path
Expand All @@ -67,4 +71,4 @@ async def lfs_get_generator(
method="GET",
allow_cache=allow_cache,
commit=None,
)
)
45 changes: 36 additions & 9 deletions olah/meta.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@


# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import os
import shutil
Expand All @@ -9,12 +13,12 @@
from fastapi import FastAPI, Request

import httpx
from olah.configs import OlahConfig
from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT

from olah.utils.url_utils import check_cache_rules_hf, get_org_repo
from olah.utils.file_utils import make_dirs


async def meta_cache_generator(app: FastAPI, save_path: str):
yield {}
with open(save_path, "rb") as f:
Expand All @@ -24,7 +28,14 @@ async def meta_cache_generator(app: FastAPI, save_path: str):
break
yield chunk

async def meta_proxy_generator(app: FastAPI, headers: Dict[str, str], meta_url: str, allow_cache: bool, save_path: str):

async def meta_proxy_generator(
app: FastAPI,
headers: Dict[str, str],
meta_url: str,
allow_cache: bool,
save_path: str,
):
try:
temp_file_path = None
async with httpx.AsyncClient(follow_redirects=True) as client:
Expand All @@ -35,7 +46,8 @@ async def meta_proxy_generator(app: FastAPI, headers: Dict[str, str], meta_url:
else:
write_temp_file = True
async with client.stream(
method="GET", url=meta_url,
method="GET",
url=meta_url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
) as response:
Expand All @@ -55,25 +67,40 @@ async def meta_proxy_generator(app: FastAPI, headers: Dict[str, str], meta_url:
if temp_file_path is not None and os.path.exists(temp_file_path):
os.remove(temp_file_path)

async def meta_generator(app: FastAPI, repo_type: Literal["models", "datasets"], org: str, repo: str, commit: str, request: Request):

async def meta_generator(
app: FastAPI,
repo_type: Literal["models", "datasets", "spaces"],
org: str,
repo: str,
commit: str,
request: Request,
):
headers = {k: v for k, v in request.headers.items()}
headers.pop("host")

# save
repos_path = app.app_settings.repos_path
save_dir = os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}")
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}"
)
save_path = os.path.join(save_dir, "meta.json")
make_dirs(save_path)

use_cache = os.path.exists(save_path)
allow_cache = await check_cache_rules_hf(app, repo_type, org, repo)

org_repo = get_org_repo(org, repo)
meta_url = urljoin(app.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}/revision/{commit}")
meta_url = urljoin(
app.app_settings.config.hf_url_base(),
f"/api/{repo_type}/{org_repo}/revision/{commit}",
)
# proxy
if use_cache:
async for item in meta_cache_generator(app, save_path):
yield item
else:
async for item in meta_proxy_generator(app, headers, meta_url, allow_cache, save_path):
async for item in meta_proxy_generator(
app, headers, meta_url, allow_cache, save_path
):
yield item
19 changes: 11 additions & 8 deletions olah/server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import datetime
import json
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import os
import argparse
import tempfile
import shutil
from typing import Annotated, Optional, Union
from urllib.parse import urljoin
from fastapi import FastAPI, Header, Request
Expand Down Expand Up @@ -238,9 +241,9 @@ async def index():
parser.add_argument("--config", "-c", type=str, default="")
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8090)
parser.add_argument("--ssl-key", type=str, default=None)
parser.add_argument("--ssl-cert", type=str, default=None)
parser.add_argument("--repos-path", type=str, default="./repos")
parser.add_argument("--ssl-key", type=str, default=None, help="The SSL key file path, if HTTPS is used")
parser.add_argument("--ssl-cert", type=str, default=None, help="The SSL cert file path, if HTTPS is used")
parser.add_argument("--repos-path", type=str, default="./repos", help="The folder to save cached repositories")
parser.add_argument("--log-path", type=str, default="./logs", help="The folder to save logs")
args = parser.parse_args()
print(args)
Expand Down Expand Up @@ -282,7 +285,7 @@ def is_default_value(args, arg_name):
host=args.host,
port=args.port,
log_level="info",
reload=True,
reload=False,
ssl_keyfile=args.ssl_key,
ssl_certfile=args.ssl_cert
)
7 changes: 7 additions & 0 deletions olah/utils/bitset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

class Bitset:
def __init__(self, size):
self.size = size
Expand Down
9 changes: 7 additions & 2 deletions olah/utils/file_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@

# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import os

Expand All @@ -9,4 +14,4 @@ def make_dirs(path: str):
else:
save_dir = os.path.dirname(path)
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
os.makedirs(save_dir, exist_ok=True)
14 changes: 10 additions & 4 deletions olah/utils/logging.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# coding=utf-8
# Copyright 2024 XiaHan
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from asyncio import AbstractEventLoop
import json
Expand Down Expand Up @@ -52,10 +58,10 @@ def build_logger(logger_name, logger_filename, logger_dir=DEFAULT_LOGGER_DIR) ->
sl = StreamToLogger(stdout_logger, logging.INFO)
sys.stdout = sl

# stderr_logger = logging.getLogger("stderr")
# stderr_logger.setLevel(logging.ERROR)
# sl = StreamToLogger(stderr_logger, logging.ERROR)
# sys.stderr = sl
stderr_logger = logging.getLogger("stderr")
stderr_logger.setLevel(logging.ERROR)
sl = StreamToLogger(stderr_logger, logging.ERROR)
sys.stderr = sl

# Get logger
logger = logging.getLogger(logger_name)
Expand Down
Loading

0 comments on commit e9d9ba8

Please sign in to comment.