Add a *2 multiplier to the Huber loss because of the 1/2 a^2 conv.
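Near zero, the Taylor expansion of the square-root term gives 1/2 a^2 (with a = model_pred - target), which differs from the a^2 of the standard MSE loss. Multiplying the Huber loss by 2 makes it match MSE for small errors, so the two losses are scaled consistently against one another.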
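The expansion can be written out as a short derivation. The shorthand a = model_pred - target and c = huber_c is introduced here for readability and is not part of the commit itself; the series is only meaningful for |a| much smaller than c.

```latex
\begin{aligned}
c\left(\sqrt{a^2 + c^2} - c\right)
  &= c^2\left(\sqrt{1 + \frac{a^2}{c^2}} - 1\right) \\
  &= c^2\left(\frac{a^2}{2c^2} - \frac{a^4}{8c^4} + O\!\left(\frac{a^6}{c^6}\right)\right) \\
  &= \frac{1}{2}a^2 - \frac{a^4}{8c^2} + O\!\left(\frac{a^6}{c^4}\right),
\qquad\text{so}\qquad
2c\left(\sqrt{a^2 + c^2} - c\right) = a^2 - \frac{a^4}{4c^2} + O\!\left(\frac{a^6}{c^4}\right).
\end{aligned}
```

The leading term of the scaled loss is a^2, the same as the per-element MSE term, which is the scaling the commit message describes.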
kabachuha committed Apr 1, 2024
1 parent a58f290 commit c6495de
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion library/train_util.py
@@ -4665,7 +4665,7 @@ def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str
     if loss_type == 'l2':
         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
     elif loss_type == 'huber' or loss_type == 'huber_scheduled':
-        loss = huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
         elif reduction == "sum":
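The effect of the changed line can be checked numerically. The following is a minimal scalar sketch using only the standard library rather than torch; the helper names `pseudo_huber` and `mse`, the `scale` parameter, and the value `huber_c = 0.1` are assumptions made for illustration and do not come from `library/train_util.py`.

```python
import math


def pseudo_huber(a: float, huber_c: float, scale: float = 2.0) -> float:
    """Scalar form of the loss term in the diff.

    scale=1.0 corresponds to the removed line, scale=2.0 to the added line.
    """
    return scale * huber_c * (math.sqrt(a * a + huber_c * huber_c) - huber_c)


def mse(a: float) -> float:
    """Per-element squared error, for comparison."""
    return a * a


huber_c = 0.1  # assumed value, chosen only for this illustration

# Compare both versions against MSE for small and large residuals.
for a in (1e-3, 1e-2, 1.0):
    old = pseudo_huber(a, huber_c, scale=1.0)
    new = pseudo_huber(a, huber_c, scale=2.0)
    print(f"a={a:g}  old/mse={old / mse(a):.4f}  new/mse={new / mse(a):.4f}")
```

For residuals much smaller than `huber_c`, the ratio to MSE should come out near 0.5 before the change and near 1 after it. For large residuals the loss grows roughly linearly, so it stays below MSE in both versions.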