fix gemma-2-27b text generation pytest
skaulintel committed Mar 6, 2025
1 parent cd7d2c7 commit 25d8cb7
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions optimum/habana/transformers/models/gemma2/modeling_gemma2.py
@@ -899,7 +899,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
+        num_logits_to_keep: int = 0,
         token_idx: Optional[torch.Tensor] = None,
         trim_logits: Optional[bool] = False,
         attn_softmax_bf16: Optional[bool] = False,
@@ -956,12 +956,7 @@ def forward(
             else:
                 hidden_states = hidden_states[:, -1, :]
 
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
-        if self.config.final_logit_softcapping is not None:
-            logits = logits / self.config.final_logit_softcapping
-            logits = torch.tanh(logits)
-            logits = logits * self.config.final_logit_softcapping
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
 
         loss = None
         if labels is not None:
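A quick sketch (not part of the commit) of the slicing behind the new num_logits_to_keep parameter: since -0 == 0 in Python, the slice [:, -0:, :] keeps every sequence position, while a positive value keeps only the last N. All shapes below are hypothetical stand-ins for the model's tensors:

    import torch

    # Hypothetical stand-ins (the real tensors are batch x seq_len x hidden_size).
    hidden_states = torch.randn(2, 5, 8)  # batch=2, seq_len=5, hidden=8
    lm_head = torch.nn.Linear(8, 16)      # vocab_size=16 for illustration

    num_logits_to_keep = 0
    logits = lm_head(hidden_states[:, -num_logits_to_keep:, :])
    print(logits.shape)  # torch.Size([2, 5, 16]) -- the -0: slice keeps all positions

    num_logits_to_keep = 1
    logits = lm_head(hidden_states[:, -num_logits_to_keep:, :])
    print(logits.shape)  # torch.Size([2, 1, 16]) -- only the final position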

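For context, the deleted lines implemented Gemma-2's final logit soft-capping, which smoothly bounds logits to (-cap, cap) via tanh. A minimal standalone sketch, assuming a cap of 30.0 (the final_logit_softcapping value in the Gemma-2 config):

    import torch

    def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
        # Same transform as the removed block: scale down, tanh, scale back up.
        return cap * torch.tanh(logits / cap)

    x = torch.tensor([0.0, 10.0, 100.0])
    print(softcap(x, 30.0))  # roughly [0.00, 9.65, 29.92]; large logits saturate near the cap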