diff --git a/gallery/plot_custom_datapoints.py b/gallery/plot_custom_datapoints.py index ea757283e86..0a62a991a75 100644 --- a/gallery/plot_custom_datapoints.py +++ b/gallery/plot_custom_datapoints.py @@ -49,7 +49,7 @@ class MyDatapoint(datapoints.Datapoint): from torchvision.transforms.v2 import functional as F -@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint) +@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint) def hflip_my_datapoint(my_dp, *args, **kwargs): print("Flipping!") out = my_dp.flip(-1) @@ -64,9 +64,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs): # .. note:: # # In our call to ``register_kernel`` above we used a string -# ``dispatcher="hflip"`` to refer to the functional we want to hook into. We +# ``functional="hflip"`` to refer to the functional we want to hook into. We # could also have used the functional *itself*, i.e. -# ``@register_kernel(dispatcher=F.hflip, ...)``. +# ``@register_kernel(functional=F.hflip, ...)``. # # The functionals that you can be hooked into are the ones in # ``torchvision.transforms.v2.functional`` and they are documented in diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 1e78c5ed6c5..fa1ed05b84b 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -163,25 +163,25 @@ def check_kernel( _check_kernel_batched_vs_unbatched(kernel, input, *args, **kwargs, **_to_tolerances(check_batched_vs_unbatched)) -def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs): - """Checks if the dispatcher can be scripted and the scripted version can be called without error.""" +def _check_functional_scripted_smoke(functional, input, *args, **kwargs): + """Checks if the functional can be scripted and the scripted version can be called without error.""" if not isinstance(input, datapoints.Image): return - dispatcher_scripted = _script(dispatcher) + functional_scripted = _script(functional) with ignore_jit_no_profile_information_warning(): - dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs) + functional_scripted(input.as_subclass(torch.Tensor), *args, **kwargs) -def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwargs): +def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs): unknown_input = object() with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): - dispatcher(unknown_input, *args, **kwargs) + functional(unknown_input, *args, **kwargs) with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy: - output = dispatcher(input, *args, **kwargs) + output = functional(input, *args, **kwargs) - spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}") + spy.assert_any_call(f"{functional.__module__}.{functional.__name__}") assert isinstance(output, type(input)) @@ -189,41 +189,41 @@ def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwar assert output.format == input.format if check_scripted_smoke: - _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs) + _check_functional_scripted_smoke(functional, input, *args, **kwargs) -def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type): - """Checks if the signature of the dispatcher matches the kernel signature.""" - dispatcher_params = list(inspect.signature(dispatcher).parameters.values())[1:] +def check_functional_kernel_signature_match(functional, *, kernel, input_type): + """Checks if the signature of the functional matches the kernel signature.""" + functional_params = list(inspect.signature(functional).parameters.values())[1:] kernel_params = list(inspect.signature(kernel).parameters.values())[1:] if issubclass(input_type, datapoints.Datapoint): - # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be + # We filter out metadata that is implicitly passed to the functional through the input datapoint, but has to be # explicitly passed to the kernel. explicit_metadata = { datapoints.BoundingBoxes: {"format", "canvas_size"}, } kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())] - dispatcher_params = iter(dispatcher_params) - for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params): + functional_params = iter(functional_params) + for functional_param, kernel_param in zip(functional_params, kernel_params): try: - # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out - # dispatcher parameters that have no kernel equivalent while keeping the order intact. - while dispatcher_param.name != kernel_param.name: - dispatcher_param = next(dispatcher_params) + # In general, the functional parameters are a superset of the kernel parameters. Thus, we filter out + # functional parameters that have no kernel equivalent while keeping the order intact. + while functional_param.name != kernel_param.name: + functional_param = next(functional_params) except StopIteration: raise AssertionError( f"Parameter `{kernel_param.name}` of kernel `{kernel.__name__}` " - f"has no corresponding parameter on the dispatcher `{dispatcher.__name__}`." + f"has no corresponding parameter on the functional `{functional.__name__}`." ) from None if issubclass(input_type, PIL.Image.Image): # PIL kernels often have more correct annotations, since they are not limited by JIT. Thus, we don't check # them in the first place. - dispatcher_param._annotation = kernel_param._annotation = inspect.Parameter.empty + functional_param._annotation = kernel_param._annotation = inspect.Parameter.empty - assert dispatcher_param == kernel_param + assert functional_param == kernel_param def _check_transform_v1_compatibility(transform, input): @@ -482,8 +482,8 @@ def test_kernel_video(self): "make_input", [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, size, make_input): - check_dispatcher( + def test_functional(self, size, make_input): + check_functional( F.resize, make_input(self.INPUT_SIZE), size=size, @@ -502,8 +502,8 @@ def test_dispatcher(self, size, make_input): (F.resize_video, datapoints.Video), ], ) - def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type) @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -608,7 +608,7 @@ def test_pil_interpolation_compat_smoke(self, interpolation, make_input): interpolation=interpolation, ) - def test_dispatcher_pil_antialias_warning(self): + def test_functional_pil_antialias_warning(self): with pytest.warns(UserWarning, match="Anti-alias option is always applied for PIL Image input"): F.resize(make_image_pil(self.INPUT_SIZE), size=self.OUTPUT_SIZES[0], antialias=False) @@ -763,8 +763,8 @@ def test_kernel_video(self): "make_input", [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, make_input): - check_dispatcher(F.horizontal_flip, make_input()) + def test_functional(self, make_input): + check_functional(F.horizontal_flip, make_input()) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -777,8 +777,8 @@ def test_dispatcher(self, make_input): (F.horizontal_flip_video, datapoints.Video), ], ) - def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", @@ -939,8 +939,8 @@ def test_kernel_video(self): "make_input", [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, make_input): - check_dispatcher(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS) + def test_functional(self, make_input): + check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -953,8 +953,8 @@ def test_dispatcher(self, make_input): (F.affine_video, datapoints.Video), ], ) - def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", @@ -1228,8 +1228,8 @@ def test_kernel_video(self): "make_input", [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, make_input): - check_dispatcher(F.vertical_flip, make_input()) + def test_functional(self, make_input): + check_functional(F.vertical_flip, make_input()) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1242,8 +1242,8 @@ def test_dispatcher(self, make_input): (F.vertical_flip_video, datapoints.Video), ], ) - def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", @@ -1378,8 +1378,8 @@ def test_kernel_video(self): "make_input", [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, make_input): - check_dispatcher(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS) + def test_functional(self, make_input): + check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1392,8 +1392,8 @@ def test_dispatcher(self, make_input): (F.rotate_video, datapoints.Video), ], ) - def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", @@ -1643,8 +1643,8 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("scale", (True, False)) - def test_dispatcher(self, make_input, input_dtype, output_dtype, device, scale): - check_dispatcher( + def test_functional(self, make_input, input_dtype, output_dtype, device, scale): + check_functional( F.to_dtype, make_input(dtype=input_dtype, device=device), dtype=output_dtype, @@ -1810,8 +1810,8 @@ def test_kernel(self, kernel, make_input, dtype, device): check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) - def test_dispatcher(self, make_input): - check_dispatcher(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) + def test_functional(self, make_input): + check_functional(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1822,8 +1822,8 @@ def test_dispatcher(self, make_input): (F.adjust_brightness_video, datapoints.Video), ], ) - def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type) @pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS) def test_image_correctness(self, brightness_factor): @@ -2042,7 +2042,7 @@ def test_get_num_frames(self, kernel, make_input): assert kernel(input) == F.get_num_frames(input) == num_frames @pytest.mark.parametrize( - ("dispatcher", "make_input"), + ("functional", "make_input"), [ (F.get_dimensions, make_bounding_box), (F.get_dimensions, make_detection_mask), @@ -2057,22 +2057,22 @@ def test_get_num_frames(self, kernel, make_input): (F.get_num_frames, make_segmentation_mask), ], ) - def test_unsupported_types(self, dispatcher, make_input): + def test_unsupported_types(self, functional, make_input): input = make_input() with pytest.raises(TypeError, match=re.escape(str(type(input)))): - dispatcher(input) + functional(input) class TestRegisterKernel: - @pytest.mark.parametrize("dispatcher", (F.resize, "resize")) - def test_register_kernel(self, dispatcher): + @pytest.mark.parametrize("functional", (F.resize, "resize")) + def test_register_kernel(self, functional): class CustomDatapoint(datapoints.Datapoint): pass kernel_was_called = False - @F.register_kernel(dispatcher, CustomDatapoint) + @F.register_kernel(functional, CustomDatapoint) def new_resize(dp, *args, **kwargs): nonlocal kernel_was_called kernel_was_called = True @@ -2090,10 +2090,10 @@ def new_resize(dp, *args, **kwargs): t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224) def test_errors(self): - with pytest.raises(ValueError, match="Could not find dispatcher with name"): + with pytest.raises(ValueError, match="Could not find functional with name"): F.register_kernel("bad_name", datapoints.Image) - with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"): + with pytest.raises(ValueError, match="Kernels can only be registered on functionals"): F.register_kernel(datapoints.Image, F.resize) with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"): @@ -2115,7 +2115,7 @@ def resize_custom_datapoint(): class TestGetKernel: - # We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination + # We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination # would also be fine KERNELS = { torch.Tensor: F.resize_image_tensor, @@ -2139,7 +2139,7 @@ class MyPILImage(PIL.Image.Image): def test_exact_match(self): # We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the - # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher + # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional # here, register the kernels without wrapper, and check the exact matching afterwards. def resize_with_pure_kernels(): pass @@ -2151,7 +2151,7 @@ def resize_with_pure_kernels(): def test_builtin_datapoint_subclass(self): # We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the - # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher + # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional # here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched # to the kernel of the corresponding superclass def resize_with_pure_kernels(): @@ -2217,8 +2217,8 @@ def test_kernel(self, kernel, make_input, dtype, device): check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION) @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) - def test_dispatcher(self, make_input): - check_dispatcher(F.permute_channels, make_input(), permutation=self._DEFAULT_PERMUTATION) + def test_functional(self, make_input): + check_functional(F.permute_channels, make_input(), permutation=self._DEFAULT_PERMUTATION) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -2229,8 +2229,8 @@ def test_dispatcher(self, make_input): (F.permute_channels_video, datapoints.Video), ], ) - def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type) + def test_functional_signature(self, kernel, input_type): + check_functional_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type) def reference_image_correctness(self, image, permutation): channel_images = image.split(1, dim=-3) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 9be7a40e8ca..844e0321e0c 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -91,13 +91,13 @@ def __init__( self._log_ratio = torch.log(torch.tensor(self.ratio)) - def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)): warnings.warn( f"{type(self).__name__}() is currently passing through inputs of type " f"datapoints.{type(inpt).__name__}. This will likely change in the future." ) - return super()._call_kernel(dispatcher, inpt, *args, **kwargs) + return super()._call_kernel(functional, inpt, *args, **kwargs) def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: img_c, img_h, img_w = query_chw(flat_inputs) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index b209140614e..b28fad6eabc 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -358,13 +358,13 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)): warnings.warn( f"{type(self).__name__}() is currently passing through inputs of type " f"datapoints.{type(inpt).__name__}. This will likely change in the future." ) - return super()._call_kernel(dispatcher, inpt, *args, **kwargs) + return super()._call_kernel(functional, inpt, *args, **kwargs) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.five_crop, inpt, self.size) @@ -405,13 +405,13 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.vertical_flip = vertical_flip - def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)): warnings.warn( f"{type(self).__name__}() is currently passing through inputs of type " f"datapoints.{type(inpt).__name__}. This will likely change in the future." ) - return super()._call_kernel(dispatcher, inpt, *args, **kwargs) + return super()._call_kernel(functional, inpt, *args, **kwargs) def _check_inputs(self, flat_inputs: List[Any]) -> None: if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask): diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index 5a310ddbd4c..d4ee8af556d 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -30,8 +30,8 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict() - def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: - kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True) + def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + kernel = _get_kernel(functional, type(inpt), allow_passthrough=True) return kernel(inpt, *args, **kwargs) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 82891b8cc8b..aed1133020f 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -203,7 +203,7 @@ def convert_format_bounding_boxes( new_format: Optional[BoundingBoxFormat] = None, inplace: bool = False, ) -> torch.Tensor: - # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor + # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for simple tensor # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the # default error that would be thrown if `new_format` had no default value. diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index d9609a52e1e..95145beee4d 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -12,7 +12,7 @@ def is_simple_tensor(inpt: Any) -> bool: return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint) -# {dispatcher: {input_type: type_specific_kernel}} +# {functional: {input_type: type_specific_kernel}} _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {} @@ -27,10 +27,10 @@ def wrapper(inpt, *args, **kwargs): return wrapper -def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True): - registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) +def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True): + registry = _KERNEL_REGISTRY.setdefault(functional, {}) if input_type in registry: - raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.") + raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.") def decorator(kernel): registry[input_type] = ( @@ -43,14 +43,14 @@ def decorator(kernel): return decorator -def _name_to_dispatcher(name): +def _name_to_functional(name): import torchvision.transforms.v2.functional # noqa try: return getattr(torchvision.transforms.v2.functional, name) except AttributeError: raise ValueError( - f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional." + f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional." ) from None @@ -59,21 +59,21 @@ def _name_to_dispatcher(name): } -def register_kernel(dispatcher, datapoint_cls): - """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type. +def register_kernel(functional, datapoint_cls): + """Decorate a kernel to register it for a functional and a (custom) datapoint type. See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage details. """ - if isinstance(dispatcher, str): - dispatcher = _name_to_dispatcher(name=dispatcher) + if isinstance(functional, str): + functional = _name_to_functional(name=functional) elif not ( - callable(dispatcher) - and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional") + callable(functional) + and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional") ): raise ValueError( - f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, " - f"but got {dispatcher}." + f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, " + f"but got {functional}." ) if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)): @@ -85,13 +85,13 @@ def register_kernel(dispatcher, datapoint_cls): if datapoint_cls in _BUILTIN_DATAPOINT_TYPES: raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}") - return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) + return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False) -def _get_kernel(dispatcher, input_type, *, allow_passthrough=False): - registry = _KERNEL_REGISTRY.get(dispatcher) +def _get_kernel(functional, input_type, *, allow_passthrough=False): + registry = _KERNEL_REGISTRY.get(functional) if not registry: - raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.") + raise ValueError(f"No kernel registered for functional {functional.__name__}.") # In case we have an exact type match, we take a shortcut. if input_type in registry: @@ -113,17 +113,17 @@ def _get_kernel(dispatcher, input_type, *, allow_passthrough=False): return lambda inpt, *args, **kwargs: inpt raise TypeError( - f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, " + f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, " f"but got {input_type} instead." ) # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop -# We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool -def _register_five_ten_crop_kernel_internal(dispatcher, input_type): - registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) +# We could get rid of this by letting _register_kernel_internal take arbitrary functionals rather than wrap_kernel: bool +def _register_five_ten_crop_kernel_internal(functional, input_type): + registry = _KERNEL_REGISTRY.setdefault(functional, {}) if input_type in registry: - raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.") + raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.") def wrap(kernel): @functools.wraps(kernel)