From 772307be7649e1333a933cfaa229dc0dec2fd331 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Thu, 18 Jan 2024 17:01:49 +0000 Subject: [PATCH] Making CTC training example more general (#28582) * add w2v2bert compatibility * Update examples/pytorch/speech-recognition/run_speech_recognition_ctc.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../run_speech_recognition_ctc.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py index 3ca9a2c6f44..35db8631a35 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py @@ -132,10 +132,17 @@ class ModelArguments: ctc_loss_reduction: Optional[str] = field( default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} ) + ctc_zero_infinity: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly" + " occur when the inputs are too short to be aligned to the targets." + }, + ) add_adapter: Optional[bool] = field( default=False, metadata={ - "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2BERT Encoder. Can be very" + "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very" "useful to downsample the output length." }, ) @@ -316,11 +323,14 @@ class DataCollatorCTCWithPadding: padding: Union[bool, str] = "longest" pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None + feature_extractor_input_name: Optional[str] = "input_values" def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: # split inputs and labels since they have to be of different lengths and need # different padding methods - input_features = [{"input_values": feature["input_values"]} for feature in features] + input_features = [ + {self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features + ] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( @@ -606,6 +616,7 @@ def remove_special_characters(batch): "gradient_checkpointing": training_args.gradient_checkpointing, "layerdrop": model_args.layerdrop, "ctc_loss_reduction": model_args.ctc_loss_reduction, + "ctc_zero_infinity": model_args.ctc_zero_infinity, "pad_token_id": tokenizer.pad_token_id, "vocab_size": len(tokenizer), "activation_dropout": model_args.activation_dropout, @@ -643,6 +654,7 @@ def remove_special_characters(batch): min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate audio_column_name = data_args.audio_column_name num_workers = data_args.preprocessing_num_workers + feature_extractor_input_name = feature_extractor.model_input_names[0] # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification phoneme_language = data_args.phoneme_language @@ -654,8 +666,9 @@ def prepare_dataset(batch): sample = batch[audio_column_name] inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]) - batch["input_values"] = inputs.input_values[0] - batch["input_length"] = len(batch["input_values"]) + batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0] + # take length of raw audio waveform + batch["input_length"] = len(sample["array"].squeeze()) # encode targets additional_kwargs = {} @@ -736,7 +749,9 @@ def compute_metrics(pred): processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir) # Instantiate custom data collator - data_collator = DataCollatorCTCWithPadding(processor=processor) + data_collator = DataCollatorCTCWithPadding( + processor=processor, feature_extractor_input_name=feature_extractor_input_name + ) # Initialize Trainer trainer = Trainer(