From c6495def1fbbaf2a0233110d50f976ed61620e83 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Mon, 1 Apr 2024 21:30:51 +0300 Subject: [PATCH] *2 multiplier to huber loss cause of 1/2 a^2 conv. For a small residual a = model_pred - target, the Taylor expansion of the smoothed Huber term huber_c * (sqrt(a^2 + huber_c^2) - huber_c) around a = 0 is approximately 1/2 a^2, whereas the standard MSE loss is a^2. Multiplying the Huber loss by 2 puts the two losses on the same scale for small residuals. --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index a2d5da83e..c380e3311 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4665,7 +4665,7 @@ def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str if loss_type == 'l2': loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) elif loss_type == 'huber' or loss_type == 'huber_scheduled': - loss = huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) + loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum":