
Commit 9af1b6a
Musicgen special tokens in tensors (#31420)
fix
zucchini-nlp authored Jun 17, 2024
1 parent eed9ed6 commit 9af1b6a
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/musicgen/modeling_musicgen.py
@@ -1666,6 +1666,8 @@ def generate(
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = input_ids.shape[0] // self.num_codebooks
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)

         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
@@ -2738,6 +2740,8 @@ def generate(
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = inputs_tensor.shape[0]
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)

         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
4 changes: 4 additions & 0 deletions
@@ -1587,6 +1587,8 @@ def generate(
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = input_ids.shape[0] // self.num_codebooks
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)

         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
@@ -2588,6 +2590,8 @@ def generate(
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = inputs_tensor.shape[0]
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)

         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
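All four hunks apply the same pattern to the overridden generate() methods: before decoding starts, the special token ids from the generation config are prepared as tensors on the device of the prompt (device=input_ids.device for the decoder-only path, device=inputs_tensor.device for the composite model), mirroring the _prepare_special_tokens call made in the shared generation code. A minimal usage sketch of the path this touches, assuming the public facebook/musicgen-small checkpoint and an optional CUDA device; this example is illustrative and not part of the commit:

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").to(device)

# Text-conditioned generation goes through the overridden generate() patched above,
# so the special token ids must end up as tensors on the same device as the prompt.
inputs = processor(text=["lo-fi beat with mellow piano"], return_tensors="pt").to(device)
audio_values = model.generate(**inputs, do_sample=True, max_new_tokens=256)

print(audio_values.shape)  # (batch_size, num_channels, num_samples)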
