register tensor and PIL kernel the same way as datapoints #7797
Changes from 7 commits
@@ -39,7 +39,7 @@
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
+from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal


@pytest.fixture(autouse=True)
@@ -173,59 +173,32 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
    dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)


def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
    """Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
    preserved in doing so. For bounding boxes also checks that the format is preserved.
    """
    input_type = type(input)

    if isinstance(input, datapoints.Datapoint):
        wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type]

        # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
        # proper kernel was wrapped
        if hasattr(wrapped_kernel, "__wrapped__"):
            assert wrapped_kernel.__wrapped__ is kernel

        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
        with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}):
            output = dispatcher(input, *args, **kwargs)

        spy.assert_called_once()
    else:
        with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy:
            output = dispatcher(input, *args, **kwargs)

            spy.assert_called_once()

    assert isinstance(output, input_type)

    if isinstance(input, datapoints.BoundingBoxes):
        assert output.format == input.format


def check_dispatcher(
    dispatcher,
    # TODO: remove this parameter
    kernel,
    input,
    *args,
    check_scripted_smoke=True,
    check_dispatch=True,
    **kwargs,
):
    unknown_input = object()
    with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
        dispatcher(unknown_input, *args, **kwargs)

    with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
        output = dispatcher(input, *args, **kwargs)

        spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")

    assert isinstance(output, type(input))

    if isinstance(input, datapoints.BoundingBoxes):
        assert output.format == input.format

    if check_scripted_smoke:
        _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)

    if check_dispatch:
        _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)


def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
    """Checks if the signature of the dispatcher matches the kernel signature."""

Author note (pmeier) on the "# TODO: remove this parameter" line: We no longer need this parameter. However, we previously parametrized over it together with the function that creates the input, so removing it is a chore that I'll deal with after the release. It has no effect on runtime: the number of tests stays exactly the same once the parameter is removed. (See the sketch just below.)
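A minimal sketch of the parametrization pattern that note refers to, in which the kernel and the input factory vary together. The specific kernels and factories (F.to_dtype_image_tensor, make_image, and so on) are illustrative assumptions, not lifted from this diff:

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.to_dtype_image_tensor, make_image),
            (F.to_dtype_video, make_video),
        ],
    )
    def test_dispatcher(self, kernel, make_input):
        # kernel rides along only so check_dispatcher can verify dispatch;
        # dropping the parameter means reworking this parametrization,
        # hence the deferred chore.
        check_dispatcher(F.to_dtype, kernel, make_input())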
@@ -412,18 +385,20 @@ def transform(bbox):


@pytest.mark.parametrize(
-    ("dispatcher", "registered_datapoint_clss"),
+    ("dispatcher", "registered_input_types"),
    [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
)
-def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
+def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
    missing = {
+        torch.Tensor,
+        PIL.Image.Image,
        datapoints.Image,
        datapoints.BoundingBoxes,
        datapoints.Mask,
        datapoints.Video,
-    } - registered_datapoint_clss
+    } - registered_input_types
    if missing:
-        names = sorted(f"datapoints.{cls.__name__}" for cls in missing)
+        names = sorted(str(t) for t in missing)
        raise AssertionError(
            "\n".join(
                [
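For orientation, the parametrization above implies that _KERNEL_REGISTRY holds one inner dict per dispatcher, keyed by input type, and that after this PR torch.Tensor and PIL.Image.Image sit in it alongside the datapoint classes. The concrete kernel assignments below are illustrative, not copied from the source:

    _KERNEL_REGISTRY = {
        F.resize: {
            torch.Tensor: F.resize_image_tensor,
            PIL.Image.Image: F.resize_image_pil,
            datapoints.Image: F.resize_image_tensor,
            datapoints.BoundingBoxes: F.resize_bounding_boxes,
            datapoints.Mask: F.resize_mask,
            datapoints.Video: F.resize_video,
        },
        # ... one entry per dispatcher
    }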
@@ -1753,11 +1728,6 @@ def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device,
            F.to_dtype,
            kernel,
            make_input(dtype=input_dtype, device=device),
-            # TODO: we could leave check_dispatch to True but it currently fails
-            # in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints.
-            # We should be able to put this back if we change the dispatch
-            # mechanism e.g. via https://github.com/pytorch/vision/pull/7733
-            check_dispatch=False,
            dtype=output_dtype,
            scale=scale,
        )
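The deleted comment explains the old failure mode: dispatch used to go through a method on the datapoint, and the datapoints never had a to_dtype() method. With the registry-based dispatch this PR completes, the kernel is looked up by input type instead, so check_dispatch can stay enabled. A rough, hypothetical contrast of the two styles (function names made up for illustration):

    # Old style: dispatch via a method on the datapoint; fails when the
    # method does not exist (no Datapoint.to_dtype() means AttributeError).
    def to_dtype_method_dispatch(inpt, dtype, scale=False):
        return inpt.to_dtype(dtype, scale=scale)

    # New style: look the kernel up in the registry by input type.
    def to_dtype_registry_dispatch(inpt, dtype, scale=False):
        kernel = _get_kernel(to_dtype_registry_dispatch, type(inpt))
        return kernel(inpt, dtype, scale=scale)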
@@ -2185,7 +2155,9 @@ def test_unsupported_types(self, dispatcher, make_input):


class TestRegisterKernel:
    @pytest.mark.parametrize("dispatcher", (F.resize, "resize"))
-    def test_register_kernel(self, dispatcher):
+    def test_register_kernel(self, mocker, dispatcher):
+        mocker.patch.dict(_KERNEL_REGISTRY, values={F.resize: _KERNEL_REGISTRY[F.resize]}, clear=True)

        class CustomDatapoint(datapoints.Datapoint):
            pass

@@ -2208,9 +2180,96 @@ def new_resize(dp, *args, **kwargs):
        t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
        t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)

-    def test_bad_disaptcher_name(self):
-        class CustomDatapoint(datapoints.Datapoint):
-            pass
+    def test_errors(self, mocker):
+        mocker.patch.dict(_KERNEL_REGISTRY, clear=True)

        with pytest.raises(ValueError, match="Could not find dispatcher with name"):
-            F.register_kernel("bad_name", CustomDatapoint)
+            F.register_kernel("bad_name", datapoints.Image)

+        with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"):
+            F.register_kernel(datapoints.Image, F.resize)

+        with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
+            F.register_kernel(F.resize, object)

+        F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)

+        with pytest.raises(ValueError, match="already has a kernel registered for type"):
+            F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
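For readers unfamiliar with the API under test: F.register_kernel(dispatcher, input_type) returns a decorator that files the kernel in the registry, which is what the happy-path line above exercises before the duplicate-registration check. A minimal usage sketch; CustomDatapoint and resize_custom are made-up names:

    class CustomDatapoint(datapoints.Datapoint):
        pass

    @F.register_kernel(F.resize, CustomDatapoint)
    def resize_custom(inpt, size, **kwargs):
        # Delegate to the tensor kernel; a Datapoint is a Tensor subclass,
        # so it can be passed through directly.
        return F.resize_image_tensor(inpt, size, **kwargs)

After this, F.resize(CustomDatapoint(torch.rand(3, 10, 10)), size=[224, 224]) routes to resize_custom instead of silently hitting the no-op fallback tested below.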
class TestGetKernel:
    def make_and_register_kernel(self, dispatcher, input_type):
        return _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(object())

    @pytest.fixture
    def dispatcher_and_kernels(self, mocker):
        mocker.patch.dict(_KERNEL_REGISTRY, clear=True)

        dispatcher = object()

        kernels = {
            cls: self.make_and_register_kernel(dispatcher, cls)
            for cls in [
                torch.Tensor,
                PIL.Image.Image,
                datapoints.Image,
                datapoints.BoundingBoxes,
                datapoints.Mask,
                datapoints.Video,
            ]
        }

        yield dispatcher, kernels

    def test_unsupported_types(self, dispatcher_and_kernels):
        dispatcher, _ = dispatcher_and_kernels

        class MyTensor(torch.Tensor):
            pass

        class MyPILImage(PIL.Image.Image):
            pass

        for input_type in [str, int, object, MyTensor, MyPILImage]:
            with pytest.raises(TypeError, match=re.escape(str(input_type))):
                _get_kernel(dispatcher, input_type)

    def test_exact_match(self, dispatcher_and_kernels):
        dispatcher, kernels = dispatcher_and_kernels

        for input_type, kernel in kernels.items():
            assert _get_kernel(dispatcher, input_type) is kernel

    def test_builtin_datapoint_subclass(self, dispatcher_and_kernels):
        dispatcher, kernels = dispatcher_and_kernels

        class MyImage(datapoints.Image):
            pass

        class MyBoundingBoxes(datapoints.BoundingBoxes):
            pass

        class MyMask(datapoints.Mask):
            pass

        class MyVideo(datapoints.Video):
            pass

        assert _get_kernel(dispatcher, MyImage) is kernels[datapoints.Image]
        assert _get_kernel(dispatcher, MyBoundingBoxes) is kernels[datapoints.BoundingBoxes]
        assert _get_kernel(dispatcher, MyMask) is kernels[datapoints.Mask]
        assert _get_kernel(dispatcher, MyVideo) is kernels[datapoints.Video]

    def test_datapoint_subclass(self, dispatcher_and_kernels):
        dispatcher, _ = dispatcher_and_kernels

        class MyDatapoint(datapoints.Datapoint):
            pass

        # Note that this will be an error in the future
        assert _get_kernel(dispatcher, MyDatapoint) is _noop

        kernel = self.make_and_register_kernel(dispatcher, MyDatapoint)

        assert _get_kernel(dispatcher, MyDatapoint) is kernel
Author note (pmeier): Since this test is part of the legacy framework, I was too lazy to handle this more elegantly.
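Pulling the four tests above together, here is a hypothetical reconstruction of the lookup behavior they pin down; this is not torchvision's literal _get_kernel implementation:

    def get_kernel(registry, dispatcher, input_type):
        kernels = registry[dispatcher]
        # Exact matches always win (test_exact_match).
        if input_type in kernels:
            return kernels[input_type]
        if issubclass(input_type, datapoints.Datapoint):
            # Subclasses of builtin datapoints resolve to the parent kernel
            # by walking the MRO up to, but not past, Datapoint
            # (test_builtin_datapoint_subclass).
            for cls in input_type.__mro__:
                if cls is datapoints.Datapoint:
                    break
                if cls in kernels:
                    return kernels[cls]
            # Unregistered Datapoint subclasses currently fall back to a
            # no-op, "an error in the future" (test_datapoint_subclass).
            return _noop
        # Everything else, including plain tensor and PIL subclasses, is
        # rejected (test_unsupported_types).
        raise TypeError(f"No kernel registered for type {input_type}.")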