Skip to content

Commit

Permalink
fix 8
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Jul 24, 2024
1 parent df8b75e commit 114a765
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions tests/utils/test_image_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,27 @@ def test_push_to_hub_in_organization_via_save_pretrained(self):
self._try_delete_repo(repo_id=tmp_repo, token=self._token)

def test_push_to_hub_dynamic_image_processor(self):
CustomImageProcessor.register_for_auto_class()
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-image-processor-{Path(tmp_dir).name}"
CustomImageProcessor.register_for_auto_class()
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)

image_processor.push_to_hub("test-dynamic-image-processor", token=self._token)
image_processor.push_to_hub(tmp_repo, token=self._token)

# This has added the proper auto_map field to the config
self.assertDictEqual(
image_processor.auto_map,
{"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"},
)
# This has added the proper auto_map field to the config
self.assertDictEqual(
image_processor.auto_map,
{"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"},
)

new_image_processor = AutoImageProcessor.from_pretrained(
f"{USER}/test-dynamic-image-processor", trust_remote_code=True
)
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
new_image_processor = AutoImageProcessor.from_pretrained(
tmp_repo, trust_remote_code=True
)
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
finally:
self._try_delete_repo(repo_id=tmp_repo, token=self._token)


class ImageProcessingUtilsTester(unittest.TestCase):
Expand Down

0 comments on commit 114a765

Please sign in to comment.