Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow up for #31973 #32025

Merged
merged 25 commits into from
Jul 25, 2024
138 changes: 77 additions & 61 deletions tests/generation/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import tempfile
import unittest
import warnings
from pathlib import Path

from huggingface_hub import HfFolder, delete_repo
from parameterized import parameterized
from requests.exceptions import HTTPError

from transformers import AutoConfig, GenerationConfig
from transformers.generation import GenerationMode
Expand Down Expand Up @@ -228,72 +228,88 @@ def setUpClass(cls):
cls._token = TOKEN
HfFolder.save_token(TOKEN)

@classmethod
def tearDownClass(cls):
    """Best-effort removal of the fixed-name test repos once the class has finished."""
    for repo_id in ("test-generation-config", "valid_org/test-generation-config-org"):
        try:
            delete_repo(token=cls._token, repo_id=repo_id)
        except HTTPError:
            # Deletion is best effort: an HTTP error raised by the call is ignored.
            pass

def test_push_to_hub(self):
    """Push a generation config with `push_to_hub` and check that the reloaded copy matches it."""
    config = GenerationConfig(do_sample=True, temperature=0.7, length_penalty=1.0)
    config.push_to_hub("test-generation-config", token=self._token)

    reloaded = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
    for name, expected in config.to_dict().items():
        # The `transformers_version` entry is excluded from the comparison.
        if name == "transformers_version":
            continue
        self.assertEqual(expected, getattr(reloaded, name))

@staticmethod
def _try_delete_repo(repo_id, token):
    """Delete `repo_id` on a best-effort basis.

    Args:
        repo_id: Identifier of the repository to delete.
        token: Authentication token forwarded to `delete_repo`.

    Failures are deliberately ignored: the tests call this helper from `finally` blocks, where a cleanup error
    must not replace the test's own outcome.
    """
    try:
        # Only the repo that was asked for is deleted. A static method has no `self`, so a call that
        # referenced `self._token` here would raise NameError and silently skip the real deletion.
        delete_repo(repo_id=repo_id, token=token)
    except Exception:
        # Ordinary errors are swallowed; KeyboardInterrupt and SystemExit still propagate.
        pass

# Push to hub via save_pretrained
def test_push_to_hub(self):
    """Push a generation config with `push_to_hub` to a per-run repo and check the reloaded copy matches it."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The temporary directory's basename is reused as a suffix so each run targets its own repo.
        # It is bound before `try` so that the `finally` clause can always refer to it.
        tmp_repo = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
        try:
            config = GenerationConfig(
                do_sample=True,
                temperature=0.7,
                length_penalty=1.0,
            )
            config.push_to_hub(tmp_repo, token=self._token)

            new_config = GenerationConfig.from_pretrained(tmp_repo)
            for k, v in config.to_dict().items():
                # The `transformers_version` entry is excluded from the comparison.
                if k != "transformers_version":
                    self.assertEqual(v, getattr(new_config, k))
        finally:
            # Always (try to) delete the repo.
            self._try_delete_repo(repo_id=tmp_repo, token=self._token)

def test_push_to_hub_via_save_pretrained(self):
    """Push a generation config through `save_pretrained(push_to_hub=True)` and compare the reloaded copy."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        try:
            # The repo id ends with the temporary directory's basename.
            tmp_repo = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
            config = GenerationConfig(do_sample=True, temperature=0.7, length_penalty=1.0)
            config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)

            reloaded = GenerationConfig.from_pretrained(tmp_repo)
            for name, expected in config.to_dict().items():
                # The `transformers_version` entry is excluded from the comparison.
                if name == "transformers_version":
                    continue
                self.assertEqual(expected, getattr(reloaded, name))
        finally:
            # Cleanup is attempted whether or not the assertions passed.
            self._try_delete_repo(repo_id=tmp_repo, token=self._token)

def test_push_to_hub_in_organization(self):
    """Push a generation config with `push_to_hub` to a per-run organization repo and compare the reloaded copy.

    The `save_pretrained(push_to_hub=True)` path is exercised by
    `test_push_to_hub_in_organization_via_save_pretrained`, so this test covers `push_to_hub` only.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The temporary directory's basename is reused as a suffix so each run targets its own repo.
        # It is bound before `try` so that the `finally` clause can always refer to it.
        tmp_repo = f"valid_org/test-generation-config-org-{Path(tmp_dir).name}"
        try:
            config = GenerationConfig(
                do_sample=True,
                temperature=0.7,
                length_penalty=1.0,
            )
            config.push_to_hub(tmp_repo, token=self._token)

            new_config = GenerationConfig.from_pretrained(tmp_repo)
            for k, v in config.to_dict().items():
                # The `transformers_version` entry is excluded from the comparison.
                if k != "transformers_version":
                    self.assertEqual(v, getattr(new_config, k))
        finally:
            # Always (try to) delete the repo.
            self._try_delete_repo(repo_id=tmp_repo, token=self._token)

def test_push_to_hub_in_organization_via_save_pretrained(self):
    """Push a generation config to an organization repo via `save_pretrained` and compare the reloaded copy."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        try:
            # The repo id ends with the temporary directory's basename.
            tmp_repo = f"valid_org/test-generation-config-org-{Path(tmp_dir).name}"
            config = GenerationConfig(do_sample=True, temperature=0.7, length_penalty=1.0)
            config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)

            reloaded = GenerationConfig.from_pretrained(tmp_repo)
            for name, expected in config.to_dict().items():
                # The `transformers_version` entry is excluded from the comparison.
                if name == "transformers_version":
                    continue
                self.assertEqual(expected, getattr(reloaded, name))
        finally:
            # Cleanup is attempted whether or not the assertions passed.
            self._try_delete_repo(repo_id=tmp_repo, token=self._token)
