diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index 69d9442c9c6..f5ea69279a1 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -29,8 +29,8 @@ from torchvision.transforms import functional as legacy_F from torchvision.transforms.v2 import functional as prototype_F -from torchvision.transforms.v2.functional import to_image_pil from torchvision.transforms.v2._utils import _get_fill +from torchvision.transforms.v2.functional import to_image_pil from torchvision.transforms.v2.utils import query_size DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)]) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 4f7fc959bbf..a4023ca2108 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -14,7 +14,7 @@ class FixedSizeCrop(Transform): def __init__( self, size: Union[int, Sequence[int]], - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, padding_mode: str = "constant", ) -> None: super().__init__() diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index fefa1c97a5b..146c8c236ef 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -20,7 +20,7 @@ def __init__( self, *, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, ) -> None: super().__init__() self.interpolation = _check_interpolation(interpolation) @@ -80,7 +80,7 @@ def _apply_image_or_video_transform( transform_id: str, magnitude: float, interpolation: Union[InterpolationMode, int], - fill: Dict[Type, datapoints._FillTypeJIT], + fill: Dict[Union[Type, str], datapoints._FillTypeJIT], ) -> Union[datapoints._ImageType, datapoints._VideoType]: fill_ = _get_fill(fill, type(image)) @@ -214,7 +214,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.policy = policy @@ -394,7 +394,7 @@ def __init__( magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops @@ -467,7 +467,7 @@ def __init__( self, num_magnitude_bins: int = 31, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, ): super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins @@ -550,7 +550,7 @@ def __init__( alpha: float = 1.0, all_ops: bool = True, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 53d2a236282..c7a1e39286f 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -488,7 +488,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: def __init__( self, padding: Union[int, Sequence[int]], - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -543,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform): def __init__( self, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, side_range: Sequence[float] = (1.0, 4.0), p: float = 0.5, ) -> None: @@ -621,7 +621,7 @@ def __init__( interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, expand: bool = False, center: Optional[List[float]] = None, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, ) -> None: super().__init__() self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) @@ -703,7 +703,7 @@ def __init__( scale: Optional[Sequence[float]] = None, shear: Optional[Union[int, float, Sequence[float]]] = None, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, center: Optional[List[float]] = None, ) -> None: super().__init__() @@ -841,7 +841,7 @@ def __init__( size: Union[int, Sequence[int]], padding: Optional[Union[int, Sequence[int]]] = None, pad_if_needed: bool = False, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -960,7 +960,7 @@ def __init__( distortion_scale: float = 0.5, p: float = 0.5, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, ) -> None: super().__init__(p=p) @@ -1062,7 +1062,7 @@ def __init__( alpha: Union[float, Sequence[float]] = 50.0, sigma: Union[float, Sequence[float]] = 5.0, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, + fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, ) -> None: super().__init__() self.alpha = _setup_float_or_seq(alpha, "alpha", 2) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index a6a09bbec2b..859586be110 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -27,7 +27,7 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: return arg -def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None: +def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> None: if isinstance(fill, dict): for value in fill.values(): _check_fill_arg(value) @@ -49,7 +49,7 @@ def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT: return fill # type: ignore[return-value] -def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]: +def _setup_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> Dict[Union[Type, str], _FillTypeJIT]: _check_fill_arg(fill) if isinstance(fill, dict):