diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 05c5cd4595b5..638a9d966e0c 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -1276,7 +1276,7 @@ def forward( if query_mask.ndim > 1: query_mask = torch.unsqueeze(query_mask, dim=-2) - pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) + pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits) pred_logits = pred_logits.to(torch.float32) return (pred_logits, image_class_embeds) diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index ee6d8aa423d1..32e2012b2146 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -1257,7 +1257,7 @@ def forward( if query_mask.ndim > 1: query_mask = torch.unsqueeze(query_mask, dim=-2) - pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) + pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits) pred_logits = pred_logits.to(torch.float32) return (pred_logits, image_class_embeds)