Handle num_items_in_batch in Mistral's forward
This PR enables handling loss keyword arguments in Mistral's
forward() method. Specifically, if `num_items_in_batch` is passed,
its value is used to normalize the loss correctly.

This relates to the Gradient Accumulation fix (#34191)

Fixes #34575
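
For context, a minimal standalone sketch (illustration only, not part of this commit, with made-up tensors) of why a per-micro-batch mean mis-normalizes the loss under gradient accumulation, while summing and dividing by `num_items_in_batch` recovers the mean over the full accumulated batch:

import torch
from torch.nn import CrossEntropyLoss

# Two micro-batches of one gradient-accumulation step, with different
# numbers of label tokens (3 and 1).
logits_a, labels_a = torch.randn(3, 5), torch.tensor([0, 1, 2])
logits_b, labels_b = torch.randn(1, 5), torch.tensor([3])
num_items_in_batch = labels_a.numel() + labels_b.numel()  # 4

# Reference: one mean over ALL tokens of the accumulated batch.
reference = CrossEntropyLoss(reduction="mean")(
    torch.cat([logits_a, logits_b]), torch.cat([labels_a, labels_b])
)

# Per-micro-batch mean, then averaging, weights each micro-batch equally,
# which is wrong when the token counts differ.
mean_fct = CrossEntropyLoss(reduction="mean")
biased = (mean_fct(logits_a, labels_a) + mean_fct(logits_b, labels_b)) / 2

# Sum reduction divided by num_items_in_batch (what this commit does per
# micro-batch) adds up to the reference across the accumulation step.
sum_fct = CrossEntropyLoss(reduction="sum")
fixed = (sum_fct(logits_a, labels_a) + sum_fct(logits_b, labels_b)) / num_items_in_batch

print(torch.isclose(fixed, reference))   # tensor(True)
print(torch.isclose(biased, reference))  # typically tensor(False)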
gheinrich committed Nov 2, 2024
1 parent 33868a0 commit a4faa09
Showing 1 changed file with 6 additions and 1 deletion.
src/transformers/models/mistral/modeling_mistral.py (6 additions & 1 deletion)
@@ -1027,6 +1027,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1095,8 +1096,12 @@ def forward(
             shift_labels = shift_labels.view(-1)
             # Ensure tensors are on the same device
             shift_labels = shift_labels.to(shift_logits.device)
-            loss_fct = CrossEntropyLoss()
+            num_items_in_batch = loss_kwargs.pop("num_items_in_batch", None)
+            reduction = "sum" if num_items_in_batch is not None else "mean"
+            loss_fct = CrossEntropyLoss(reduction=reduction)
             loss = loss_fct(shift_logits, shift_labels)
+            if reduction == "sum":
+                loss = loss / num_items_in_batch
 
         if not return_dict:
             output = (logits,) + outputs[1:]
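
A rough usage sketch of the new kwarg (hypothetical training-loop fragment; `model`, `optimizer`, and `micro_batches` are assumed, and the real Trainer's token accounting may differ slightly):

# Count label tokens across all micro-batches of one accumulation step
# (-100 is the usual ignore index for Hugging Face labels).
num_items_in_batch = sum(int((mb["labels"] != -100).sum()) for mb in micro_batches)

for mb in micro_batches:
    out = model(**mb, num_items_in_batch=num_items_in_batch)
    # Each loss is already (sum of token losses) / num_items_in_batch, so the
    # gradients from repeated backward() add up to the full-batch mean.
    out.loss.backward()

optimizer.step()
optimizer.zero_grad()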
