[examples] update whisper fine-tuning #29938

Merged: 5 commits, Apr 26, 2024
Changes from 1 commit
@@ -440,9 +440,24 @@ def main():
         model.freeze_encoder()
         model.model.encoder.gradient_checkpointing = False
 
-    if data_args.language is not None:
-        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
+    if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
+        # We only need to set the language and task ids in a multilingual setting
         tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
+        model.generation_config.update(
+            **{
+                "language": data_args.language,
+                "task": data_args.task,
+            }
+        )
+    elif data_args.language is not None:
+        raise ValueError(
+            "Setting language token for an English-only checkpoint is not permitted. The language argument should "
+            "only be set for multilingual checkpoints."
+        )
+
+    if hasattr(model.generation_config, "forced_decoder_ids"):
+        # forced decoder ids are now handled entirely by the decoder input ids
+        model.generation_config.forced_decoder_ids = None
 
     # 6. Resample speech dataset if necessary
     dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
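
For readers who want to trace the new branching without loading a checkpoint, here is a minimal sketch of the same control flow run against a stand-in object. It is not the script's code: `configure_language_and_task` is a hypothetical helper name, `SimpleNamespace` stands in for the real generation config, plain attribute assignment stands in for `generation_config.update(...)`, and the `tokenizer.set_prefix_tokens(...)` call is left out.

```python
from types import SimpleNamespace


def configure_language_and_task(generation_config, language, task):
    """Hypothetical helper that mirrors the branch structure of the hunk above."""
    if getattr(generation_config, "is_multilingual", False):
        # Multilingual checkpoint: record language and task on the config.
        # (The script uses generation_config.update(...); plain attribute
        # assignment is used here because the stand-in has no such method.)
        generation_config.language = language
        generation_config.task = task
    elif language is not None:
        # English-only checkpoint combined with an explicit language: reject.
        raise ValueError(
            "Setting language token for an English-only checkpoint is not permitted. "
            "The language argument should only be set for multilingual checkpoints."
        )

    if hasattr(generation_config, "forced_decoder_ids"):
        # Clear any forced decoder ids, as the hunk does.
        generation_config.forced_decoder_ids = None

    return generation_config


# Stand-in configs (assumed shapes, not real Whisper generation configs).
multilingual = SimpleNamespace(is_multilingual=True, forced_decoder_ids=[[1, 2]])
english_only = SimpleNamespace(is_multilingual=False, forced_decoder_ids=[[1, 2]])

configure_language_and_task(multilingual, "hindi", "transcribe")
configure_language_and_task(english_only, None, "transcribe")

try:
    configure_language_and_task(SimpleNamespace(is_multilingual=False), "hindi", "transcribe")
    rejected = False
except ValueError:
    rejected = True
```

One consequence visible in the sketch: an English-only config with no language set passes through untouched apart from the `forced_decoder_ids` reset, while a config that lacks `is_multilingual` altogether is treated the same as an English-only one.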