From a142f161313199bcfa67afe1990d1f0f39a973bb Mon Sep 17 00:00:00 2001 From: CalOmnie Date: Mon, 20 Jan 2025 16:00:46 +0100 Subject: [PATCH] transformers.image_transforms.normalize wrong types (#35773) transformers.image_transforms.normalize documents and checks for the wrong type for std and mean arguments Co-authored-by: Louis Groux --- src/transformers/image_transforms.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index e7d3a5abb7a8..0a3d5e4fa300 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -15,7 +15,7 @@ import warnings from math import ceil -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Sequence, Tuple, Union import numpy as np @@ -357,8 +357,8 @@ def resize( def normalize( image: np.ndarray, - mean: Union[float, Iterable[float]], - std: Union[float, Iterable[float]], + mean: Union[float, Sequence[float]], + std: Union[float, Sequence[float]], data_format: Optional[ChannelDimension] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: @@ -370,9 +370,9 @@ def normalize( Args: image (`np.ndarray`): The image to normalize. - mean (`float` or `Iterable[float]`): + mean (`float` or `Sequence[float]`): The mean to use for normalization. - std (`float` or `Iterable[float]`): + std (`float` or `Sequence[float]`): The standard deviation to use for normalization. data_format (`ChannelDimension`, *optional*): The channel dimension format of the output image. If unset, will use the inferred format from the input. @@ -393,14 +393,14 @@ def normalize( if not np.issubdtype(image.dtype, np.floating): image = image.astype(np.float32) - if isinstance(mean, Iterable): + if isinstance(mean, Sequence): if len(mean) != num_channels: raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}") else: mean = [mean] * num_channels mean = np.array(mean, dtype=image.dtype) - if isinstance(std, Iterable): + if isinstance(std, Sequence): if len(std) != num_channels: raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}") else: