From 26ed9a35089dbc233c54d48b74e9a12044879550 Mon Sep 17 00:00:00 2001 From: Scruel Date: Thu, 18 Jan 2024 03:25:54 +0800 Subject: [PATCH 1/2] fix: suppress `GatedRepoError` to use cache file (fix #28558). --- src/transformers/modeling_flax_utils.py | 1 + src/transformers/modeling_tf_utils.py | 1 + src/transformers/modeling_utils.py | 2 + src/transformers/models/auto/auto_factory.py | 1 + .../models/auto/tokenization_auto.py | 1 + src/transformers/pipelines/__init__.py | 1 + src/transformers/tokenization_utils_base.py | 2 + src/transformers/tools/base.py | 2 + src/transformers/utils/hub.py | 50 +++++++++++++++---- src/transformers/utils/peft_utils.py | 1 + 10 files changed, 51 insertions(+), 11 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index eb14216c5cd9..43b38b16c086 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -786,6 +786,7 @@ def from_pretrained( "user_agent": user_agent, "revision": revision, "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, "_raise_exceptions_for_missing_entries": False, "_commit_hash": commit_hash, } diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 4b1ceb2053ea..513046ad90c6 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -2788,6 +2788,7 @@ def from_pretrained( "user_agent": user_agent, "revision": revision, "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, "_raise_exceptions_for_missing_entries": False, "_commit_hash": commit_hash, } diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3f19ec1884e7..ca82e94e5e91 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2926,6 +2926,7 @@ def from_pretrained( token=token, revision=revision, subfolder=subfolder, + _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, ) @@ -3369,6 +3370,7 @@ def from_pretrained( "user_agent": user_agent, "revision": revision, "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, "_raise_exceptions_for_missing_entries": False, "_commit_hash": commit_hash, } diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 92dbb006f6d5..c488c512095e 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -488,6 +488,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): resolved_config_file = cached_file( pretrained_model_name_or_path, CONFIG_NAME, + _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, **hub_kwargs, diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 27998e6c0f05..1145f5f8eef2 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -597,6 +597,7 @@ def get_tokenizer_config( revision=revision, local_files_only=local_files_only, subfolder=subfolder, + _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, _commit_hash=commit_hash, diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index b70617db20fe..ac8ad9d86ae2 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -747,6 +747,7 @@ def pipeline( resolved_config_file = cached_file( pretrained_model_name_or_path, CONFIG_NAME, + _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, **hub_kwargs, diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index b7377ea3143b..d389af676fd0 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1961,6 +1961,7 @@ def from_pretrained( local_files_only=local_files_only, subfolder=subfolder, user_agent=user_agent, + _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, _commit_hash=commit_hash, @@ -1997,6 +1998,7 @@ def from_pretrained( user_agent=user_agent, revision=revision, subfolder=subfolder, + _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, _commit_hash=commit_hash, diff --git a/src/transformers/tools/base.py b/src/transformers/tools/base.py index 4042b28ac64c..43ed1ed86389 100644 --- a/src/transformers/tools/base.py +++ b/src/transformers/tools/base.py @@ -228,6 +228,7 @@ def from_hub( TOOL_CONFIG_FILE, token=token, **hub_kwargs, + _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, ) @@ -238,6 +239,7 @@ def from_hub( CONFIG_NAME, token=token, **hub_kwargs, + _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, ) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 6b427ed4df0a..1e571f0d7efd 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -145,6 +145,24 @@ def is_offline_mode(): HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}" HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples" +_NO_RETURN = object() + + +def _return_cache_or_none_for_condition( + path_or_repo_id: str, + full_filename: str, + cache_dir: Union[str, Path, None] = None, + revision: Optional[str] = None, + condition: bool = False, +): + # We try to see if we have a cached version (not up to date): + resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision) + if resolved_file is not None and resolved_file != _CACHED_NO_EXIST: + return resolved_file + if condition: + return None + return _NO_RETURN + def is_remote_url(url_or_filename): parsed = urlparse(url_or_filename) @@ -266,6 +284,7 @@ def cached_file( subfolder: str = "", repo_type: Optional[str] = None, user_agent: Optional[Union[str, Dict[str, str]]] = None, + _raise_exceptions_for_gated_repo: bool = True, _raise_exceptions_for_missing_entries: bool = True, _raise_exceptions_for_connection_errors: bool = True, _commit_hash: Optional[str] = None, @@ -335,6 +354,8 @@ def cached_file( token = use_auth_token # Private arguments + # _raise_exceptions_for_gated_repo: if False, do not raise an exception for gated repo error but return + # None. # _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return # None. # _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return @@ -397,6 +418,11 @@ def cached_file( local_files_only=local_files_only, ) except GatedRepoError as e: + resolved_file = _return_cache_or_none_for_condition( + path_or_repo_id, full_filename, cache_dir, revision, not _raise_exceptions_for_gated_repo + ) + if _NO_RETURN != resolved_file: + return resolved_file raise EnvironmentError( "You are trying to access a gated repo.\nMake sure to request access at " f"https://huggingface.co/{path_or_repo_id} and pass a token having permission to this repo either " @@ -416,12 +442,15 @@ def cached_file( f"'https://huggingface.co/{path_or_repo_id}' for available revisions." ) from e except LocalEntryNotFoundError as e: - # We try to see if we have a cached version (not up to date): - resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision) - if resolved_file is not None and resolved_file != _CACHED_NO_EXIST: + resolved_file = _return_cache_or_none_for_condition( + path_or_repo_id, + full_filename, + cache_dir, + revision, + not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors, + ) + if _NO_RETURN != resolved_file: return resolved_file - if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors: - return None raise EnvironmentError( f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the" f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named" @@ -438,13 +467,11 @@ def cached_file( f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files." ) from e except HTTPError as err: - # First we try to see if we have a cached version (not up to date): - resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision) - if resolved_file is not None and resolved_file != _CACHED_NO_EXIST: + resolved_file = _return_cache_or_none_for_condition( + path_or_repo_id, full_filename, cache_dir, revision, not _raise_exceptions_for_connection_errors + ) + if _NO_RETURN != resolved_file: return resolved_file - if not _raise_exceptions_for_connection_errors: - return None - raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}") except HFValidationError as e: raise EnvironmentError( @@ -545,6 +572,7 @@ def get_file_from_repo( revision=revision, local_files_only=local_files_only, subfolder=subfolder, + _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, ) diff --git a/src/transformers/utils/peft_utils.py b/src/transformers/utils/peft_utils.py index 7830acd0b4d2..2078f1ae9609 100644 --- a/src/transformers/utils/peft_utils.py +++ b/src/transformers/utils/peft_utils.py @@ -96,6 +96,7 @@ def find_adapter_config_file( local_files_only=local_files_only, subfolder=subfolder, _commit_hash=_commit_hash, + _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, ) From e2908afb7a3f8e0e9ebf9a2d66e3598c903b5613 Mon Sep 17 00:00:00 2001 From: Scruel Date: Wed, 24 Jan 2024 21:49:00 +0800 Subject: [PATCH 2/2] move condition_to_return parameter back to outside. --- src/transformers/utils/hub.py | 40 ++++++++++++----------------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 1e571f0d7efd..edc6fb48fb29 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -145,23 +145,15 @@ def is_offline_mode(): HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}" HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples" -_NO_RETURN = object() - -def _return_cache_or_none_for_condition( - path_or_repo_id: str, - full_filename: str, - cache_dir: Union[str, Path, None] = None, - revision: Optional[str] = None, - condition: bool = False, +def _get_cache_file_to_return( + path_or_repo_id: str, full_filename: str, cache_dir: Union[str, Path, None] = None, revision: Optional[str] = None ): # We try to see if we have a cached version (not up to date): resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision) if resolved_file is not None and resolved_file != _CACHED_NO_EXIST: return resolved_file - if condition: - return None - return _NO_RETURN + return None def is_remote_url(url_or_filename): @@ -418,10 +410,8 @@ def cached_file( local_files_only=local_files_only, ) except GatedRepoError as e: - resolved_file = _return_cache_or_none_for_condition( - path_or_repo_id, full_filename, cache_dir, revision, not _raise_exceptions_for_gated_repo - ) - if _NO_RETURN != resolved_file: + resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision) + if resolved_file is not None or not _raise_exceptions_for_gated_repo: return resolved_file raise EnvironmentError( "You are trying to access a gated repo.\nMake sure to request access at " @@ -442,14 +432,12 @@ def cached_file( f"'https://huggingface.co/{path_or_repo_id}' for available revisions." ) from e except LocalEntryNotFoundError as e: - resolved_file = _return_cache_or_none_for_condition( - path_or_repo_id, - full_filename, - cache_dir, - revision, - not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors, - ) - if _NO_RETURN != resolved_file: + resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision) + if ( + resolved_file is not None + or not _raise_exceptions_for_missing_entries + or not _raise_exceptions_for_connection_errors + ): return resolved_file raise EnvironmentError( f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the" @@ -467,10 +455,8 @@ def cached_file( f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files." ) from e except HTTPError as err: - resolved_file = _return_cache_or_none_for_condition( - path_or_repo_id, full_filename, cache_dir, revision, not _raise_exceptions_for_connection_errors - ) - if _NO_RETURN != resolved_file: + resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision) + if resolved_file is not None or not _raise_exceptions_for_connection_errors: return resolved_file raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}") except HFValidationError as e: