From 5b8f201d9582ccd5179bec576ac6aefa5132d349 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Sun, 14 Jul 2024 17:34:34 +0530 Subject: [PATCH] Removed a wrong key-word argument in sigmoid_focal_loss() function call. --- src/transformers/models/rt_detr/modeling_rt_detr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 850b8dc2f627..e61521d88800 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -2163,7 +2163,7 @@ def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True): target_classes[idx] = target_classes_original target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] - loss = sigmoid_focal_loss(src_logits, target, self.alpha, self.gamma, reduction="none") + loss = sigmoid_focal_loss(src_logits, target, self.alpha, self.gamma) loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes return {"loss_focal": loss}