From e44b878c0252ac1c841afcd68dd873c7fe307289 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Fri, 28 Jun 2024 01:07:33 +0800 Subject: [PATCH] Fix float out of range in owlvit and owlv2 when using FP16 or lower precision (#31657) --- src/transformers/models/owlv2/modeling_owlv2.py | 2 +- src/transformers/models/owlvit/modeling_owlvit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 05c5cd4595b5..638a9d966e0c 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -1276,7 +1276,7 @@ def forward( if query_mask.ndim > 1: query_mask = torch.unsqueeze(query_mask, dim=-2) - pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) + pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits) pred_logits = pred_logits.to(torch.float32) return (pred_logits, image_class_embeds) diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index ee6d8aa423d1..32e2012b2146 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -1257,7 +1257,7 @@ def forward( if query_mask.ndim > 1: query_mask = torch.unsqueeze(query_mask, dim=-2) - pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) + pred_logits = torch.where(query_mask == 0, torch.finfo(pred_logits.dtype).min, pred_logits) pred_logits = pred_logits.to(torch.float32) return (pred_logits, image_class_embeds)