120 changes: 62 additions & 58 deletions tests/models/auto/test_processor_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@
import unittest
from pathlib import Path
from shutil import copyfile
from uuid import uuid4

from huggingface_hub import HfFolder, Repository, create_repo, delete_repo
from requests.exceptions import HTTPError

import transformers
from transformers import (
Expand Down Expand Up @@ -374,69 +372,73 @@ def setUpClass(cls):
cls._token = TOKEN
HfFolder.save_token(TOKEN)

@classmethod
def tearDownClass(cls):
    """Best-effort removal of the fixed-name processor test repos once the class has finished."""
    for repo_id in ("test-processor", "valid_org/test-processor-org"):
        try:
            delete_repo(token=cls._token, repo_id=repo_id)
        except HTTPError:
            # Deletion is best effort: an HTTP error raised by the call is ignored.
            pass

@staticmethod
def _try_delete_repo(repo_id, token):
    """Delete `repo_id` on a best-effort basis.

    Args:
        repo_id: Identifier of the repository to delete.
        token: Authentication token forwarded to `delete_repo`.

    Failures are deliberately ignored: the tests call this helper from `finally` blocks, where a cleanup error
    must not replace the test's own outcome.
    """
    try:
        # Only the repo that was asked for is deleted. A static method has neither `cls` nor `self`, so a
        # call that referenced `cls._token` here would raise NameError and the real deletion would be skipped.
        delete_repo(repo_id=repo_id, token=token)
    except Exception:
        # Ordinary errors are swallowed; KeyboardInterrupt and SystemExit still propagate.
        pass

def test_push_to_hub_via_save_pretrained(self):
    """Push a processor through `save_pretrained(push_to_hub=True)` to a per-run repo and compare the reloaded copy."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The temporary directory's basename is reused as a suffix so each run targets its own repo.
        # It is bound before `try` so that the `finally` clause can always refer to it.
        tmp_repo = f"{USER}/test-processor-{Path(tmp_dir).name}"
        try:
            processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
            # Push to hub via save_pretrained. The files are written into `tmp_dir`: passing the repo id as
            # the save directory would create a directory of that name relative to the working directory.
            processor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)

            new_processor = Wav2Vec2Processor.from_pretrained(tmp_repo)
            for k, v in processor.feature_extractor.__dict__.items():
                self.assertEqual(v, getattr(new_processor.feature_extractor, k))
            self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
        finally:
            # Always (try to) delete the repo.
            self._try_delete_repo(repo_id=tmp_repo, token=self._token)

def test_push_to_hub_in_organization_via_save_pretrained(self):
    """Push a processor to a per-run organization repo via `save_pretrained` and compare the reloaded copy."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The temporary directory's basename is reused as a suffix so each run targets its own repo.
        # It is bound before `try` so that the `finally` clause can always refer to it.
        tmp_repo = f"valid_org/test-processor-org-{Path(tmp_dir).name}"
        try:
            processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)

            # Push to hub via save_pretrained
            processor.save_pretrained(
                tmp_dir,
                repo_id=tmp_repo,
                push_to_hub=True,
                token=self._token,
            )

            # Reload from the repo that was just pushed, not from a fixed repo name.
            new_processor = Wav2Vec2Processor.from_pretrained(tmp_repo)
            for k, v in processor.feature_extractor.__dict__.items():
                self.assertEqual(v, getattr(new_processor.feature_extractor, k))
            self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
        finally:
            # Always (try to) delete the repo.
            self._try_delete_repo(repo_id=tmp_repo, token=self._token)

def test_push_to_hub_dynamic_processor(self):
Copy link
Collaborator Author

@ydshieh ydshieh Jul 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For def test_push_to_hub_dynamic_processor(self):

Since @Rocketknight1 you changed it a few days ago, maybe you would like to review this (small) part too?

@amyeroberts I got your approval in general with a POC. I have applied the change to all places. But for this block, maybe you can take a final look too?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will merge this evening if there are no further comments requesting additional changes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM as well!

CustomFeatureExtractor.register_for_auto_class()
CustomTokenizer.register_for_auto_class()
CustomProcessor.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-processor-{Path(tmp_dir).name}"

feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
CustomFeatureExtractor.register_for_auto_class()
CustomTokenizer.register_for_auto_class()
CustomProcessor.register_for_auto_class()

with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = CustomTokenizer(vocab_file)
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)

processor = CustomProcessor(feature_extractor, tokenizer)
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = CustomTokenizer(vocab_file)

random_repo_id = f"{USER}/test-dynamic-processor-{uuid4()}"
try:
with tempfile.TemporaryDirectory() as tmp_dir:
create_repo(random_repo_id, token=self._token)
repo = Repository(tmp_dir, clone_from=random_repo_id, token=self._token)
processor = CustomProcessor(feature_extractor, tokenizer)

create_repo(tmp_repo, token=self._token)
repo = Repository(tmp_dir, clone_from=tmp_repo, token=self._token)
processor.save_pretrained(tmp_dir)

# This has added the proper auto_map field to the feature extractor config
Expand Down Expand Up @@ -466,8 +468,10 @@ def test_push_to_hub_dynamic_processor(self):

repo.push_to_hub()

new_processor = AutoProcessor.from_pretrained(random_repo_id, trust_remote_code=True)
# Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module
self.assertEqual(new_processor.__class__.__name__, "CustomProcessor")
finally:
delete_repo(repo_id=random_repo_id)
new_processor = AutoProcessor.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module
self.assertEqual(new_processor.__class__.__name__, "CustomProcessor")

finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
Loading
Loading