diff --git a/test/common_utils.py b/test/common_utils.py index b8b02828683..af8f5783263 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -503,12 +503,13 @@ def make_image( device="cpu", memory_format=torch.contiguous_format, ): + dtype = dtype or torch.uint8 max_value = get_max_value(dtype) data = torch.testing.make_tensor( (*batch_dims, get_num_channels(color_space), *size), low=0, high=max_value, - dtype=dtype or torch.uint8, + dtype=dtype, device=device, memory_format=memory_format, )