Skip to content

Commit

Permalink
transformers.image_transforms.normalize wrong types (huggingface#35773)
Browse files Browse the repository at this point in the history
transformers.image_transforms.normalize documents and checks for the wrong type for std and mean arguments

Co-authored-by: Louis Groux <louis.cal.groux@gmail.com>
  • Loading branch information
2 people authored and elvircrn committed Feb 13, 2025
1 parent 1cd739a commit f17a8d0
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/transformers/image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import warnings
from math import ceil
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -357,8 +357,8 @@ def resize(

def normalize(
image: np.ndarray,
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
mean: Union[float, Sequence[float]],
std: Union[float, Sequence[float]],
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
Expand All @@ -370,9 +370,9 @@ def normalize(
Args:
image (`np.ndarray`):
The image to normalize.
mean (`float` or `Iterable[float]`):
mean (`float` or `Sequence[float]`):
The mean to use for normalization.
std (`float` or `Iterable[float]`):
std (`float` or `Sequence[float]`):
The standard deviation to use for normalization.
data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If unset, will use the inferred format from the input.
Expand All @@ -393,14 +393,14 @@ def normalize(
if not np.issubdtype(image.dtype, np.floating):
image = image.astype(np.float32)

if isinstance(mean, Iterable):
if isinstance(mean, Sequence):
if len(mean) != num_channels:
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
else:
mean = [mean] * num_channels
mean = np.array(mean, dtype=image.dtype)

if isinstance(std, Iterable):
if isinstance(std, Sequence):
if len(std) != num_channels:
raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
else:
Expand Down

0 comments on commit f17a8d0

Please sign in to comment.