Update speech recognition example
regisss committed Jan 19, 2024
1 parent 60a0d29 commit eed2d38
Showing 2 changed files with 46 additions and 21 deletions.
28 changes: 25 additions & 3 deletions examples/speech-recognition/run_speech_recognition_ctc.py
@@ -139,6 +139,20 @@ class ModelArguments:
     ctc_loss_reduction: Optional[str] = field(
         default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
     )
+    ctc_zero_infinity: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly"
+            " occur when the inputs are too short to be aligned to the targets."
+        },
+    )
+    add_adapter: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very "
+            "useful to downsample the output length."
+        },
+    )
 
 
 @dataclass
@@ -315,11 +329,14 @@ class DataCollatorCTCWithPadding:
     padding: Union[bool, str] = "longest"
     pad_to_multiple_of: Optional[int] = None
     pad_to_multiple_of_labels: Optional[int] = None
+    feature_extractor_input_name: Optional[str] = "input_values"
 
     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
-        input_features = [{"input_values": feature["input_values"]} for feature in features]
+        input_features = [
+            {self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
+        ]
         label_features = [{"input_ids": feature["labels"]} for feature in features]
 
         batch = self.processor.pad(
@@ -612,9 +629,11 @@ def remove_special_characters(batch):
             "gradient_checkpointing": training_args.gradient_checkpointing,
             "layerdrop": model_args.layerdrop,
             "ctc_loss_reduction": model_args.ctc_loss_reduction,
+            "ctc_zero_infinity": model_args.ctc_zero_infinity,
             "pad_token_id": tokenizer.pad_token_id,
             "vocab_size": len(tokenizer),
             "activation_dropout": model_args.activation_dropout,
+            "add_adapter": model_args.add_adapter,
         }
     )
 
@@ -653,6 +672,7 @@ def remove_special_characters(batch):
     min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
     audio_column_name = data_args.audio_column_name
     num_workers = data_args.preprocessing_num_workers
+    feature_extractor_input_name = feature_extractor.model_input_names[0]
 
     # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
     phoneme_language = data_args.phoneme_language
@@ -664,8 +684,9 @@ def prepare_dataset(batch):
         sample = batch[audio_column_name]
 
         inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
-        batch["input_values"] = inputs.input_values[0]
-        batch["input_length"] = len(batch["input_values"])
+        batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
+        # take length of raw audio waveform
+        batch["input_length"] = len(sample["array"].squeeze())
 
         # encode targets
         additional_kwargs = {}
@@ -748,6 +769,7 @@ def compute_metrics(pred):
     # Instantiate custom data collator
     data_collator = DataCollatorCTCWithPadding(
         processor=processor,
+        feature_extractor_input_name=feature_extractor_input_name,
         pad_to_multiple_of=int(max_input_length),
         pad_to_multiple_of_labels=500,
    )
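The notable change in this file is that the hard-coded "input_values" key is gone: the script now asks the feature extractor which input name its model expects (feature_extractor.model_input_names[0]), so the same example serves Wav2Vec2-style models ("input_values") and Wav2Vec2-BERT ("input_features"). The new ctc_zero_infinity flag is forwarded to the model config, where it controls the zero_infinity behavior of torch.nn.CTCLoss, zeroing the infinite losses that appear when a target transcript is too long to align to its input. A minimal sketch of the input-name lookup, assuming the two Hub checkpoints below are available:

from transformers import AutoFeatureExtractor

# Each feature extractor advertises the tensor name its model consumes;
# reading model_input_names[0] avoids hard-coding "input_values".
for checkpoint in ("facebook/wav2vec2-base", "facebook/w2v-bert-2.0"):
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    print(checkpoint, "->", feature_extractor.model_input_names[0])
# Expected: "input_values" for Wav2Vec2, "input_features" for Wav2Vec2-BERT
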
39 changes: 21 additions & 18 deletions tests/example_diff/run_speech_recognition_ctc.txt
@@ -30,33 +30,37 @@
>
> require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
>
141d147
145c152
< "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very"
---
> "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very "
155d161
<
251c257
265c271
< "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
---
> "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
390c396
407c413
< parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
---
> parser = HfArgumentParser((ModelArguments, DataTrainingArguments, GaudiTrainingArguments))
433a440,445
450a457,462
> gaudi_config = GaudiConfig.from_pretrained(
> training_args.gaudi_config_name,
> cache_dir=model_args.cache_dir,
> use_auth_token=True if data_args.use_auth_token else None,
> )
>
434a447
451a464
> mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
436,437c449,451
453,454c466,468
< f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
< f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
---
> f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
> + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
> + f"mixed-precision training: {mixed_precision}"
450,456c464,469
467,473c481,486
< if training_args.do_train:
< raw_datasets["train"] = load_dataset(
< data_args.dataset_name,
@@ -71,7 +75,7 @@
> split=data_args.train_split_name,
> token=data_args.token,
> )
458,463c471,476
475,480c488,493
< if data_args.audio_column_name not in raw_datasets["train"].column_names:
< raise ValueError(
< f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
@@ -85,7 +89,7 @@
> " Make sure to set `--audio_column_name` to the correct audio column - one of"
> f" {', '.join(raw_datasets['train'].column_names)}."
> )
465,470c478,483
482,487c495,500
< if data_args.text_column_name not in raw_datasets["train"].column_names:
< raise ValueError(
< f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
@@ -99,33 +103,32 @@
> "Make sure to set `--text_column_name` to the correct text column - one of "
> f"{', '.join(raw_datasets['train'].column_names)}."
> )
472,473c485,486
489,490c502,503
< if data_args.max_train_samples is not None:
< raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
---
> if data_args.max_train_samples is not None:
> raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
491c504
508c521
< f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
---
> f'[{"".join(data_args.chars_to_ignore).replace(" ", "")}]' if data_args.chars_to_ignore is not None else None
628a642,646
647a661,665
> raise RuntimeError(
> f"The dataset sampling rate ({dataset_sampling_rate}) is different from the feature extractor one"
> f" ({feature_extractor.sampling_rate}).Data resampling should be done. The Datasets library does not"
> " support it on HPUs yet."
> )
731c749,753
< data_collator = DataCollatorCTCWithPadding(processor=processor)
753c771,774
< processor=processor, feature_extractor_input_name=feature_extractor_input_name
---
> data_collator = DataCollatorCTCWithPadding(
> processor=processor,
> feature_extractor_input_name=feature_extractor_input_name,
> pad_to_multiple_of=int(max_input_length),
> pad_to_multiple_of_labels=500,
> )
734c756
757c778
< trainer = Trainer(
---
> trainer = GaudiTrainer(
735a758
758a780
> gaudi_config=gaudi_config,
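The reference diff above tracks how the Gaudi example diverges from the upstream Transformers script: GaudiTrainingArguments replaces TrainingArguments, a GaudiConfig is loaded from training_args.gaudi_config_name, and GaudiTrainer receives it as an extra keyword argument. A condensed sketch of that wiring, assuming optimum-habana is installed (the helper function and the example config name are illustrative, not part of the example script):

from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments


def build_trainer(model, training_args: GaudiTrainingArguments, cache_dir=None):
    # Load the Gaudi configuration named in the training arguments
    # (e.g. a Hub repository such as "Habana/wav2vec2") and hand it to
    # GaudiTrainer alongside the usual Trainer arguments.
    gaudi_config = GaudiConfig.from_pretrained(
        training_args.gaudi_config_name,
        cache_dir=cache_dir,
    )
    return GaudiTrainer(model=model, gaudi_config=gaudi_config, args=training_args)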
