diff --git a/docs/source/en/main_classes/trainer.mdx b/docs/source/en/main_classes/trainer.mdx index 67ab6aba42ef..409a6c6d33af 100644 --- a/docs/source/en/main_classes/trainer.mdx +++ b/docs/source/en/main_classes/trainer.mdx @@ -61,7 +61,7 @@ class CustomTrainer(Trainer): outputs = model(**inputs) logits = outputs.get("logits") # compute custom loss (suppose one has 3 labels with different weights) - loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0])) + loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device)) loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) return (loss, outputs) if return_outputs else loss ```