Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[examples] update whisper fine-tuning #29938

Merged
merged 5 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions examples/pytorch/speech-recognition/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ python run_speech_recognition_seq2seq.py \
--dataset_name="mozilla-foundation/common_voice_11_0" \
--dataset_config_name="hi" \
--language="hindi" \
--task="transcribe" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--max_steps="5000" \
Expand All @@ -384,12 +385,10 @@ python run_speech_recognition_seq2seq.py \
--save_steps="1000" \
--generation_max_length="225" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--gradient_checkpointing \
--group_by_length \
--fp16 \
--overwrite_output_dir \
--do_train \
Expand All @@ -399,7 +398,8 @@ python run_speech_recognition_seq2seq.py \
```
On a single V100, training should take approximately 8 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**.

If training on a different language, you should be sure to change the `language` argument. The `language` argument should be omitted for English speech recognition.
If training on a different language, you should be sure to change the `language` argument. The `language` and `task`
arguments should be omitted for English speech recognition.

#### Multi GPU Whisper Training
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 2 GPU devices in half-precision:
Expand All @@ -410,6 +410,7 @@ torchrun \
--dataset_name="mozilla-foundation/common_voice_11_0" \
--dataset_config_name="hi" \
--language="hindi" \
--task="transcribe" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--max_steps="5000" \
Expand All @@ -425,12 +426,10 @@ torchrun \
--save_steps="1000" \
--generation_max_length="225" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--gradient_checkpointing \
--group_by_length \
--fp16 \
--overwrite_output_dir \
--do_train \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,16 @@ class ModelArguments:
)
forced_decoder_ids: List[List[int]] = field(
default=None,
metadata={
metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."},
)
suppress_tokens: List[int] = field(
default=None, metadata={
"help": (
"A list of pairs of integers which indicates a mapping from generation indices to token indices "
"that will be forced before sampling. For example, [[0, 123]] means the first generated token "
"will always be a token of index 123."
"Deprecated. The use of `suppress_tokens` should not be required for the majority of fine-tuning examples."
"Should you need to use `suppress_tokens`, please manually update them in the fine-tuning script directly."
)
},
)
suppress_tokens: List[int] = field(
default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
)
apply_spec_augment: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -400,8 +399,6 @@ def main():
trust_remote_code=model_args.trust_remote_code,
)

config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})

# SpecAugment for whisper models
if getattr(config, "model_type", None) == "whisper":
config.update({"apply_spec_augment": model_args.apply_spec_augment})
Expand Down Expand Up @@ -440,9 +437,35 @@ def main():
model.freeze_encoder()
model.model.encoder.gradient_checkpointing = False

if data_args.language is not None:
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
# We only need to set the language and task ids in a multilingual setting
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
model.generation_config.update(
**{
"language": data_args.language,
"task": data_args.task,
}
)
elif data_args.language is not None:
raise ValueError(
"Setting language token for an English-only checkpoint is not permitted. The language argument should "
"only be set for multilingual checkpoints."
)

# TODO (Sanchit): deprecate these arguments in v4.41
if model_args.forced_decoder_ids is not None:
logger.warning(
"The use of `forced_decoder_ids` is deprecated and will be removed in v4.41."
"Please use the `language` and `task` arguments instead"
)
model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids

if model_args.suppress_tokens is not None:
logger.warning(
"The use of `suppress_tokens` is deprecated and will be removed in v4.41."
"Should you need `suppress_tokens`, please manually set them in the fine-tuning script."
)
model.generation_config.suppress_tokens = model_args.suppress_tokens

# 6. Resample speech dataset if necessary
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
Expand Down
Loading