Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: suppress GatedRepoError to use cache file (fix #28558). #28566

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ def from_pretrained(
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
Expand Down
1 change: 1 addition & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2788,6 +2788,7 @@ def from_pretrained(
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2926,6 +2926,7 @@ def from_pretrained(
token=token,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
Expand Down Expand Up @@ -3369,6 +3370,7 @@ def from_pretrained(
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/auto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
resolved_config_file = cached_file(
pretrained_model_name_or_path,
CONFIG_NAME,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
**hub_kwargs,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ def get_tokenizer_config(
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,7 @@ def pipeline(
resolved_config_file = cached_file(
pretrained_model_name_or_path,
CONFIG_NAME,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
**hub_kwargs,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,6 +1961,7 @@ def from_pretrained(
local_files_only=local_files_only,
subfolder=subfolder,
user_agent=user_agent,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
Expand Down Expand Up @@ -1997,6 +1998,7 @@ def from_pretrained(
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def from_hub(
TOOL_CONFIG_FILE,
token=token,
**hub_kwargs,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
Expand All @@ -238,6 +239,7 @@ def from_hub(
CONFIG_NAME,
token=token,
**hub_kwargs,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
Expand Down
36 changes: 25 additions & 11 deletions src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ def is_offline_mode():
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"


def _get_cache_file_to_return(
path_or_repo_id: str, full_filename: str, cache_dir: Union[str, Path, None] = None, revision: Optional[str] = None
):
# We try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
return resolved_file
return None


def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
Expand Down Expand Up @@ -266,6 +276,7 @@ def cached_file(
subfolder: str = "",
repo_type: Optional[str] = None,
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_gated_repo: bool = True,
_raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True,
_commit_hash: Optional[str] = None,
Expand Down Expand Up @@ -335,6 +346,8 @@ def cached_file(
token = use_auth_token

# Private arguments
# _raise_exceptions_for_gated_repo: if False, do not raise an exception for gated repo error but return
# None.
# _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
# None.
# _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
Expand Down Expand Up @@ -397,6 +410,9 @@ def cached_file(
local_files_only=local_files_only,
)
except GatedRepoError as e:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if resolved_file is not None or not _raise_exceptions_for_gated_repo:
return resolved_file
raise EnvironmentError(
"You are trying to access a gated repo.\nMake sure to request access at "
f"https://huggingface.co/{path_or_repo_id} and pass a token having permission to this repo either "
Expand All @@ -416,12 +432,13 @@ def cached_file(
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
) from e
except LocalEntryNotFoundError as e:
# We try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if (
resolved_file is not None
or not _raise_exceptions_for_missing_entries
or not _raise_exceptions_for_connection_errors
):
return resolved_file
if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
Expand All @@ -438,13 +455,9 @@ def cached_file(
f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
) from e
except HTTPError as err:
# First we try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if resolved_file is not None or not _raise_exceptions_for_connection_errors:
return resolved_file
if not _raise_exceptions_for_connection_errors:
return None

raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
except HFValidationError as e:
raise EnvironmentError(
Expand Down Expand Up @@ -545,6 +558,7 @@ def get_file_from_repo(
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/peft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def find_adapter_config_file(
local_files_only=local_files_only,
subfolder=subfolder,
_commit_hash=_commit_hash,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
Expand Down
Loading