From 4aa2053707b2fa3391be2c126d98b47ad0fbf5d9 Mon Sep 17 00:00:00 2001 From: Lucain Date: Mon, 26 Aug 2024 14:41:36 +0200 Subject: [PATCH] Refacto error parsing (HfHubHttpError) (#2474) * Refacto error parsing (HfHubHttpError) * wait... what? --- src/huggingface_hub/errors.py | 80 +-------- src/huggingface_hub/utils/__init__.py | 4 +- src/huggingface_hub/utils/_errors.py | 161 ------------------ src/huggingface_hub/utils/_http.py | 227 +++++++++++++++++++++++++- tests/test_utils_errors.py | 166 ++++++++----------- 5 files changed, 306 insertions(+), 332 deletions(-) delete mode 100644 src/huggingface_hub/utils/_errors.py diff --git a/src/huggingface_hub/errors.py b/src/huggingface_hub/errors.py index 0a95eae78a..feebe5edb1 100644 --- a/src/huggingface_hub/errors.py +++ b/src/huggingface_hub/errors.py @@ -67,54 +67,14 @@ class HfHubHTTPError(HTTPError): ``` """ - request_id: Optional[str] = None - server_message: Optional[str] = None - - def __init__(self, message: str, response: Optional[Response] = None): - # Parse server information if any. - if response is not None: - # Import here to avoid circular import - from .utils._fixes import JSONDecodeError - - self.request_id = response.headers.get("X-Request-Id") - try: - server_data = response.json() - except JSONDecodeError: - server_data = {} - - # Retrieve server error message from multiple sources - server_message_from_headers = response.headers.get("X-Error-Message") - server_message_from_body = server_data.get("error") - server_multiple_messages_from_body = "\n".join( - error["message"] for error in server_data.get("errors", []) if "message" in error - ) - - # Concatenate error messages - _server_message = "" - if server_message_from_headers is not None: # from headers - _server_message += server_message_from_headers + "\n" - if server_message_from_body is not None: # from body "error" - if isinstance(server_message_from_body, list): - server_message_from_body = "\n".join(server_message_from_body) - if server_message_from_body not in _server_message: - _server_message += server_message_from_body + "\n" - if server_multiple_messages_from_body is not None: # from body "errors" - if server_multiple_messages_from_body not in _server_message: - _server_message += server_multiple_messages_from_body + "\n" - _server_message = _server_message.strip() - - # Set message to `HfHubHTTPError` (if any) - if _server_message != "": - self.server_message = _server_message + def __init__(self, message: str, response: Optional[Response] = None, *, server_message: Optional[str] = None): + self.request_id = response.headers.get("x-request-id") if response is not None else None + self.server_message = server_message super().__init__( - _format_error_message( - message, - request_id=self.request_id, - server_message=self.server_message, - ), - response=response, # type: ignore - request=response.request if response is not None else None, # type: ignore + message, + response=response, # type: ignore [arg-type] + request=response.request if response is not None else None, # type: ignore [arg-type] ) def append_to_message(self, additional_message: str) -> None: @@ -211,7 +171,7 @@ class FileMetadataError(OSError): """ -# REPOSIORY ERRORS +# REPOSITORY ERRORS class RepositoryNotFoundError(HfHubHTTPError): @@ -351,29 +311,3 @@ class BadRequestError(HfHubHTTPError, ValueError): huggingface_hub.utils._errors.BadRequestError: Bad request for check endpoint: {details} (Request ID: XXX) ``` """ - - -def _format_error_message(message: str, request_id: Optional[str], server_message: Optional[str]) -> str: - """ - Format the `HfHubHTTPError` error message based on initial message and information - returned by the server. - - Used when initializing `HfHubHTTPError`. - """ - # Add message from response body - if server_message is not None and len(server_message) > 0 and server_message.lower() not in message.lower(): - if "\n\n" in message: - message += "\n" + server_message - else: - message += "\n\n" + server_message - - # Add Request ID - if request_id is not None and str(request_id).lower() not in message.lower(): - request_id_message = f" (Request ID: {request_id})" - if "\n" in message: - newline_index = message.index("\n") - message = message[:newline_index] + request_id_message + message[newline_index:] - else: - message += request_id_message - - return message diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py index cf75acf35c..5d41aa6ff7 100644 --- a/src/huggingface_hub/utils/__init__.py +++ b/src/huggingface_hub/utils/__init__.py @@ -46,9 +46,6 @@ ) from ._chunk_utils import chunk_iterable from ._datetime import parse_datetime -from ._errors import ( - hf_raise_for_status, -) from ._experimental import experimental from ._fixes import SoftTemporaryDirectory, WeakFileLock, yaml_dump from ._git_credential import list_credential_helpers, set_git_credential, unset_git_credential @@ -58,6 +55,7 @@ configure_http_backend, fix_hf_endpoint_in_url, get_session, + hf_raise_for_status, http_backoff, reset_sessions, ) diff --git a/src/huggingface_hub/utils/_errors.py b/src/huggingface_hub/utils/_errors.py deleted file mode 100644 index 20abb6facd..0000000000 --- a/src/huggingface_hub/utils/_errors.py +++ /dev/null @@ -1,161 +0,0 @@ -import re -from typing import Optional - -from requests import HTTPError, Response - -from ..errors import ( - BadRequestError, - DisabledRepoError, - EntryNotFoundError, - GatedRepoError, - HfHubHTTPError, - RepositoryNotFoundError, - RevisionNotFoundError, -) - - -REPO_API_REGEX = re.compile( - r""" - # staging or production endpoint - ^https://[^/]+ - ( - # on /api/repo_type/repo_id - /api/(models|datasets|spaces)/(.+) - | - # or /repo_id/resolve/revision/... - /(.+)/resolve/(.+) - ) - """, - flags=re.VERBOSE, -) - - -def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None: - """ - Internal version of `response.raise_for_status()` that will refine a - potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`. - - This helper is meant to be the unique method to raise_for_status when making a call - to the Hugging Face Hub. - - - Example: - ```py - import requests - from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError - - response = get_session().post(...) - try: - hf_raise_for_status(response) - except HfHubHTTPError as e: - print(str(e)) # formatted message - e.request_id, e.server_message # details returned by server - - # Complete the error message with additional information once it's raised - e.append_to_message("\n`create_commit` expects the repository to exist.") - raise - ``` - - Args: - response (`Response`): - Response from the server. - endpoint_name (`str`, *optional*): - Name of the endpoint that has been called. If provided, the error message - will be more complete. - - - - Raises when the request has failed: - - - [`~utils.RepositoryNotFoundError`] - If the repository to download from cannot be found. This may be because it - doesn't exist, because `repo_type` is not set correctly, or because the repo - is `private` and you do not have access. - - [`~utils.GatedRepoError`] - If the repository exists but is gated and the user is not on the authorized - list. - - [`~utils.RevisionNotFoundError`] - If the repository exists but the revision couldn't be find. - - [`~utils.EntryNotFoundError`] - If the repository exists but the entry (e.g. the requested file) couldn't be - find. - - [`~utils.BadRequestError`] - If request failed with a HTTP 400 BadRequest error. - - [`~utils.HfHubHTTPError`] - If request failed for a reason not listed above. - - - """ - try: - response.raise_for_status() - except HTTPError as e: - error_code = response.headers.get("X-Error-Code") - error_message = response.headers.get("X-Error-Message") - - if error_code == "RevisionNotFound": - message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}." - raise RevisionNotFoundError(message, response) from e - - elif error_code == "EntryNotFound": - message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}." - raise EntryNotFoundError(message, response) from e - - elif error_code == "GatedRepo": - message = ( - f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}." - ) - raise GatedRepoError(message, response) from e - - elif error_message == "Access to this resource is disabled.": - message = ( - f"{response.status_code} Client Error." - + "\n\n" - + f"Cannot access repository for url {response.url}." - + "\n" - + "Access to this resource is disabled." - ) - raise DisabledRepoError(message, response) from e - - elif error_code == "RepoNotFound" or ( - response.status_code == 401 - and response.request is not None - and response.request.url is not None - and REPO_API_REGEX.search(response.request.url) is not None - ): - # 401 is misleading as it is returned for: - # - private and gated repos if user is not authenticated - # - missing repos - # => for now, we process them as `RepoNotFound` anyway. - # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9 - message = ( - f"{response.status_code} Client Error." - + "\n\n" - + f"Repository Not Found for url: {response.url}." - + "\nPlease make sure you specified the correct `repo_id` and" - " `repo_type`.\nIf you are trying to access a private or gated repo," - " make sure you are authenticated." - ) - raise RepositoryNotFoundError(message, response) from e - - elif response.status_code == 400: - message = ( - f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:" - ) - raise BadRequestError(message, response=response) from e - - elif response.status_code == 403: - message = ( - f"\n\n{response.status_code} Forbidden: {error_message}." - + f"\nCannot access content at: {response.url}." - + "\nMake sure your token has the correct permissions." - ) - raise HfHubHTTPError(message, response=response) from e - - elif response.status_code == 416: - range_header = response.request.headers.get("Range") - message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}." - raise HfHubHTTPError(message, response=response) from e - - # Convert `HTTPError` into a `HfHubHTTPError` to display request information - # as well (request id and/or server error message) - raise HfHubHTTPError(str(e), response=response) from e diff --git a/src/huggingface_hub/utils/_http.py b/src/huggingface_hub/utils/_http.py index 495152ca15..84f6302b32 100644 --- a/src/huggingface_hub/utils/_http.py +++ b/src/huggingface_hub/utils/_http.py @@ -16,6 +16,7 @@ import io import os +import re import threading import time import uuid @@ -24,14 +25,24 @@ from typing import Callable, Optional, Tuple, Type, Union import requests -from requests import Response +from requests import HTTPError, Response from requests.adapters import HTTPAdapter from requests.models import PreparedRequest from huggingface_hub.errors import OfflineModeIsEnabled from .. import constants +from ..errors import ( + BadRequestError, + DisabledRepoError, + EntryNotFoundError, + GatedRepoError, + HfHubHTTPError, + RepositoryNotFoundError, + RevisionNotFoundError, +) from . import logging +from ._fixes import JSONDecodeError from ._typing import HTTP_METHOD_T @@ -43,6 +54,21 @@ X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" X_REQUEST_ID = "x-request-id" +REPO_API_REGEX = re.compile( + r""" + # staging or production endpoint + ^https://[^/]+ + ( + # on /api/repo_type/repo_id + /api/(models|datasets|spaces)/(.+) + | + # or /repo_id/resolve/revision/... + /(.+)/resolve/(.+) + ) + """, + flags=re.VERBOSE, +) + class UniqueRequestIdAdapter(HTTPAdapter): X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" @@ -317,3 +343,202 @@ def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str: url = url.replace(constants._HF_DEFAULT_ENDPOINT, endpoint) url = url.replace(constants._HF_DEFAULT_STAGING_ENDPOINT, endpoint) return url + + +def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None: + """ + Internal version of `response.raise_for_status()` that will refine a + potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`. + + This helper is meant to be the unique method to raise_for_status when making a call + to the Hugging Face Hub. + + + Example: + ```py + import requests + from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError + + response = get_session().post(...) + try: + hf_raise_for_status(response) + except HfHubHTTPError as e: + print(str(e)) # formatted message + e.request_id, e.server_message # details returned by server + + # Complete the error message with additional information once it's raised + e.append_to_message("\n`create_commit` expects the repository to exist.") + raise + ``` + + Args: + response (`Response`): + Response from the server. + endpoint_name (`str`, *optional*): + Name of the endpoint that has been called. If provided, the error message + will be more complete. + + + + Raises when the request has failed: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it + doesn't exist, because `repo_type` is not set correctly, or because the repo + is `private` and you do not have access. + - [`~utils.GatedRepoError`] + If the repository exists but is gated and the user is not on the authorized + list. + - [`~utils.RevisionNotFoundError`] + If the repository exists but the revision couldn't be find. + - [`~utils.EntryNotFoundError`] + If the repository exists but the entry (e.g. the requested file) couldn't be + find. + - [`~utils.BadRequestError`] + If request failed with a HTTP 400 BadRequest error. + - [`~utils.HfHubHTTPError`] + If request failed for a reason not listed above. + + + """ + try: + response.raise_for_status() + except HTTPError as e: + error_code = response.headers.get("X-Error-Code") + error_message = response.headers.get("X-Error-Message") + + if error_code == "RevisionNotFound": + message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}." + raise _format(RevisionNotFoundError, message, response) from e + + elif error_code == "EntryNotFound": + message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}." + raise _format(EntryNotFoundError, message, response) from e + + elif error_code == "GatedRepo": + message = ( + f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}." + ) + raise _format(GatedRepoError, message, response) from e + + elif error_message == "Access to this resource is disabled.": + message = ( + f"{response.status_code} Client Error." + + "\n\n" + + f"Cannot access repository for url {response.url}." + + "\n" + + "Access to this resource is disabled." + ) + raise _format(DisabledRepoError, message, response) from e + + elif error_code == "RepoNotFound" or ( + response.status_code == 401 + and response.request is not None + and response.request.url is not None + and REPO_API_REGEX.search(response.request.url) is not None + ): + # 401 is misleading as it is returned for: + # - private and gated repos if user is not authenticated + # - missing repos + # => for now, we process them as `RepoNotFound` anyway. + # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9 + message = ( + f"{response.status_code} Client Error." + + "\n\n" + + f"Repository Not Found for url: {response.url}." + + "\nPlease make sure you specified the correct `repo_id` and" + " `repo_type`.\nIf you are trying to access a private or gated repo," + " make sure you are authenticated." + ) + raise _format(RepositoryNotFoundError, message, response) from e + + elif response.status_code == 400: + message = ( + f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:" + ) + raise _format(BadRequestError, message, response) from e + + elif response.status_code == 403: + message = ( + f"\n\n{response.status_code} Forbidden: {error_message}." + + f"\nCannot access content at: {response.url}." + + "\nMake sure your token has the correct permissions." + ) + raise _format(HfHubHTTPError, message, response) from e + + elif response.status_code == 416: + range_header = response.request.headers.get("Range") + message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}." + raise _format(HfHubHTTPError, message, response) from e + + # Convert `HTTPError` into a `HfHubHTTPError` to display request information + # as well (request id and/or server error message) + raise _format(HfHubHTTPError, "", response) from e + + +def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError: + server_errors = [] + + # Retrieve server error from header + from_headers = response.headers.get("X-Error-Message") + if from_headers is not None: + server_errors.append(from_headers) + + # Retrieve server error from body + try: + # Case errors are returned in a JSON format + data = response.json() + + error = data.get("error") + if error is not None: + if isinstance(error, list): + # Case {'error': ['my error 1', 'my error 2']} + server_errors.extend(error) + else: + # Case {'error': 'my error'} + server_errors.append(error) + + errors = data.get("errors") + if errors is not None: + # Case {'errors': [{'message': 'my error 1'}, {'message': 'my error 2'}]} + for error in errors: + if "message" in error: + server_errors.append(error["message"]) + + except JSONDecodeError: + # Case error is directly returned as text + if response.text: + server_errors.append(response.text) + + # Strip all server messages + server_errors = [line.strip() for line in server_errors if line.strip()] + + # Deduplicate server messages (keep order) + # taken from https://stackoverflow.com/a/17016257 + server_errors = list(dict.fromkeys(server_errors)) + + # Format server error + server_message = "\n".join(server_errors) + + # Add server error to custom message + final_error_message = custom_message + if server_message and server_message.lower() not in custom_message.lower(): + if "\n\n" in custom_message: + final_error_message += "\n" + server_message + else: + final_error_message += "\n\n" + server_message + + # Add Request ID + request_id = str(response.headers.get(X_REQUEST_ID, "")) + if len(request_id) > 0 and request_id.lower() not in final_error_message.lower(): + request_id_message = f" (Request ID: {request_id})" + if "\n" in final_error_message: + newline_index = final_error_message.index("\n") + final_error_message = ( + final_error_message[:newline_index] + request_id_message + final_error_message[newline_index:] + ) + else: + final_error_message += request_id_message + + # Return + return error_type(final_error_message.strip(), response=response, server_message=server_message or None) diff --git a/tests/test_utils_errors.py b/tests/test_utils_errors.py index 5b6e443558..fc27fa418b 100644 --- a/tests/test_utils_errors.py +++ b/tests/test_utils_errors.py @@ -11,46 +11,46 @@ RepositoryNotFoundError, RevisionNotFoundError, ) -from huggingface_hub.utils._errors import REPO_API_REGEX, hf_raise_for_status +from huggingface_hub.utils._http import REPO_API_REGEX, X_REQUEST_ID, _format, hf_raise_for_status class TestErrorUtils(unittest.TestCase): def test_hf_raise_for_status_repo_not_found(self) -> None: response = Response() - response.headers = {"X-Error-Code": "RepoNotFound", "X-Request-Id": 123} + response.headers = {"X-Error-Code": "RepoNotFound", X_REQUEST_ID: 123} response.status_code = 404 with self.assertRaisesRegex(RepositoryNotFoundError, "Repository Not Found") as context: hf_raise_for_status(response) - self.assertEqual(context.exception.response.status_code, 404) - self.assertIn("Request ID: 123", str(context.exception)) + assert context.exception.response.status_code == 404 + assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_disabled_repo(self) -> None: response = Response() - response.headers = {"X-Error-Message": "Access to this resource is disabled.", "X-Request-Id": 123} + response.headers = {"X-Error-Message": "Access to this resource is disabled.", X_REQUEST_ID: 123} response.status_code = 403 with self.assertRaises(DisabledRepoError) as context: hf_raise_for_status(response) - self.assertEqual(context.exception.response.status_code, 403) - self.assertIn("Request ID: 123", str(context.exception)) + assert context.exception.response.status_code == 403 + assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_401_repo_url(self) -> None: response = Response() - response.headers = {"X-Request-Id": 123} + response.headers = {X_REQUEST_ID: 123} response.status_code = 401 response.request = PreparedRequest() response.request.url = "https://huggingface.co/api/models/username/reponame" with self.assertRaisesRegex(RepositoryNotFoundError, "Repository Not Found") as context: hf_raise_for_status(response) - self.assertEqual(context.exception.response.status_code, 401) - self.assertIn("Request ID: 123", str(context.exception)) + assert context.exception.response.status_code == 401 + assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_403_wrong_token_scope(self) -> None: response = Response() - response.headers = {"X-Request-Id": 123, "X-Error-Message": "specific error message"} + response.headers = {X_REQUEST_ID: 123, "X-Error-Message": "specific error message"} response.status_code = 403 response.request = PreparedRequest() response.request.url = "https://huggingface.co/api/repos/create" @@ -58,40 +58,40 @@ def test_hf_raise_for_status_403_wrong_token_scope(self) -> None: with self.assertRaisesRegex(HfHubHTTPError, expected_message_part) as context: hf_raise_for_status(response) - self.assertEqual(context.exception.response.status_code, 403) - self.assertIn("Request ID: 123", str(context.exception)) + assert context.exception.response.status_code == 403 + assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_401_not_repo_url(self) -> None: response = Response() - response.headers = {"X-Request-Id": 123} + response.headers = {X_REQUEST_ID: 123} response.status_code = 401 response.request = PreparedRequest() response.request.url = "https://huggingface.co/api/collections" with self.assertRaises(HfHubHTTPError) as context: hf_raise_for_status(response) - self.assertEqual(context.exception.response.status_code, 401) - self.assertIn("Request ID: 123", str(context.exception)) + assert context.exception.response.status_code == 401 + assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_revision_not_found(self) -> None: response = Response() - response.headers = {"X-Error-Code": "RevisionNotFound", "X-Request-Id": 123} + response.headers = {"X-Error-Code": "RevisionNotFound", X_REQUEST_ID: 123} response.status_code = 404 with self.assertRaisesRegex(RevisionNotFoundError, "Revision Not Found") as context: hf_raise_for_status(response) - self.assertEqual(context.exception.response.status_code, 404) - self.assertIn("Request ID: 123", str(context.exception)) + assert context.exception.response.status_code == 404 + assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_entry_not_found(self) -> None: response = Response() - response.headers = {"X-Error-Code": "EntryNotFound", "X-Request-Id": 123} + response.headers = {"X-Error-Code": "EntryNotFound", X_REQUEST_ID: 123} response.status_code = 404 with self.assertRaisesRegex(EntryNotFoundError, "Entry Not Found") as context: hf_raise_for_status(response) - self.assertEqual(context.exception.response.status_code, 404) - self.assertIn("Request ID: 123", str(context.exception)) + assert context.exception.response.status_code == 404 + assert "Request ID: 123" in str(context.exception) def test_hf_raise_for_status_bad_request_no_endpoint_name(self) -> None: """Test HTTPError converted to BadRequestError if error 400.""" @@ -99,7 +99,7 @@ def test_hf_raise_for_status_bad_request_no_endpoint_name(self) -> None: response.status_code = 400 with self.assertRaisesRegex(BadRequestError, "Bad request:") as context: hf_raise_for_status(response) - self.assertEqual(context.exception.response.status_code, 400) + assert context.exception.response.status_code == 400 def test_hf_raise_for_status_bad_request_with_endpoint_name(self) -> None: """Test endpoint name is added to BadRequestError message.""" @@ -107,21 +107,21 @@ def test_hf_raise_for_status_bad_request_with_endpoint_name(self) -> None: response.status_code = 400 with self.assertRaisesRegex(BadRequestError, "Bad request for preupload endpoint:") as context: hf_raise_for_status(response, endpoint_name="preupload") - self.assertEqual(context.exception.response.status_code, 400) + assert context.exception.response.status_code == 400 def test_hf_raise_for_status_fallback(self) -> None: """Test HTTPError is converted to HfHubHTTPError.""" response = Response() response.status_code = 404 response.headers = { - "X-Request-Id": "test-id", + X_REQUEST_ID: "test-id", } response.url = "test_URL" with self.assertRaisesRegex(HfHubHTTPError, "Request ID: test-id") as context: hf_raise_for_status(response) - self.assertEqual(context.exception.response.status_code, 404) - self.assertEqual(context.exception.response.url, "test_URL") + assert context.exception.response.status_code == 404 + assert context.exception.response.url == "test_URL" class TestHfHubHTTPError(unittest.TestCase): @@ -136,54 +136,48 @@ def setUp(self) -> None: def test_hf_hub_http_error_initialization(self) -> None: """Test HfHubHTTPError is initialized properly.""" error = HfHubHTTPError("this is a message", response=self.response) - self.assertEqual(str(error), "this is a message") - self.assertEqual(error.response, self.response) - self.assertIsNone(error.request_id) - self.assertIsNone(error.server_message) + assert str(error) == "this is a message" + assert error.response == self.response + assert error.request_id is None + assert error.server_message is None def test_hf_hub_http_error_init_with_request_id(self) -> None: """Test request id is added to the message.""" - self.response.headers = {"X-Request-Id": "test-id"} - error = HfHubHTTPError("this is a message", response=self.response) - self.assertEqual(str(error), "this is a message (Request ID: test-id)") - self.assertEqual(error.request_id, "test-id") + self.response.headers = {X_REQUEST_ID: "test-id"} + error = _format(HfHubHTTPError, "this is a message", response=self.response) + assert str(error) == "this is a message (Request ID: test-id)" + assert error.request_id == "test-id" def test_hf_hub_http_error_init_with_request_id_and_multiline_message(self) -> None: """Test request id is added to the end of the first line.""" - self.response.headers = {"X-Request-Id": "test-id"} - error = HfHubHTTPError("this is a message\nthis is more details", response=self.response) - self.assertEqual(str(error), "this is a message (Request ID: test-id)\nthis is more details") - - error = HfHubHTTPError("this is a message\n\nthis is more details", response=self.response) - self.assertEqual( - str(error), - "this is a message (Request ID: test-id)\n\nthis is more details", - ) + self.response.headers = {X_REQUEST_ID: "test-id"} + error = _format(HfHubHTTPError, "this is a message\nthis is more details", response=self.response) + assert str(error) == "this is a message (Request ID: test-id)\nthis is more details" + + error = _format(HfHubHTTPError, "this is a message\n\nthis is more details", response=self.response) + assert str(error) == "this is a message (Request ID: test-id)\n\nthis is more details" def test_hf_hub_http_error_init_with_request_id_already_in_message(self) -> None: """Test request id is not duplicated in error message (case insensitive)""" - self.response.headers = {"X-Request-Id": "test-id"} - error = HfHubHTTPError("this is a message on request TEST-ID", response=self.response) - self.assertEqual(str(error), "this is a message on request TEST-ID") - self.assertEqual(error.request_id, "test-id") + self.response.headers = {X_REQUEST_ID: "test-id"} + error = _format(HfHubHTTPError, "this is a message on request TEST-ID", response=self.response) + assert str(error) == "this is a message on request TEST-ID" + assert error.request_id == "test-id" def test_hf_hub_http_error_init_with_server_error(self) -> None: """Test server error is added to the error message.""" self.response._content = b'{"error": "This is a message returned by the server"}' - error = HfHubHTTPError("this is a message", response=self.response) - self.assertEqual(str(error), "this is a message\n\nThis is a message returned by the server") - self.assertEqual(error.server_message, "This is a message returned by the server") + error = _format(HfHubHTTPError, "this is a message", response=self.response) + assert str(error) == "this is a message\n\nThis is a message returned by the server" + assert error.server_message == "This is a message returned by the server" def test_hf_hub_http_error_init_with_server_error_and_multiline_message( self, ) -> None: """Test server error is added to the error message after the details.""" self.response._content = b'{"error": "This is a message returned by the server"}' - error = HfHubHTTPError("this is a message\n\nSome details.", response=self.response) - self.assertEqual( - str(error), - "this is a message\n\nSome details.\nThis is a message returned by the server", - ) + error = _format(HfHubHTTPError, "this is a message\n\nSome details.", response=self.response) + assert str(error) == "this is a message\n\nSome details.\nThis is a message returned by the server" def test_hf_hub_http_error_init_with_multiple_server_errors( self, @@ -196,11 +190,8 @@ def test_hf_hub_http_error_init_with_multiple_server_errors( b'{"httpStatusCode": 400, "errors": [{"message": "this is error 1", "type":' b' "error"}, {"message": "this is error 2", "type": "error"}]}' ) - error = HfHubHTTPError("this is a message\n\nSome details.", response=self.response) - self.assertEqual( - str(error), - "this is a message\n\nSome details.\nthis is error 1\nthis is error 2", - ) + error = _format(HfHubHTTPError, "this is a message\n\nSome details.", response=self.response) + assert str(error) == "this is a message\n\nSome details.\nthis is error 1\nthis is error 2" def test_hf_hub_http_error_init_with_server_error_already_in_message( self, @@ -210,42 +201,38 @@ def test_hf_hub_http_error_init_with_server_error_already_in_message( Case insensitive. """ self.response._content = b'{"error": "repo NOT found"}' - error = HfHubHTTPError( + error = _format( + HfHubHTTPError, "this is a message\n\nRepo Not Found. and more\nand more", response=self.response, ) - self.assertEqual( - str(error), - "this is a message\n\nRepo Not Found. and more\nand more", - ) + assert str(error) == "this is a message\n\nRepo Not Found. and more\nand more" def test_hf_hub_http_error_init_with_unparsable_server_error( self, ) -> None: - """Test error message is unchanged and exception is not raised..""" + """Server returned a text message (not as JSON) => should be added to the exception.""" self.response._content = b"this is not a json-formatted string" - error = HfHubHTTPError("this is a message", response=self.response) - self.assertEqual(str(error), "this is a message") - self.assertIsNone(error.server_message) # still None since not parsed + error = _format(HfHubHTTPError, "this is a message", response=self.response) + assert str(error) == "this is a message\n\nthis is not a json-formatted string" + assert error.server_message == "this is not a json-formatted string" def test_hf_hub_http_error_append_to_message(self) -> None: """Test add extra information to existing HfHubHTTPError.""" - error = HfHubHTTPError("this is a message", response=self.response) + error = _format(HfHubHTTPError, "this is a message", response=self.response) error.args = error.args + (1, 2, 3) # faking some extra args error.append_to_message("\nthis is an additional message") - self.assertEqual( - error.args, - ("this is a message\nthis is an additional message", 1, 2, 3), - ) - self.assertIsNone(error.server_message) # added message is not from server + assert error.args == ("this is a message\nthis is an additional message", 1, 2, 3) + + assert error.server_message is None # added message is not from server def test_hf_hub_http_error_init_with_error_message_in_header(self) -> None: """Test server error from header is added to the error message.""" self.response.headers = {"X-Error-Message": "Error message from headers."} - error = HfHubHTTPError("this is a message", response=self.response) - self.assertEqual(str(error), "this is a message\n\nError message from headers.") - self.assertEqual(error.server_message, "Error message from headers.") + error = _format(HfHubHTTPError, "this is a message", response=self.response) + assert str(error) == "this is a message\n\nError message from headers." + assert error.server_message == "Error message from headers." def test_hf_hub_http_error_init_with_error_message_from_header_and_body( self, @@ -253,15 +240,9 @@ def test_hf_hub_http_error_init_with_error_message_from_header_and_body( """Test server error from header and from body are added to the error message.""" self.response._content = b'{"error": "Error message from body."}' self.response.headers = {"X-Error-Message": "Error message from headers."} - error = HfHubHTTPError("this is a message", response=self.response) - self.assertEqual( - str(error), - "this is a message\n\nError message from headers.\nError message from body.", - ) - self.assertEqual( - error.server_message, - "Error message from headers.\nError message from body.", - ) + error = _format(HfHubHTTPError, "this is a message", response=self.response) + assert str(error) == "this is a message\n\nError message from headers.\nError message from body." + assert error.server_message == "Error message from headers.\nError message from body." def test_hf_hub_http_error_init_with_error_message_duplicated_in_header_and_body( self, @@ -272,12 +253,9 @@ def test_hf_hub_http_error_init_with_error_message_duplicated_in_header_and_body """ self.response._content = b'{"error": "Error message duplicated in headers and body."}' self.response.headers = {"X-Error-Message": "Error message duplicated in headers and body."} - error = HfHubHTTPError("this is a message", response=self.response) - self.assertEqual( - str(error), - "this is a message\n\nError message duplicated in headers and body.", - ) - self.assertEqual(error.server_message, "Error message duplicated in headers and body.") + error = _format(HfHubHTTPError, "this is a message", response=self.response) + assert str(error) == "this is a message\n\nError message duplicated in headers and body." + assert error.server_message == "Error message duplicated in headers and body." @pytest.mark.parametrize(