diff --git a/examples/text-to-speech/README.md b/examples/text-to-speech/README.md
index 5b98b30493..a1e089f55e 100644
--- a/examples/text-to-speech/README.md
+++ b/examples/text-to-speech/README.md
@@ -36,4 +36,5 @@ python3 run_pipeline.py \
 ```
 Models that have been validated:
   - [microsoft/speecht5_tts](https://huggingface.co/microsoft/speecht5_tts)
+  - [facebook/hf-seamless-m4t-medium](https://huggingface.co/facebook/hf-seamless-m4t-medium)
   - [facebook/mms-tts-eng](https://huggingface.co/facebook/mms-tts-eng)
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index d39933a903..b10eb11cf6 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -238,14 +238,20 @@ def _prepare_decoder_input_ids_for_generation(
         if token_idx is None:
             decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
         else:
-            max_length = max_new_tokens + 2 if max_new_tokens is not None else self.generation_config.max_length
+            decoder_input_ids_len = decoder_input_ids.shape[-1]
+            max_length = (
+                max_new_tokens + decoder_input_ids_len + 1
+                if max_new_tokens is not None
+                else self.generation_config.max_length
+            )
             if max_length != decoder_start_token_id.shape[-1]:
                 decoder_start_token_id = torch.nn.functional.pad(
                     decoder_start_token_id,
                     (0, max_length - decoder_start_token_id.shape[-1]),
                     value=pad_token_id,
                 )
-            decoder_input_ids = decoder_start_token_id.index_copy(1, token_idx, decoder_input_ids)
+            decoder_start_token_id[:, 1 : 1 + decoder_input_ids_len, ...] = decoder_input_ids
+            decoder_input_ids = decoder_start_token_id
             token_idx.add_(1)
         if "decoder_attention_mask" in model_kwargs:
             decoder_attention_mask = model_kwargs["decoder_attention_mask"]
@@ -1109,11 +1115,14 @@ def generate(
             )
         else:
             assert generation_config.bucket_size <= 0, "Untested path for bucket>0"
-            token_idx = 1
+            if model_kwargs.get("decoder_input_ids", None) is None:
+                token_idx = 1
+            else:
+                token_idx = model_kwargs["decoder_input_ids"].shape[-1]
             model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device)
         if model_kwargs.get("decoder_attention_mask", None) is None and generation_config.use_cache:
             max_length = (
-                generation_config.max_new_tokens + 1
+                generation_config.max_new_tokens + token_idx
                 if generation_config.max_new_tokens is not None
                 else generation_config.max_length
             )
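The first `utils.py` hunk changes how the static-shape decoder buffer is sized and filled: the old code sized it as `max_new_tokens + 2`, which only leaves room for a single start token, whereas models such as SeamlessM4T supply a multi-token decoder prompt. The sketch below is a minimal standalone illustration of the new logic with made-up tensor values; variable names mirror the diff, but it is not the actual method, which operates on batched model inputs inside the generation mixin.

```python
import torch

pad_token_id = 0
max_new_tokens = 4

# A decoder prompt longer than a single start token, as models like
# SeamlessM4T produce (the case the old `max_new_tokens + 2` sizing
# did not account for). Values are illustrative.
decoder_start_token_id = torch.tensor([[3]])   # shape (1, 1)
decoder_input_ids = torch.tensor([[5, 6, 7]])  # shape (1, 3)

decoder_input_ids_len = decoder_input_ids.shape[-1]
# Buffer must hold: start token + prompt + max_new_tokens generated tokens.
max_length = max_new_tokens + decoder_input_ids_len + 1

if max_length != decoder_start_token_id.shape[-1]:
    # Right-pad the start token out to the full static length.
    decoder_start_token_id = torch.nn.functional.pad(
        decoder_start_token_id,
        (0, max_length - decoder_start_token_id.shape[-1]),
        value=pad_token_id,
    )

# Write the whole prompt right after the start token, rather than
# index_copy-ing a single token at token_idx as before.
decoder_start_token_id[:, 1 : 1 + decoder_input_ids_len] = decoder_input_ids
decoder_input_ids = decoder_start_token_id

print(decoder_input_ids)  # tensor([[3, 5, 6, 7, 0, 0, 0, 0]])
```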
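The second `utils.py` hunk keeps the bookkeeping in `generate` consistent with that buffer layout: `token_idx` (the next position to write) now starts at the decoder prompt length instead of always 1, and the static decoder attention mask is sized with `+ token_idx` instead of `+ 1` so it covers the prompt as well as the new tokens. A minimal sketch of just this bookkeeping, again with illustrative values:

```python
import torch

# Illustrative stand-in for the real model_kwargs passed to generate().
model_kwargs = {"decoder_input_ids": torch.tensor([[3, 5, 6, 7]])}
max_new_tokens = 4

# token_idx points at the next position to fill: right after the prompt
# when decoder_input_ids is supplied, otherwise right after the start token.
if model_kwargs.get("decoder_input_ids", None) is None:
    token_idx = 1
else:
    token_idx = model_kwargs["decoder_input_ids"].shape[-1]

# The static decoder attention mask must cover prompt + new tokens.
max_length = max_new_tokens + token_idx

print(token_idx, max_length)  # 4 8
```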