Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Aug 1, 2023
1 parent 68fd513 commit e19cf4a
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@

from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2._utils import _get_fill
from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2.utils import query_size

DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
Expand Down
12 changes: 6 additions & 6 deletions torchvision/transforms/v2/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
self,
*,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__()
self.interpolation = _check_interpolation(interpolation)
Expand Down Expand Up @@ -80,7 +80,7 @@ def _apply_image_or_video_transform(
transform_id: str,
magnitude: float,
interpolation: Union[InterpolationMode, int],
fill: Dict[Type, datapoints._FillTypeJIT],
fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
) -> Union[datapoints._ImageType, datapoints._VideoType]:
fill_ = _get_fill(fill, type(image))

Expand Down Expand Up @@ -214,7 +214,7 @@ def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
Expand Down Expand Up @@ -394,7 +394,7 @@ def __init__(
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
Expand Down Expand Up @@ -467,7 +467,7 @@ def __init__(
self,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
Expand Down Expand Up @@ -550,7 +550,7 @@ def __init__(
alpha: float = 1.0,
all_ops: bool = True,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
Expand Down
14 changes: 7 additions & 7 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
Expand Down Expand Up @@ -543,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):

def __init__(
self,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
Expand Down Expand Up @@ -621,7 +621,7 @@ def __init__(
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
Expand Down Expand Up @@ -703,7 +703,7 @@ def __init__(
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -841,7 +841,7 @@ def __init__(
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
Expand Down Expand Up @@ -960,7 +960,7 @@ def __init__(
distortion_scale: float = 0.5,
p: float = 0.5,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__(p=p)

Expand Down Expand Up @@ -1062,7 +1062,7 @@ def __init__(
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
return arg


def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> None:
if isinstance(fill, dict):
for value in fill.values():
_check_fill_arg(value)
Expand All @@ -49,7 +49,7 @@ def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
return fill # type: ignore[return-value]


def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
def _setup_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> Dict[Union[Type, str], _FillTypeJIT]:
_check_fill_arg(fill)

if isinstance(fill, dict):
Expand Down

0 comments on commit e19cf4a

Please sign in to comment.