✌️ Remove double compute of sum in SFTTrainer (#3001)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
lexasub and qgallouedec authored Mar 4, 2025
1 parent 402187b commit ea1d9be
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion trl/trainer/sft_trainer.py
@@ -488,7 +488,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
total_tokens = self.accelerator.gather_for_metrics(total_tokens)

# Compute the mean token accuracy and log it
-accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0
+total_sum = total_tokens.sum()
+accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
self._metrics[mode]["mean_token_accuracy"].append(accuracy)

return (loss, outputs) if return_outputs else loss
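
The change above stores `total_tokens.sum()` in `total_sum` so the reduction runs once and is reused for both the zero guard and the division. A minimal sketch of the same pattern follows, assuming plain Python lists stand in for the gathered per-process tensors; the helper name `mean_token_accuracy` and the sample counts are hypothetical and not part of TRL.

```python
def mean_token_accuracy(correct_tokens, total_tokens):
    """Return correct/total over all processes, or 0.0 when no tokens were counted.

    Hypothetical helper: lists of ints stand in for the tensors that
    accelerator.gather_for_metrics returns in the diff above.
    """
    # Reduce once and reuse the result for both the guard and the division.
    total_sum = sum(total_tokens)
    return sum(correct_tokens) / total_sum if total_sum > 0 else 0.0


# Made-up per-process counts, for illustration only.
accuracy = mean_token_accuracy([90, 80], [100, 100])
empty = mean_token_accuracy([], [])
```

The guard keeps the metric defined when a batch contributes no countable tokens, which would otherwise cause a division by zero.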