Fix under-generation issue for speech-to-speech translation models by adding optional generation args (#4662)

* OSS ckpts for Interspeech 2022 paper

* HF interface update

* local test

* local test

* revert local test

* address comments

* add Hk<>En models

* add Hk<>En models

* add Hk<>En models

* add hk->en

* add hk->en

* add hk->en

* add hk->en

* add hk->en

* debug

* debug

* debug

* fix undergeneration for S2UT

* fix typo

* fix typo

* fix bug
sravyapopuri388 authored Aug 24, 2022
1 parent f826615 commit eda7037
Showing 4 changed files with 10 additions and 6 deletions.
3 changes: 3 additions & 0 deletions fairseq/hub_utils.py
@@ -83,6 +83,9 @@ def from_pretrained(
         model_path,
         arg_overrides=kwargs,
     )
+    if "generation_args" in kwargs and kwargs["generation_args"]:
+        for key in kwargs["generation_args"]:
+            setattr(args["generation"], key, kwargs["generation_args"][key])
 
     return {
         "args": args,
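The added block applies user overrides after the checkpoint's stored config has been loaded, so a caller-supplied value wins over whatever was saved at training time. A minimal sketch of the mechanics, assuming the OmegaConf-style config that fairseq checkpoints carry (field names and values here are illustrative):

    from omegaconf import OmegaConf

    # Stand-in for the `args` returned by checkpoint loading; real configs
    # carry many more fields under "generation".
    args = OmegaConf.create({"generation": {"beam": 5, "max_len_b": 200}})

    generation_args = {"beam": 10, "max_len_b": 500}  # hypothetical overrides
    for key in generation_args:
        setattr(args["generation"], key, generation_args[key])

    assert args.generation.beam == 10  # the override replaced the stored value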
2 changes: 1 addition & 1 deletion fairseq/models/speech_to_text/hub_interface.py
@@ -30,7 +30,7 @@ def __init__(self, cfg, task, model):
         self.task = task
         self.model = model
         self.model.eval()
-        self.generator = self.task.build_generator([self.model], self.cfg)
+        self.generator = self.task.build_generator([self.model], self.cfg.generation)
 
     @classmethod
     def get_model_input(cls, task, audio: Union[str, torch.Tensor]):
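Passing self.cfg.generation rather than the whole config hands build_generator the sub-config that the hub_utils change above mutates, so load-time overrides actually reach decoding. A hedged sketch of driving the interface's classmethod API, assuming task, model, and generator were obtained from a loaded checkpoint; the audio path is a placeholder:

    from fairseq.models.speech_to_text.hub_interface import S2THubInterface

    # `task`, `model`, and `generator` are assumed to come from a checkpoint
    # loaded elsewhere; the wav path is illustrative.
    sample = S2THubInterface.get_model_input(task, "/path/to/utterance.wav")
    prediction = S2THubInterface.get_prediction(task, model, generator, sample)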
3 changes: 2 additions & 1 deletion fairseq/models/speech_to_text/xm_transformer.py
@@ -510,17 +510,18 @@ def from_pretrained(
         data_name_or_path=".",
         config_yaml="config.yaml",
         task="speech_to_text",
+        generation_args=None,
         **kwargs,
     ):
         from fairseq import hub_utils
 
         x = hub_utils.from_pretrained(
             model_name_or_path,
             checkpoint_file,
             data_name_or_path,
             archive_map=cls.hub_models(),
             config_yaml=config_yaml,
             task=task,
+            generation_args=generation_args,
             **kwargs,
         )
         return S2THubInterface(x["args"], x["task"], x["models"][0])
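With generation_args exposed here, decoding settings can be overridden at load time without editing the checkpoint. A hedged usage sketch; the paths are placeholders and the keys mirror fairseq's generation config (beam, max_len_a, max_len_b, and so on):

    from fairseq.models.speech_to_text.xm_transformer import XMTransformerModel

    # Paths and file names below are illustrative; from_pretrained returns
    # an S2THubInterface wrapping the loaded model.
    s2ut = XMTransformerModel.from_pretrained(
        "/path/to/checkpoint_dir",
        checkpoint_file="checkpoint_best.pt",
        config_yaml="config.yaml",
        task="speech_to_text",
        generation_args={"beam": 10, "max_len_b": 1000},
    )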
8 changes: 4 additions & 4 deletions fairseq/sequence_generator.py
@@ -374,7 +374,7 @@ def _generate(
             # handle max length constraint
             if step >= max_len:
                 lprobs[:, : self.eos] = -math.inf
-                lprobs[:, self.eos + 1 :] = -math.inf
+                lprobs[:, self.eos + 1:] = -math.inf
 
             # handle prefix tokens (possibly with different lengths)
             if (
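The hunk above is the point where a hypothesis is cut off: once step reaches max_len, every vocabulary column except EOS is masked to -inf, so the search has no choice but to emit EOS and terminate. A self-contained illustration with toy sizes:

    import math
    import torch

    eos = 2
    lprobs = torch.randn(4, 8)        # (beam, vocab) log-probabilities
    lprobs[:, :eos] = -math.inf       # mask everything below EOS
    lprobs[:, eos + 1:] = -math.inf   # mask everything above EOS
    assert (lprobs.argmax(dim=-1) == eos).all()  # only EOS can be chosen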
@@ -604,7 +604,7 @@ def _prefix_tokens(
         if eos_mask.any():
             # validate that the first beam matches the prefix
             first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
-                :, 0, 1 : step + 1
+                :, 0, 1: step + 1
             ]
             eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
             target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
@@ -649,12 +649,12 @@ def finalize_hypos(
         # tokens is (batch * beam, max_len). So the index_select
         # gets the newly EOS rows, then selects cols 1..{step + 2}
         tokens_clone = tokens.index_select(0, bbsz_idx)[
-            :, 1 : step + 2
+            :, 1: step + 2
         ]  # skip the first index, which is EOS
 
         tokens_clone[:, step] = self.eos
         attn_clone = (
-            attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
+            attn.index_select(0, bbsz_idx)[:, :, 1: step + 2]
             if attn is not None
             else None
         )
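Note that the four slice edits in this file are whitespace-only: "1 : step + 2" and "1: step + 2" compile to the same slice, so decoding behavior is unchanged by them. A quick check:

    import torch

    step = 3
    t = torch.arange(10)
    # Both spellings parse to slice(1, step + 2); only the formatting differs.
    assert torch.equal(t[1 : step + 2], t[1: step + 2])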
