Add OmDet-Turbo #31843

Merged on Sep 25, 2024 (62 commits)

Commits
51e7764
Add template with add-new-model-like
Jul 8, 2024
5c922a6
Add rough OmDetTurboEncoder and OmDetTurboDecoder
yonigozlan Jul 9, 2024
bf6775a
Add working OmDetTurbo convert to hf
yonigozlan Jul 10, 2024
e5f4fb4
Change OmDetTurbo encoder to RT-DETR encoder
yonigozlan Jul 12, 2024
9e5532c
Add swin timm backbone as default, add always partition fix for swin …
yonigozlan Jul 13, 2024
3bbdb5e
Add labels and tasks caching
yonigozlan Jul 15, 2024
61e3eda
Fix make fix-copies
yonigozlan Jul 15, 2024
85a7424
Format omdet_turbo
yonigozlan Jul 17, 2024
b778378
fix Tokenizer tests
yonigozlan Jul 18, 2024
f74f4c0
Fix style and quality
yonigozlan Jul 18, 2024
71b3499
Reformat omdet_turbo
yonigozlan Jul 19, 2024
45c9bfa
Fix quality, style, copies
yonigozlan Jul 19, 2024
605cd31
Standardize processor kwargs
yonigozlan Jul 20, 2024
225e210
Fix style
yonigozlan Jul 20, 2024
87be14c
Add output_hidden_states and output_attentions
yonigozlan Jul 21, 2024
d8595ab
Add personalized multi-head attention, improve docstrings
yonigozlan Jul 22, 2024
d0a99fb
Add integrated test and fix copy, style, quality
yonigozlan Jul 22, 2024
96487a0
Fix unprotected import
yonigozlan Jul 22, 2024
bc72859
Cleanup comments and fix unprotected imports
yonigozlan Jul 22, 2024
ad82fd6
Add fix different prompts in batch (key_padding_mask)
yonigozlan Jul 23, 2024
f104cfd
Add key_padding_mask to custom multi-head attention module
yonigozlan Jul 23, 2024
09c15cc
Replace attention_mask by key_padding_mask
yonigozlan Jul 23, 2024
5711351
Remove OmDetTurboModel and refactor
yonigozlan Jul 31, 2024
e23bc1f
Refactor processing of classes and abstract use of timm backbone
yonigozlan Aug 1, 2024
89ae1e5
Add testing, fix output attentions and hidden states, add cache for a…
yonigozlan Aug 4, 2024
32ec423
Fix copies, style, quality
yonigozlan Aug 4, 2024
54fff5f
Add documentation, convert key_padding_mask to attention_mask
yonigozlan Aug 5, 2024
46f9793
revert changes to backbone_utils
yonigozlan Aug 5, 2024
b45afb3
Fix docstrings rst
yonigozlan Aug 5, 2024
de18095
Fix unused argument in config
yonigozlan Aug 5, 2024
c11564e
Fix image link documentation
yonigozlan Aug 5, 2024
6635c35
Reorder config and cleanup
yonigozlan Aug 5, 2024
91d0295
Add tokenizer_init_kwargs in merge_kwargs of the processor
yonigozlan Aug 6, 2024
8647324
Change AutoTokenizer to CLIPTokenizer in convert
yonigozlan Aug 6, 2024
fa9cfb7
Fix init_weights
yonigozlan Aug 6, 2024
fd6d7b3
Add ProcessorMixin tests, Fix convert while waiting on uniform kwargs
yonigozlan Aug 8, 2024
b2b8716
change processor kwargs and make task input optional
yonigozlan Aug 9, 2024
a1aeca0
Fix omdet docs
yonigozlan Aug 12, 2024
87df421
Remove unnecessary tests for processor kwargs
yonigozlan Aug 13, 2024
e9210a4
Replace nested BatchEncoding output of the processor by a flattened B…
yonigozlan Aug 13, 2024
618053e
Make modifications from Pavel review
yonigozlan Aug 15, 2024
99c2d3b
Add changes Amy review
yonigozlan Aug 21, 2024
62949f1
Remove unused param
yonigozlan Aug 22, 2024
9447650
Remove normalize_before param, Modify processor call docstring
yonigozlan Aug 22, 2024
df903ba
Remove redundant decoder class, add gradient checkpointing for decoder
yonigozlan Aug 22, 2024
e89a1ff
Remove commented out code
yonigozlan Aug 27, 2024
e1d5126
Fix inference in fp16 and add fp16 integrated test
yonigozlan Sep 7, 2024
497e3f9
update omdet md doc
yonigozlan Sep 7, 2024
8a769e9
Add OmdetTurboModel
yonigozlan Sep 7, 2024
8f31d65
fix caching and nit
yonigozlan Sep 7, 2024
dbd908c
add OmDetTurboModel to tests
yonigozlan Sep 7, 2024
1c2d3aa
nit change repeated key test
yonigozlan Sep 9, 2024
5af6160
Improve inference speed in eager mode
yonigozlan Sep 11, 2024
b651def
fix copies
yonigozlan Sep 11, 2024
66ef0b9
Fix nit
yonigozlan Sep 18, 2024
0980201
remove OmdetTurboModel
yonigozlan Sep 20, 2024
79766ec
[run-slow] omdet_turbo
yonigozlan Sep 20, 2024
2267ceb
[run-slow] omdet_turbo
yonigozlan Sep 20, 2024
971974f
skip dataparallel test
yonigozlan Sep 23, 2024
d039fc0
[run-slow] omdet_turbo
yonigozlan Sep 23, 2024
9ff88d1
update weights to new path
yonigozlan Sep 24, 2024
94682a1
remove unnecessary config in class
yonigozlan Sep 25, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -852,6 +852,8 @@
title: MGP-STR
- local: model_doc/nougat
title: Nougat
- local: model_doc/omdet-turbo
title: OmDet-Turbo
- local: model_doc/oneformer
title: OneFormer
- local: model_doc/owlvit
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -235,6 +235,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Nyströmformer](model_doc/nystromformer) | ✅ | ❌ | ❌ |
| [OLMo](model_doc/olmo) | ✅ | ❌ | ❌ |
| [OLMoE](model_doc/olmoe) | ✅ | ❌ | ❌ |
| [OmDet-Turbo](model_doc/omdet-turbo) | ✅ | ❌ | ❌ |
| [OneFormer](model_doc/oneformer) | ✅ | ❌ | ❌ |
| [OpenAI GPT](model_doc/openai-gpt) | ✅ | ✅ | ❌ |
| [OpenAI GPT-2](model_doc/gpt2) | ✅ | ✅ | ✅ |
164 changes: 164 additions & 0 deletions docs/source/en/model_doc/omdet-turbo.md
@@ -0,0 +1,164 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# OmDet-Turbo

## Overview

The OmDet-Turbo model was proposed in [Real-time Transformer-based Open-Vocabulary Detection with Efficient Fusion Head](https://arxiv.org/abs/2403.06892) by Tiancheng Zhao, Peng Liu, Xuan He, Lu Zhang, Kyusong Lee. OmDet-Turbo incorporates components from RT-DETR and introduces a swift multimodal fusion module to achieve real-time open-vocabulary object detection while maintaining high accuracy. The base model achieves up to 100.2 FPS and 53.4 AP on COCO zero-shot.

The abstract from the paper is the following:

*End-to-end transformer-based detectors (DETRs) have shown exceptional performance in both closed-set and open-vocabulary object detection (OVD) tasks through the integration of language modalities. However, their demanding computational requirements have hindered their practical application in real-time object detection (OD) scenarios. In this paper, we scrutinize the limitations of two leading models in the OVDEval benchmark, OmDet and Grounding-DINO, and introduce OmDet-Turbo. This novel transformer-based real-time OVD model features an innovative Efficient Fusion Head (EFH) module designed to alleviate the bottlenecks observed in OmDet and Grounding-DINO. Notably, OmDet-Turbo-Base achieves a 100.2 frames per second (FPS) with TensorRT and language cache techniques applied. Notably, in zero-shot scenarios on COCO and LVIS datasets, OmDet-Turbo achieves performance levels nearly on par with current state-of-the-art supervised models. Furthermore, it establishes new state-of-the-art benchmarks on ODinW and OVDEval, boasting an AP of 30.1 and an NMS-AP of 26.86, respectively. The practicality of OmDet-Turbo in industrial applications is underscored by its exceptional performance on benchmark datasets and superior inference speed, positioning it as a compelling choice for real-time object detection tasks.*

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/omdet_turbo_architecture.jpeg" alt="drawing" width="600"/>

<small> OmDet-Turbo architecture overview. Taken from the <a href="https://arxiv.org/abs/2403.06892">original paper</a>. </small>

This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan).
The original code can be found [here](https://github.com/om-ai-lab/OmDet).

## Usage tips

One unique property of OmDet-Turbo compared to other zero-shot object detection models, such as [Grounding DINO](grounding-dino), is its decoupled classes and prompt embedding structure, which allows caching of text embeddings. This means the model needs both classes and task as inputs, where `classes` is a list of objects to detect and `task` is the grounded text used to guide open-vocabulary detection. This approach limits the scope of the open-vocabulary detection and makes the decoding process faster.

[`OmDetTurboProcessor`] is used to prepare the classes, task and image triplet. The task input is optional; when not provided, it defaults to `"Detect [class1], [class2], [class3], ..."`. To process the results from the model, one can use `post_process_grounded_object_detection` from [`OmDetTurboProcessor`]. Notably, this function takes the input classes as an argument: unlike with other zero-shot object detection models, the decoupling of classes and task embeddings means that no decoding of predicted class embeddings is needed in the post-processing step, and the predicted classes can be matched directly to the ones passed as input.
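
As a quick illustration of the default task behavior, the two calls below should be equivalent (a minimal sketch, assuming `processor` and `image` are set up as in the examples below, and that the default prompt follows the `"Detect [class1], [class2], ..."` format quoted above):

```python
classes = ["cat", "remote"]

# let the processor build the default task prompt from the classes
inputs = processor(image, text=classes, return_tensors="pt")

# pass an equivalent task explicitly (assumption: the default format is "Detect class1, class2, ...")
task = "Detect {}.".format(", ".join(classes))
inputs = processor(image, text=classes, task=task, return_tensors="pt")
```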

## Usage example

### Single image inference

Here's how to load the model and prepare the inputs to perform zero-shot object detection on a single image:

```python
import requests
from PIL import Image

from transformers import AutoProcessor, OmDetTurboForObjectDetection

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
classes = ["cat", "remote"]
inputs = processor(image, text=classes, return_tensors="pt")

outputs = model(**inputs)

# convert outputs (bounding boxes and class logits)
results = processor.post_process_grounded_object_detection(
    outputs,
    classes=classes,
    target_sizes=[image.size[::-1]],
    score_threshold=0.3,
    nms_threshold=0.3,
)[0]
for score, class_name, box in zip(
    results["scores"], results["classes"], results["boxes"]
):
    box = [round(i, 1) for i in box.tolist()]
    print(
        f"Detected {class_name} with confidence "
        f"{round(score.item(), 2)} at location {box}"
    )
```
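
This PR also adds an fp16 integration test, so half-precision inference is supported. Below is a minimal sketch of loading and running the model in `torch.float16`, assuming a CUDA device is available and reusing `processor`, `image` and `classes` from above (not taken from the original doc):

```python
import torch

model_fp16 = OmDetTurboForObjectDetection.from_pretrained(
    "omlab/omdet-turbo-swin-tiny-hf", torch_dtype=torch.float16
).to("cuda")

# BatchFeature.to casts only floating-point tensors, so input_ids stay integer
inputs_fp16 = processor(image, text=classes, return_tensors="pt").to("cuda", torch.float16)
with torch.no_grad():
    outputs = model_fp16(**inputs_fp16)
```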

### Multi-image inference

OmDet-Turbo can perform batched multi-image inference, with support for different text prompts and classes in the same batch:

```python
>>> import torch
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image
>>> from transformers import AutoProcessor, OmDetTurboForObjectDetection

>>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
>>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")

>>> url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image1 = Image.open(BytesIO(requests.get(url1).content)).convert("RGB")
>>> classes1 = ["cat", "remote"]
>>> task1 = "Detect {}.".format(", ".join(classes1))

>>> url2 = "http://images.cocodataset.org/train2017/000000257813.jpg"
>>> image2 = Image.open(BytesIO(requests.get(url2).content)).convert("RGB")
>>> classes2 = ["boat"]
>>> task2 = "Detect everything that looks like a boat."

>>> url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
>>> image3 = Image.open(BytesIO(requests.get(url3).content)).convert("RGB")
>>> classes3 = ["statue", "trees"]
>>> task3 = "Focus on the foreground, detect statue and trees."

>>> inputs = processor(
... images=[image1, image2, image3],
... text=[classes1, classes2, classes3],
... task=[task1, task2, task3],
... return_tensors="pt",
... )

>>> with torch.no_grad():
... outputs = model(**inputs)

>>> # convert outputs (bounding boxes and class logits)
>>> results = processor.post_process_grounded_object_detection(
... outputs,
... classes=[classes1, classes2, classes3],
... target_sizes=[image1.size[::-1], image2.size[::-1], image3.size[::-1]],
... score_threshold=0.2,
... nms_threshold=0.3,
... )

>>> for i, result in enumerate(results):
... for score, class_name, box in zip(
... result["scores"], result["classes"], result["boxes"]
... ):
... box = [round(i, 1) for i in box.tolist()]
... print(
... f"Detected {class_name} with confidence "
... f"{round(score.item(), 2)} at location {box} in image {i}"
... )
Detected remote with confidence 0.77 at location [39.9, 70.4, 176.7, 118.0] in image 0
Detected cat with confidence 0.72 at location [11.6, 54.2, 314.8, 474.0] in image 0
Detected remote with confidence 0.56 at location [333.4, 75.8, 370.7, 187.0] in image 0
Detected cat with confidence 0.55 at location [345.2, 24.0, 639.8, 371.7] in image 0
Detected boat with confidence 0.32 at location [146.9, 219.8, 209.6, 250.7] in image 1
Detected boat with confidence 0.3 at location [319.1, 223.2, 403.2, 238.4] in image 1
Detected boat with confidence 0.27 at location [37.7, 220.3, 84.0, 235.9] in image 1
Detected boat with confidence 0.22 at location [407.9, 207.0, 441.7, 220.2] in image 1
Detected statue with confidence 0.73 at location [544.7, 210.2, 651.9, 502.8] in image 2
Detected trees with confidence 0.25 at location [3.9, 584.3, 391.4, 785.6] in image 2
Detected trees with confidence 0.25 at location [1.4, 621.2, 118.2, 787.8] in image 2
Detected statue with confidence 0.2 at location [428.1, 205.5, 767.3, 759.5] in image 2

```
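
Since the class and task embeddings are decoupled from the image features (see Usage tips), keeping the class list fixed across a stream of images lets the model reuse its cached text embeddings instead of recomputing them per image. A minimal sketch, reusing `processor` and `model` from above, where `frames` is a hypothetical iterable of PIL images:

```python
import torch

classes = ["person", "car"]  # fixed class list, so cached text embeddings can be reused
for frame in frames:  # `frames` is a hypothetical iterable of PIL images
    inputs = processor(frame, text=classes, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        classes=classes,
        target_sizes=[frame.size[::-1]],
        score_threshold=0.3,
        nms_threshold=0.3,
    )[0]
```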

## OmDetTurboConfig

[[autodoc]] OmDetTurboConfig

## OmDetTurboProcessor

[[autodoc]] OmDetTurboProcessor
- post_process_grounded_object_detection

## OmDetTurboForObjectDetection

[[autodoc]] OmDetTurboForObjectDetection
- forward
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
@@ -606,6 +606,10 @@
"models.nystromformer": ["NystromformerConfig"],
"models.olmo": ["OlmoConfig"],
"models.olmoe": ["OlmoeConfig"],
"models.omdet_turbo": [
"OmDetTurboConfig",
"OmDetTurboProcessor",
],
"models.oneformer": [
"OneFormerConfig",
"OneFormerProcessor",
@@ -2844,6 +2848,12 @@
"OlmoePreTrainedModel",
]
)
_import_structure["models.omdet_turbo"].extend(
[
"OmDetTurboForObjectDetection",
"OmDetTurboPreTrainedModel",
]
)
_import_structure["models.oneformer"].extend(
[
"OneFormerForUniversalSegmentation",
@@ -5385,6 +5395,10 @@
)
from .models.olmo import OlmoConfig
from .models.olmoe import OlmoeConfig
from .models.omdet_turbo import (
OmDetTurboConfig,
OmDetTurboProcessor,
)
from .models.oneformer import (
OneFormerConfig,
OneFormerProcessor,
@@ -7351,6 +7365,10 @@
OlmoeModel,
OlmoePreTrainedModel,
)
from .models.omdet_turbo import (
OmDetTurboForObjectDetection,
OmDetTurboPreTrainedModel,
)
from .models.oneformer import (
OneFormerForUniversalSegmentation,
OneFormerModel,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -171,6 +171,7 @@
nystromformer,
olmo,
olmoe,
omdet_turbo,
oneformer,
openai,
opt,
5 changes: 5 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -60,6 +60,7 @@
("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
("clap", "ClapConfig"),
("clip", "CLIPConfig"),
("clip_text_model", "CLIPTextConfig"),
("clip_vision_model", "CLIPVisionConfig"),
("clipseg", "CLIPSegConfig"),
("clvp", "ClvpConfig"),
@@ -189,6 +190,7 @@
("nystromformer", "NystromformerConfig"),
("olmo", "OlmoConfig"),
("olmoe", "OlmoeConfig"),
("omdet-turbo", "OmDetTurboConfig"),
("oneformer", "OneFormerConfig"),
("open-llama", "OpenLlamaConfig"),
("openai-gpt", "OpenAIGPTConfig"),
@@ -346,6 +348,7 @@
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
("clap", "CLAP"),
("clip", "CLIP"),
("clip_text_model", "CLIPTextModel"),
("clip_vision_model", "CLIPVisionModel"),
("clipseg", "CLIPSeg"),
("clvp", "CLVP"),
@@ -493,6 +496,7 @@
("nystromformer", "Nyströmformer"),
("olmo", "OLMo"),
("olmoe", "OLMoE"),
("omdet-turbo", "OmDet-Turbo"),
("oneformer", "OneFormer"),
("open-llama", "OpenLlama"),
("openai-gpt", "OpenAI GPT"),
@@ -661,6 +665,7 @@
("xclip", "x_clip"),
("clip_vision_model", "clip"),
("qwen2_audio_encoder", "qwen2_audio"),
("clip_text_model", "clip"),
("siglip_vision_model", "siglip"),
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -60,6 +60,7 @@
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
("clap", "ClapModel"),
("clip", "CLIPModel"),
("clip_text_model", "CLIPTextModel"),
("clip_vision_model", "CLIPVisionModel"),
("clipseg", "CLIPSegModel"),
("clvp", "ClvpModelForConditionalGeneration"),
@@ -179,6 +180,7 @@
("nystromformer", "NystromformerModel"),
("olmo", "OlmoModel"),
("olmoe", "OlmoeModel"),
("omdet-turbo", "OmDetTurboForObjectDetection"),
("oneformer", "OneFormerModel"),
("open-llama", "OpenLlamaModel"),
("openai-gpt", "OpenAIGPTModel"),
@@ -809,6 +811,7 @@
[
# Model for Zero Shot Object Detection mapping
("grounding-dino", "GroundingDinoForObjectDetection"),
("omdet-turbo", "OmDetTurboForObjectDetection"),
("owlv2", "Owlv2ForObjectDetection"),
("owlvit", "OwlViTForObjectDetection"),
]
@@ -1323,6 +1326,7 @@
("albert", "AlbertModel"),
("bert", "BertModel"),
("big_bird", "BigBirdModel"),
("clip_text_model", "CLIPTextModel"),
("data2vec-text", "Data2VecTextModel"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
4 changes: 4 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -344,6 +344,10 @@
),
("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
(
"omdet-turbo",
("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None),
),
("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
(
"openai-gpt",
56 changes: 56 additions & 0 deletions src/transformers/models/omdet_turbo/__init__.py
@@ -0,0 +1,56 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {
"configuration_omdet_turbo": ["OmDetTurboConfig"],
"processing_omdet_turbo": ["OmDetTurboProcessor"],
}

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_omdet_turbo"] = [
"OmDetTurboForObjectDetection",
"OmDetTurboPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_omdet_turbo import (
OmDetTurboConfig,
)
from .processing_omdet_turbo import OmDetTurboProcessor

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_omdet_turbo import (
OmDetTurboForObjectDetection,
OmDetTurboPreTrainedModel,
)

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)