Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature Extractors] Fix kwargs to pre-trained #30260

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/transformers/feature_extraction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,17 @@ def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrain
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

feature_extractor = cls(**feature_extractor_dict)

# Update feature_extractor with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(feature_extractor, key):
setattr(feature_extractor, key, value)
if key in feature_extractor_dict:
feature_extractor_dict[key] = value
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)

feature_extractor = cls(**feature_extractor_dict)

logger.info(f"Feature extractor {feature_extractor}")
if return_unused_kwargs:
return feature_extractor, kwargs
Expand Down
14 changes: 14 additions & 0 deletions tests/models/whisper/test_feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@ def test_feat_extract_to_json_file(self):
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)

def test_feat_extract_from_pretrained_kwargs(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(
tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"]
)

mel_1 = feat_extract_first.mel_filters
mel_2 = feat_extract_second.mel_filters
self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1])

def test_call(self):
# Tests that all call wrap to encode_plus and batch_encode_plus
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
Expand Down
Loading