Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jul 13, 2024
1 parent 1c98ce2 commit e34cbce
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 46 deletions.
66 changes: 21 additions & 45 deletions olah/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,25 +132,8 @@ async def _file_header(
if os.path.exists(head_path):
with open(head_path, "r", encoding="utf-8") as f:
response_headers = json.loads(f.read())
response_headers = {k.lower():v for k, v in response_headers.items()}
new_headers = {k.lower():v for k, v in FILE_HEADER_TEMPLATE.items()}
new_headers["content-type"] = response_headers["content-type"]
# new_headers["content-length"] = response_headers["content-length"]
if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers:
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers:
new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers:
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
new_headers["etag"] = response_headers["etag"]

if commit is not None:
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
return new_headers
response_headers_dict = {k.lower():v for k, v in response_headers.items()}
else:
# Redirect Header
if "range" in headers:
headers.pop("range")
async with client.stream(
method=method,
url=url,
Expand All @@ -161,21 +144,18 @@ async def _file_header(
if allow_cache and method.lower() == "head":
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(response_headers_dict, ensure_ascii=False))
if "location" in response_headers_dict:
location_url = urlparse(response_headers_dict["location"])
if location_url.netloc == app.app_settings.config.hf_lfs_netloc:
response_headers_dict["location"] = urljoin(
app.app_settings.config.mirror_lfs_url_base(),
get_url_tail(location_url),
)
else:
response_headers_dict["location"] = urljoin(
app.app_settings.config.mirror_url_base(),
get_url_tail(location_url),
)
if commit is not None:
response_headers_dict[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
return response_headers_dict

new_headers = {}
new_headers["content-type"] = response_headers_dict["content-type"]
new_headers["content-length"] = response_headers_dict["content-length"]
if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers_dict:
new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
new_headers["etag"] = response_headers_dict["etag"]
return new_headers

async def _get_file_block_from_cache(cache_file: OlahCache, block_index: int):
    """Read a single block from ``cache_file``.

    The coroutine awaits nothing itself: it calls ``read_block``
    synchronously and hands the result straight back.

    Args:
        cache_file: Cache object whose ``read_block`` method is called.
        block_index: Index forwarded unchanged to ``read_block``.

    Returns:
        Whatever ``cache_file.read_block(block_index)`` returns.
    """
    block = cache_file.read_block(block_index)
    return block
Expand Down Expand Up @@ -314,17 +294,13 @@ async def _file_realtime_stream(
allow_cache=allow_cache,
commit=commit,
)
response_headers = await _file_header(
app=app,
save_path=save_path,
head_path=head_path,
client=client,
method=method,
url=redirect_loc,
headers=request_headers,
allow_cache=allow_cache,
commit=commit,
)
file_size = int(head_info["content-length"])
response_headers = {k: v for k,v in head_info.items()}
if "range" in request_headers:
start_pos, end_pos = parse_range_params(request_headers.get("range", f"bytes={0}-{file_size}"), file_size)
response_headers["content-length"] = end_pos - start_pos
if commit is not None:
response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
yield response_headers
if method.lower() == "get":
async for each_chunk in _file_chunk_get(
Expand All @@ -336,7 +312,7 @@ async def _file_realtime_stream(
url=redirect_loc,
headers=request_headers,
allow_cache=allow_cache,
file_size=int(head_info["content-length"]),
file_size=file_size,
commit=commit,
):
yield each_chunk
Expand Down
2 changes: 1 addition & 1 deletion olah/utils/olah_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

CURRENT_OLAH_CACHE_VERSION = 8
DEFAULT_BLOCK_MASK_MAX = 1024 * 1024
DEFAULT_BLOCK_SIZE = 64 * 1024 * 1024
DEFAULT_BLOCK_SIZE = 8 * 1024 * 1024


class OlahCacheHeader(object):
Expand Down

0 comments on commit e34cbce

Please sign in to comment.