diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py
index d0f4d03139..b0c2da15bf 100644
--- a/fairseq/hub_utils.py
+++ b/fairseq/hub_utils.py
@@ -83,6 +83,9 @@ def from_pretrained(
         model_path,
         arg_overrides=kwargs,
     )
+    if "generation_args" in kwargs and kwargs["generation_args"]:
+        for key in kwargs["generation_args"]:
+            setattr(args["generation"], key, kwargs["generation_args"][key])
 
     return {
         "args": args,
diff --git a/fairseq/models/speech_to_text/hub_interface.py b/fairseq/models/speech_to_text/hub_interface.py
index 173bb73064..c7e1f5ee06 100644
--- a/fairseq/models/speech_to_text/hub_interface.py
+++ b/fairseq/models/speech_to_text/hub_interface.py
@@ -30,7 +30,7 @@ def __init__(self, cfg, task, model):
         self.task = task
         self.model = model
         self.model.eval()
-        self.generator = self.task.build_generator([self.model], self.cfg)
+        self.generator = self.task.build_generator([self.model], self.cfg.generation)
 
     @classmethod
     def get_model_input(cls, task, audio: Union[str, torch.Tensor]):
diff --git a/fairseq/models/speech_to_text/xm_transformer.py b/fairseq/models/speech_to_text/xm_transformer.py
index e67b7c6715..c82dea9ba4 100644
--- a/fairseq/models/speech_to_text/xm_transformer.py
+++ b/fairseq/models/speech_to_text/xm_transformer.py
@@ -510,10 +510,10 @@ def from_pretrained(
         data_name_or_path=".",
         config_yaml="config.yaml",
         task="speech_to_text",
+        generation_args=None,
         **kwargs,
     ):
         from fairseq import hub_utils
-
         x = hub_utils.from_pretrained(
             model_name_or_path,
             checkpoint_file,
@@ -521,6 +521,7 @@ def from_pretrained(
             archive_map=cls.hub_models(),
             config_yaml=config_yaml,
             task=task,
+            generation_args=generation_args,
             **kwargs,
         )
         return S2THubInterface(x["args"], x["task"], x["models"][0])
diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py
index 13f99078c7..5176f5d267 100644
--- a/fairseq/sequence_generator.py
+++ b/fairseq/sequence_generator.py
@@ -374,7 +374,7 @@ def _generate(
             # handle max length constraint
             if step >= max_len:
                 lprobs[:, : self.eos] = -math.inf
-                lprobs[:, self.eos + 1 :] = -math.inf
+                lprobs[:, self.eos + 1:] = -math.inf
 
             # handle prefix tokens (possibly with different lengths)
             if (
@@ -604,7 +604,7 @@ def _prefix_tokens(
         if eos_mask.any():
             # validate that the first beam matches the prefix
             first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
-                :, 0, 1 : step + 1
+                :, 0, 1: step + 1
             ]
             eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
             target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
@@ -649,12 +649,12 @@ def finalize_hypos(
         # tokens is (batch * beam, max_len). So the index_select
         # gets the newly EOS rows, then selects cols 1..{step + 2}
         tokens_clone = tokens.index_select(0, bbsz_idx)[
-            :, 1 : step + 2
+            :, 1: step + 2
         ]  # skip the first index, which is EOS
         tokens_clone[:, step] = self.eos
 
         attn_clone = (
-            attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
+            attn.index_select(0, bbsz_idx)[:, :, 1: step + 2]
             if attn is not None
             else None
         )