
OmDet Turbo processor standardization #34937

Merged
merged 13 commits on Jan 17, 2025
87 changes: 45 additions & 42 deletions docs/source/en/model_doc/omdet-turbo.md
@@ -44,37 +44,40 @@ One unique property of OmDet-Turbo compared to other zero-shot object detection
Here's how to load the model and prepare the inputs to perform zero-shot object detection on a single image:

```python
import requests
from PIL import Image

from transformers import AutoProcessor, OmDetTurboForObjectDetection

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
classes = ["cat", "remote"]
inputs = processor(image, text=classes, return_tensors="pt")

outputs = model(**inputs)

# convert outputs (bounding boxes and class logits)
results = processor.post_process_grounded_object_detection(
    outputs,
    classes=classes,
    target_sizes=[image.size[::-1]],
    score_threshold=0.3,
    nms_threshold=0.3,
)[0]
for score, class_name, box in zip(
    results["scores"], results["classes"], results["boxes"]
):
    box = [round(i, 1) for i in box.tolist()]
    print(
        f"Detected {class_name} with confidence "
        f"{round(score.item(), 2)} at location {box}"
    )
>>> import torch
>>> import requests
>>> from PIL import Image

>>> from transformers import AutoProcessor, OmDetTurboForObjectDetection

>>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
>>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text_labels = ["cat", "remote"]
>>> inputs = processor(image, text=text_labels, return_tensors="pt")

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # convert outputs (bounding boxes and class logits)
>>> results = processor.post_process_grounded_object_detection(
... outputs,
... target_sizes=[(image.height, image.width)],
... text_labels=text_labels,
... threshold=0.3,
... nms_threshold=0.3,
... )
>>> result = results[0]
>>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"]
>>> for box, score, text_label in zip(boxes, scores, text_labels):
...     box = [round(i, 2) for i in box.tolist()]
...     print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
Detected remote with confidence 0.768 at location [39.89, 70.35, 176.74, 118.04]
Detected cat with confidence 0.72 at location [11.6, 54.19, 314.8, 473.95]
Detected remote with confidence 0.563 at location [333.38, 75.77, 370.7, 187.03]
Detected cat with confidence 0.552 at location [345.15, 23.95, 639.75, 371.67]
```
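
To sanity-check the detections, here is a minimal sketch (not part of the documented example) that draws the returned boxes onto the image; it assumes the `image`, `boxes`, `scores`, and `text_labels` variables from the snippet above and that Pillow is available:

```python
from PIL import ImageDraw

# Draw each detection on the original image (illustration only).
draw = ImageDraw.Draw(image)
for box, score, text_label in zip(boxes, scores, text_labels):
    x0, y0, x1, y1 = box.tolist()
    draw.rectangle((x0, y0, x1, y1), outline="red", width=2)
    draw.text((x0, y0), f"{text_label}: {score.item():.2f}", fill="red")
image.save("omdet_turbo_detections.png")
```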

### Multi image inference
@@ -93,22 +96,22 @@ OmDet-Turbo can perform batched multi-image inference, with support for different

>>> url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image1 = Image.open(BytesIO(requests.get(url1).content)).convert("RGB")
>>> classes1 = ["cat", "remote"]
>>> task1 = "Detect {}.".format(", ".join(classes1))
>>> text_labels1 = ["cat", "remote"]
>>> task1 = "Detect {}.".format(", ".join(text_labels1))

>>> url2 = "http://images.cocodataset.org/train2017/000000257813.jpg"
>>> image2 = Image.open(BytesIO(requests.get(url2).content)).convert("RGB")
>>> classes2 = ["boat"]
>>> text_labels2 = ["boat"]
>>> task2 = "Detect everything that looks like a boat."

>>> url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
>>> image3 = Image.open(BytesIO(requests.get(url3).content)).convert("RGB")
>>> classes3 = ["statue", "trees"]
>>> text_labels3 = ["statue", "trees"]
>>> task3 = "Focus on the foreground, detect statue and trees."

>>> inputs = processor(
... images=[image1, image2, image3],
... text=[classes1, classes2, classes3],
... text=[text_labels1, text_labels2, text_labels3],
... task=[task1, task2, task3],
... return_tensors="pt",
... )
@@ -119,19 +122,19 @@ OmDet-Turbo can perform batched multi-image inference, with support for different
>>> # convert outputs (bounding boxes and class logits)
>>> results = processor.post_process_grounded_object_detection(
... outputs,
... classes=[classes1, classes2, classes3],
... target_sizes=[image1.size[::-1], image2.size[::-1], image3.size[::-1]],
... score_threshold=0.2,
... text_labels=[text_labels1, text_labels2, text_labels3],
... target_sizes=[(image.height, image.width) for image in [image1, image2, image3]],
... threshold=0.2,
... nms_threshold=0.3,
... )

>>> for i, result in enumerate(results):
...     for score, class_name, box in zip(
...         result["scores"], result["classes"], result["boxes"]
...     for score, text_label, box in zip(
...         result["scores"], result["text_labels"], result["boxes"]
...     ):
...         box = [round(i, 1) for i in box.tolist()]
...         print(
...             f"Detected {class_name} with confidence "
...             f"Detected {text_label} with confidence "
...             f"{round(score.item(), 2)} at location {box} in image {i}"
...         )
Detected remote with confidence 0.77 at location [39.9, 70.4, 176.7, 118.0] in image 0
33 changes: 19 additions & 14 deletions src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
@@ -143,22 +143,24 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
The predicted class of the objects from the encoder.
encoder_extracted_states (`torch.FloatTensor`):
The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder.
decoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`):
decoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs.
decoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`):
decoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
encoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`):
encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs.
encoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`):
encoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
classes_structure (`torch.LongTensor`, *optional*):
The number of queried classes for each image.
"""

loss: torch.FloatTensor = None
@@ -173,6 +175,7 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
classes_structure: Optional[torch.LongTensor] = None


# Copied from models.deformable_detr.load_cuda_kernels
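
To illustrate what the new `classes_structure` output carries, here is a minimal, self-contained sketch (dummy data, not the model's actual tensors) showing how a flattened per-class tensor could be split back into per-image groups:

```python
import torch

# classes_structure holds the number of queried classes per image,
# e.g. tensor([2, 1, 2]) for the three-image batch in the docs example.
classes_structure = torch.tensor([2, 1, 2])

# Dummy flattened per-class tensor stacked across the whole batch.
flat_class_queries = torch.randn(int(classes_structure.sum()), 256)

# Recover one chunk per image.
per_image_queries = torch.split(flat_class_queries, classes_structure.tolist())
for i, queries in enumerate(per_image_queries):
    print(f"image {i}: {queries.shape[0]} class queries")
```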
@@ -1667,16 +1670,16 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
@replace_return_docstrings(output_type=OmDetTurboObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
classes_input_ids: Tensor,
classes_attention_mask: Tensor,
tasks_input_ids: Tensor,
tasks_attention_mask: Tensor,
classes_structure: Tensor,
labels: Optional[Tensor] = None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
pixel_values: torch.FloatTensor,
classes_input_ids: torch.LongTensor,
classes_attention_mask: torch.LongTensor,
tasks_input_ids: torch.LongTensor,
tasks_attention_mask: torch.LongTensor,
classes_structure: torch.LongTensor,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]:
r"""
Returns:
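
As a quick check that the processor produces the tensors the annotated signature above expects, here is a short sketch (dummy image; the checkpoint name is taken from the docs example, and the listed keys are inferred from the forward arguments):

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
image = Image.new("RGB", (640, 480))  # dummy image, illustration only
inputs = processor(image, text=["cat", "remote"], return_tensors="pt")

# Expected keys (matching forward()): classes_attention_mask, classes_input_ids,
# classes_structure, pixel_values, tasks_attention_mask, tasks_input_ids
print(sorted(inputs.keys()))
for name, tensor in inputs.items():
    print(name, tuple(tensor.shape), tensor.dtype)
```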
@@ -1770,6 +1773,7 @@ def forward(
decoder_outputs[2],
encoder_outputs[1],
encoder_outputs[2],
classes_structure,
]
if output is not None
)
@@ -1787,4 +1791,5 @@ def forward(
decoder_attentions=decoder_outputs.attentions,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
classes_structure=classes_structure,
)