Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making CTC training example more general #28582

Merged
merged 2 commits into from
Jan 18, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,17 @@ class ModelArguments:
ctc_loss_reduction: Optional[str] = field(
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
)
ctc_zero_infinity: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly"
"occur when the inputs are too short to be aligned to the targets."
ylacombe marked this conversation as resolved.
Show resolved Hide resolved
},
)
add_adapter: Optional[bool] = field(
default=False,
metadata={
"help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2BERT Encoder. Can be very"
"help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very"
"useful to downsample the output length."
},
)
Expand Down Expand Up @@ -316,11 +323,14 @@ class DataCollatorCTCWithPadding:
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
pad_to_multiple_of_labels: Optional[int] = None
feature_extractor_input_name: Optional[str] = "input_values"

def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
input_features = [{"input_values": feature["input_values"]} for feature in features]
input_features = [
{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
]
label_features = [{"input_ids": feature["labels"]} for feature in features]

batch = self.processor.pad(
Expand Down Expand Up @@ -606,6 +616,7 @@ def remove_special_characters(batch):
"gradient_checkpointing": training_args.gradient_checkpointing,
"layerdrop": model_args.layerdrop,
"ctc_loss_reduction": model_args.ctc_loss_reduction,
"ctc_zero_infinity": model_args.ctc_zero_infinity,
"pad_token_id": tokenizer.pad_token_id,
"vocab_size": len(tokenizer),
"activation_dropout": model_args.activation_dropout,
Expand Down Expand Up @@ -643,6 +654,7 @@ def remove_special_characters(batch):
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
audio_column_name = data_args.audio_column_name
num_workers = data_args.preprocessing_num_workers
feature_extractor_input_name = feature_extractor.model_input_names[0]

# `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
phoneme_language = data_args.phoneme_language
Expand All @@ -654,8 +666,9 @@ def prepare_dataset(batch):
sample = batch[audio_column_name]

inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
batch["input_values"] = inputs.input_values[0]
batch["input_length"] = len(batch["input_values"])
batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

# take length of raw audio waveform
batch["input_length"] = len(sample["array"].squeeze())

# encode targets
additional_kwargs = {}
Expand Down Expand Up @@ -736,7 +749,9 @@ def compute_metrics(pred):
processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)

# Instantiate custom data collator
data_collator = DataCollatorCTCWithPadding(processor=processor)
data_collator = DataCollatorCTCWithPadding(
processor=processor, feature_extractor_input_name=feature_extractor_input_name
)

# Initialize Trainer
trainer = Trainer(
Expand Down
Loading