Skip to content

Commit

Permalink
Clean up jpeg tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Aug 10, 2023
1 parent f2b6f43 commit 1f5f875
Showing 1 changed file with 0 additions and 73 deletions.
73 changes: 0 additions & 73 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,77 +422,6 @@ def test_encode_jpeg_errors():
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))


def _collect_if(cond):
# TODO: remove this once test_encode_jpeg_reference and test_write_jpeg_reference
# are removed
def _inner(test_func):
if cond:
return test_func
else:
return pytest.mark.dont_collect(test_func)

return _inner


@_collect_if(cond=False)
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_encode_jpeg_reference(img_path):
# This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg (the reference), but it
# starts encoding the torchvision version from an image that comes from
# decode_jpeg, which can yield different results from pil.decode (see
# test_decode... which uses a high tolerance).
# Instead, we should start encoding from the exact same decoded image, for a
# valid comparison. This is done in test_encode_jpeg, but unfortunately
# these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this.
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, "jpeg_write")
expected_file = os.path.join(write_folder, f"{filename}_pil.jpg")
img = decode_jpeg(read_file(img_path))

with open(expected_file, "rb") as f:
pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75)
assert_equal(jpeg_bytes, pil_bytes)


@_collect_if(cond=False)
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_write_jpeg_reference(img_path, tmpdir):
# FIXME: Remove this eventually, see test_encode_jpeg_reference
data = read_file(img_path)
img = decode_jpeg(data)

basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(tmpdir, f"{filename}_torch.jpg")
pil_jpeg = os.path.join(basedir, "jpeg_write", f"{filename}_pil.jpg")

write_jpeg(img, torch_jpeg, quality=75)

with open(torch_jpeg, "rb") as f:
torch_bytes = f.read()

with open(pil_jpeg, "rb") as f:
pil_bytes = f.read()

assert_equal(torch_bytes, pil_bytes)


# TODO: Remove the skip. See https://github.com/pytorch/vision/issues/5162.
@pytest.mark.skip("this test fails because PIL uses libjpeg-turbo")
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
Expand All @@ -511,8 +440,6 @@ def test_encode_jpeg(img_path):
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


# TODO: Remove the skip. See https://github.com/pytorch/vision/issues/5162.
@pytest.mark.skip("this test fails because PIL uses libjpeg-turbo")
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
Expand Down

0 comments on commit 1f5f875

Please sign in to comment.