diff --git a/test/test_utils.py b/test/test_utils.py index 32b3db596310..521774c03776 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,7 +9,7 @@ import torch import torchvision.transforms.functional as F import torchvision.utils as utils -from common_utils import assert_equal +from common_utils import assert_equal, cpu_and_gpu from PIL import __version__ as PILLOW_VERSION, Image, ImageColor @@ -203,12 +203,13 @@ def test_draw_no_boxes(): ], ) @pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1)) -def test_draw_segmentation_masks(colors, alpha): +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_draw_segmentation_masks(colors, alpha, device): """This test makes sure that masks draw their corresponding color where they should""" num_masks, h, w = 2, 100, 100 dtype = torch.uint8 - img = torch.randint(0, 256, size=(3, h, w), dtype=dtype) - masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool) + img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device) + masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device) # For testing we enforce that there's no overlap between the masks. The # current behaviour is that the last mask's color will take priority when @@ -234,7 +235,7 @@ def test_draw_segmentation_masks(colors, alpha): for mask, color in zip(masks, colors): if isinstance(color, str): color = ImageColor.getrgb(color) - color = torch.tensor(color, dtype=dtype) + color = torch.tensor(color, dtype=dtype, device=device) if alpha == 1: assert (out[:, mask] == color[:, None]).all()