diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7eff2a383026..8c48749bc0c9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -852,6 +852,8 @@ title: MGP-STR - local: model_doc/nougat title: Nougat + - local: model_doc/omdet-turbo + title: OmDet-Turbo - local: model_doc/oneformer title: OneFormer - local: model_doc/owlvit diff --git a/docs/source/en/index.md b/docs/source/en/index.md index c18426de4c03..478184fdd344 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -235,6 +235,7 @@ Flax), PyTorch, and/or TensorFlow. | [Nyströmformer](model_doc/nystromformer) | ✅ | ❌ | ❌ | | [OLMo](model_doc/olmo) | ✅ | ❌ | ❌ | | [OLMoE](model_doc/olmoe) | ✅ | ❌ | ❌ | +| [OmDet-Turbo](model_doc/omdet-turbo) | ✅ | ❌ | ❌ | | [OneFormer](model_doc/oneformer) | ✅ | ❌ | ❌ | | [OpenAI GPT](model_doc/openai-gpt) | ✅ | ✅ | ❌ | | [OpenAI GPT-2](model_doc/gpt2) | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/omdet-turbo.md b/docs/source/en/model_doc/omdet-turbo.md new file mode 100644 index 000000000000..190ac3e31eea --- /dev/null +++ b/docs/source/en/model_doc/omdet-turbo.md @@ -0,0 +1,164 @@ + + +# OmDet-Turbo + +## Overview + +The OmDet-Turbo model was proposed in [Real-time Transformer-based Open-Vocabulary Detection with Efficient Fusion Head](https://arxiv.org/abs/2403.06892) by Tiancheng Zhao, Peng Liu, Xuan He, Lu Zhang, Kyusong Lee. OmDet-Turbo incorporates components from RT-DETR and introduces a swift multimodal fusion module to achieve real-time open-vocabulary object detection while maintaining high accuracy. The base model runs at up to 100.2 FPS and reaches 53.4 AP in zero-shot evaluation on COCO. + +The abstract from the paper is the following: + +*End-to-end transformer-based detectors (DETRs) have shown exceptional performance in both closed-set and open-vocabulary object detection (OVD) tasks through the integration of language modalities. However, their demanding computational requirements have hindered their practical application in real-time object detection (OD) scenarios. In this paper, we scrutinize the limitations of two leading models in the OVDEval benchmark, OmDet and Grounding-DINO, and introduce OmDet-Turbo. This novel transformer-based real-time OVD model features an innovative Efficient Fusion Head (EFH) module designed to alleviate the bottlenecks observed in OmDet and Grounding-DINO. Notably, OmDet-Turbo-Base achieves a 100.2 frames per second (FPS) with TensorRT and language cache techniques applied. Notably, in zero-shot scenarios on COCO and LVIS datasets, OmDet-Turbo achieves performance levels nearly on par with current state-of-the-art supervised models. Furthermore, it establishes new state-of-the-art benchmarks on ODinW and OVDEval, boasting an AP of 30.1 and an NMS-AP of 26.86, respectively. The practicality of OmDet-Turbo in industrial applications is underscored by its exceptional performance on benchmark datasets and superior inference speed, positioning it as a compelling choice for real-time object detection tasks.* + +*OmDet-Turbo architecture overview. Taken from the original paper.* + +This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan). +The original code can be found [here](https://github.com/om-ai-lab/OmDet).
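Because the class labels are encoded separately from the image and the task prompt, their text embeddings can be cached and reused across images (the cache capacity is controlled by the `cache_size` argument of [`OmDetTurboConfig`]). The snippet below is a minimal sketch of that pattern, assuming the `omlab/omdet-turbo-swin-tiny-hf` checkpoint and a hypothetical list of local image files; full usage details follow in the sections below.

```python
import torch
from PIL import Image

from transformers import AutoProcessor, OmDetTurboForObjectDetection

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")

# a fixed label set reused for every image, so its text embeddings can be cached by the model
classes = ["cat", "remote"]
image_paths = ["frame_000.jpg", "frame_001.jpg", "frame_002.jpg"]  # hypothetical local files

for path in image_paths:
    image = Image.open(path).convert("RGB")
    inputs = processor(image, text=classes, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # keep confident, non-overlapping detections for this image
    results = processor.post_process_grounded_object_detection(
        outputs,
        classes=classes,
        target_sizes=[image.size[::-1]],
        score_threshold=0.3,
        nms_threshold=0.3,
    )[0]
    detections = [(c, round(s.item(), 2)) for c, s in zip(results["classes"], results["scores"])]
    print(path, detections)
```

Since the class list does not change between iterations, the model can serve the corresponding text embeddings from its internal cache instead of re-encoding them, which is part of what enables the real-time speeds reported above.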
+ +## Usage tips + +One unique property of OmDet-Turbo compared to other zero-shot object detection models, such as [Grounding DINO](grounding-dino), is the decoupled classes and prompt embedding structure that allows caching of text embeddings. This means that the model needs both classes and task as inputs, where classes is a list of objects we want to detect and task is the grounded text used to guide open-vocabulary detection. This approach limits the scope of the open-vocabulary detection and makes the decoding process faster. + +[`OmDetTurboProcessor`] is used to prepare the classes, task and image triplet. The task input is optional, and when not provided, it will default to `"Detect [class1], [class2], [class3], ..."`. To process the results from the model, one can use `post_process_grounded_object_detection` from [`OmDetTurboProcessor`]. Notably, this function takes the input classes as an argument: because classes and task embeddings are decoupled, unlike in other zero-shot object detection models, no decoding of the predicted class embeddings is needed in the post-processing step, and the predictions can be matched directly to the input classes. + +## Usage example + +### Single image inference + +Here's how to load the model and prepare the inputs to perform zero-shot object detection on a single image: + +```python +import requests +from PIL import Image + +from transformers import AutoProcessor, OmDetTurboForObjectDetection + +processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") +model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = Image.open(requests.get(url, stream=True).raw) +classes = ["cat", "remote"] +inputs = processor(image, text=classes, return_tensors="pt") + +outputs = model(**inputs) + +# convert outputs (bounding boxes and class logits) +results = processor.post_process_grounded_object_detection( + outputs, + classes=classes, + target_sizes=[image.size[::-1]], + score_threshold=0.3, + nms_threshold=0.3, +)[0] +for score, class_name, box in zip( + results["scores"], results["classes"], results["boxes"] +): + box = [round(i, 1) for i in box.tolist()] + print( + f"Detected {class_name} with confidence " + f"{round(score.item(), 2)} at location {box}" + ) +``` + +### Multi image inference + +OmDet-Turbo can perform batched multi-image inference, with support for different text prompts and classes in the same batch: + +```python +>>> import torch +>>> import requests +>>> from io import BytesIO +>>> from PIL import Image +>>> from transformers import AutoProcessor, OmDetTurboForObjectDetection + +>>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") +>>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") + +>>> url1 = "http://images.cocodataset.org/val2017/000000039769.jpg" +>>> image1 = Image.open(BytesIO(requests.get(url1).content)).convert("RGB") +>>> classes1 = ["cat", "remote"] +>>> task1 = "Detect {}.".format(", ".join(classes1)) + +>>> url2 = "http://images.cocodataset.org/train2017/000000257813.jpg" +>>> image2 = Image.open(BytesIO(requests.get(url2).content)).convert("RGB") +>>> classes2 = ["boat"] +>>> task2 = "Detect everything that looks like a boat."
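>>> # each image in the batch can be paired with its own classes and task prompt; when task is omitted, the processor defaults to "Detect [class1], [class2], ..."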
+ +>>> url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" +>>> image3 = Image.open(BytesIO(requests.get(url3).content)).convert("RGB") +>>> classes3 = ["statue", "trees"] +>>> task3 = "Focus on the foreground, detect statue and trees." + +>>> inputs = processor( +... images=[image1, image2, image3], +... text=[classes1, classes2, classes3], +... task=[task1, task2, task3], +... return_tensors="pt", +... ) + +>>> with torch.no_grad(): +... outputs = model(**inputs) + +>>> # convert outputs (bounding boxes and class logits) +>>> results = processor.post_process_grounded_object_detection( +... outputs, +... classes=[classes1, classes2, classes3], +... target_sizes=[image1.size[::-1], image2.size[::-1], image3.size[::-1]], +... score_threshold=0.2, +... nms_threshold=0.3, +... ) + +>>> for i, result in enumerate(results): +... for score, class_name, box in zip( +... result["scores"], result["classes"], result["boxes"] +... ): +... box = [round(i, 1) for i in box.tolist()] +... print( +... f"Detected {class_name} with confidence " +... f"{round(score.item(), 2)} at location {box} in image {i}" +... ) +Detected remote with confidence 0.77 at location [39.9, 70.4, 176.7, 118.0] in image 0 +Detected cat with confidence 0.72 at location [11.6, 54.2, 314.8, 474.0] in image 0 +Detected remote with confidence 0.56 at location [333.4, 75.8, 370.7, 187.0] in image 0 +Detected cat with confidence 0.55 at location [345.2, 24.0, 639.8, 371.7] in image 0 +Detected boat with confidence 0.32 at location [146.9, 219.8, 209.6, 250.7] in image 1 +Detected boat with confidence 0.3 at location [319.1, 223.2, 403.2, 238.4] in image 1 +Detected boat with confidence 0.27 at location [37.7, 220.3, 84.0, 235.9] in image 1 +Detected boat with confidence 0.22 at location [407.9, 207.0, 441.7, 220.2] in image 1 +Detected statue with confidence 0.73 at location [544.7, 210.2, 651.9, 502.8] in image 2 +Detected trees with confidence 0.25 at location [3.9, 584.3, 391.4, 785.6] in image 2 +Detected trees with confidence 0.25 at location [1.4, 621.2, 118.2, 787.8] in image 2 +Detected statue with confidence 0.2 at location [428.1, 205.5, 767.3, 759.5] in image 2 + +``` + +## OmDetTurboConfig + +[[autodoc]] OmDetTurboConfig + +## OmDetTurboProcessor + +[[autodoc]] OmDetTurboProcessor + - post_process_grounded_object_detection + +## OmDetTurboForObjectDetection + +[[autodoc]] OmDetTurboForObjectDetection + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 36775d8454ab..c3260ad0ae60 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -606,6 +606,10 @@ "models.nystromformer": ["NystromformerConfig"], "models.olmo": ["OlmoConfig"], "models.olmoe": ["OlmoeConfig"], + "models.omdet_turbo": [ + "OmDetTurboConfig", + "OmDetTurboProcessor", + ], "models.oneformer": [ "OneFormerConfig", "OneFormerProcessor", @@ -2844,6 +2848,12 @@ "OlmoePreTrainedModel", ] ) + _import_structure["models.omdet_turbo"].extend( + [ + "OmDetTurboForObjectDetection", + "OmDetTurboPreTrainedModel", + ] + ) _import_structure["models.oneformer"].extend( [ "OneFormerForUniversalSegmentation", @@ -5385,6 +5395,10 @@ ) from .models.olmo import OlmoConfig from .models.olmoe import OlmoeConfig + from .models.omdet_turbo import ( + OmDetTurboConfig, + OmDetTurboProcessor, + ) from .models.oneformer import ( OneFormerConfig, OneFormerProcessor, @@ -7351,6 +7365,10 @@ OlmoeModel, OlmoePreTrainedModel, ) + from .models.omdet_turbo import ( + 
OmDetTurboForObjectDetection, + OmDetTurboPreTrainedModel, + ) from .models.oneformer import ( OneFormerForUniversalSegmentation, OneFormerModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 2022048cd455..0819277194b3 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -171,6 +171,7 @@ nystromformer, olmo, olmoe, + omdet_turbo, oneformer, openai, opt, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 2cd7d550d90b..54df20a07b1f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -60,6 +60,7 @@ ("chinese_clip_vision_model", "ChineseCLIPVisionConfig"), ("clap", "ClapConfig"), ("clip", "CLIPConfig"), + ("clip_text_model", "CLIPTextConfig"), ("clip_vision_model", "CLIPVisionConfig"), ("clipseg", "CLIPSegConfig"), ("clvp", "ClvpConfig"), @@ -189,6 +190,7 @@ ("nystromformer", "NystromformerConfig"), ("olmo", "OlmoConfig"), ("olmoe", "OlmoeConfig"), + ("omdet-turbo", "OmDetTurboConfig"), ("oneformer", "OneFormerConfig"), ("open-llama", "OpenLlamaConfig"), ("openai-gpt", "OpenAIGPTConfig"), @@ -346,6 +348,7 @@ ("chinese_clip_vision_model", "ChineseCLIPVisionModel"), ("clap", "CLAP"), ("clip", "CLIP"), + ("clip_text_model", "CLIPTextModel"), ("clip_vision_model", "CLIPVisionModel"), ("clipseg", "CLIPSeg"), ("clvp", "CLVP"), @@ -493,6 +496,7 @@ ("nystromformer", "Nyströmformer"), ("olmo", "OLMo"), ("olmoe", "OLMoE"), + ("omdet-turbo", "OmDet-Turbo"), ("oneformer", "OneFormer"), ("open-llama", "OpenLlama"), ("openai-gpt", "OpenAI GPT"), @@ -661,6 +665,7 @@ ("xclip", "x_clip"), ("clip_vision_model", "clip"), ("qwen2_audio_encoder", "qwen2_audio"), + ("clip_text_model", "clip"), ("siglip_vision_model", "siglip"), ("chinese_clip_vision_model", "chinese_clip"), ("rt_detr_resnet", "rt_detr"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e0d15f1e2365..8b9e3fe5df95 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -60,6 +60,7 @@ ("chinese_clip_vision_model", "ChineseCLIPVisionModel"), ("clap", "ClapModel"), ("clip", "CLIPModel"), + ("clip_text_model", "CLIPTextModel"), ("clip_vision_model", "CLIPVisionModel"), ("clipseg", "CLIPSegModel"), ("clvp", "ClvpModelForConditionalGeneration"), @@ -179,6 +180,7 @@ ("nystromformer", "NystromformerModel"), ("olmo", "OlmoModel"), ("olmoe", "OlmoeModel"), + ("omdet-turbo", "OmDetTurboForObjectDetection"), ("oneformer", "OneFormerModel"), ("open-llama", "OpenLlamaModel"), ("openai-gpt", "OpenAIGPTModel"), @@ -809,6 +811,7 @@ [ # Model for Zero Shot Object Detection mapping ("grounding-dino", "GroundingDinoForObjectDetection"), + ("omdet-turbo", "OmDetTurboForObjectDetection"), ("owlv2", "Owlv2ForObjectDetection"), ("owlvit", "OwlViTForObjectDetection"), ] @@ -1323,6 +1326,7 @@ ("albert", "AlbertModel"), ("bert", "BertModel"), ("big_bird", "BigBirdModel"), + ("clip_text_model", "CLIPTextModel"), ("data2vec-text", "Data2VecTextModel"), ("deberta", "DebertaModel"), ("deberta-v2", "DebertaV2Model"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index e735579108d8..8a7b8c2330d3 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -344,6 +344,10 @@ ), ("olmo", (None, 
"GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ( + "omdet-turbo", + ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None), + ), ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ( "openai-gpt", diff --git a/src/transformers/models/omdet_turbo/__init__.py b/src/transformers/models/omdet_turbo/__init__.py new file mode 100644 index 000000000000..34eb6386298f --- /dev/null +++ b/src/transformers/models/omdet_turbo/__init__.py @@ -0,0 +1,56 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_omdet_turbo": ["OmDetTurboConfig"], + "processing_omdet_turbo": ["OmDetTurboProcessor"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_omdet_turbo"] = [ + "OmDetTurboForObjectDetection", + "OmDetTurboPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_omdet_turbo import ( + OmDetTurboConfig, + ) + from .processing_omdet_turbo import OmDetTurboProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_omdet_turbo import ( + OmDetTurboForObjectDetection, + OmDetTurboPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py b/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py new file mode 100644 index 000000000000..cb5e69db5f90 --- /dev/null +++ b/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py @@ -0,0 +1,290 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""OmDet-Turbo model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class OmDetTurboConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OmDetTurboForObjectDetection`]. + It is used to instantiate a OmDet-Turbo model according to the specified arguments, defining the model architecture + Instantiating a configuration with the defaults will yield a similar configuration to that of the OmDet-Turbo + [omlab/omdet-turbo-swin-tiny-hf](https://huggingface.co/omlab/omdet-turbo-swin-tiny-hf) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`PretrainedConfig`, *optional*): + The configuration of the text backbone. + backbone_config (`PretrainedConfig`, *optional*): + The configuration of the vision backbone. + use_timm_backbone (`bool`, *optional*, defaults to `True`): + Whether to use the timm for the vision backbone. + backbone (`str`, *optional*, defaults to `"swin_tiny_patch4_window7_224"`): + The name of the pretrained vision backbone to use. If `use_pretrained_backbone=False` a randomly initialized + backbone with the same architecture `backbone` is used. + backbone_kwargs (`dict`, *optional*): + Additional kwargs for the vision backbone. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use a pretrained vision backbone. + apply_layernorm_after_vision_backbone (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization on the feature maps of the vision backbone output. + image_size (`int`, *optional*, defaults to 640): + The size (resolution) of each image. + disable_custom_kernels (`bool`, *optional*, defaults to `False`): + Whether to disable custom kernels. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value for layer normalization. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value for batch normalization. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + text_projection_in_dim (`int`, *optional*, defaults to 512): + The input dimension for the text projection. + text_projection_out_dim (`int`, *optional*, defaults to 512): + The output dimension for the text projection. + task_encoder_hidden_dim (`int`, *optional*, defaults to 1024): + The feedforward dimension for the task encoder. + class_embed_dim (`int`, *optional*, defaults to 512): + The dimension of the classes embeddings. + class_distance_type (`str`, *optional*, defaults to `"cosine"`): + The type of of distance to compare predicted classes to projected classes embeddings. + Can be `"cosine"` or `"dot"`. + num_queries (`int`, *optional*, defaults to 900): + The number of queries. + csp_activation (`str`, *optional*, defaults to `"silu"`): + The activation function of the Cross Stage Partial (CSP) networks of the encoder. + conv_norm_activation (`str`, *optional*, defaults to `"gelu"`): + The activation function of the ConvNormLayer layers of the encoder. + encoder_feedforward_activation (`str`, *optional*, defaults to `"relu"`): + The activation function for the feedforward network of the encoder. 
+ encoder_feedforward_dropout (`float`, *optional*, defaults to 0.0): + The dropout rate following the activation of the encoder feedforward network. + encoder_dropout (`float`, *optional*, defaults to 0.0): + The dropout rate of the encoder multi-head attention module. + hidden_expansion (`int`, *optional*, defaults to 1): + The hidden expansion of the CSP networks in the encoder. + vision_features_channels (`tuple(int)`, *optional*, defaults to `[256, 256, 256]`): + The projected vision features channels used as inputs for the decoder. + encoder_hidden_dim (`int`, *optional*, defaults to 256): + The hidden dimension of the encoder. + encoder_in_channels (`List(int)`, *optional*, defaults to `[192, 384, 768]`): + The input channels for the encoder. + encoder_projection_indices (`List(int)`, *optional*, defaults to `[2]`): + The indices of the input features projected by each layer. + encoder_attention_heads (`int`, *optional*, defaults to 8): + The number of attention heads for the encoder. + encoder_dim_feedforward (`int`, *optional*, defaults to 2048): + The feedforward dimension for the encoder. + encoder_layers (`int`, *optional*, defaults to 1): + The number of layers in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The positional encoding temperature in the encoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of feature levels for the multi-scale deformable attention module of the decoder. + decoder_hidden_dim (`int`, *optional*, defaults to 256): + The hidden dimension of the decoder. + decoder_num_heads (`int`, *optional*, defaults to 8): + The number of heads for the decoder. + decoder_num_layers (`int`, *optional*, defaults to 6): + The number of layers for the decoder. + decoder_activation (`str`, *optional*, defaults to `"relu"`): + The activation function for the decoder. + decoder_dim_feedforward (`int`, *optional*, defaults to 2048): + The feedforward dimension for the decoder. + decoder_num_points (`int`, *optional*, defaults to 4): + The number of points sampled in the decoder multi-scale deformable attention module. + decoder_dropout (`float`, *optional*, defaults to 0.0): + The dropout rate for the decoder. + eval_size (`Tuple[int, int]`, *optional*): + Height and width used to compute the effective height and width of the position embeddings after taking + into account the stride (see RTDetr). + learn_initial_query (`bool`, *optional*, defaults to `False`): + Whether to learn the initial query. + cache_size (`int`, *optional*, defaults to 100): + The cache size for the classes and prompts caches. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the model is used as an encoder-decoder model or not. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from the architecture. The values in kwargs will be saved as part of the configuration + and can be used to control the model outputs.
+ + Examples: + + ```python + >>> from transformers import OmDetTurboConfig, OmDetTurboForObjectDetection + + >>> # Initializing a OmDet-Turbo omlab/omdet-turbo-tiny style configuration + >>> configuration = OmDetTurboConfig() + + >>> # Initializing a model (with random weights) from the omlab/omdet-turbo-tiny style configuration + >>> model = OmDetTurboForObjectDetection(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "omdet-turbo" + attribute_map = { + "encoder_hidden_dim": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + text_config=None, + backbone_config=None, + use_timm_backbone=True, + backbone="swin_tiny_patch4_window7_224", + backbone_kwargs=None, + use_pretrained_backbone=False, + apply_layernorm_after_vision_backbone=True, + image_size=640, + disable_custom_kernels=False, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + init_std=0.02, + text_projection_in_dim=512, + text_projection_out_dim=512, + task_encoder_hidden_dim=1024, + class_embed_dim=512, + class_distance_type="cosine", + num_queries=900, + csp_activation="silu", + conv_norm_activation="gelu", + encoder_feedforward_activation="relu", + encoder_feedforward_dropout=0.0, + encoder_dropout=0.0, + hidden_expansion=1, + vision_features_channels=[256, 256, 256], + encoder_hidden_dim=256, + encoder_in_channels=[192, 384, 768], + encoder_projection_indices=[2], + encoder_attention_heads=8, + encoder_dim_feedforward=2048, + encoder_layers=1, + positional_encoding_temperature=10000, + num_feature_levels=3, + decoder_hidden_dim=256, + decoder_num_heads=8, + decoder_num_layers=6, + decoder_activation="relu", + decoder_dim_feedforward=2048, + decoder_num_points=4, + decoder_dropout=0.0, + eval_size=None, + learn_initial_query=False, + cache_size=100, + is_encoder_decoder=True, + **kwargs, + ): + if use_timm_backbone: + if backbone_config is None: + backbone_kwargs = { + "out_indices": [1, 2, 3], + "img_size": image_size, + "always_partition": True, + } + elif backbone_config is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `swin` vision config.") + backbone_config = CONFIG_MAPPING["swin"]( + window_size=7, + image_size=image_size, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + out_indices=[2, 3, 4], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + if text_config is None: + logger.info( + "`text_config` is `None`. Initializing the config with the default `clip_text_model` text config." + ) + text_config = CONFIG_MAPPING["clip_text_model"]() + elif isinstance(text_config, dict): + text_model_type = text_config.get("model_type") + text_config = CONFIG_MAPPING[text_model_type](**text_config) + + if class_distance_type not in ["cosine", "dot"]: + raise ValueError( + f"Invalid `class_distance_type`. It should be either `cosine` or `dot`, but got {class_distance_type}." 
+ ) + + self.text_config = text_config + self.backbone_config = backbone_config + self.use_timm_backbone = use_timm_backbone + self.backbone = backbone + self.backbone_kwargs = backbone_kwargs + self.use_pretrained_backbone = use_pretrained_backbone + self.apply_layernorm_after_vision_backbone = apply_layernorm_after_vision_backbone + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.init_std = init_std + self.text_projection_in_dim = text_projection_in_dim + self.text_projection_out_dim = text_projection_out_dim + self.task_encoder_hidden_dim = task_encoder_hidden_dim + self.class_embed_dim = class_embed_dim + self.class_distance_type = class_distance_type + self.num_queries = num_queries + self.csp_activation = csp_activation + self.conv_norm_activation = conv_norm_activation + self.encoder_feedforward_activation = encoder_feedforward_activation + self.encoder_feedforward_dropout = encoder_feedforward_dropout + self.encoder_dropout = encoder_dropout + self.hidden_expansion = hidden_expansion + self.vision_features_channels = vision_features_channels + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.encoder_projection_indices = encoder_projection_indices + self.encoder_attention_heads = encoder_attention_heads + self.encoder_dim_feedforward = encoder_dim_feedforward + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.num_feature_levels = num_feature_levels + self.decoder_hidden_dim = decoder_hidden_dim + self.decoder_num_heads = decoder_num_heads + self.decoder_num_layers = decoder_num_layers + self.decoder_activation = decoder_activation + self.decoder_dim_feedforward = decoder_dim_feedforward + self.decoder_num_points = decoder_num_points + self.decoder_dropout = decoder_dropout + self.eval_size = eval_size + self.learn_initial_query = learn_initial_query + self.cache_size = cache_size + self.is_encoder_decoder = is_encoder_decoder + + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) diff --git a/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py b/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py new file mode 100644 index 000000000000..2e515e983408 --- /dev/null +++ b/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py @@ -0,0 +1,349 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OmDet-Turbo checkpoints from the original repository. 
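
A sketch of a typical invocation (the output folder name is illustrative; `--push_to_hub` and
`--use_timm_backbone` are optional flags defined at the bottom of this script):

    python convert_omdet_turbo_to_hf.py --model_name omdet-turbo-tiny \
        --pytorch_dump_folder_path converted-omdet-turbo-tiny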
+ +URL: https://github.com/om-ai-lab/OmDet""" + +import argparse + +import requests +import torch +from PIL import Image + +from transformers import ( + CLIPTokenizer, + DetrImageProcessor, + OmDetTurboConfig, + OmDetTurboForObjectDetection, + OmDetTurboProcessor, +) + + +IMAGE_MEAN = [123.675, 116.28, 103.53] +IMAGE_STD = [58.395, 57.12, 57.375] + + +def get_omdet_turbo_config(model_name, use_timm_backbone): + if "tiny" in model_name: + window_size = 7 + embed_dim = 96 + depths = (2, 2, 6, 2) + num_heads = (3, 6, 12, 24) + image_size = 640 + else: + raise ValueError("Model not supported, only supports tiny variant.") + + config = OmDetTurboConfig( + backbone_window_size=window_size, + backbone_image_size=image_size, + backbone_embed_dim=embed_dim, + backbone_depths=depths, + backbone_num_heads=num_heads, + backbone_out_indices=(1, 2, 3), + text_config={"model_type": "clip_text_model"}, + use_timm_backbone=use_timm_backbone, + backbone="swin_tiny_patch4_window7_224" if use_timm_backbone else None, + apply_layernorm_after_vision_backbone=True if use_timm_backbone else False, + use_pretrained_backbone=False, + ) + + return config + + +def create_rename_keys_vision(state_dict, config): + rename_keys = [] + # fmt: off + ########################################## VISION BACKBONE - START + for layer_name in state_dict.keys(): + if layer_name.startswith("backbone") and not layer_name.startswith("backbone.norm"): + if config.use_timm_backbone: + layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone._backbone") + layer_name_replace = layer_name_replace.replace(".layers.", ".layers_") + if "downsample" in layer_name: + # get layer number + layer_num = int(layer_name.split(".")[2]) + layer_name_replace = layer_name_replace.replace(f"{layer_num}.downsample", f"{layer_num+1}.downsample") + else: + layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone") + layer_name_replace = layer_name_replace.replace("patch_embed.proj", "embeddings.patch_embeddings.projection") + layer_name_replace = layer_name_replace.replace("patch_embed.norm", "embeddings.norm") + if layer_name.startswith("backbone.layers"): + layer_name_replace = layer_name_replace.replace("norm1", "layernorm_before") + layer_name_replace = layer_name_replace.replace("norm2", "layernorm_after") + layer_name_replace = layer_name_replace.replace("attn.proj", "attention.output.dense") + layer_name_replace = layer_name_replace.replace("mlp.fc1", "intermediate.dense") + layer_name_replace = layer_name_replace.replace("mlp.fc2", "output.dense") + layer_name_replace = layer_name_replace.replace(".layers.", ".encoder.layers.") + layer_name_replace = layer_name_replace.replace(".attn.", ".attention.self.") + elif layer_name.startswith("backbone.norm"): + layer_num = int(layer_name.split("norm")[1].split(".")[0]) + if config.use_timm_backbone: + layer_name_replace = layer_name.replace("backbone", "vision_backbone") + layer_name_replace = layer_name_replace.replace(f"norm{layer_num}", f"layer_norms.{layer_num-1}") + else: + layer_name_replace = layer_name.replace(f"backbone.norm{layer_num}", f"vision_backbone.vision_backbone.hidden_states_norms.stage{layer_num+1}") + else: + continue + rename_keys.append((layer_name, layer_name_replace)) + ########################################## VISION BACKBONE - END + + ########################################## ENCODER - START + for layer_name, params in state_dict.items(): + if "neck" in layer_name: + layer_name_replace = layer_name.replace("neck", 
"encoder") + layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers") + if "fpn_blocks" in layer_name or "pan_blocks" in layer_name or "lateral_convs" in layer_name or "downsample_convs" in layer_name: + layer_name_replace = layer_name_replace.replace(".m.", ".bottlenecks.") + layer_name_replace = layer_name_replace.replace(".cv", ".conv") + layer_name_replace = layer_name_replace.replace(".bn", ".norm") + if "encoder_layer" in layer_name: + layer_name_replace = layer_name_replace.replace("encoder_layer", "encoder.0.layers.0") + layer_name_replace = layer_name_replace.replace(".linear", ".fc") + layer_name_replace = layer_name_replace.replace("norm1", "self_attn_layer_norm") + layer_name_replace = layer_name_replace.replace("norm2", "final_layer_norm") + rename_keys.append((layer_name, layer_name_replace)) + ########################################## ENCODER - END + + ########################################## DECODER - START + for layer_name, params in state_dict.items(): + if layer_name.startswith("decoder"): + layer_name_replace = layer_name.replace("decoder.decoder.layers", "decoder.layers") + layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers") + layer_name_replace = layer_name_replace.replace("query_pos_head", "query_position_head") + layer_name_replace = layer_name_replace.replace("enc_bbox_head", "encoder_bbox_head") + layer_name_replace = layer_name_replace.replace("enc_output", "encoder_vision_features") + layer_name_replace = layer_name_replace.replace("dec_score_head", "decoder_class_head") + layer_name_replace = layer_name_replace.replace("dec_bbox_head", "decoder_bbox_head") + layer_name_replace = layer_name_replace.replace("enc_score_head", "encoder_class_head") + rename_keys.append((layer_name, layer_name_replace)) + ########################################## DECODER - END + # fmt: on + return rename_keys + + +def create_rename_keys_language(state_dict): + rename_keys = [] + # fmt: off + for layer_name in state_dict.keys(): + if layer_name.startswith("language_backbone") and not layer_name.startswith("language_backbone.text_projection"): + layer_name_replace = layer_name.replace("language_backbone", "language_backbone.model.text_model") + layer_name_replace = layer_name_replace.replace("transformer.resblocks", "encoder.layers") + layer_name_replace = layer_name_replace.replace("token_embedding", "embeddings.token_embedding") + layer_name_replace = layer_name_replace.replace("positional_embedding", "embeddings.position_embedding.weight") + layer_name_replace = layer_name_replace.replace(".attn", ".self_attn") + layer_name_replace = layer_name_replace.replace(".mlp.c_fc", ".mlp.fc1") + layer_name_replace = layer_name_replace.replace(".mlp.c_proj", ".mlp.fc2") + layer_name_replace = layer_name_replace.replace("ln_final", "final_layer_norm") + layer_name_replace = layer_name_replace.replace(".ln_", ".layer_norm") + rename_keys.append((layer_name, layer_name_replace)) + # fmt: on + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v_vision(state_dict, config): + state_dict_keys = list(state_dict.keys()) + for layer_name_vision in state_dict_keys: + if layer_name_vision.startswith("vision_backbone") and "qkv" in layer_name_vision: + layer_num = int(layer_name_vision.split(".")[4]) + hidden_size = config.backbone_config.embed_dim * 2**layer_num + if "weight" in 
layer_name_vision: + in_proj_weight = state_dict.pop(layer_name_vision) + state_dict[layer_name_vision.replace("qkv.weight", "key.weight")] = in_proj_weight[:hidden_size, :] + state_dict[layer_name_vision.replace("qkv.weight", "query.weight")] = in_proj_weight[ + hidden_size : hidden_size * 2, : + ] + state_dict[layer_name_vision.replace("qkv.weight", "value.weight")] = in_proj_weight[-hidden_size:, :] + elif "bias" in layer_name_vision: + in_proj_bias = state_dict.pop(layer_name_vision) + state_dict[layer_name_vision.replace("qkv.bias", "key.bias")] = in_proj_bias[:hidden_size] + state_dict[layer_name_vision.replace("qkv.bias", "query.bias")] = in_proj_bias[ + hidden_size : hidden_size * 2 + ] + state_dict[layer_name_vision.replace("qkv.bias", "value.bias")] = in_proj_bias[-hidden_size:] + + +def read_in_q_k_v_text(state_dict, config): + state_dict_keys = list(state_dict.keys()) + hidden_size = config.text_config.projection_dim + for layer_name_text in state_dict_keys: + if layer_name_text.startswith("language_backbone") and "in_proj" in layer_name_text: + if "weight" in layer_name_text: + in_proj_weight = state_dict.pop(layer_name_text) + state_dict[layer_name_text.replace("in_proj_weight", "q_proj.weight")] = in_proj_weight[ + :hidden_size, : + ] + state_dict[layer_name_text.replace("in_proj_weight", "k_proj.weight")] = in_proj_weight[ + hidden_size : hidden_size * 2, : + ] + state_dict[layer_name_text.replace("in_proj_weight", "v_proj.weight")] = in_proj_weight[ + -hidden_size:, : + ] + elif "bias" in layer_name_text: + in_proj_bias = state_dict.pop(layer_name_text) + state_dict[layer_name_text.replace("in_proj_bias", "q_proj.bias")] = in_proj_bias[:hidden_size] + state_dict[layer_name_text.replace("in_proj_bias", "k_proj.bias")] = in_proj_bias[ + hidden_size : hidden_size * 2 + ] + state_dict[layer_name_text.replace("in_proj_bias", "v_proj.bias")] = in_proj_bias[-hidden_size:] + + +def read_in_q_k_v_encoder(state_dict, config): + embed_dim = config.encoder_hidden_dim + # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias) + in_proj_weight = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict["encoder.encoder.0.layers.0.self_attn.query.weight"] = in_proj_weight[:embed_dim, :] + state_dict["encoder.encoder.0.layers.0.self_attn.query.bias"] = in_proj_bias[:embed_dim] + state_dict["encoder.encoder.0.layers.0.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :] + state_dict["encoder.encoder.0.layers.0.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2] + state_dict["encoder.encoder.0.layers.0.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :] + state_dict["encoder.encoder.0.layers.0.self_attn.value.bias"] = in_proj_bias[-embed_dim:] + + +def read_in_q_k_v_decoder(state_dict, config): + for layer_num in range(config.decoder_num_layers): + embed_dim = config.decoder_hidden_dim + # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_weight") + in_proj_bias = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"decoder.layers.{layer_num}.self_attn.query.weight"] = 
in_proj_weight[:embed_dim, :] + state_dict[f"decoder.layers.{layer_num}.self_attn.query.bias"] = in_proj_bias[:embed_dim] + state_dict[f"decoder.layers.{layer_num}.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :] + state_dict[f"decoder.layers.{layer_num}.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2] + state_dict[f"decoder.layers.{layer_num}.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :] + state_dict[f"decoder.layers.{layer_num}.self_attn.value.bias"] = in_proj_bias[-embed_dim:] + + +def run_test(model, processor): + # We will verify our results on an image of cute cats + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + + classes = ["cat", "remote"] + task = "Detect {}.".format(", ".join(classes)) + inputs = processor(image, text=classes, task=task, return_tensors="pt") + + # Running forward + with torch.no_grad(): + outputs = model(**inputs) + + predicted_slice = outputs[1][0, :3, :3] + print(predicted_slice) + expected_slice = torch.tensor([[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]) + + assert torch.allclose(predicted_slice, expected_slice, atol=1e-4) + print("Looks ok!") + + +@torch.no_grad() +def convert_omdet_turbo_checkpoint(args): + model_name = args.model_name + pytorch_dump_folder_path = args.pytorch_dump_folder_path + push_to_hub = args.push_to_hub + use_timm_backbone = args.use_timm_backbone + + checkpoint_mapping = { + "omdet-turbo-tiny": [ + "https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/OmDet-Turbo_tiny_SWIN_T.pth", + "https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/ViT-B-16.pt", + ], + } + # Define default OmDetTurbo configuation + config = get_omdet_turbo_config(model_name, use_timm_backbone) + + # Load original checkpoint + checkpoint_url = checkpoint_mapping[model_name] + original_state_dict_vision = torch.hub.load_state_dict_from_url(checkpoint_url[0], map_location="cpu")["model"] + original_state_dict_vision = {k.replace("module.", ""): v for k, v in original_state_dict_vision.items()} + + # Rename keys + new_state_dict = original_state_dict_vision.copy() + rename_keys_vision = create_rename_keys_vision(new_state_dict, config) + + rename_keys_language = create_rename_keys_language(new_state_dict) + + for src, dest in rename_keys_vision: + rename_key(new_state_dict, src, dest) + + for src, dest in rename_keys_language: + rename_key(new_state_dict, src, dest) + + if not use_timm_backbone: + read_in_q_k_v_vision(new_state_dict, config) + read_in_q_k_v_text(new_state_dict, config) + read_in_q_k_v_encoder(new_state_dict, config) + read_in_q_k_v_decoder(new_state_dict, config) + # add "model" prefix to all keys + new_state_dict = {f"model.{k}": v for k, v in new_state_dict.items()} + + # Load HF model + model = OmDetTurboForObjectDetection(config) + model.eval() + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + + image_processor = DetrImageProcessor( + size={"height": config.backbone_image_size, "width": config.backbone_image_size}, + do_rescale=False, + image_mean=IMAGE_MEAN, + image_std=IMAGE_STD, + do_pad=False, + ) + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") + processor = OmDetTurboProcessor(image_processor=image_processor, tokenizer=tokenizer) + + # end-to-end consistency test + run_test(model, processor) + + if 
pytorch_dump_folder_path is not None: + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model.push_to_hub(f"omlab/{model_name}") + processor.push_to_hub(f"omlab/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="omdet-turbo-tiny", + type=str, + choices=["omdet-turbo-tiny"], + help="Name of the OmDetTurbo model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + parser.add_argument( + "--use_timm_backbone", action="store_true", help="Whether or not to use timm backbone for vision backbone." + ) + + args = parser.parse_args() + convert_omdet_turbo_checkpoint(args) diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py new file mode 100644 index 000000000000..bb6c8838ff8c --- /dev/null +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -0,0 +1,1810 @@ +# coding=utf-8 +# Copyright 2024 Om Research Lab and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OmDet-Turbo model.""" + +import math +import os +import warnings +from collections import OrderedDict +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from ...activations import ACT2CLS, ACT2FN +from ...file_utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_cuda_available, + replace_return_docstrings, +) +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_utils import PreTrainedModel +from ...utils import is_ninja_available, logging +from ...utils.backbone_utils import load_backbone +from ..auto import AutoModel +from .configuration_omdet_turbo import OmDetTurboConfig + + +MultiScaleDeformableAttention = None + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "OmDetTurboConfig" + + +@dataclass +class OmDetTurboEncoderOutput(ModelOutput): + """ + Base class for outputs of the OmDetTurboHybridEncoder. + + Args: + last_hidden_state (`torch.FloatTensor`): + Last hidden states of the encoder. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + extracted_states (`Tuple[torch.FloatTensor]`): + The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + extracted_states: Tuple[torch.FloatTensor] = None + + +@dataclass +class OmDetTurboDecoderOutput(ModelOutput): + """ + Base class for outputs of the OmDetTurboDecoder. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder. + decoder_coords (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The predicted coordinates of the objects. + decoder_classes (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes)`): + The predicted classes of the objects. + encoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The predicted coordinates of the objects from the encoder. + encoder_class_logits (`Tuple[torch.FloatTensor]`) of shape `(batch_size, num_queries, num_classes)`: + The predicted class of the objects from the encoder. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The initial reference points. + intermediate_reference_points (`Tuple[Tuple[torch.FloatTensor]]`): + The intermediate reference points. + hidden_states (`Optional[Tuple[torch.FloatTensor]]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention, cross-attention and multi-scale deformable attention heads. 
+ """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_coords: torch.FloatTensor = None + decoder_classes: torch.FloatTensor = None + encoder_coord_logits: torch.FloatTensor = None + encoder_class_logits: Tuple[torch.FloatTensor] = None + init_reference_points: torch.FloatTensor = None + intermediate_reference_points: Tuple[Tuple[torch.FloatTensor]] = None + + +@dataclass +class OmDetTurboObjectDetectionOutput(ModelOutput): + """ + Output type of [`OmDetTurboObjectDetectionOutput`]. + + Args: + loss (`torch.FloatTensor`): + The loss value. + decoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The predicted coordinates logits of the objects. + decoder_class_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes)`): + The predicted class of the objects. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The initial reference points. + intermediate_reference_points (`Tuple[Tuple[torch.FloatTensor]]`): + The intermediate reference points. + encoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + The predicted coordinates of the objects from the encoder. + encoder_class_logits (`Tuple[torch.FloatTensor]`): + The predicted class of the objects from the encoder. + encoder_extracted_states (`torch.FloatTensor`): + The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder. + decoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`): + Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention, cross-attention and multi-scale deformable attention heads. + encoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + encoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`): + Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention, cross-attention and multi-scale deformable attention heads. 
+ """ + + loss: torch.FloatTensor = None + decoder_coord_logits: torch.FloatTensor = None + decoder_class_logits: torch.FloatTensor = None + init_reference_points: torch.FloatTensor = None + intermediate_reference_points: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + encoder_coord_logits: torch.FloatTensor = None + encoder_class_logits: Tuple[torch.FloatTensor] = None + encoder_extracted_states: torch.FloatTensor = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +# Copied from models.deformable_detr.load_cuda_kernels +def load_cuda_kernels(): + from torch.utils.cpp_extension import load + + global MultiScaleDeformableAttention + + root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr" + src_files = [ + root / filename + for filename in [ + "vision.cpp", + os.path.join("cpu", "ms_deform_attn_cpu.cpp"), + os.path.join("cuda", "ms_deform_attn_cuda.cu"), + ] + ] + + MultiScaleDeformableAttention = load( + "MultiScaleDeformableAttention", + src_files, + with_cuda=True, + extra_include_paths=[str(root)], + extra_cflags=["-DWITH_CUDA=1"], + extra_cuda_cflags=[ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ], + ) + + +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention +def multi_scale_deformable_attention( + value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + # Ignore copy + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +class OmDetTurboLRUCache: 
+ def __init__(self, capacity: int): + self.cache = OrderedDict() + self.capacity = capacity + self.current_load = 0 + + def has(self, key) -> bool: + return key in self.cache + + def get(self, key): + """ + Get the value of the key if the key exists in the cache, otherwise return None. + Move the key to the end of the cache to show that it was recently used. + """ + if key not in self.cache: + return None + self.cache.move_to_end(key) + return self.cache[key] + + def put(self, key, value) -> None: + """ + Add the key-value pair to the cache. + Move the key to the end of the cache to show that it was recently used. + If the cache is full, remove the first key (least recently used). + """ + if key not in self.cache: + self.current_load += 1 + if self.current_load > self.capacity: + self.cache.popitem(last=False) + self.current_load -= 1 + + self.cache[key] = value + self.cache.move_to_end(key) + + +class OmDetTurboLanguageBackbone(nn.Module): + def __init__(self, config: OmDetTurboConfig): + super().__init__() + self.model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation) + self.text_projection = nn.Parameter(torch.zeros(config.text_projection_in_dim, config.text_projection_out_dim)) + + def forward(self, hidden_states, mask=None, encode_type="task"): + text_outputs = self.model(hidden_states) + pooled_output = text_outputs[0] + if encode_type == "task": + if mask is None: + raise ValueError("mask is required for task encoding") + max_len = (mask != 0).sum(1).max().item() + truncated_mask = mask[:, :max_len] + truncated_output = pooled_output[:, :max_len, :] + return truncated_output.transpose(0, 1), truncated_mask + elif encode_type == "class": + max_pooled_output = pooled_output[torch.arange(pooled_output.shape[0]), hidden_states.argmax(dim=-1)] + projected_output = max_pooled_output @ self.text_projection + return projected_output + else: + raise ValueError(f"encode_type {encode_type} is not supported") + + +class OmDetTurboVisionBackbone(nn.Module): + def __init__(self, config: OmDetTurboConfig): + super().__init__() + self.apply_layernorm_after_vision_backbone = config.apply_layernorm_after_vision_backbone + self.vision_backbone = load_backbone(config) + self.layer_norms = nn.ModuleList( + [nn.LayerNorm(in_channel_dim, eps=config.layer_norm_eps) for in_channel_dim in config.encoder_in_channels] + ) + + def forward(self, pixel_values): + outputs = self.vision_backbone(pixel_values).feature_maps + if self.apply_layernorm_after_vision_backbone: + outputs = [ + layer_norm(output).permute(0, 3, 1, 2).contiguous() + for layer_norm, output in zip(self.layer_norms, outputs) + ] + + return outputs + + +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction +class MultiScaleDeformableAttentionFunction(Function): + @staticmethod + def forward( + context, + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step, + ): + context.im2col_step = im2col_step + output = MultiScaleDeformableAttention.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + context.im2col_step, + ) + context.save_for_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights + ) + return output + + @staticmethod + @once_differentiable + def backward(context, grad_output): + ( + value, + value_spatial_shapes, + value_level_start_index, + 
sampling_locations, + attention_weights, + ) = context.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + context.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OmDetTurbo, Deformable DETR->OmDet-Turbo +class OmDetTurboMultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. + """ + + def __init__(self, config: OmDetTurboConfig, num_heads: int, n_points: int): + super().__init__() + + kernel_loaded = MultiScaleDeformableAttention is not None + if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded: + try: + load_cuda_kernels() + except Exception as e: + logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in OmDetTurboMultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 64 + + self.d_model = config.d_model + self.n_levels = config.num_feature_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.disable_custom_kernels = config.disable_custom_kernels + + self._reset_parameters() + + def _reset_parameters(self): + nn.init.constant_(self.sampling_offsets.weight.data, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(self.attention_weights.weight.data, 0.0) + nn.init.constant_(self.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(self.value_proj.weight.data) + nn.init.constant_(self.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(self.output_proj.weight.data) + nn.init.constant_(self.output_proj.bias.data, 0.0) + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: 
Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + output_attentions: bool = False, + ): + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + # Ignore copy + total_elements = sum([shape[0] * shape[1] for shape in spatial_shapes_list]) + if total_elements != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + # we invert the attention_mask + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + batch_size, num_queries, self.n_heads, self.n_levels, self.n_points + ) + # batch_size, num_queries, n_heads, n_levels, n_points, 2 + num_coordinates = reference_points.shape[-1] + if num_coordinates == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif num_coordinates == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + if self.disable_custom_kernels: + # PyTorch implementation + output = multi_scale_deformable_attention( + value, spatial_shapes_list, sampling_locations, attention_weights + ) + else: + try: + # custom kernel + output = MultiScaleDeformableAttentionFunction.apply( + value, + spatial_shapes, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + except Exception: + # PyTorch implementation + output = multi_scale_deformable_attention( + value, spatial_shapes_list, sampling_locations, attention_weights + ) + output = self.output_proj(output) + + return output, attention_weights + + +# Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrConvNormLayer with RTDetr->OmDetTurbo +class OmDetTurboConvNormLayer(nn.Module): + def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +# Copied from 
transformers.models.rt_detr.modeling_rt_detr.RTDetrRepVggBlock with RTDetr->OmDetTurbo, activation_function->csp_activation +class OmDetTurboRepVggBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: OmDetTurboConfig): + super().__init__() + + activation = config.csp_activation + hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion) + self.conv1 = OmDetTurboConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1) + self.conv2 = OmDetTurboConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +# Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrCSPRepLayer with RTDetr->OmDetTurbo, activation_function->csp_activation +class OmDetTurboCSPRepLayer(nn.Module): + """ + Cross Stage Partial (CSP) network layer with RepVGG blocks. + """ + + def __init__(self, config: OmDetTurboConfig): + super().__init__() + + in_channels = config.encoder_hidden_dim * 2 + out_channels = config.encoder_hidden_dim + num_blocks = 3 + activation = config.csp_activation + + hidden_channels = int(out_channels * config.hidden_expansion) + self.conv1 = OmDetTurboConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.conv2 = OmDetTurboConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.bottlenecks = nn.Sequential(*[OmDetTurboRepVggBlock(config) for _ in range(num_blocks)]) + if hidden_channels != out_channels: + self.conv3 = OmDetTurboConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation) + else: + self.conv3 = nn.Identity() + + def forward(self, hidden_state): + device = hidden_state.device + hidden_state_1 = self.conv1(hidden_state) + hidden_state_1 = self.bottlenecks(hidden_state_1).to(device) + hidden_state_2 = self.conv2(hidden_state).to(device) + return self.conv3(hidden_state_1 + hidden_state_2) + + +class OmDetTurboMultiheadAttention(nn.Module): + """Equivalent implementation of nn.MultiheadAttention with `batch_first=True`.""" + + def __init__(self, config, hidden_size, num_attention_heads, dropout): + super().__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"The hidden size ({hidden_size}) is not a multiple of the number of attention " + f"heads ({num_attention_heads})" + ) + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + self.out_proj = nn.Linear(hidden_size, hidden_size) + self.dropout = nn.Dropout(dropout) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + queries: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + query_layer = self.transpose_for_scores(self.query(queries)) + key_layer = 
self.transpose_for_scores(self.key(keys)) + value_layer = self.transpose_for_scores(self.value(values)) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + context_layer = self.out_proj(context_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class OmDetTurboEncoderLayer(nn.Module): + def __init__(self, config: OmDetTurboConfig): + super().__init__() + self.self_attn = OmDetTurboMultiheadAttention( + config, + hidden_size=config.encoder_hidden_dim, + num_attention_heads=config.num_attention_heads, + dropout=config.encoder_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.encoder_dropout) + self.activation_fn = ACT2FN[config.encoder_feedforward_activation] + self.encoder_feedforward_dropout = nn.Dropout(config.encoder_feedforward_dropout) + self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_dim_feedforward) + self.fc2 = nn.Linear(config.encoder_dim_feedforward, config.encoder_hidden_dim) + self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + + @staticmethod + def with_pos_embed(tensor, pos_embed): + return tensor if pos_embed is None else tensor + pos_embed + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor = None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + position_embeddings (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + query = key = self.with_pos_embed(hidden_states, position_embeddings) + + hidden_states = self.self_attn( + queries=query, + keys=key, + values=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states, attentions = hidden_states if output_attentions else (hidden_states[0], None) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.encoder_feedforward_dropout(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + if output_attentions: + return hidden_states, attentions + + return (hidden_states,) + + +class OmDetTurboEncoder(nn.Module): + def __init__(self, config: OmDetTurboConfig): + super().__init__() + + self.layers = nn.ModuleList([OmDetTurboEncoderLayer(config) for _ in range(config.encoder_layers)]) + + def forward( + self, src, src_mask=None, pos_embed=None, output_attentions: bool = False + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + hidden_states = src + attention = () if output_attentions else None + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=src_mask, + position_embeddings=pos_embed, + output_attentions=output_attentions, + ) + if output_attentions: + attention = attention + (hidden_states[1],) + hidden_states = hidden_states[0] + + return hidden_states, attention + + +class OmDetTurboHybridEncoder(nn.Module): + """ + Encoder consisting of channel projection layers, a set of `OmDetTurboEncoder`, a top-down Feature Pyramid Network + (FPN) and a bottom-up Path Aggregation Network (PAN). 
More details on the paper: https://arxiv.org/abs/2304.08069 + + Args: + config: OmDetTurboConfig + """ + + def __init__(self, config: OmDetTurboConfig): + super().__init__() + self.config = config + self.in_channels = config.encoder_in_channels + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encoder_projection_indices = config.encoder_projection_indices + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + + self.channel_projection_layers = nn.ModuleList() + for in_channel in self.in_channels: + self.channel_projection_layers.append( + nn.Sequential( + nn.Conv2d(in_channel, self.encoder_hidden_dim, kernel_size=(1, 1), bias=False), + nn.BatchNorm2d(self.encoder_hidden_dim), + ) + ) + + # encoder transformer + self.encoder = nn.ModuleList([OmDetTurboEncoder(config) for _ in range(len(self.encoder_projection_indices))]) + # top-down fpn + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1, 0, -1): + self.lateral_convs.append( + OmDetTurboConvNormLayer( + config, + in_channels=self.encoder_hidden_dim, + out_channels=self.encoder_hidden_dim, + kernel_size=1, + stride=1, + activation=config.conv_norm_activation, + ) + ) + self.fpn_blocks.append(OmDetTurboCSPRepLayer(config)) + + # bottom-up pan + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1): + self.downsample_convs.append( + OmDetTurboConvNormLayer( + config, + in_channels=self.encoder_hidden_dim, + out_channels=self.encoder_hidden_dim, + kernel_size=3, + stride=2, + activation=config.conv_norm_activation, + ) + ) + self.pan_blocks.append(OmDetTurboCSPRepLayer(config)) + + @staticmethod + def build_2d_sincos_position_embedding( + width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32 + ): + grid_w = torch.arange(int(width), dtype=dtype, device=device) + grid_h = torch.arange(int(height), dtype=dtype, device=device) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + if embed_dim % 4 != 0: + raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] + + def forward( + self, + inputs_embeddings=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layers) that is passed to the encoder. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
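+
+        Note:
+            The channel projection layers are applied to `inputs_embeddings` inside this method, so it is expected
+            to be the sequence of multi-scale feature maps returned by the vision backbone.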
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeddings + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + # get projection features + projected_features = [self.channel_projection_layers[i](feature) for i, feature in enumerate(hidden_states)] + # encoder + for encoder_layer_index, feature_to_project_index in enumerate(self.encoder_projection_indices): + if output_hidden_states: + encoder_states = encoder_states + (projected_features[feature_to_project_index],) + height, width = projected_features[feature_to_project_index].shape[2:] + # flatten [batch, channel, height, width] to [batch, height*width, channel] + src_flatten = projected_features[feature_to_project_index].flatten(2).permute(0, 2, 1) + if self.training or self.eval_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + width, + height, + self.encoder_hidden_dim, + self.positional_encoding_temperature, + device=src_flatten.device, + dtype=src_flatten.dtype, + ).to(src_flatten.device, src_flatten.dtype) + else: + pos_embed = None + layer_outputs = self.encoder[encoder_layer_index]( + src_flatten, + pos_embed=pos_embed, + output_attentions=output_attentions, + ) + projected_features[feature_to_project_index] = ( + layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous() + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (projected_features[feature_to_project_index],) + + # Feature Pyramid Network (FPN) + fpn_feature_maps = [projected_features[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_high = fpn_feature_maps[0] + feat_low = projected_features[idx - 1] + feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high) + fpn_feature_maps[0] = feat_high + upsample_feat = F.interpolate(feat_high, scale_factor=2.0, mode="nearest") + fps_map = self.fpn_blocks[len(self.in_channels) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1)) + fpn_feature_maps.insert(0, fps_map) + + # Path Aggregation Network (PAN) + fpn_states = [fpn_feature_maps[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = fpn_states[-1] + feat_high = fpn_feature_maps[idx + 1] + downsample_feat = self.downsample_convs[idx](feat_low) + hidden_states = self.pan_blocks[idx]( + torch.concat([downsample_feat, feat_high.to(downsample_feat.device)], dim=1) + ) + fpn_states.append(hidden_states) + if not return_dict: + return (fpn_states[-1], encoder_states, all_attentions, fpn_states) + return OmDetTurboEncoderOutput( + last_hidden_state=fpn_states[-1], + hidden_states=encoder_states, + attentions=all_attentions, + extracted_states=fpn_states, + ) + + +class OmDetTurboMLPWithDropout(nn.Module): + def __init__(self, config): + super().__init__() + self.linear1 = nn.Linear(config.class_embed_dim, config.task_encoder_hidden_dim) + self.activation = ACT2FN[config.decoder_activation] + self.dropout = nn.Dropout(config.decoder_dropout) + self.linear2 = nn.Linear(config.task_encoder_hidden_dim, config.class_embed_dim) + + def forward(self, x): + return 
self.linear2(self.dropout(self.activation(self.linear1(x)))) + + +class OmDetTurboMLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + hidden_layers_dims = [hidden_dim] * (num_layers - 1) + layers_dims = [input_dim] + hidden_layers_dims + [output_dim] + self.layers = nn.ModuleList( + [nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(layers_dims[:-1], layers_dims[1:])] + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class OmDetTurboResidualLayer(nn.Module): + """ + A residual connection followed by a layer norm. + """ + + def __init__(self, config): + super().__init__() + self.norm1 = nn.LayerNorm(config.class_embed_dim, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.decoder_dropout) + + def forward(self, x, y): + return self.norm1(x + self.dropout(y)) + + +class OmDetTurboTaskEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.mlp = OmDetTurboMLPWithDropout(config) + self.res1 = OmDetTurboResidualLayer(config) + + def forward(self, x): + mlp_out = self.mlp(x) + x = self.res1(x, mlp_out) + return x + + +class OmDetTurboDeformableTransformerDecoderLayer(nn.Module): + """ + A single layer of the Deformable Transformer Decoder. + """ + + def __init__(self, config): + super().__init__() + # self attention + self.self_attn = OmDetTurboMultiheadAttention( + config, + hidden_size=config.decoder_hidden_dim, + num_attention_heads=config.decoder_num_heads, + dropout=config.decoder_dropout, + ) + self.dropout1 = nn.Dropout(config.decoder_dropout) + self.norm1 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps) + + # cross attention + self.cross_attn = OmDetTurboMultiscaleDeformableAttention( + config, num_heads=config.decoder_num_heads, n_points=config.decoder_num_points + ) + self.dropout2 = nn.Dropout(config.decoder_dropout) + self.norm2 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps) + + # feed forward network + self.linear1 = nn.Linear(config.decoder_hidden_dim, config.decoder_dim_feedforward) + self.act = ACT2FN[config.decoder_activation] + self.dropout3 = nn.Dropout(config.decoder_dropout) + self.linear2 = nn.Linear(config.decoder_dim_feedforward, config.decoder_hidden_dim) + self.dropout4 = nn.Dropout(config.decoder_dropout) + self.norm3 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps) + + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward( + self, + decoder_embeddings, + task_features, + reference_points, + vision_features, + vision_shapes, + vision_shapes_list, + level_start_index=None, + attention_mask=None, + padding_mask=None, + query_position=None, + output_attentions=None, + output_hidden_states=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states + + origin_embedding_len = decoder_embeddings.shape[1] + + # self attention + query = key = self.with_pos_embed(decoder_embeddings, query_position) + # combine task_features with query, key, value + task_features = task_features.transpose(0, 1) + query = 
torch.cat((query, task_features), dim=1) + key = torch.cat((key, task_features), dim=1) + decoder_embeddings = torch.cat((decoder_embeddings, task_features), dim=1) + + outputs = self.self_attn( + query, + key, + decoder_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + context, self_attention = outputs if output_attentions else (outputs[0], None) + decoder_embeddings = decoder_embeddings + self.dropout1(context) + decoder_embeddings = self.norm1(decoder_embeddings) + + task_features = decoder_embeddings[:, origin_embedding_len:, :].transpose(0, 1) + decoder_embeddings = decoder_embeddings[:, :origin_embedding_len, :] + + # cross attention + hidden_states = self.with_pos_embed(decoder_embeddings, query_position) + reference_points = reference_points.unsqueeze(2) + outputs, cross_attention = self.cross_attn( + hidden_states=hidden_states, + attention_mask=padding_mask, + encoder_hidden_states=vision_features, + reference_points=reference_points, + spatial_shapes=vision_shapes, + spatial_shapes_list=vision_shapes_list, + level_start_index=level_start_index, + ) + decoder_embeddings = decoder_embeddings + self.dropout2(outputs) + residual = self.norm2(decoder_embeddings) + + # feed forward network + decoder_embeddings = self.linear2(self.dropout3(self.act(self.linear1(residual)))) + decoder_embeddings = residual + self.dropout4(decoder_embeddings) + decoder_embeddings = self.norm3(decoder_embeddings) + + return ( + decoder_embeddings, + task_features, + self_attention if output_attentions else None, + cross_attention if output_attentions else None, + ) + + +class OmDetTurboPreTrainedModel(PreTrainedModel): + config_class = OmDetTurboConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + + def _init_weights(self, module): + def linear_init_(module_to_init): + bound = 1 / math.sqrt(module_to_init.weight.shape[0]) + nn.init.uniform_(module_to_init.weight, -bound, bound) + if hasattr(module_to_init, "bias") and module_to_init.bias is not None: + nn.init.uniform_(module_to_init.bias, -bound, bound) + + if isinstance(module, OmDetTurboEncoderLayer): + linear_init_(module.fc1) + linear_init_(module.fc2) + elif isinstance(module, OmDetTurboDecoder): + nn.init.constant_(module.encoder_bbox_head.layers[-1].weight, 0.0) + nn.init.constant_(module.encoder_bbox_head.layers[-1].bias, 0.0) + for mlp in module.decoder_bbox_head: + nn.init.constant_(mlp.layers[-1].weight, 0.0) + nn.init.constant_(mlp.layers[-1].bias, 0.0) + linear_init_(module.encoder_vision_features[0]) + nn.init.xavier_uniform_(module.encoder_vision_features[0].weight) + if module.learn_initial_query: + nn.init.xavier_uniform_(module.tgt_embed.weight) + nn.init.xavier_uniform_(module.query_position_head.layers[0].weight) + nn.init.xavier_uniform_(module.query_position_head.layers[1].weight) + for layer in module.channel_projection_layers: + nn.init.xavier_uniform_(layer[0].weight) + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, OmDetTurboDecoder): + module.gradient_checkpointing = value + + @staticmethod + def _get_cache_key_at_index(input_ids, attention_mask, index): + input_ids = input_ids[index] + input_mask = attention_mask[index] + cache_key = tuple(input_ids[input_mask != 0].tolist()) + return cache_key + + def get_cached_class_embeddings(self, 
classes_input_ids, classes_attention_mask): + not_cached_index = [] + not_cached_classes = [] + total_embeddings = [] + for idx, _ in enumerate(classes_input_ids): + cache_key = self._get_cache_key_at_index(classes_input_ids, classes_attention_mask, idx) + if self.language_cache_class.has(cache_key): + total_embeddings.append(self.language_cache_class.get(cache_key)) + else: + total_embeddings.append(None) + not_cached_index.append(idx) + not_cached_classes.append(cache_key) + + if not_cached_classes: + not_cached_classes_ids = torch.stack([classes_input_ids[idx] for idx in not_cached_index]) + embeddings = self.language_backbone(not_cached_classes_ids, encode_type="class") + for idx, emb in enumerate(embeddings): + idx_to_put = not_cached_index[idx] + total_embeddings[idx_to_put] = emb + self.language_cache_class.put(not_cached_classes[idx], emb) + + total_class_embs = torch.stack(total_embeddings).to(self.device) + return total_class_embs + + def get_cached_task_embeddings(self, tasks_input_ids, tasks_attention_mask): + not_cached_index = [] + not_cached_tasks = [] + total_task_features = [] + total_task_masks = [] + for idx, _ in enumerate(tasks_input_ids): + cache_key = self._get_cache_key_at_index(tasks_input_ids, tasks_attention_mask, idx) + if self.language_cache_prompt.has(cache_key): + task_feature, task_mask = self.language_cache_prompt.get(cache_key) + total_task_features.append(task_feature) + total_task_masks.append(task_mask) + else: + total_task_features.append(None) + total_task_masks.append(None) + not_cached_index.append(idx) + not_cached_tasks.append(cache_key) + + if not_cached_tasks: + not_cached_index_ids = torch.stack([tasks_input_ids[idx] for idx in not_cached_index]) + not_cached_mask = torch.stack([tasks_attention_mask[idx] for idx in not_cached_index]) + embeddings, masks = self.language_backbone(not_cached_index_ids, mask=not_cached_mask, encode_type="task") + + for idx in range(embeddings.shape[1]): + emb = embeddings[:, [idx], :] + idx_to_put = not_cached_index[idx] + cur_mask = torch.unsqueeze(masks[idx], dim=0).to(self.device) + total_task_features[idx_to_put] = emb + total_task_masks[idx_to_put] = cur_mask + self.language_cache_prompt.put(not_cached_tasks[idx], (emb, cur_mask)) + + # pad before concat if needed + max_len = max([task.shape[0] for task in total_task_features]) + for idx, task in enumerate(total_task_features): + if task.shape[0] < max_len: + pad_size = max_len - task.shape[0] + total_task_features[idx] = F.pad(task, (0, 0, 0, 0, 0, pad_size)) + total_task_masks[idx] = F.pad(total_task_masks[idx], (0, pad_size)) + + total_task_features = torch.cat(total_task_features, dim=1).to(self.device) + total_task_masks = torch.cat(total_task_masks, dim=0).to(self.device) + + return total_task_features, total_task_masks + + def get_language_embedding( + self, + classes_input_ids, + classes_attention_mask, + tasks_input_ids, + tasks_attention_mask, + classes_structure, + ): + batched_classes_embeddings = self.get_cached_class_embeddings(classes_input_ids, classes_attention_mask) + # regroup class embeddings using saved structure + max_class_size = torch.max(classes_structure) + class_embeddings_regrouped = [] + start = 0 + for size in classes_structure: + pad_size = max_class_size - size + class_embeddings_regrouped.append( + F.pad(batched_classes_embeddings[start : start + size], (0, 0, 0, pad_size)).unsqueeze(1) + ) + start += size + class_embeddings = torch.cat(class_embeddings_regrouped, dim=1) + + task_embeddings, task_mask = 
self.get_cached_task_embeddings(tasks_input_ids, tasks_attention_mask)
+
+        return class_embeddings, task_embeddings, task_mask
+
+
+OMDET_TURBO_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`OmDetTurboConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+OMDET_TURBO_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Padding will be ignored by default should you provide it.
+
+            Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`] for
+            details.
+
+        classes_input_ids (`torch.LongTensor` of shape `(total_classes (>= batch_size), sequence_length)`):
+            Indices of input classes sequence tokens in the vocabulary of the language model.
+            Several classes can be provided for each task, thus the tokenized classes are flattened
+            and the structure of the classes is provided in the `classes_structure` argument.
+
+            Indices can be obtained using [`OmDetTurboProcessor`]. See [`OmDetTurboProcessor.__call__`] for
+            details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+        classes_attention_mask (`torch.BoolTensor` of shape `(total_classes (>= batch_size), sequence_length)`):
+            Attention mask for the classes. This is a binary mask that indicates which tokens should be attended to,
+            and which should not.
+
+        tasks_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input tasks sequence tokens in the vocabulary of the language model.
+
+            Indices can be obtained using [`OmDetTurboProcessor`]. See [`OmDetTurboProcessor.__call__`] for
+            details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+        tasks_attention_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+            Attention mask for the tasks. This is a binary mask that indicates which tokens should be attended to,
+            and which should not.
+
+        classes_structure (`torch.LongTensor` of shape `(batch_size)`):
+            Structure of the classes. This tensor indicates the number of classes for each task. For example, if the
+            first sample in the batch has three classes and the second has one, `classes_structure` is `[3, 1]` and
+            `classes_input_ids` contains four tokenized class prompts.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ + """ + + +def _cosine_similarity_scaled(a, b, logit_scale): + a = a / a.norm(dim=2, keepdim=True).clamp_min(1e-12) + b = b / b.norm(dim=1, keepdim=True).clamp_min(1e-12) + logit_scale = logit_scale.exp() + logits_per_image = logit_scale * torch.bmm(a, b) + return logits_per_image + + +def get_class_similarity(class_distance_type, cls_feature, class_proj): + logit_scale = torch.tensor(1 / 0.07).log() + if class_distance_type == "cosine": + class_logits = _cosine_similarity_scaled(cls_feature, class_proj, logit_scale) + elif class_distance_type == "dot": + class_logits = torch.bmm(cls_feature, class_proj) + else: + raise Exception("Unknown class_distance_type {}".format(class_distance_type)) + return class_logits + + +def _inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +class OmDetTurboDecoder(OmDetTurboPreTrainedModel): + def __init__(self, config: OmDetTurboConfig): + self.config = config + super().__init__(config) + self.gradient_checkpointing = False + + hidden_dim = config.decoder_hidden_dim + self.num_queries = config.num_queries + self.class_distance_type = config.class_distance_type + self.learn_initial_query = config.learn_initial_query + + # backbone feature projection + self.channel_projection_layers = nn.ModuleList( + nn.Sequential(nn.Conv2d(x, hidden_dim, 1, bias=False), nn.BatchNorm2d(hidden_dim)) + for x in config.vision_features_channels + ) + self.task_encoder = OmDetTurboTaskEncoder(config) + if config.class_embed_dim != hidden_dim: + self.task_project = nn.Linear(config.class_embed_dim, hidden_dim) + + # Transformer module + self.layers = nn.ModuleList( + [OmDetTurboDeformableTransformerDecoderLayer(config) for _ in range(config.decoder_num_layers)] + ) + self.decoder_num_layers = config.decoder_num_layers + # decoder embedding + if self.learn_initial_query: + self.tgt_embed = nn.Embedding(self.num_queries, hidden_dim) + self.query_position_head = OmDetTurboMLP( + input_dim=4, hidden_dim=2 * hidden_dim, output_dim=hidden_dim, num_layers=2 + ) + + # encoder head + self.encoder_vision_features = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps) + ) + self.encoder_class_head = nn.Linear(config.class_embed_dim, hidden_dim) + self.encoder_bbox_head = OmDetTurboMLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=4, num_layers=3) + + # decoder head + self.decoder_class_head = nn.ModuleList( + [nn.Linear(config.class_embed_dim, hidden_dim) for _ in range(config.decoder_num_layers)] + ) + self.decoder_bbox_head = nn.ModuleList( + [OmDetTurboMLP(hidden_dim, hidden_dim, 4, num_layers=3) for _ in range(config.decoder_num_layers)] + ) + + # Initialize weights and apply final processing + self.post_init() + + @lru_cache(maxsize=32) + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32): + # We always generate anchors in float32 to preserve equivalence between + # dynamic and static anchor inference + # Ignore copy + if spatial_shapes is None: + raise ValueError("spatial_shapes must be provided") + + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid( + torch.arange(end=height, dtype=dtype, device=device), + torch.arange(end=width, dtype=dtype, device=device), + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], -1) + valid_wh = torch.tensor([width, height], dtype=dtype, device=device) + grid_xy = (grid_xy.unsqueeze(0) 
+ 0.5) / valid_wh + wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**level) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = torch.concat(anchors, 1) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.inf) + + return anchors, valid_mask + + def _get_encoder_input(self, vision_features): + # get projection features + vision_features = [self.channel_projection_layers[i](feat) for i, feat in enumerate(vision_features)] + # get encoder inputs + new_vision_features = [] + new_vision_shapes_list = [] + for feat in vision_features: + height, width = feat.shape[2:] + # [batch_size, channels, height, width] -> [batch_size, height*width, channels] + new_vision_features.append(feat.flatten(2).permute(0, 2, 1)) + # [num_feature_levels, 2] + new_vision_shapes_list.append((height, width)) + + # [batch_size, height*width, channels] + new_vision_features = torch.cat(new_vision_features, 1) + new_vision_shapes = torch.tensor(new_vision_shapes_list, dtype=torch.int64).to(vision_features[0].device) + level_start_index = torch.cat((new_vision_shapes.new_zeros((1,)), new_vision_shapes.prod(1).cumsum(0)[:-1])) + + return new_vision_features, new_vision_shapes, new_vision_shapes_list, level_start_index + + def _get_decoder_input( + self, vision_features, vision_shapes, class_features, denoise_embeddings=None, denoise_bboxes=None + ): + batch_size = len(vision_features) + # prepare input for decoder + anchors, valid_mask = self.generate_anchors( + vision_shapes, device=vision_features.device, dtype=vision_features.dtype + ) + predicted_class_features = self.encoder_vision_features( + torch.where( + valid_mask, vision_features, torch.tensor(0.0, dtype=vision_features.dtype).to(vision_features.device) + ) + ) + + original_class_projected = self.encoder_class_head(class_features).permute(1, 2, 0) + encoder_class_similarity = get_class_similarity( + self.class_distance_type, predicted_class_features, original_class_projected + ) + + # dynamic anchors + static content + # (batch_size, height*width, 4) + encoder_outputs_bboxes = self.encoder_bbox_head(predicted_class_features) + anchors + + # query selection + # (batch_size, num_queries) + topk_ind = torch.topk(encoder_class_similarity.max(-1).values, self.num_queries, dim=1).indices.view(-1) + # (batch_size, num_queries) + batch_ind = ( + torch.arange(end=batch_size, dtype=topk_ind.dtype, device=topk_ind.device) + .unsqueeze(-1) + .repeat(1, self.num_queries) + .view(-1) + ) + + reference_points = encoder_outputs_bboxes[batch_ind, topk_ind].view(batch_size, self.num_queries, -1) + encoder_bboxes = reference_points.sigmoid() + if denoise_bboxes is not None: + reference_points = torch.cat([denoise_bboxes, reference_points], 1) + if self.training: + reference_points = reference_points.detach() + encoder_class_similarity = encoder_class_similarity[batch_ind, topk_ind].view(batch_size, self.num_queries, -1) + + if self.learn_initial_query: + embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1) + else: + embeddings = predicted_class_features[batch_ind, topk_ind].view(batch_size, self.num_queries, -1) + if self.training: + embeddings = embeddings.detach() + if denoise_embeddings is not None: + embeddings = torch.cat([denoise_embeddings, embeddings], 1) + + return embeddings, reference_points, 
encoder_bboxes, encoder_class_similarity, anchors + + def forward( + self, + vision_features, + class_features, + task_features, + task_mask, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + """ + Args: + vision_features (`torch.FloatTensor`): The sequence of vision features. shape depends on the vision + backbone. + class_features (`torch.FloatTensor`): The sequence of class features of shape + `(class_sequence_length, batch_size, class_embed_dim)`. + task_features (`torch.FloatTensor`): The sequence of task features of shape + `(task_sequence_length, batch_size, decoder_hidden_dim)`. + task_mask (`torch.LongTensor`): The mask for the task features of shape `(batch_size, task_sequence_length)`. + output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain + tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_features, vision_shapes, vision_shapes_list, level_start_index = self._get_encoder_input( + vision_features + ) + + # todo add denoising for training + denoise_embeddings, denoise_bboxes, key_padding_mask = None, None, None + batch_size = task_mask.shape[0] + + # compose attn_mask for vision_emb and task_emb fusion + task_features = self.task_encoder(task_features) + if self.task_project is not None: + task_features = self.task_project(task_features) + src_key_mask = (task_mask == 0).detach() + attn_mask_len = self.num_queries + fusion_size = attn_mask_len + task_features.shape[0] + key_padding_mask = torch.zeros([batch_size, fusion_size], dtype=torch.bool).to(task_features.device) + key_padding_mask[:, attn_mask_len:] = src_key_mask + attention_mask = _prepare_4d_attention_mask(~key_padding_mask, dtype=vision_features.dtype) + decoder_embeddings, reference_points, encoder_bboxes, encoder_class_similarity, init_reference_points = ( + self._get_decoder_input( + vision_features, tuple(vision_shapes_list), class_features, denoise_embeddings, denoise_bboxes + ) + ) + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if output_attentions else None + predicted_class_features = decoder_embeddings + + if output_hidden_states: + all_hidden_states = all_hidden_states + (predicted_class_features,) + decoder_bboxes = [] + decoder_classes = [] + last_refined_bbox = None + reference_points = reference_points.sigmoid() + for i, layer in enumerate(self.layers): + if self.gradient_checkpointing and self.training: + predicted_class_features, task_features, self_attention, cross_attention = ( + self._gradient_checkpointing_func( + layer.__call__, + predicted_class_features, + task_features, + reference_points, + vision_features, + vision_shapes, + vision_shapes_list, + level_start_index=level_start_index, + attention_mask=attention_mask, + 
query_position=self.query_position_head(reference_points), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + ) + else: + predicted_class_features, task_features, self_attention, cross_attention = layer( + predicted_class_features, + task_features, + reference_points, + vision_features, + vision_shapes, + vision_shapes_list, + level_start_index=level_start_index, + attention_mask=attention_mask, + query_position=self.query_position_head(reference_points), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + if output_attentions: + all_self_attns = all_self_attns + (self_attention,) + all_cross_attns = all_cross_attns + (cross_attention,) + if output_hidden_states: + all_hidden_states = all_hidden_states + (predicted_class_features,) + + refined_bbox = torch.sigmoid( + self.decoder_bbox_head[i](predicted_class_features) + _inverse_sigmoid(reference_points) + ) + original_class_projected = self.decoder_class_head[i](class_features).permute(1, 2, 0) + if self.training: + decoder_classes.append( + get_class_similarity( + class_distance_type=self.class_distance_type, + cls_feature=predicted_class_features, + class_proj=original_class_projected, + ) + ) + if i == 0: + decoder_bboxes.append(refined_bbox) + else: + decoder_bboxes.append( + torch.sigmoid( + self.decoder_bbox_head[i](predicted_class_features) + _inverse_sigmoid(last_refined_bbox) + ) + ) + elif i == self.decoder_num_layers - 1: + decoder_classes.append( + get_class_similarity(self.class_distance_type, predicted_class_features, original_class_projected) + ) + decoder_bboxes.append(refined_bbox) + break + last_refined_bbox = refined_bbox + reference_points = refined_bbox.detach() if self.training else refined_bbox + if output_attentions: + all_attns += (all_self_attns, all_cross_attns) + + last_hidden_state = predicted_class_features + decoder_bboxes = torch.stack(decoder_bboxes) + decoder_classes = torch.stack(decoder_classes) + + if not return_dict: + return ( + last_hidden_state, + all_hidden_states, + all_attns, + decoder_bboxes, + decoder_classes, + encoder_bboxes, + encoder_class_similarity, + init_reference_points, + reference_points, + ) + + return OmDetTurboDecoderOutput( + last_hidden_state=last_hidden_state, + hidden_states=all_hidden_states, + attentions=all_attns, + decoder_coords=decoder_bboxes, + decoder_classes=decoder_classes, + encoder_coord_logits=encoder_bboxes, + encoder_class_logits=encoder_class_similarity, + init_reference_points=init_reference_points, + intermediate_reference_points=reference_points, + ) + + +@add_start_docstrings( + """ + OmDetTurbo Model (consisting of a vision and a text backbone, and encoder-decoder architecture) outputting + bounding boxes and classes scores for tasks such as COCO detection. 
+ """, + OMDET_TURBO_START_DOCSTRING, +) +class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel): + def __init__(self, config: OmDetTurboConfig): + super().__init__(config) + self.vision_backbone = OmDetTurboVisionBackbone(config) + self.language_backbone = OmDetTurboLanguageBackbone(config) + self.encoder = OmDetTurboHybridEncoder(config) + self.decoder = OmDetTurboDecoder(config) + self.num_queries = config.num_queries + + self.language_cache_class = OmDetTurboLRUCache(config.cache_size) + self.language_cache_prompt = OmDetTurboLRUCache(config.cache_size) + self.vocab_size = config.text_config.vocab_size + self.post_init() + + def get_input_embeddings(self): + return self.language_backbone.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_backbone.model.set_input_embeddings(value) + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_backbone.model.resize_token_embeddings( + new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of + ) + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + @add_start_docstrings_to_model_forward(OMDET_TURBO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=OmDetTurboObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Tensor, + classes_input_ids: Tensor, + classes_attention_mask: Tensor, + tasks_input_ids: Tensor, + tasks_attention_mask: Tensor, + classes_structure: Tensor, + labels: Optional[Tensor] = None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ) -> Union[Tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]: + r""" + Returns: + + Examples: + + ```python + >>> import requests + >>> from PIL import Image + + >>> from transformers import AutoProcessor, OmDetTurboForObjectDetection + + >>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-tiny") + >>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-tiny") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> classes = ["cat", "remote"] + >>> task = "Detect {}.".format(", ".join(classes)) + >>> inputs = processor(image, text=classes, task=task, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> # convert outputs (bounding boxes and class logits) + >>> results = processor.post_process_grounded_object_detection( + ... outputs, + ... classes=classes, + ... target_sizes=[image.size[::-1]], + ... score_threshold=0.3, + ... nms_threshold=0.3, + >>> )[0] + >>> for score, class_name, box in zip(results["scores"], results["classes"], results["boxes"]): + ... box = [round(i, 1) for i in box.tolist()] + ... print( + ... f"Detected {class_name} with confidence " + ... f"{round(score.item(), 2)} at location {box}" + ... 
) + Detected remote with confidence 0.76 at location [39.9, 71.3, 176.5, 117.9] + Detected cat with confidence 0.72 at location [345.1, 22.5, 639.7, 371.9] + Detected cat with confidence 0.65 at location [12.7, 53.8, 315.5, 475.3] + Detected remote with confidence 0.57 at location [333.4, 75.6, 370.7, 187.0] + ```""" + if labels is not None: + raise NotImplementedError("Training is not implemented yet") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + loss = None + image_features = self.vision_backbone(pixel_values) + encoder_outputs = self.encoder( + image_features, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + class_features, task_features, task_mask = self.get_language_embedding( + classes_input_ids, + classes_attention_mask, + tasks_input_ids, + tasks_attention_mask, + classes_structure, + ) + encoder_extracted_states = encoder_outputs.extracted_states if return_dict else encoder_outputs[-1] + decoder_outputs = self.decoder( + encoder_extracted_states, + class_features, + task_features, + task_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return tuple( + output + for output in [ + loss, + decoder_outputs[3][-1], + decoder_outputs[4][-1], + decoder_outputs[7], + decoder_outputs[8], + decoder_outputs[5], + decoder_outputs[6], + encoder_outputs[-1], + decoder_outputs[1], + decoder_outputs[2], + encoder_outputs[1], + encoder_outputs[2], + ] + if output is not None + ) + + return OmDetTurboObjectDetectionOutput( + loss=loss, + decoder_coord_logits=decoder_outputs.decoder_coords[-1], + decoder_class_logits=decoder_outputs.decoder_classes[-1], + init_reference_points=decoder_outputs.init_reference_points, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + encoder_coord_logits=decoder_outputs.encoder_coord_logits, + encoder_class_logits=decoder_outputs.encoder_class_logits, + encoder_extracted_states=encoder_outputs.extracted_states, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) diff --git a/src/transformers/models/omdet_turbo/processing_omdet_turbo.py b/src/transformers/models/omdet_turbo/processing_omdet_turbo.py new file mode 100644 index 000000000000..909281b0c686 --- /dev/null +++ b/src/transformers/models/omdet_turbo/processing_omdet_turbo.py @@ -0,0 +1,362 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for OmDet-Turbo. 
+""" + +import sys +from typing import List, Optional, Tuple, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_transforms import center_to_corners_format +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, +) + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class OmDetTurboTextKwargs(TextKwargs, total=False): + task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] + + +class OmDetTurboProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: OmDetTurboTextKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": "max_length", + "truncation": True, + "max_length": 77, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + "task": None, + }, + "images_kwargs": {}, + } + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + from torchvision.ops.boxes import batched_nms + + +def clip_boxes(box, box_size: Tuple[int, int]): + """ + Clip the boxes by limiting x coordinates to the range [0, width] + and y coordinates to the range [0, height]. + + Args: + box (Tensor): The box to be clipped. + box_size (height, width): The clipping box's size. + """ + assert torch.isfinite(box).all(), "Box tensor contains infinite or NaN!" + height, width = box_size + x1 = box[:, 0].clamp(min=0, max=width) + y1 = box[:, 1].clamp(min=0, max=height) + x2 = box[:, 2].clamp(min=0, max=width) + y2 = box[:, 3].clamp(min=0, max=height) + box = torch.stack((x1, y1, x2, y2), dim=-1) + + return box + + +def compute_score(boxes): + """ + Compute logit scores per class for each box (proposal) and an array of class indices + corresponding to each proposal, flattened across the proposal_num. + The indices in `classes` will later be used to filter and match the predicted classes + with the input class names. + """ + num_classes = boxes.shape[2] + proposal_num = boxes.shape[1] + scores = torch.sigmoid(boxes) + classes = torch.arange(num_classes, device=boxes.device).unsqueeze(0).repeat(proposal_num, 1).flatten(0, 1) + return scores, classes + + +def _post_process_boxes_for_image( + boxes: TensorType, + scores: TensorType, + predicted_classes: TensorType, + classes: List[str], + image_size: Tuple[int, int], + num_classes: int, + score_threshold: float, + nms_threshold: float, + max_num_det: int = None, +) -> dict: + """ + Filter predicted results using given thresholds and NMS. + Args: + boxes (torch.Tensor): A Tensor of predicted class-specific or class-agnostic + boxes for the image. Shape : (num_queries, max_num_classes_in_batch * 4) if doing + class-specific regression, or (num_queries, 4) if doing class-agnostic + regression. + scores (torch.Tensor): A Tensor of predicted class scores for the image. + Shape : (num_queries, max_num_classes_in_batch + 1) + predicted_classes (torch.Tensor): A Tensor of predicted classes for the image. + Shape : (num_queries * (max_num_classes_in_batch + 1),) + classes (List[str]): The input classes names. + image_size (tuple): A tuple of (height, width) for the image. + num_classes (int): The number of classes given for this image. 
+ score_threshold (float): Only return detections with a confidence score exceeding this + threshold. + nms_threshold (float): The threshold to use for box non-maximum suppression. Value in [0, 1]. + max_num_det (int, optional): The maximum number of detections to return. Default is None. + Returns: + dict: A dictionary the following keys: + "boxes" (Tensor): A tensor of shape (num_filtered_objects, 4), containing the predicted boxes in (x1, y1, x2, y2) format. + "scores" (Tensor): A tensor of shape (num_filtered_objects,), containing the predicted confidence scores for each detection. + "classes" (List[str]): A list of strings, where each string is the predicted class for the + corresponding detection + """ + proposal_num = len(boxes) if max_num_det is None else max_num_det + scores_per_image, topk_indices = scores.flatten(0, 1).topk(proposal_num, sorted=False) + classes_per_image = predicted_classes[topk_indices] + box_pred_per_image = boxes.view(-1, 1, 4).repeat(1, num_classes, 1).view(-1, 4) + box_pred_per_image = box_pred_per_image[topk_indices] + + # Score filtering + box_pred_per_image = center_to_corners_format(box_pred_per_image) + box_pred_per_image = box_pred_per_image * torch.tensor(image_size[::-1]).repeat(2).to(box_pred_per_image.device) + filter_mask = scores_per_image > score_threshold # R x K + score_keep = filter_mask.nonzero(as_tuple=False).view(-1) + box_pred_per_image = box_pred_per_image[score_keep] + scores_per_image = scores_per_image[score_keep] + classes_per_image = classes_per_image[score_keep] + + filter_classes_mask = classes_per_image < len(classes) + classes_keep = filter_classes_mask.nonzero(as_tuple=False).view(-1) + box_pred_per_image = box_pred_per_image[classes_keep] + scores_per_image = scores_per_image[classes_keep] + classes_per_image = classes_per_image[classes_keep] + + # NMS + keep = batched_nms(box_pred_per_image, scores_per_image, classes_per_image, nms_threshold) + box_pred_per_image = box_pred_per_image[keep] + scores_per_image = scores_per_image[keep] + classes_per_image = classes_per_image[keep] + classes_per_image = [classes[i] for i in classes_per_image] + + # create an instance + result = {} + result["boxes"] = clip_boxes(box_pred_per_image, image_size) + result["scores"] = scores_per_image + result["classes"] = classes_per_image + + return result + + +class OmDetTurboProcessor(ProcessorMixin): + r""" + Constructs a OmDet-Turbo processor which wraps a Deformable DETR image processor and an AutoTokenizer into a + single processor. + + [`OmDetTurboProcessor`] offers all the functionalities of [`DetrImageProcessor`] and + [`AutoTokenizer`]. See the docstring of [`~OmDetTurboProcessor.__call__`] and [`~OmDetTurboProcessor.decode`] + for more information. + + Args: + image_processor (`DetrImageProcessor`): + An instance of [`DetrImageProcessor`]. The image processor is a required input. + tokenizer (`AutoTokenizer`): + An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input. 
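+
+    Example (a minimal usage sketch, assuming the tiny Swin checkpoint and the COCO image already used in this PR's tests):
+
+    ```python
+    >>> import requests
+    >>> from PIL import Image
+    >>> from transformers import AutoProcessor
+
+    >>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
+    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+    >>> # `text` is the list of classes to detect, `task` is the optional grounded prompt
+    >>> inputs = processor(images=image, text=["cat", "remote"], task="Detect cat, remote.", return_tensors="pt")
+    >>> # the encoding holds the image tensors plus the tokenized classes, the tokenized task and `classes_structure`
+    ```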
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "DetrImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images: ImageInput = None, + text: Union[List[str], List[List[str]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[OmDetTurboProcessorKwargs], + ) -> BatchFeature: + """ + This method uses [*DetrImageProcessor.__call__] method to prepare image(s) for the model, and + [CLIPTokenizerFast.__call__] to prepare text for the model. + + Please refer to the docstring of the above two methods for more information. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. + text (`Union[str, List[str], List[List[str]]]`): + The classes used to limit the scope of the open vocabulary detection. Expects a list of strings or a list + of list of strings. Batched classes can be of different lengths. + Examples: ["cat", "dog", "bird"], [["cat", "dog", "bird"], ["hat", "person"], ["car"]] + Kwargs: + task (`Union[str, List[str], TextInput, PreTokenizedInput]`): + The grounded text used to guide open vocabulary detection. Expects a single string or a list of strings. + Examples: "Detect a cat, a dog, and a bird.",[ "Detect everything.", "Detect trees and flowers."] + When not provided, the default task is "Detect [class1], [class2], [class3]" etc. + ... + """ + if images is None or text is None: + raise ValueError("You have to specify both `images` and `text`") + + output_kwargs = self._merge_kwargs( + OmDetTurboProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if isinstance(text, str): + text = text.strip(" ").split(",") + + if not (len(text) and isinstance(text[0], (list, tuple))): + text = [text] + + task = output_kwargs["text_kwargs"].pop("task", None) + if task is None: + task = ["Detect {}.".format(", ".join(text_single)) for text_single in text] + elif not isinstance(task, (list, tuple)): + task = [task] + + encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"]) + tasks_encoding = self.tokenizer(text=task, **output_kwargs["text_kwargs"]) + + classes = text + + classes_structure = torch.tensor([len(class_single) for class_single in classes], dtype=torch.long) + classes_flattened = [class_single for class_batch in classes for class_single in class_batch] + classes_encoding = self.tokenizer(text=classes_flattened, **output_kwargs["text_kwargs"]) + + encoding = BatchFeature() + encoding.update({f"tasks_{key}": value for key, value in tasks_encoding.items()}) + encoding.update({f"classes_{key}": value for key, value in classes_encoding.items()}) + encoding.update({"classes_structure": classes_structure}) + encoding.update(encoding_image_processor) + + return encoding + + # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. 
+ """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_grounded_object_detection( + self, + outputs, + classes: Union[List[str], List[List[str]]], + score_threshold: float = 0.3, + nms_threshold: float = 0.5, + target_sizes: Optional[Union[TensorType, List[Tuple]]] = None, + max_num_det: Optional[int] = None, + ): + """ + Converts the raw output of [`OmDetTurboForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format and get the associated text class. + + Args: + outputs ([`OmDetTurboObjectDetectionOutput`]): + Raw outputs of the model. + classes (Union[List[str], List[List[str]]]): The input classes names. + score_threshold (float, defaults to 0.3): Only return detections with a confidence score exceeding this + threshold. + nms_threshold (float, defaults to 0.5): The threshold to use for box non-maximum suppression. Value in [0, 1]. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*, defaults to None): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + max_num_det (int, *optional*, defaults to None): The maximum number of detections to return. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, classes and boxes for an image + in the batch as predicted by the model. 
+ """ + if isinstance(classes[0], str): + classes = [classes] + + boxes_logits = outputs.decoder_coord_logits + scores_logits = outputs.decoder_class_logits + + # Inputs consistency check + if target_sizes is None: + height = ( + self.image_processor.size["height"] + if "height" in self.image_processor.size + else self.image_processor.size["shortest_edge"] + ) + width = ( + self.image_processor.size["width"] + if "width" in self.image_processor.size + else self.image_processor.size["longest_edge"] + ) + target_sizes = ((height, width),) * len(boxes_logits) + elif len(target_sizes[0]) != 2: + raise ValueError( + "Each element of target_sizes must contain the size (height, width) of each image of the batch" + ) + if len(target_sizes) != len(boxes_logits): + raise ValueError("Make sure that you pass in as many target sizes as output sequences") + if len(classes) != len(boxes_logits): + raise ValueError("Make sure that you pass in as many classes group as output sequences") + + # Convert target_sizes to list for easier handling + if isinstance(target_sizes, torch.Tensor): + target_sizes = target_sizes.tolist() + + scores, predicted_classes = compute_score(scores_logits) + num_classes = scores_logits.shape[2] + results = [] + for scores_img, box_per_img, image_size, class_names in zip(scores, boxes_logits, target_sizes, classes): + results.append( + _post_process_boxes_for_image( + box_per_img, + scores_img, + predicted_classes, + class_names, + image_size, + num_classes, + score_threshold=score_threshold, + nms_threshold=nms_threshold, + max_num_det=max_num_det, + ) + ) + + return results diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2db7b38b5803..ef10b91ea558 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6552,6 +6552,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class OmDetTurboForObjectDetection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class OmDetTurboPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class OneFormerForUniversalSegmentation(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/omdet_turbo/__init__.py b/tests/models/omdet_turbo/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/omdet_turbo/test_modeling_omdet_turbo.py b/tests/models/omdet_turbo/test_modeling_omdet_turbo.py new file mode 100644 index 000000000000..ed85c4c00078 --- /dev/null +++ b/tests/models/omdet_turbo/test_modeling_omdet_turbo.py @@ -0,0 +1,904 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch OmDet-Turbo model.""" + +import copy +import unittest +from io import BytesIO + +import requests + +from transformers import OmDetTurboConfig, is_torch_available, is_vision_available +from transformers.feature_extraction_utils import BatchFeature +from transformers.file_utils import cached_property +from transformers.testing_utils import ( + require_timm, + require_torch, + require_torch_gpu, + require_vision, + slow, + torch_device, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + import torch.nn.functional as F + + from transformers import OmDetTurboForObjectDetection + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoProcessor + + +class OmDetTurboModelTester: + def __init__( + self, + parent, + batch_size=6, + is_training=False, + num_channels=3, + max_text_len=7, + num_classes=3, + use_timm_backbone=False, + backbone=None, + apply_layernorm_after_vision_backbone=False, + image_size=224, + text_projection_in_dim=16, + text_projection_out_dim=16, + class_embed_dim=16, + hidden_size=8, + num_hidden_layers=2, + num_attention_heads=2, + num_queries=20, + encoder_in_channels=(16, 32, 64), + encoder_dim_feedforward=32, + num_projection_layers=1, + decoder_n_points=4, + num_feature_levels=3, + ): + super().__init__() + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.num_channels = num_channels + self.max_text_len = max_text_len + self.num_classes = num_classes + self.use_timm_backbone = use_timm_backbone + self.backbone = backbone + self.apply_layernorm_after_vision_backbone = apply_layernorm_after_vision_backbone + self.image_size = image_size + self.text_projection_in_dim = text_projection_in_dim + self.text_projection_out_dim = text_projection_out_dim + self.class_embed_dim = class_embed_dim + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_queries = num_queries + self.encoder_in_channels = encoder_in_channels + self.encoder_dim_feedforward = encoder_dim_feedforward + self.num_projection_layers = num_projection_layers + self.decoder_n_points = decoder_n_points + self.num_feature_levels = num_feature_levels + + self.encoder_seq_length_vision = self.image_size // 32 + self.decoder_seq_length = self.num_queries + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + input_ids_tasks = ids_tensor([self.batch_size, self.max_text_len], self.num_classes) + input_ids_tasks = input_ids_tasks.to(torch_device) + input_ids_classes = torch.cat( + [ids_tensor([self.num_classes, self.max_text_len], self.num_classes) for _ in range(self.batch_size)] + ) + input_ids_classes = input_ids_classes.to(torch_device) + attention_mask_tasks = torch.ones_like(input_ids_tasks, device=torch_device) + attention_mask_classes = torch.ones_like(input_ids_classes, device=torch_device) + classes_structure = torch.ones(self.batch_size, dtype=torch.long, device=torch_device) * self.num_classes + encoding = BatchFeature() + encoding.update( + { + "pixel_values": pixel_values, + "classes_input_ids": input_ids_classes, + "classes_attention_mask": attention_mask_classes, + "tasks_input_ids": input_ids_tasks, + 
"tasks_attention_mask": attention_mask_tasks, + "classes_structure": classes_structure, + } + ) + config = self.get_config() + return config, encoding + + def get_config(self): + text_backbone = { + "hidden_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "intermediate_size": 16, + "max_position_embeddings": 8, + "model_type": "clip_text_model", + } + backbone_config = { + "embed_dim": self.hidden_size, + "depths": (1, 1, 1, 1), + "num_heads": (1, 1, 1, 1), + "window_size": 7, + "image_size": self.image_size, + "out_indices": (2, 3, 4), + "model_type": "swin", + } + + return OmDetTurboConfig( + text_config=text_backbone, + backbone_config=backbone_config, + use_timm_backbone=self.use_timm_backbone, + backbone=self.backbone, + apply_layernorm_after_vision_backbone=self.apply_layernorm_after_vision_backbone, + decoder_num_layers=self.num_hidden_layers, + image_size=self.image_size, + encoder_in_channels=self.encoder_in_channels, + num_queries=self.num_queries, + encoder_layers=self.num_hidden_layers, + encoder_projection_indices=[2] * self.num_projection_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_num_heads=self.num_attention_heads, + decoder_num_points=self.decoder_n_points, + num_feature_levels=self.num_feature_levels, + encoder_dim_feedforward=self.encoder_dim_feedforward, + task_encoder_hidden_dim=self.encoder_dim_feedforward, + decoder_dim_feedforward=self.encoder_dim_feedforward, + class_embed_dim=self.class_embed_dim, + text_projection_in_dim=self.text_projection_in_dim, + text_projection_out_dim=self.text_projection_out_dim, + encoder_hidden_dim=self.hidden_size, + decoder_hidden_dim=self.hidden_size, + vision_features_channels=[self.hidden_size, self.hidden_size, self.hidden_size], + ) + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def create_and_check_object_detection_head_model(self, config, inputs_dict): + model = OmDetTurboForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(**inputs_dict) + + self.parent.assertEqual(result.decoder_coord_logits.shape, (self.batch_size, self.num_queries, 4)) + self.parent.assertEqual( + result.decoder_class_logits.shape, (self.batch_size, self.num_queries, self.num_classes) + ) + + +@require_torch +class OmDetTurboModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (OmDetTurboForObjectDetection,) if is_torch_available() else () + is_encoder_decoder = True + test_pruning = False + test_head_masking = False + pipeline_model_mapping = ( + {"zero-shot-object-detection": OmDetTurboForObjectDetection} if is_torch_available() else {} + ) + + # special case for head models + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + return inputs_dict + + def setUp(self): + self.model_tester = OmDetTurboModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=OmDetTurboConfig, + has_text_modality=False, + common_properties=["d_model", "encoder_attention_heads", "decoder_num_heads"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_object_detection_head_model(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_object_detection_head_model(config, inputs_dict) + + @unittest.skip( + reason="Unsupported as 
classes_input_ids are classes input are flattened by the processor: https://github.com/huggingface/transformers/issues/33669" + ) + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip(reason="OmDet-Turbo does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'") + def test_torchscript_output_attentions(self): + pass + + @unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'") + def test_torchscript_output_hidden_states(self): + pass + + @unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'") + def test_torchscript_simple(self): + pass + + @unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'") + def test_torchscript_output_hidden_state(self): + pass + + def test_resize_tokens_embeddings(self): + # rewrite as OmDet-Turbo does not have "input_ids" and "decoder_input_ids" + ( + original_config, + inputs_dict, + ) = self.model_tester.prepare_config_and_inputs_for_common() + if not self.test_resize_embeddings: + self.skipTest(reason="test_resize_embeddings is set to `False`") + + for model_class in self.all_model_classes: + config = copy.deepcopy(original_config) + model = model_class(config) + model.to(torch_device) + model_embed_pre_resize = model.get_input_embeddings() + type_model_embed_pre_resize = type(model_embed_pre_resize) + + if self.model_tester.is_training is False: + model.eval() + + model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size + # Retrieve the embeddings and clone theme + model_embed = model.resize_token_embeddings(model_vocab_size) + cloned_embeddings = model_embed.weight.clone() + + # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + model_embed = model.resize_token_embeddings(model_vocab_size + 10) + new_model_vocab_size = ( + model.config.text_config.vocab_size + if hasattr(model.config, "text_config") + else model.config.vocab_size + ) + self.assertEqual(new_model_vocab_size, model_vocab_size + 10) + # Check that it actually resizes the embeddings matrix + self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10) + # Check to make sure the type of embeddings returned post resizing is same as type of input + type_model_embed_post_resize = type(model_embed) + self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + model(**self._prepare_for_class(inputs_dict, model_class)) + + # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size + model_embed = model.resize_token_embeddings(model_vocab_size - 15) + new_model_vocab_size = ( + model.config.text_config.vocab_size + if hasattr(model.config, "text_config") + else model.config.vocab_size + ) + self.assertEqual(new_model_vocab_size, model_vocab_size - 15) + # Check that it actually resizes the embeddings matrix + self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15) + + # Check that the model can still do a forward pass successfully (every parameter should be resized) + # Input ids should be clamped to the maximum size of the vocabulary + inputs_dict["tasks_input_ids"].clamp_(max=model_vocab_size - 15 - 1) + + # make sure that classes_input_ids are resized as well + if "classes_input_ids" in inputs_dict: + 
inputs_dict["classes_input_ids"].clamp_(max=model_vocab_size - 15 - 1) + model(**self._prepare_for_class(inputs_dict, model_class)) + + # Check that adding and removing tokens has not modified the first part of the embedding matrix. + models_equal = True + for p1, p2 in zip(cloned_embeddings, model_embed.weight): + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + config = copy.deepcopy(original_config) + model = model_class(config) + model.to(torch_device) + + model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size + model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1) + new_model_vocab_size = ( + model.config.text_config.vocab_size + if hasattr(model.config, "text_config") + else model.config.vocab_size + ) + self.assertTrue(new_model_vocab_size + 10, model_vocab_size) + + model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64) + new_model_vocab_size = ( + model.config.text_config.vocab_size + if hasattr(model.config, "text_config") + else model.config.vocab_size + ) + self.assertTrue(model_embed.weight.shape[0] // 64, 0) + + self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size) + self.assertTrue(new_model_vocab_size, model.vocab_size) + + model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64) + self.assertTrue(model_embed.weight.shape[0] // 64, 0) + + # Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size + target_dimension = 128 + model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64) + self.assertTrue(model_embed.weight.shape[0], target_dimension) + + with self.assertRaisesRegex( + ValueError, + "Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer", + ): + model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3) + + # Overwrite as `init_reference_points` is not batch dependent and contains `inf` values + def test_batching_equivalence(self): + """ + Tests that the model supports batching and that the output is nearly the same for the same input in + different batch sizes. + (Why "nearly the same" not "exactly the same"? 
Batching uses different matmul shapes, which often leads to + different results: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535) + """ + + def get_tensor_equivalence_function(batched_input): + # models operating on continuous spaces have higher abs difference than LMs + # instead, we can rely on cos distance for image/speech models, similar to `diffusers` + if "input_ids" not in batched_input: + return lambda tensor1, tensor2: ( + 1.0 - F.cosine_similarity(tensor1.float().flatten(), tensor2.float().flatten(), dim=0, eps=1e-38) + ) + return lambda tensor1, tensor2: torch.max(torch.abs(tensor1 - tensor2)) + + def recursive_check(batched_object, single_row_object, model_name, key): + if isinstance(batched_object, (list, tuple)): + for batched_object_value, single_row_object_value in zip(batched_object, single_row_object): + recursive_check(batched_object_value, single_row_object_value, model_name, key) + elif isinstance(batched_object, dict): + for batched_object_value, single_row_object_value in zip( + batched_object.values(), single_row_object.values() + ): + recursive_check(batched_object_value, single_row_object_value, model_name, key) + # do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects + elif batched_object is None or not isinstance(batched_object, torch.Tensor): + return + elif batched_object.dim() == 0: + return + elif key != "init_reference_points": + # init + # indexing the first element does not always work + # e.g. models that output similarity scores of size (N, M) would need to index [0, 0] + slice_ids = [slice(0, index) for index in single_row_object.shape] + batched_row = batched_object[slice_ids] + self.assertFalse( + torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}" + ) + self.assertFalse( + torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}" + ) + self.assertFalse( + torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}" + ) + self.assertFalse( + torch.isinf(single_row_object).any(), + f"Single row output has `inf` in {model_name} for key={key}", + ) + self.assertTrue( + (equivalence(batched_row, single_row_object)) <= 1e-03, + msg=( + f"Batched and Single row outputs are not equal in {model_name} for key={key}. " + f"Difference={equivalence(batched_row, single_row_object)}." 
+ ), + ) + + config, batched_input = self.model_tester.prepare_config_and_inputs_for_common() + equivalence = get_tensor_equivalence_function(batched_input) + + for model_class in self.all_model_classes: + config.output_hidden_states = True + + model_name = model_class.__name__ + if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"): + config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class) + batched_input_prepared = self._prepare_for_class(batched_input, model_class) + model = model_class(config).to(torch_device).eval() + batch_size = self.model_tester.batch_size + single_row_input = {} + for key, value in batched_input_prepared.items(): + single_batch_shape = value.shape[0] // batch_size + single_row_input[key] = value[:single_batch_shape] + + with torch.no_grad(): + model_batched_output = model(**batched_input_prepared) + model_row_output = model(**single_row_input) + + if isinstance(model_batched_output, torch.Tensor): + model_batched_output = {"model_output": model_batched_output} + model_row_output = {"model_output": model_row_output} + + for key in model_batched_output: + # DETR starts from zero-init queries to decoder, leading to cos_similarity = `nan` + if hasattr(self, "zero_init_hidden_state") and "decoder_hidden_states" in key: + model_batched_output[key] = model_batched_output[key][1:] + model_row_output[key] = model_row_output[key][1:] + if key in ("decoder_class_logits", "decoder_classes", "encoder_class_logits"): + # check if all elements are close to 0, if so skip the test as the test strugles with comparing + # tensors with all elements close to 0 + if torch.allclose( + model_batched_output[key], torch.zeros_like(model_batched_output[key]), atol=1e-6 + ) and torch.allclose(model_row_output[key], torch.zeros_like(model_row_output[key]), atol=1e-6): + continue + + recursive_check(model_batched_output[key], model_row_output[key], model_name, key) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions[-1] + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions[-1] + self.assertEqual( + len(attentions), self.model_tester.num_hidden_layers * self.model_tester.num_projection_layers + ) + # Rest of the shape seems to depend on backbone output shapes and image size + self.assertListEqual( + list(attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + self.model_tester.encoder_seq_length_vision**2, + self.model_tester.encoder_seq_length_vision**2, + ], + ) + # decoder attentions + decoder_attentions = outputs.decoder_attentions[0] + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + 
list(decoder_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + self.model_tester.num_queries + self.model_tester.max_text_len, + self.model_tester.num_queries + self.model_tester.max_text_len, + ], + ) + + # cross attentions + cross_attentions = outputs.decoder_attentions[-1] + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + self.model_tester.num_feature_levels, + self.model_tester.decoder_n_points, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + self_attentions = outputs.encoder_attentions[-1] + + self.assertEqual( + len(self_attentions), self.model_tester.num_hidden_layers * self.model_tester.num_projection_layers + ) + self.assertListEqual( + list(attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + self.model_tester.encoder_seq_length_vision**2, + self.model_tester.encoder_seq_length_vision**2, + ], + ) + + # overwrite since encoder_hidden_states are 3-dim and not 2-dim + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_projection_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + seq_len = self.model_tester.encoder_seq_length_vision + + self.assertListEqual(list(hidden_states[0].shape[-3:]), [self.model_tester.hidden_size, seq_len, seq_len]) + + hidden_states = outputs.decoder_hidden_states + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.decoder_seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # removed retain_grad and grad on decoder_hidden_states, as queries don't require grad + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + 
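+        # retain gradients on the encoder hidden states/attentions and on the cross-attentions so that,
+        # after the backward pass below, we can assert that they actually received a gradient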
+ encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_attentions = outputs.encoder_attentions[0][0] + encoder_hidden_states.retain_grad() + encoder_attentions.retain_grad() + + cross_attentions = outputs.decoder_attentions[-1][0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + if ( + "embeddings" in name + or ".fc" in name + or "decoder.channel_projection_layers" in name + or "query_position_head" in name + or "decoder.encoder_vision_features" in name + ): + continue + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} seems not properly initialized", + ) + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + return image + + +def prepare_text(): + classes = ["cat", "remote"] + task = "Detect {}.".format(", ".join(classes)) + return classes, task + + +def prepare_img_batched(): + url1 = "http://images.cocodataset.org/val2017/000000039769.jpg" + url2 = "http://images.cocodataset.org/train2017/000000257813.jpg" + url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + + return [Image.open(BytesIO(requests.get(url).content)).convert("RGB") for url in [url1, url2, url3]] + + +def prepare_text_batched(): + classes1 = ["cat", "remote"] + classes2 = ["boat"] + classes3 = ["statue", "trees", "torch"] + + task1 = "Detect {}.".format(", ".join(classes1)) + task2 = "Detect all the boat in the image." + task3 = "Focus on the foreground, detect statue, torch and trees." 
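+    # each image in the batch gets its own class list and its own free-form detection prompt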
+ return [classes1, classes2, classes3], [task1, task2, task3] + + +@require_timm +@require_vision +@slow +class OmDetTurboModelIntegrationTests(unittest.TestCase): + @cached_property + def default_processor(self): + return AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") if is_vision_available() else None + + def test_inference_object_detection_head(self): + model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device) + + processor = self.default_processor + image = prepare_img() + classes, task = prepare_text() + encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**encoding) + + expected_shape_coord_logits = torch.Size((1, model.config.num_queries, 4)) + expected_shape_class_logits = torch.Size((1, model.config.num_queries, 2)) + self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits) + self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits) + + expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]]).to( + torch_device + ) + expected_coord_logits = torch.tensor( + [[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1)) + self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3)) + + # verify grounded postprocessing + results = processor.post_process_grounded_object_detection( + outputs, classes=[classes], target_sizes=[image.size[::-1]] + )[0] + expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device) + expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device) + + self.assertEqual(len(results["scores"]), 4) + self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2)) + self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2)) + + expected_classes = ["remote", "cat", "remote", "cat"] + self.assertListEqual(results["classes"], expected_classes) + + def test_inference_object_detection_head_fp16(self): + model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to( + torch_device, dtype=torch.float16 + ) + + processor = self.default_processor + image = prepare_img() + classes, task = prepare_text() + encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to( + torch_device, dtype=torch.float16 + ) + + with torch.no_grad(): + outputs = model(**encoding) + + expected_shape_coord_logits = torch.Size((1, model.config.num_queries, 4)) + expected_shape_class_logits = torch.Size((1, model.config.num_queries, 2)) + self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits) + self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits) + + expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]]).to( + torch_device, dtype=torch.float16 + ) + expected_coord_logits = torch.tensor( + [[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]] + ).to(torch_device, dtype=torch.float16) + + self.assertTrue(torch.allclose(outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1)) + self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:3, 
:3], expected_coord_logits, atol=1e-3)) + + # verify grounded postprocessing + results = processor.post_process_grounded_object_detection( + outputs, classes=[classes], target_sizes=[image.size[::-1]] + )[0] + expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device, dtype=torch.float16) + expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to( + torch_device, dtype=torch.float16 + ) + + self.assertEqual(len(results["scores"]), 4) + self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2)) + self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-1)) + + expected_classes = ["remote", "cat", "remote", "cat"] + self.assertListEqual(results["classes"], expected_classes) + + def test_inference_object_detection_head_no_task(self): + model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device) + + processor = self.default_processor + image = prepare_img() + classes, _ = prepare_text() + encoding = processor(images=image, text=classes, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**encoding) + + expected_shape_coord_logits = torch.Size((1, model.config.num_queries, 4)) + expected_shape_class_logits = torch.Size((1, model.config.num_queries, 2)) + self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits) + self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits) + + expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]]).to( + torch_device + ) + expected_coord_logits = torch.tensor( + [[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1)) + self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3)) + + # verify grounded postprocessing + results = processor.post_process_grounded_object_detection( + outputs, classes=[classes], target_sizes=[image.size[::-1]] + )[0] + expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device) + expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device) + + self.assertEqual(len(results["scores"]), 4) + self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2)) + self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2)) + + expected_classes = ["remote", "cat", "remote", "cat"] + self.assertListEqual(results["classes"], expected_classes) + + def test_inference_object_detection_head_batched(self): + torch_device = "cpu" + model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device) + + processor = self.default_processor + images_batched = prepare_img_batched() + classes_batched, tasks_batched = prepare_text_batched() + encoding = processor(images=images_batched, text=classes_batched, task=tasks_batched, return_tensors="pt").to( + torch_device + ) + + with torch.no_grad(): + outputs = model(**encoding) + + expected_shape_coord_logits = torch.Size((len(images_batched), model.config.num_queries, 4)) + expected_shape_class_logits = torch.Size((len(images_batched), model.config.num_queries, 3)) + self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits) + self.assertEqual(outputs.decoder_class_logits.shape, 
expected_shape_class_logits) + + expected_class_logits = torch.tensor( + [[[0.9427, -2.5958, -7.7601]], [[-2.3408, -9.3516, -9.3516]], [[1.0740, -2.3315, -1.1885]]] + ).to(torch_device) + + expected_coord_logits = torch.tensor( + [[[0.2550, 0.5501, 0.4738]], [[0.2535, 0.6006, 0.0353]], [[0.3742, 0.3337, 0.0666]]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.decoder_class_logits[:, :1, :3], expected_class_logits, atol=1e-1)) + self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:, :1, :3], expected_coord_logits, atol=1e-3)) + + # verify grounded postprocessing + results = processor.post_process_grounded_object_detection( + outputs, + classes=classes_batched, + target_sizes=[image.size[::-1] for image in images_batched], + score_threshold=0.2, + ) + expected_scores = torch.tensor([0.7675, 0.3016, 0.7454]).to(torch_device) + expected_slice_boxes = torch.tensor( + [ + [39.8870, 70.3522, 176.7424, 118.0354], + [146.5446, 219.7132, 209.6983, 251.0456], + [545.3470, 209.9055, 651.9860, 502.1882], + ] + ).to(torch_device) + + self.assertListEqual([len(result["scores"]) for result in results], [4, 4, 6]) + self.assertTrue( + torch.allclose(torch.stack([result["scores"][0] for result in results]), expected_scores, atol=1e-2) + ) + self.assertTrue( + torch.allclose(torch.stack([result["boxes"][0, :] for result in results]), expected_slice_boxes, atol=1e-2) + ) + + expected_classes = [ + ["remote", "cat", "remote", "cat"], + ["boat", "boat", "boat", "boat"], + ["statue", "trees", "trees", "torch", "statue", "statue"], + ] + self.assertListEqual([result["classes"] for result in results], expected_classes) + + @require_torch_gpu + def test_inference_object_detection_head_equivalence_cpu_gpu(self): + processor = self.default_processor + image = prepare_img() + classes, task = prepare_text() + encoding = processor(images=image, text=classes, task=task, return_tensors="pt") + # 1. run model on CPU + model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") + + with torch.no_grad(): + cpu_outputs = model(**encoding) + + # 2. run model on GPU + model.to("cuda") + encoding = encoding.to("cuda") + with torch.no_grad(): + gpu_outputs = model(**encoding) + + # 3. 
assert equivalence + expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]]) + expected_coord_logits = torch.tensor( + [[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]] + ) + + self.assertTrue(torch.allclose(cpu_outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1)) + self.assertTrue(torch.allclose(cpu_outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3)) + + # verify grounded postprocessing + results_cpu = processor.post_process_grounded_object_detection( + cpu_outputs, classes=[classes], target_sizes=[image.size[::-1]] + )[0] + result_gpu = processor.post_process_grounded_object_detection( + gpu_outputs, classes=[classes], target_sizes=[image.size[::-1]] + )[0] + + self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-2)) + self.assertTrue(torch.allclose(results_cpu["boxes"][0, :], result_gpu["boxes"][0, :].cpu(), atol=1e-2)) diff --git a/tests/models/omdet_turbo/test_processor_omdet_turbo.py b/tests/models/omdet_turbo/test_processor_omdet_turbo.py new file mode 100644 index 000000000000..e6e2a1f50c52 --- /dev/null +++ b/tests/models/omdet_turbo/test_processor_omdet_turbo.py @@ -0,0 +1,363 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import shutil +import tempfile +import unittest + +import numpy as np +import pytest + +from transformers import AutoProcessor, CLIPTokenizerFast, OmDetTurboProcessor +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +IMAGE_MEAN = [123.675, 116.28, 103.53] +IMAGE_STD = [58.395, 57.12, 57.375] + +if is_torch_available(): + import torch + + from transformers.models.omdet_turbo.modeling_omdet_turbo import OmDetTurboObjectDetectionOutput + +if is_vision_available(): + from PIL import Image + + from transformers import DetrImageProcessor + + +@require_torch +@require_vision +class OmDetTurboProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = OmDetTurboProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = DetrImageProcessor() + tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32") + + processor = OmDetTurboProcessor(image_processor, tokenizer) + processor.save_pretrained(self.tmpdirname) + + self.input_keys = [ + "tasks_input_ids", + "tasks_attention_mask", + "classes_input_ids", + "classes_attention_mask", + "classes_structure", + "pixel_values", + "pixel_mask", + ] + + self.batch_size = 5 + self.num_queries = 5 + self.embed_dim = 3 + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. 
+ """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + def get_fake_omdet_turbo_output(self): + torch.manual_seed(42) + return OmDetTurboObjectDetectionOutput( + decoder_coord_logits=torch.rand(self.batch_size, self.num_queries, 4), + decoder_class_logits=torch.rand(self.batch_size, self.num_queries, self.embed_dim), + ) + + def get_fake_omdet_turbo_classes(self): + return [[f"class{i}_{j}" for i in range(self.num_queries)] for j in range(self.batch_size)] + + def test_post_process_grounded_object_detection(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor) + + omdet_turbo_output = self.get_fake_omdet_turbo_output() + omdet_turbo_classes = self.get_fake_omdet_turbo_classes() + + post_processed = processor.post_process_grounded_object_detection( + omdet_turbo_output, omdet_turbo_classes, target_sizes=[(400, 30) for _ in range(self.batch_size)] + ) + + self.assertEqual(len(post_processed), self.batch_size) + self.assertEqual(list(post_processed[0].keys()), ["boxes", "scores", "classes"]) + self.assertEqual(post_processed[0]["boxes"].shape, (self.num_queries, 4)) + self.assertEqual(post_processed[0]["scores"].shape, (self.num_queries,)) + expected_scores = torch.tensor([0.7310, 0.6579, 0.6513, 0.6444, 0.6252]) + self.assertTrue(torch.allclose(post_processed[0]["scores"], expected_scores, atol=1e-4)) + + expected_box_slice = torch.tensor([14.9657, 141.2052, 30.0000, 312.9670]) + self.assertTrue(torch.allclose(post_processed[0]["boxes"][0], expected_box_slice, atol=1e-4)) + + def test_save_load_pretrained_additional_features(self): + processor = OmDetTurboProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor()) + processor.save_pretrained(self.tmpdirname) + + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0) + + processor = OmDetTurboProcessor.from_pretrained( + self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0 + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, CLIPTokenizerFast) + + self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.image_processor, DetrImageProcessor) + + def test_image_processor(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor).image_processor + + image_input = self.prepare_image_inputs() + + input_image_proc = image_processor(image_input, return_tensors="np") + input_processor = processor(images=image_input, return_tensors="np") + + for key in input_image_proc.keys(): + self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_tokenizer(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor).tokenizer + + input_str = "lower newer" + + encoded_processor = processor(text=input_str, padding="max_length", truncation=True, max_length=77) + + encoded_tok = 
tokenizer(input_str, padding="max_length", truncation=True, max_length=77) + + for key in encoded_tok.keys(): + self.assertListEqual(encoded_tok[key], encoded_processor[key]) + + def test_processor(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor) + + input_tasks = "task" + input_classes = ["class1", "class2"] + image_input = self.prepare_image_inputs() + + input_processor = processor(images=image_input, text=input_classes, task=input_tasks, return_tensors="pt") + + for key in self.input_keys: + assert torch.is_tensor(input_processor[key]) + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + + def test_tokenizer_decode(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor) + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor) + + input_tasks = "task" + input_classes = ["class1", "class2"] + image_input = self.prepare_image_inputs() + inputs = processor(images=image_input, text=input_classes, task=input_tasks, return_tensors="pt") + + self.assertListEqual(list(inputs.keys()), self.input_keys) + + @require_vision + @require_torch + def test_tokenizer_defaults_preserved_by_kwargs(self): + # Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes. + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=117) + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + image_input = self.prepare_image_inputs() + inputs = processor(images=image_input, text=[input_str], task=input_str, return_tensors="pt") + + self.assertEqual(len(inputs["tasks_input_ids"][0]), 117) + self.assertEqual(len(inputs["classes_input_ids"][0]), 117) + + @require_vision + @require_torch + def test_kwargs_overrides_default_tokenizer_kwargs(self): + # Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes. 
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer", max_length=117)
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+        inputs = processor(images=image_input, text=[input_str], task=input_str, return_tensors="pt", max_length=112)
+
+        self.assertEqual(len(inputs["tasks_input_ids"][0]), 112)
+        self.assertEqual(len(inputs["classes_input_ids"][0]), 112)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs(self):
+        # Rewritten because the processor returns "tasks_input_ids" and "classes_input_ids" instead of "input_ids".
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+        inputs = processor(
+            images=image_input,
+            text=[input_str],
+            task=input_str,
+            return_tensors="pt",
+            size={"height": 214, "width": 214},
+            padding="max_length",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+        self.assertEqual(len(inputs["tasks_input_ids"][0]), 76)
+        self.assertEqual(len(inputs["classes_input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        # Rewritten because the processor returns "tasks_input_ids" and "classes_input_ids" instead of "input_ids".
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            images=image_input,
+            text=[input_str],
+            task=input_str,
+            return_tensors="pt",
+            size={"height": 214, "width": 214},
+            padding="longest",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        # with padding="longest", the text is padded to the longest sequence in the batch, not to max_length
+        self.assertEqual(len(inputs["tasks_input_ids"][0]), 6)
+        self.assertEqual(len(inputs["classes_input_ids"][0]), 6)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested(self):
+        # Rewritten because the processor returns "tasks_input_ids" and "classes_input_ids" instead of "input_ids".
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"size": {"height": 214, "width": 214}},
+            "text_kwargs": {"padding": "max_length", "max_length": 76, "task": input_str},
+        }
+
+        inputs = processor(images=image_input, text=[input_str], **all_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["tasks_input_ids"][0]), 76)
+        self.assertEqual(len(inputs["classes_input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested_from_dict(self):
+        # Rewritten because the processor returns "tasks_input_ids" and "classes_input_ids" instead of "input_ids".
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"size": {"height": 214, "width": 214}},
+            "text_kwargs": {"padding": "max_length", "max_length": 76, "task": input_str},
+        }
+
+        inputs = processor(images=image_input, text=[input_str], **all_kwargs)
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["tasks_input_ids"][0]), 76)
+        self.assertEqual(len(inputs["classes_input_ids"][0]), 76)
diff --git a/utils/check_table.py b/utils/check_table.py
index 02541e87ddba..587681844955 100644
--- a/utils/check_table.py
+++ b/utils/check_table.py
@@ -173,7 +173,13 @@ def _center_text(text: str, width: int) -> str:
     "XLS-R": "Wav2Vec2",
     "XLSR-Wav2Vec2": "Wav2Vec2",
 }
-MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel", "ChineseCLIPVisionModel", "Qwen2AudioEncoder"]
+MODEL_NAMES_TO_IGNORE = [
+    "ChineseCLIPVisionModel",
+    "CLIPTextModel",
+    "CLIPVisionModel",
+    "Qwen2AudioEncoder",
+    "SiglipVisionModel",
+]


 def get_model_table_from_auto_modules() -> str:
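The rewritten kwargs tests above hinge on the processor emitting separate task and class text encodings rather than a single `input_ids`. For reference, a minimal sketch of inspecting those outputs with the `omlab/omdet-turbo-swin-tiny-hf` checkpoint might look like this; the dummy image mirrors how the tests build their inputs, and the printed key names are taken from the assertions above:

```python
import numpy as np
from PIL import Image

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")

# a small dummy RGB image, similar to the random images the processor tests use
image = Image.fromarray(np.random.randint(0, 255, size=(30, 400, 3), dtype=np.uint8))

inputs = processor(
    images=image,
    text=["cat", "remote"],          # candidate classes
    task="Detect cat, remote.",      # grounded task prompt
    return_tensors="pt",
)

# The task prompt and the candidate classes are tokenized separately, so the output
# carries "tasks_input_ids" and "classes_input_ids" instead of a single "input_ids".
print([key for key in inputs.keys() if "input_ids" in key])
print(inputs["tasks_input_ids"].shape, inputs["classes_input_ids"].shape)
```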