Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up jpeg tests #7820

Merged
merged 2 commits into from
Aug 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 0 additions & 73 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,77 +422,6 @@ def test_encode_jpeg_errors():
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))


def _collect_if(cond):
# TODO: remove this once test_encode_jpeg_reference and test_write_jpeg_reference
# are removed
def _inner(test_func):
if cond:
return test_func
else:
return pytest.mark.dont_collect(test_func)

return _inner


@_collect_if(cond=False)
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_encode_jpeg_reference(img_path):
# This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg (the reference), but it
# starts encoding the torchvision version from an image that comes from
# decode_jpeg, which can yield different results from pil.decode (see
# test_decode... which uses a high tolerance).
# Instead, we should start encoding from the exact same decoded image, for a
# valid comparison. This is done in test_encode_jpeg, but unfortunately
# these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this.
dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, "jpeg_write")
expected_file = os.path.join(write_folder, f"{filename}_pil.jpg")
img = decode_jpeg(read_file(img_path))

with open(expected_file, "rb") as f:
pil_bytes = f.read()
pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75)
assert_equal(jpeg_bytes, pil_bytes)


@_collect_if(cond=False)
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_write_jpeg_reference(img_path, tmpdir):
# FIXME: Remove this eventually, see test_encode_jpeg_reference
data = read_file(img_path)
img = decode_jpeg(data)

basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(tmpdir, f"{filename}_torch.jpg")
pil_jpeg = os.path.join(basedir, "jpeg_write", f"{filename}_pil.jpg")

write_jpeg(img, torch_jpeg, quality=75)

with open(torch_jpeg, "rb") as f:
torch_bytes = f.read()

with open(pil_jpeg, "rb") as f:
pil_bytes = f.read()

assert_equal(torch_bytes, pil_bytes)


# TODO: Remove the skip. See https://github.com/pytorch/vision/issues/5162.
@pytest.mark.skip("this test fails because PIL uses libjpeg-turbo")
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
Expand All @@ -511,8 +440,6 @@ def test_encode_jpeg(img_path):
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


# TODO: Remove the skip. See https://github.com/pytorch/vision/issues/5162.
@pytest.mark.skip("this test fails because PIL uses libjpeg-turbo")
@pytest.mark.parametrize(
"img_path",
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
Expand Down