From a957b7911a758d54597914b4479fe6e81424d64f Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Fri, 21 Feb 2025 09:04:19 +0000 Subject: [PATCH] Add SigLIP 2 (#36323) * Docs * Inits * Auto classes * Add siglip base * Add base tests * Fix Siglip V1 for fix res version * Add image processor * Update conversion * Experimenting with vectorized embeddings * Fixup * Add modular Siglip2Processor * Add modular configuration * Rename num patches * Correct image and text features merging * Working conversion script * Refactoring conversion script * Remove unused code in conversion script * Shorten dict a bit * Refactoring conversion * Done conversion refactoring * Fixup * Modular siglip2 * Make model exportable and compilable without graph breaks * Remove position_ids from image_processor * REmove position ids from modeling file * Update modular * Type hint * Fixup * Set defaults to processor * Add integration test * Revert spatial shapes back to tensor * Change order * Fix most of the tests * Fix docstring * Remove interpolate_pos_encoding arg (not needed) * Update docs * Standardize processing * Fix attention_mask in vision head * Siglip v1: remove double transpose in FA2 * Update modular file * Update FA2 test * Update expected logits * Fix interpolation for siglip2 image processor * Skip init test * Skip dispatch on flash test * Fix modeling tests * Fixup * Add dummy objects * Fix some docstrings * Add siglip2 in index.md * Fix consistency * Add docs * Remove size and data format * Add image processor tests * Fix * Add fast image processor * Fix style * Fix * Docs * Set lowercase for tokenizer * Adjust head size for Siglip v1 * Update siglip2 for consistency with siglip1 * Update siglip2 conversion * Update pipeline * Update checkpoints in tests * Update checkpoint name * Fix pooling for image classification model * Fix FA2 test * Update processor * Fix check repo * Update docs * Fix typos * Fix docstring for fast image processor * Add siglip2 to FA2 docs * Fix fast ip tests * Fix constitency * Fix tokenizer class for siglip v1 * Fix missing header * Refactor scaling for clip, siglip, siglip2 * Remove unused imports * Make fast IP default for siglip2 * Update docs * Update checkpoints * Update modular * Update paper link * Fixup * Fix name in toctree * Fix test --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/siglip2.md | 276 +++ docs/source/en/perf_infer_gpu_one.md | 2 + src/transformers/__init__.py | 32 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 3 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 7 + src/transformers/models/clip/modeling_clip.py | 7 +- .../models/siglip/configuration_siglip.py | 4 + .../models/siglip/convert_siglip_to_hf.py | 319 +++- .../models/siglip/modeling_siglip.py | 23 +- .../models/siglip/processing_siglip.py | 4 +- src/transformers/models/siglip2/__init__.py | 30 + .../models/siglip2/configuration_siglip2.py | 277 +++ .../models/siglip2/convert_siglip2_to_hf.py | 438 +++++ .../siglip2/image_processing_siglip2.py | 343 ++++ .../siglip2/image_processing_siglip2_fast.py | 322 ++++ .../models/siglip2/modeling_siglip2.py | 1634 +++++++++++++++++ .../models/siglip2/modular_siglip2.py | 537 ++++++ .../models/siglip2/processing_siglip2.py | 171 ++ .../zero_shot_image_classification.py | 9 +- src/transformers/utils/dummy_pt_objects.py | 35 + 
.../utils/dummy_torchvision_objects.py | 7 + .../utils/dummy_vision_objects.py | 7 + tests/models/siglip2/__init__.py | 0 .../siglip2/test_image_processing_siglip2.py | 200 ++ tests/models/siglip2/test_modeling_siglip2.py | 989 ++++++++++ tests/test_modeling_common.py | 5 + utils/check_repo.py | 2 + 33 files changed, 5570 insertions(+), 122 deletions(-) create mode 100644 docs/source/en/model_doc/siglip2.md create mode 100644 src/transformers/models/siglip2/__init__.py create mode 100644 src/transformers/models/siglip2/configuration_siglip2.py create mode 100644 src/transformers/models/siglip2/convert_siglip2_to_hf.py create mode 100644 src/transformers/models/siglip2/image_processing_siglip2.py create mode 100644 src/transformers/models/siglip2/image_processing_siglip2_fast.py create mode 100644 src/transformers/models/siglip2/modeling_siglip2.py create mode 100644 src/transformers/models/siglip2/modular_siglip2.py create mode 100644 src/transformers/models/siglip2/processing_siglip2.py create mode 100644 tests/models/siglip2/__init__.py create mode 100644 tests/models/siglip2/test_image_processing_siglip2.py create mode 100644 tests/models/siglip2/test_modeling_siglip2.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 06a6f172fae1..7d7201da5027 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -965,6 +965,8 @@ title: Segment Anything - local: model_doc/siglip title: SigLIP + - local: model_doc/siglip2 + title: SigLIP2 - local: model_doc/smolvlm title: SmolVLM - local: model_doc/speech-encoder-decoder diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 6a168e9905ba..a6961b06a47b 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -317,6 +317,7 @@ Flax), PyTorch, and/or TensorFlow. | [SEW](model_doc/sew) | ✅ | ❌ | ❌ | | [SEW-D](model_doc/sew-d) | ✅ | ❌ | ❌ | | [SigLIP](model_doc/siglip) | ✅ | ❌ | ❌ | +| [SigLIP2](model_doc/siglip2) | ✅ | ❌ | ❌ | | [SmolVLM](model_doc/smolvlm) | ✅ | ❌ | ❌ | | [Speech Encoder decoder](model_doc/speech-encoder-decoder) | ✅ | ❌ | ✅ | | [Speech2Text](model_doc/speech_to_text) | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/siglip2.md b/docs/source/en/model_doc/siglip2.md new file mode 100644 index 000000000000..054e09189ad1 --- /dev/null +++ b/docs/source/en/model_doc/siglip2.md @@ -0,0 +1,276 @@ + + +# SigLIP2 + +## Overview + +The SigLIP2 model was proposed in [SigLIP 2: Multilingual Vision-Language Encoders with Improved Semantic Understanding, Localization, and Dense Features](https://huggingface.co/papers/2502.14786) by Michael Tschannen, Alexey Gritsenko, Xiao Wang, Muhammad Ferjad Naeem, Ibrahim Alabdulmohsin, +Nikhil Parthasarathy, Talfan Evans, Lucas Beyer, Ye Xia, Basil Mustafa, Olivier Hénaff, Jeremiah Harmsen, +Andreas Steiner and Xiaohua Zhai. + +The model comes in two variants + + 1) FixRes - model works with fixed resolution images (backward compatible with SigLIP v1) + 2) NaFlex - model works with variable image aspect ratios and resolutions (SigLIP2 in `transformers`) + +The abstract from the paper is the following: + +*We introduce SigLIP 2, a family of new multilingual vision-language encoders that build on the success +of the original SigLIP. In this second iteration, we extend the original image-text training objective with +several prior, independently developed techniques into a unified recipe—this includes decoder-based +pretraining, self-supervised losses (self-distillation, masked prediction) and online data curation. 
With
+these changes, SigLIP 2 models outperform their SigLIP counterparts at all model scales in core capabilities,
+including zero-shot classification (best SigLIP 2 ViT-g/16 achieves 85.0% ImageNet zero-shot
+accuracy), image-text retrieval, and transfer performance when extracting visual representations for
+Vision-Language Models (VLMs). Furthermore, the new training recipe leads to significant improvements
+on localization and dense prediction tasks. We also train variants which support multiple resolutions
+and preserve the input’s native aspect ratio. Finally, we train on a more diverse data-mixture that
+includes de-biasing techniques, leading to much better multilingual understanding and improved
+fairness. To provide users with the ability to trade-off inference cost with performance, we release model
+checkpoints at four sizes (ViT-B/86M, L/303M, So400m/400M, and g/1B).*
+
+## Usage tips
+
+- Usage of SigLIP2 is similar to [SigLIP](siglip) and [CLIP](clip). The main difference from CLIP is the training loss, which does not require a global view of all the pairwise similarities of images and texts within a batch. One needs to apply the sigmoid activation function to the logits, rather than the softmax.
+- Training is supported but does not use `torch.distributed` utilities, which may limit batch size scalability. However, DDP and FSDP work in single-node multi-GPU setups.
+- When using the standalone [`GemmaTokenizerFast`], make sure to pass `padding="max_length"` and `max_length=64`, as that's how the model was trained.
+- The model was trained with *lowercased* text, so make sure to apply the same preprocessing to your text labels.
+- To get the same results as the pipeline, a prompt template of "this is a photo of {label}" should be used.
+- The NaFlex variant supports processing images at higher resolutions by adjusting the `max_num_patches` parameter in the `Processor`. The default value is `max_num_patches=256`. Increasing `max_num_patches` to 1024 (4x) will approximately double the processed image height and width, while preserving the aspect ratio.
+
+
+
+This model was contributed by [qubvel](https://huggingface.co/qubvel-hf).
+The original code can be found [here](https://github.com/google-research/big_vision/tree/main).
+
+## Usage example
+
+There are two main ways to use SigLIP2: either using the pipeline API, which abstracts away all the complexity for you, or by using the `Siglip2Model` class yourself.
+
+### FixRes variant
+
+**Pipeline API**
+
+The pipeline allows you to use the model in a few lines of code:
+
+```python
+>>> from transformers import pipeline
+>>> from PIL import Image
+>>> import requests
+
+>>> # load pipe
+>>> image_classifier = pipeline(
+... task="zero-shot-image-classification",
+... model="google/siglip2-base-patch16-224",
+...
) + +>>> # load image +>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +>>> image = Image.open(requests.get(url, stream=True).raw) + +>>> # inference +>>> candidate_labels = ["2 cats", "a plane", "a remote"] +>>> outputs = image_classifier(image, candidate_labels=candidate_labels) +>>> outputs = [{"score": round(output["score"], 4), "label": output["label"] } for output in outputs] +>>> print(outputs) +[{'score': 0.1499, 'label': '2 cats'}, {'score': 0.0008, 'label': 'a remote'}, {'score': 0.0, 'label': 'a plane'}] +``` + +**Using the model yourself** + +If you want to do the pre- and postprocessing yourself, here's how to do that: + +```python +>>> from PIL import Image +>>> import requests +>>> from transformers import AutoProcessor, AutoModel +>>> import torch + +>>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224") +>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + +>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" +>>> image = Image.open(requests.get(url, stream=True).raw) + +>>> candidate_labels = ["2 cats", "2 dogs"] +# follows the pipeline prompt template to get same results +>>> texts = [f"This is a photo of {label}." for label in candidate_labels] + +# IMPORTANT: we pass `padding=max_length` and `max_length=64` since the model was trained with this +>>> inputs = processor(text=texts, images=image, padding="max_length", max_length=64, return_tensors="pt") + +>>> with torch.no_grad(): +... outputs = model(**inputs) + +>>> logits_per_image = outputs.logits_per_image +>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities +>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'") +15.0% that image 0 is '2 cats' +``` + +### NaFlex variant + +NaFlex combines ideas from FlexiViT, i.e. supporting multiple, predefined sequence lengths +with a single ViT model, and NaViT, namely processing images at their native aspect ratio. +This enables processing different types of images at appropriate resolution, e.g. using a +larger resolution to process document images, while at the same time minimizing the impact +of aspect ratio distortion on certain inference tasks, e.g. on OCR. + +Given a patch size and target sequence length, NaFlex preprocesses the data by first resizing +the input image such that the height and width after resizing are multiples of the patch size, +while + + 1. keeping the aspect ratio distortion as small as possible + 2. producing a sequence length of at most the desired target sequence length (`max_num_patches`) + +The resulting distortion in width and height is at most `(patch_size - 1) / width` and +`(patch_size - 1) / height`, respectively, which tends to be small for common resolutions and aspect ratios. +After resizing, the image is split into a sequence of patches, and a mask with padding information is added. + +```python +>>> from PIL import Image +>>> import requests +>>> from transformers import AutoProcessor, AutoModel +>>> import torch + +>>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-naflex") +>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex") + +>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" +>>> image = Image.open(requests.get(url, stream=True).raw) + +>>> candidate_labels = ["2 cats", "2 dogs"] +# follows the pipeline prompt template to get same results +>>> texts = [f"This is a photo of {label}." 
for label in candidate_labels]
+
+# the default value for `max_num_patches` is 256, but you can increase the resulting image resolution by providing
+# higher values, e.g. `max_num_patches=512`
+>>> inputs = processor(text=texts, images=image, max_num_patches=256, return_tensors="pt")
+
+>>> with torch.no_grad():
+... outputs = model(**inputs)
+
+>>> logits_per_image = outputs.logits_per_image
+>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
+>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
+21.1% that image 0 is '2 cats'
+```
+
+## Resources
+
+A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with SigLIP2.
+
+- [Zero-shot image classification task guide](../tasks/zero_shot_image_classification)
+- A demo notebook for SigLIP2 can be found [here](https://github.com/qubvel/transformers-notebooks/tree/master/notebooks/SigLIP2_inference.ipynb). 🌎
+
+If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
+
+
+## Combining SigLIP2 and Flash Attention 2
+
+First, make sure to install the latest version of Flash Attention 2.
+
+```bash
+pip install -U flash-attn --no-build-isolation
+```
+
+Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the flash-attn repository. Also make sure to load your model in half-precision (e.g. `torch.float16`).
+
+To load and run a model using Flash Attention 2, refer to the snippet below:
+
+```python
+>>> import torch
+>>> import requests
+>>> from PIL import Image
+>>> from transformers import AutoProcessor, AutoModel
+>>> device = "cuda" # the device to load the model onto
+
+>>> model = AutoModel.from_pretrained(
+... "google/siglip2-so400m-patch14-384",
+... attn_implementation="flash_attention_2",
+... torch_dtype=torch.float16,
+... device_map=device,
+... )
+>>> processor = AutoProcessor.from_pretrained("google/siglip2-so400m-patch14-384")
+
+>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+>>> image = Image.open(requests.get(url, stream=True).raw)
+
+>>> candidate_labels = ["2 cats", "2 dogs"]
+# follows the pipeline prompt template to get same results
+>>> texts = [f'This is a photo of {label}.' for label in candidate_labels]
+# important: we pass `padding=max_length` since the model was trained with this
+>>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt").to(device)
+
+>>> with torch.no_grad():
+... with torch.autocast(device):
+...
outputs = model(**inputs) + +>>> logits_per_image = outputs.logits_per_image +>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities +>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'") +19.8% that image 0 is '2 cats' +``` + +## Siglip2Config + +[[autodoc]] Siglip2Config + +## Siglip2TextConfig + +[[autodoc]] Siglip2TextConfig + +## Siglip2VisionConfig + +[[autodoc]] Siglip2VisionConfig + +## Siglip2ImageProcessor + +[[autodoc]] Siglip2ImageProcessor + - preprocess + +## Siglip2ImageProcessorFast + +[[autodoc]] Siglip2ImageProcessorFast + - preprocess + +## Siglip2Processor + +[[autodoc]] Siglip2Processor + +## Siglip2Model + +[[autodoc]] Siglip2Model + - forward + - get_text_features + - get_image_features + +## Siglip2TextModel + +[[autodoc]] Siglip2TextModel + - forward + +## Siglip2VisionModel + +[[autodoc]] Siglip2VisionModel + - forward + +## Siglip2ForImageClassification + +[[autodoc]] Siglip2ForImageClassification + - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index b8896114eccb..7f57a99c7d35 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -111,6 +111,7 @@ FlashAttention-2 is currently supported for the following architectures: * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) +* [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) * [helium](https://huggingface.co/docs/transformers/main/en/model_doc/heliumtransformers.HeliumModel) @@ -310,6 +311,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) +* [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index dc427aad5727..ed2682901008 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -776,6 +776,12 @@ "SiglipTextConfig", "SiglipVisionConfig", ], + "models.siglip2": [ + "Siglip2Config", + "Siglip2Processor", + "Siglip2TextConfig", + "Siglip2VisionConfig", + ], "models.smolvlm": ["SmolVLMConfig"], "models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"], "models.speech_to_text": [ @@ -1289,6 +1295,7 @@ _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"]) _import_structure["models.seggpt"].extend(["SegGptImageProcessor"]) _import_structure["models.siglip"].append("SiglipImageProcessor") + 
_import_structure["models.siglip2"].append("Siglip2ImageProcessor") _import_structure["models.smolvlm"].extend(["SmolVLMImageProcessor"]) _import_structure["models.superglue"].extend(["SuperGlueImageProcessor"]) _import_structure["models.superpoint"].extend(["SuperPointImageProcessor"]) @@ -1330,6 +1337,7 @@ _import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast") _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast") _import_structure["models.siglip"].append("SiglipImageProcessorFast") + _import_structure["models.siglip2"].append("Siglip2ImageProcessorFast") _import_structure["models.vit"].append("ViTImageProcessorFast") try: @@ -3559,6 +3567,15 @@ "SiglipVisionModel", ] ) + _import_structure["models.siglip2"].extend( + [ + "Siglip2ForImageClassification", + "Siglip2Model", + "Siglip2PreTrainedModel", + "Siglip2TextModel", + "Siglip2VisionModel", + ] + ) _import_structure["models.smolvlm"].extend( [ "SmolVLMForConditionalGeneration", @@ -5942,6 +5959,12 @@ SiglipTextConfig, SiglipVisionConfig, ) + from .models.siglip2 import ( + Siglip2Config, + Siglip2Processor, + Siglip2TextConfig, + Siglip2VisionConfig, + ) from .models.smolvlm import SmolVLMConfig from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig from .models.speech_to_text import ( @@ -6472,6 +6495,7 @@ from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor from .models.seggpt import SegGptImageProcessor from .models.siglip import SiglipImageProcessor + from .models.siglip2 import Siglip2ImageProcessor from .models.smolvlm import SmolVLMImageProcessor from .models.superglue import SuperGlueImageProcessor from .models.superpoint import SuperPointImageProcessor @@ -6509,6 +6533,7 @@ from .models.qwen2_vl import Qwen2VLImageProcessorFast from .models.rt_detr import RTDetrImageProcessorFast from .models.siglip import SiglipImageProcessorFast + from .models.siglip2 import Siglip2ImageProcessorFast from .models.vit import ViTImageProcessorFast try: @@ -8288,6 +8313,13 @@ SiglipTextModel, SiglipVisionModel, ) + from .models.siglip2 import ( + Siglip2ForImageClassification, + Siglip2Model, + Siglip2PreTrainedModel, + Siglip2TextModel, + Siglip2VisionModel, + ) from .models.smolvlm import ( SmolVLMForConditionalGeneration, SmolVLMModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 220f5dfa59c6..74dad4a2418b 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -245,6 +245,7 @@ sew, sew_d, siglip, + siglip2, smolvlm, speech_encoder_decoder, speech_to_text, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7e81e41006a6..8b2b514496d8 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -271,6 +271,7 @@ ("sew", "SEWConfig"), ("sew-d", "SEWDConfig"), ("siglip", "SiglipConfig"), + ("siglip2", "Siglip2Config"), ("siglip_vision_model", "SiglipVisionConfig"), ("smolvlm", "SmolVLMConfig"), ("smolvlm_vision", "SmolVLMVisionConfig"), @@ -617,6 +618,8 @@ ("sew", "SEW"), ("sew-d", "SEW-D"), ("siglip", "SigLIP"), + ("siglip2", "SigLIP2"), + ("siglip2_vision_model", "Siglip2VisionModel"), ("siglip_vision_model", "SiglipVisionModel"), ("smolvlm", "SmolVLM"), ("smolvlm_vision", "SmolVLMVisionTransformer"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 
ef4d9b25d1d2..4942b8f39b7e 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -136,6 +136,7 @@ ("segformer", ("SegformerImageProcessor",)), ("seggpt", ("SegGptImageProcessor",)), ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")), + ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")), ("superglue", "SuperGlueImageProcessor"), ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")), ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 8fbe1b6c0d68..cf6518c41760 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -250,6 +250,7 @@ ("sew", "SEWModel"), ("sew-d", "SEWDModel"), ("siglip", "SiglipModel"), + ("siglip2", "Siglip2Model"), ("siglip_vision_model", "SiglipVisionModel"), ("smolvlm", "SmolVLMModel"), ("smolvlm_vision", "SmolVLMVisionTransformer"), @@ -721,6 +722,7 @@ ("resnet", "ResNetForImageClassification"), ("segformer", "SegformerForImageClassification"), ("siglip", "SiglipForImageClassification"), + ("siglip2", "Siglip2ForImageClassification"), ("swiftformer", "SwiftFormerForImageClassification"), ("swin", "SwinForImageClassification"), ("swinv2", "Swinv2ForImageClassification"), @@ -1403,6 +1405,7 @@ ("clip", "CLIPModel"), ("clipseg", "CLIPSegModel"), ("siglip", "SiglipModel"), + ("siglip2", "Siglip2Model"), ] ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index f329d9e465e5..03b8c860f60b 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -99,6 +99,7 @@ ("sew", "Wav2Vec2Processor"), ("sew-d", "Wav2Vec2Processor"), ("siglip", "SiglipProcessor"), + ("siglip2", "Siglip2Processor"), ("speech_to_text", "Speech2TextProcessor"), ("speech_to_text_2", "Speech2Text2Processor"), ("speecht5", "SpeechT5Processor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 5ee4f612285f..61c2c2e23d2f 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -479,6 +479,13 @@ ), ), ("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)), + ( + "siglip2", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), ("speech_to_text_2", ("Speech2Text2Tokenizer", None)), ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)), diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 01c8f4dcbc9a..5e4ebd24690a 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -1394,10 +1394,9 @@ def forward( text_embeds = text_embeds / _get_vector_norm(text_embeds) # cosine similarity as logits - logit_scale = self.logit_scale.exp() - logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * logit_scale.to( - text_embeds.device - ) + logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) + logits_per_text = logits_per_text * self.logit_scale.exp().to(text_embeds.device) + 
logits_per_image = logits_per_text.t() loss = None diff --git a/src/transformers/models/siglip/configuration_siglip.py b/src/transformers/models/siglip/configuration_siglip.py index ad676046f348..f4a140cecc2c 100644 --- a/src/transformers/models/siglip/configuration_siglip.py +++ b/src/transformers/models/siglip/configuration_siglip.py @@ -59,6 +59,8 @@ class SiglipTextConfig(PretrainedConfig): The id of the beginning-of-sequence token in the vocabulary. eos_token_id (`int`, *optional*, defaults to 49407): The id of the end-of-sequence token in the vocabulary. + projection_size (`int`, *optional*, defaults to `hidden_size`): + The size of the projection head. Example: @@ -94,6 +96,7 @@ def __init__( pad_token_id=1, bos_token_id=49406, eos_token_id=49407, + projection_size=None, **kwargs, ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -107,6 +110,7 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act self.attention_dropout = attention_dropout + self.projection_size = projection_size if projection_size is not None else hidden_size class SiglipVisionConfig(PretrainedConfig): diff --git a/src/transformers/models/siglip/convert_siglip_to_hf.py b/src/transformers/models/siglip/convert_siglip_to_hf.py index 163f6f279792..8b0a8a250dd8 100644 --- a/src/transformers/models/siglip/convert_siglip_to_hf.py +++ b/src/transformers/models/siglip/convert_siglip_to_hf.py @@ -19,7 +19,8 @@ import argparse import collections -from pathlib import Path +import os +from typing import Tuple import numpy as np import requests @@ -28,7 +29,14 @@ from numpy import load from PIL import Image -from transformers import SiglipConfig, SiglipImageProcessor, SiglipModel, SiglipProcessor, SiglipTokenizer +from transformers import ( + GemmaTokenizerFast, + SiglipConfig, + SiglipImageProcessor, + SiglipModel, + SiglipProcessor, + SiglipTokenizer, +) from transformers.utils import logging @@ -36,6 +44,33 @@ logger = logging.get_logger(__name__) +MODEL_CONFIGS = { + "base": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_hidden_layers": 12, + "num_attention_heads": 12, + }, + "large": { + "hidden_size": 1024, + "intermediate_size": 4096, + "num_hidden_layers": 24, + "num_attention_heads": 16, + }, + "giant-opt": { + "hidden_size": 1536, + "intermediate_size": 6144, + "num_hidden_layers": 40, + "num_attention_heads": 16, + }, + "so400m": { + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + }, +} + model_name_to_checkpoint = { # base checkpoints "siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz", @@ -49,56 +84,146 @@ "siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz", # so400m checkpoints "siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz", + # ----------------- v2 ----------------- + # base checkpoints + "siglip2-base-patch32-256": "gv-hf/siglip2/siglip2_b32_256.npz", + "siglip2-base-patch16-224": "gv-hf/siglip2/siglip2_b16_224.npz", + "siglip2-base-patch16-256": "gv-hf/siglip2/siglip2_b16_256.npz", + "siglip2-base-patch16-384": "gv-hf/siglip2/siglip2_b16_384.npz", + "siglip2-base-patch16-512": "gv-hf/siglip2/siglip2_b16_512.npz", + # large checkpoints + "siglip2-large-patch16-256": "gv-hf/siglip2/siglip2_l16_256.npz", + "siglip2-large-patch16-384": "gv-hf/siglip2/siglip2_l16_384.npz", + "siglip2-large-patch16-512": 
"gv-hf/siglip2/siglip2_l16_512.npz", + # giant opt checkpoints + "siglip2-giant-opt-patch16-256": "gv-hf/siglip2/siglip2_g-opt16_256.npz", + "siglip2-giant-opt-patch16-384": "gv-hf/siglip2/siglip2_g-opt16_384.npz", + # so400m checkpoints + "siglip2-so400m-patch14-224": "gv-hf/siglip2/siglip2_so400m14_224.npz", + "siglip2-so400m-patch14-384": "gv-hf/siglip2/siglip2_so400m14_384.npz", + "siglip2-so400m-patch16-256": "gv-hf/siglip2/siglip2_so400m16_256.npz", + "siglip2-so400m-patch16-384": "gv-hf/siglip2/siglip2_so400m16_384.npz", + "siglip2-so400m-patch16-512": "gv-hf/siglip2/siglip2_so400m16_512.npz", } -model_name_to_image_size = { - "siglip-base-patch16-224": 224, - "siglip-base-patch16-256": 256, - "siglip-base-patch16-384": 384, - "siglip-base-patch16-512": 512, - "siglip-large-patch16-256": 256, - "siglip-large-patch16-384": 384, - "siglip-base-patch16-256-i18n": 256, - "siglip-so400m-patch14-384": 384, -} +# ------------------------------------------------------------------------------------------------------ +# CONFIG +# ------------------------------------------------------------------------------------------------------ + + +def get_image_size_from_model_name(model_name: str) -> int: + if "-i18n" not in model_name: + size = model_name.split("-")[-1] + else: + size = model_name.split("-")[-2] + return int(size) + + +def get_patch_size_from_model_name(model_name: str) -> int: + patch_str = [x for x in model_name.split("-") if "patch" in x][0] + return int(patch_str[-2:]) + + +def get_vocab_size_from_model_name(model_name: str) -> int: + if "siglip2" in model_name: + vocab_size = 256000 + elif "-i18n" in model_name: + vocab_size = 250000 + else: + vocab_size = 32000 + return vocab_size + + +def get_vocab_file_from_model_name(model_name: str) -> str: + # get vocab file + if "i18n" in model_name: + vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model" + else: + vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model" + return vocab_file + + +def get_text_and_vision_vit_variants(model_name: str) -> Tuple[str, str]: + variant = model_name.split("-")[1] if "giant-opt" not in model_name else "giant-opt" + return { + "base": ("base", "base"), + "large": ("large", "large"), + "so400m": ("so400m", "so400m"), + # g-opt siglip2 is not symmetric + "giant-opt": ("so400m", "giant-opt"), + }[variant] def get_siglip_config(model_name): - config = SiglipConfig() - - vocab_size = 250000 if "i18n" in model_name else 32000 - image_size = model_name_to_image_size[model_name] - patch_size = 16 if "patch16" in model_name else 14 - - # size of the architecture - config.vision_config.image_size = image_size - config.vision_config.patch_size = patch_size - config.text_config.vocab_size = vocab_size - - if "base" in model_name: - pass - elif "large" in model_name: - config.text_config.hidden_size = 1024 - config.text_config.intermediate_size = 4096 - config.text_config.num_hidden_layers = 24 - config.text_config.num_attention_heads = 16 - config.vision_config.hidden_size = 1024 - config.vision_config.intermediate_size = 4096 - config.vision_config.num_hidden_layers = 24 - config.vision_config.num_attention_heads = 16 - elif "so400m" in model_name: - config.text_config.hidden_size = 1152 - config.text_config.intermediate_size = 4304 - config.text_config.num_hidden_layers = 27 - config.text_config.num_attention_heads = 16 - config.vision_config.hidden_size = 1152 - config.vision_config.intermediate_size = 4304 - config.vision_config.num_hidden_layers = 
27 - config.vision_config.num_attention_heads = 16 + text_variant, vision_variant = get_text_and_vision_vit_variants(model_name) + text_config = MODEL_CONFIGS[text_variant].copy() + vision_config = MODEL_CONFIGS[vision_variant].copy() + + text_config["vocab_size"] = get_vocab_size_from_model_name(model_name) + vision_config["image_size"] = get_image_size_from_model_name(model_name) + vision_config["patch_size"] = get_patch_size_from_model_name(model_name) + + if text_config["hidden_size"] != vision_config["hidden_size"]: + text_config["projection_size"] = vision_config["hidden_size"] + + return SiglipConfig(text_config=text_config, vision_config=vision_config) + + +# ------------------------------------------------------------------------------------------------------ +# PROCESSING +# ------------------------------------------------------------------------------------------------------ + + +def get_tokenizer(model_name: str) -> GemmaTokenizerFast: + if "siglip2" in model_name: + tokenizer = GemmaTokenizerFast.from_pretrained( + "google/gemma-2-9b-it", + add_bos_token=False, + add_eos_token=True, + padding_side="right", + do_lower_case=True, + # important: make tokenizer NOT return attention_mask since original one doesn't require it + model_input_names=["input_ids"], + ) + else: + # for siglip v1 + vocab_file = get_vocab_file_from_model_name(model_name) + # important: make tokenizer not return attention_mask since original one doesn't require it + tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"]) + return tokenizer + + +def get_image_processor(model_name: str) -> SiglipImageProcessor: + image_size = get_image_size_from_model_name(model_name) + size = {"height": image_size, "width": image_size} + if "siglip2" in model_name: + image_processor = SiglipImageProcessor(size=size, resample=2) # bilinear resampling else: - raise ValueError("Model not supported") + image_processor = SiglipImageProcessor(size=size) + return image_processor - return config + +# ------------------------------------------------------------------------------------------------------ +# CONVERT FUNCTIONS +# ------------------------------------------------------------------------------------------------------ + + +def split_encoderblock_layers(state_dict: dict) -> dict: + """ + Split the encoderblock weight into layers. In some cases they are concatenated in + the original checkpoints. + """ + # Make shallow copy + state_dict = state_dict.copy() + # Split encoderblock weight into layers + keys = list(state_dict.keys()) + for key in keys: + if "/encoderblock/" in key: + weight = state_dict.pop(key) + for i, weight_i in enumerate(weight): + new_name = key.replace("encoderblock", f"encoderblock_{i}") + state_dict[new_name] = weight_i + return state_dict def create_rename_keys(config): @@ -258,23 +383,21 @@ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logit Copy/paste/tweak model's weights to our SigLIP structure. 
""" - # define default SigLIP configuration + # Define default SigLIP configuration config = get_siglip_config(model_name) - # get checkpoint + # Get checkpoint checkpoint = model_name_to_checkpoint[model_name] + if not os.path.exists(checkpoint): + org, repo_id, *filepath = checkpoint.split("/") + checkpoint = hf_hub_download(repo_id=f"{org}/{repo_id}", filename="/".join(filepath)) - # get vocab file - if "i18n" in model_name: - vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model" - else: - vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model" - - # load original state dict + # Load original state dict data = load(checkpoint) state_dict = flatten_nested_dict(data) + state_dict = split_encoderblock_layers(state_dict) - # remove and rename some keys + # Remove and rename some keys rename_keys = create_rename_keys(config) for src, dest in rename_keys: rename_key(state_dict, src, dest, config) @@ -282,64 +405,61 @@ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logit # qkv matrices of attention pooling head need special treatment read_in_q_k_v_head(state_dict, config) - # load HuggingFace model + # Load HuggingFace model model = SiglipModel(config).eval() model.load_state_dict(state_dict) - # create processor - # important: make tokenizer not return attention_mask since original one doesn't require it - image_size = config.vision_config.image_size - size = {"height": image_size, "width": image_size} - image_processor = SiglipImageProcessor(size=size) - tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"]) + # Create processor + image_processor = get_image_processor(model_name) + tokenizer = get_tokenizer(model_name) processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer) - # verify on dummy images and texts + # Verify forward pass on dummy images and texts url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg" image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB") url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg" image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB") texts = ["an apple", "a picture of an apple"] - inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", padding="max_length") - - # verify input_ids against original ones - if image_size == 224: - filename = "siglip_pixel_values.pt" - elif image_size == 256: - filename = "siglip_pixel_values_256.pt" - elif image_size == 384: - filename = "siglip_pixel_values_384.pt" - elif image_size == 512: - filename = "siglip_pixel_values_512.pt" - else: - raise ValueError("Image size not supported") - - filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset") - original_pixel_values = torch.load(filepath) - filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset") - original_input_ids = torch.load(filepath) + inputs = processor(images=[image_1, image_2], text=texts, padding="max_length", max_length=64, return_tensors="pt") + with torch.no_grad(): + outputs = model(**inputs) - if "i18n" not in model_name: - assert inputs.input_ids.tolist() == original_input_ids.tolist() + if verify_logits: + image_size = config.vision_config.image_size + + # verify input_ids against original ones + if image_size == 224: + filename = "siglip_pixel_values.pt" + elif image_size == 256: + filename = 
"siglip_pixel_values_256.pt" + elif image_size == 384: + filename = "siglip_pixel_values_384.pt" + elif image_size == 512: + filename = "siglip_pixel_values_512.pt" + else: + raise ValueError("Image size not supported") - print("Mean of original pixel values:", original_pixel_values.mean()) - print("Mean of new pixel values:", inputs.pixel_values.mean()) + filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset") + original_pixel_values = torch.load(filepath) + filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset") + original_input_ids = torch.load(filepath) - # note: we're testing with original pixel values here since we don't have exact pixel values - with torch.no_grad(): - outputs = model(input_ids=inputs.input_ids, pixel_values=original_pixel_values) + if "i18n" not in model_name: + assert inputs.input_ids.tolist() == original_input_ids.tolist() - # with torch.no_grad(): - # outputs = model(input_ids=inputs.input_ids, pixel_values=inputs.pixel_values) + print("Mean of original pixel values:", original_pixel_values.mean()) + print("Mean of new pixel values:", inputs.pixel_values.mean()) - print(outputs.logits_per_image[:3, :3]) + # note: we're testing with original pixel values here since we don't have exact pixel values + with torch.no_grad(): + outputs = model(input_ids=original_input_ids, pixel_values=original_pixel_values) + print(outputs.logits_per_image[:3, :3]) - probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities - print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") - print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'") + probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities + print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'") - if verify_logits: if model_name == "siglip-base-patch16-224": expected_slice = torch.tensor( [[-2.9621, -2.1672], [-0.2713, 0.2910]], @@ -375,15 +495,16 @@ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logit print("Looks ok!") if pytorch_dump_folder_path is not None: - Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + pytorch_dump_folder_path = os.path.join(pytorch_dump_folder_path, model_name) + os.makedirs(pytorch_dump_folder_path, exist_ok=True) print(f"Saving model {model_name} to {pytorch_dump_folder_path}") model.save_pretrained(pytorch_dump_folder_path) print(f"Saving processor to {pytorch_dump_folder_path}") processor.save_pretrained(pytorch_dump_folder_path) if push_to_hub: - model.push_to_hub(f"nielsr/{model_name}") - processor.push_to_hub(f"nielsr/{model_name}") + model.push_to_hub(f"s0225/{model_name}", private=True) + processor.push_to_hub(f"s0225/{model_name}", private=True) if __name__ == "__main__": @@ -401,7 +522,7 @@ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logit ) parser.add_argument( "--verify_logits", - action="store_false", + action="store_true", help="Whether to verify logits against the original implementation.", ) parser.add_argument( diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index d8a317493a10..9c54ed9c03fc 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -471,15 +471,9 @@ def forward( # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we 
just need to keep the original shape - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) dropout_rate = self.dropout if self.training else 0.0 @@ -936,7 +930,7 @@ def __init__(self, config: SiglipTextConfig): self.encoder = SiglipEncoder(config) self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self.head = nn.Linear(embed_dim, embed_dim) + self.head = nn.Linear(embed_dim, config.projection_size) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @@ -1415,10 +1409,11 @@ def forward( text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) # cosine similarity as logits - logits_per_text = ( - torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp() - + self.logit_bias - ) + logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) + + logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device) + logits_per_text = logits_per_text * logit_scale.exp() + logit_bias + logits_per_image = logits_per_text.t() loss = None diff --git a/src/transformers/models/siglip/processing_siglip.py b/src/transformers/models/siglip/processing_siglip.py index fd89287fc3f4..7a37cebabfe7 100644 --- a/src/transformers/models/siglip/processing_siglip.py +++ b/src/transformers/models/siglip/processing_siglip.py @@ -41,7 +41,7 @@ class SiglipProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "SiglipImageProcessor" - tokenizer_class = "SiglipTokenizer" + tokenizer_class = "AutoTokenizer" def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) @@ -113,7 +113,7 @@ def __call__( image_features = self.image_processor(images, return_tensors=return_tensors) if text is not None and images is not None: - encoding["pixel_values"] = image_features.pixel_values + encoding.update(image_features) return encoding elif text is not None: return encoding diff --git a/src/transformers/models/siglip2/__init__.py b/src/transformers/models/siglip2/__init__.py new file mode 100644 index 000000000000..fe5b732b9513 --- /dev/null +++ b/src/transformers/models/siglip2/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_siglip2 import * + from .image_processing_siglip2 import * + from .image_processing_siglip2_fast import * + from .modeling_siglip2 import * + from .processing_siglip2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/siglip2/configuration_siglip2.py b/src/transformers/models/siglip2/configuration_siglip2.py new file mode 100644 index 000000000000..6cb379c670ad --- /dev/null +++ b/src/transformers/models/siglip2/configuration_siglip2.py @@ -0,0 +1,277 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_siglip2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Siglip2TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Siglip2TextModel`]. It is used to instantiate a + Siglip2 text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip2 + [google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip2 text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`Siglip2Model`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. 
+ intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + projection_size (`int`, *optional*, defaults to `hidden_size`): + The size of the projection head. + + Example: + + ```python + >>> from transformers import Siglip2TextConfig, Siglip2TextModel + + >>> # Initializing a Siglip2TextConfig with google/siglip2-base-patch16-224 style configuration + >>> configuration = Siglip2TextConfig() + + >>> # Initializing a Siglip2TextModel (with random weights) from the google/siglip2-base-patch16-224 style configuration + >>> model = Siglip2TextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip2_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip2 + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + projection_size=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.projection_size = projection_size if projection_size is not None else hidden_size + + +class Siglip2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a + Siglip2 vision encoder according to the specified arguments, defining the model architecture. 
Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2 + [google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + num_patches (`int`, *optional*, defaults to 256): + The number of patches in the image with the size of (`patch_size`, `patch_size`). + The image is resized to fill maximum of this number of patches, and to preserve + the aspect ratio. In case the resulted number of patches is lower, the image is + padded in "patch" dimension. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + Example: + + ```python + >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel + + >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration + >>> configuration = Siglip2VisionConfig() + + >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration + >>> model = Siglip2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip2_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + num_patches=256, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.num_patches = num_patches + + +class Siglip2Config(PretrainedConfig): + r""" + [`Siglip2Config`] is the configuration class to store the configuration of a [`Siglip2Model`]. It is used to + instantiate a Siglip2 model according to the specified arguments, defining the text model and vision model configs. 
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip2 + [google/siglip2-base-patch16-224](https://huggingface.co/google/siglip2-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Siglip2TextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Siglip2VisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import Siglip2Config, Siglip2Model + + >>> # Initializing a Siglip2Config with google/siglip2-base-patch16-224 style configuration + >>> configuration = Siglip2Config() + + >>> # Initializing a Siglip2Model (with random weights) from the google/siglip2-base-patch16-224 style configuration + >>> model = Siglip2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a Siglip2Config from a Siglip2TextConfig and a Siglip2VisionConfig + >>> from transformers import Siglip2TextConfig, Siglip2VisionConfig + + >>> # Initializing a Siglip2Text and Siglip2Vision configuration + >>> config_text = Siglip2TextConfig() + >>> config_vision = Siglip2VisionConfig() + + >>> config = Siglip2Config.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "siglip2" + sub_configs = {"text_config": Siglip2TextConfig, "vision_config": Siglip2VisionConfig} + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `Siglip2TextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `Siglip2VisionConfig` with default values.") + + self.text_config = Siglip2TextConfig(**text_config) + self.vision_config = Siglip2VisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: Siglip2TextConfig, vision_config: Siglip2VisionConfig, **kwargs): + r""" + Instantiate a [`Siglip2Config`] (or a derived class) from siglip2 text model configuration and siglip2 vision + model configuration. + + Returns: + [`Siglip2Config`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + +__all__ = ["Siglip2Config", "Siglip2TextConfig", "Siglip2VisionConfig"] diff --git a/src/transformers/models/siglip2/convert_siglip2_to_hf.py b/src/transformers/models/siglip2/convert_siglip2_to_hf.py new file mode 100644 index 000000000000..819596498996 --- /dev/null +++ b/src/transformers/models/siglip2/convert_siglip2_to_hf.py @@ -0,0 +1,438 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Siglip2 checkpoints from the original repository. + +URL: https://github.com/google-research/big_vision/tree/main +""" + +import argparse +import collections +import os +import re + +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image, ImageDraw + +from transformers import GemmaTokenizerFast, Siglip2Config, Siglip2ImageProcessorFast, Siglip2Model, Siglip2Processor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +COMMON_CONFIG_PARAMS = { + "base": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_hidden_layers": 12, + "num_attention_heads": 12, + }, + "large": { + "hidden_size": 1024, + "intermediate_size": 4096, + "num_hidden_layers": 24, + "num_attention_heads": 16, + }, + "so400m": { + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + }, +} + +MODEL_NAME_TO_CHECKPOINT_PATH = { + # base checkpoints + "siglip2-base-patch16-naflex": "gv-hf/siglip2/siglip2_b16_naflex.npz", + "siglip2-so400m-patch16-naflex": "gv-hf/siglip2/siglip2_so400m16_naflex.npz", +} + +# fmt: off +EXPECTED_OUTPUTS = { + "siglip2-base-patch16-naflex": torch.tensor([ + [ 1.0195, -0.0280, -1.4468], + [ -4.5395, -6.2269, -1.5667], + [ 4.1757, 5.0358, 3.5159], + [ 9.4264, 10.1879, 6.3353], + [ 2.4409, 3.1058, 4.5491], + [-12.3230, -13.7355, -13.4632], + [ 1.1520, 1.1687, -1.9647], + ]), + "siglip2-so400m-patch16-naflex": torch.tensor([ + [ 0.9422, 0.5540, -2.4405], + [ -7.3522, -9.4931, -6.3499], + [ 5.7852, 6.7288, 7.7893], + [ 9.9881, 10.8136, 9.2121], + [ 5.3660, 5.7746, 8.4130], + [-12.7218, -14.2631, -13.6442], + [ 0.6384, 0.4278, -0.9022], + ]), +} +# fmt: on + +# fmt: off +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # Vision embeddings + r"params/img/embedding/kernel": r"vision_model.embeddings.patch_embedding.weight", + r"params/img/embedding/bias": r"vision_model.embeddings.patch_embedding.bias", + r"params/img/pos_embedding": r"vision_model.embeddings.position_embedding.weight", + # Vision encoder + r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_0/scale": r"vision_model.encoder.layers.\1.layer_norm1.weight", + r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_0/bias": r"vision_model.encoder.layers.\1.layer_norm1.bias", + r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_1/scale": r"vision_model.encoder.layers.\1.layer_norm2.weight", + r"params/img/Transformer/encoderblock_(\d+)/LayerNorm_1/bias": r"vision_model.encoder.layers.\1.layer_norm2.bias", + r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_0/kernel": r"vision_model.encoder.layers.\1.mlp.fc1.weight", + r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_0/bias": r"vision_model.encoder.layers.\1.mlp.fc1.bias", + r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_1/kernel": r"vision_model.encoder.layers.\1.mlp.fc2.weight", + r"params/img/Transformer/encoderblock_(\d+)/MlpBlock_0/Dense_1/bias": r"vision_model.encoder.layers.\1.mlp.fc2.bias", + 
r"params/img/Transformer/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/kernel": r"vision_model.encoder.layers.\1.self_attn.\2_proj.weight", + r"params/img/Transformer/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/bias": r"vision_model.encoder.layers.\1.self_attn.\2_proj.bias", + # Vision norm + r"params/img/Transformer/encoder_norm/scale": r"vision_model.post_layernorm.weight", + r"params/img/Transformer/encoder_norm/bias": r"vision_model.post_layernorm.bias", + # Vision head + r"params/img/MAPHead_0/probe": r"vision_model.head.probe", + r"params/img/MAPHead_0/LayerNorm_0/scale": r"vision_model.head.layernorm.weight", + r"params/img/MAPHead_0/LayerNorm_0/bias": r"vision_model.head.layernorm.bias", + r"params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel": r"vision_model.head.mlp.fc1.weight", + r"params/img/MAPHead_0/MlpBlock_0/Dense_0/bias": r"vision_model.head.mlp.fc1.bias", + r"params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel": r"vision_model.head.mlp.fc2.weight", + r"params/img/MAPHead_0/MlpBlock_0/Dense_1/bias": r"vision_model.head.mlp.fc2.bias", + r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel": r"vision_model.head.attention.out_proj.weight", + r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias": r"vision_model.head.attention.out_proj.bias", + r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/kernel": r"vision_model.head.attention.in_proj_weight", + r"params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/bias": r"vision_model.head.attention.in_proj_bias", + # Text embeddings + r"params/txt/Embed_0/embedding": r"text_model.embeddings.token_embedding.weight", + r"params/txt/pos_embedding": r"text_model.embeddings.position_embedding.weight", + # Text encoder + r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_0/scale": r"text_model.encoder.layers.\1.layer_norm1.weight", + r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_0/bias": r"text_model.encoder.layers.\1.layer_norm1.bias", + r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_1/scale": r"text_model.encoder.layers.\1.layer_norm2.weight", + r"params/txt/Encoder_0/encoderblock_(\d+)/LayerNorm_1/bias": r"text_model.encoder.layers.\1.layer_norm2.bias", + r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_0/kernel": r"text_model.encoder.layers.\1.mlp.fc1.weight", + r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_0/bias": r"text_model.encoder.layers.\1.mlp.fc1.bias", + r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_1/kernel": r"text_model.encoder.layers.\1.mlp.fc2.weight", + r"params/txt/Encoder_0/encoderblock_(\d+)/MlpBlock_0/Dense_1/bias": r"text_model.encoder.layers.\1.mlp.fc2.bias", + r"params/txt/Encoder_0/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/kernel": r"text_model.encoder.layers.\1.self_attn.\2_proj.weight", + r"params/txt/Encoder_0/encoderblock_(\d+)/MultiHeadDotProductAttention_0/(q|k|v|out)[a-z]*/bias": r"text_model.encoder.layers.\1.self_attn.\2_proj.bias", + # Text encoder norm and head + r"params/txt/Encoder_0/encoder_norm/scale": r"text_model.final_layer_norm.weight", + r"params/txt/Encoder_0/encoder_norm/bias": r"text_model.final_layer_norm.bias", + r"params/txt/head/kernel": r"text_model.head.weight", + r"params/txt/head/bias": r"text_model.head.bias", + # learned temperature and bias + r"params/t": r"logit_scale", + r"params/b": r"logit_bias", +} +# fmt: on + + +# -------------------------------------------------------------------------------------------- +# Model objects: 
configuration, tokenizer, image processor +# -------------------------------------------------------------------------------------------- + + +def get_siglip2_config(model_name: str) -> Siglip2Config: + """ + Create a configuration for the Siglip2 model based on the model name. + """ + + _, variant, patch, _ = model_name.split("-") + patch_size = int(patch[-2:]) + num_patches = 256 + + common_options = COMMON_CONFIG_PARAMS[variant] + vision_config = { + "patch_size": patch_size, + "num_patches": num_patches, + **common_options, + } + text_config = { + "vocab_size": 256_000, + **common_options, + } + config = Siglip2Config( + vision_config=vision_config, + text_config=text_config, + ) + return config + + +def get_siglip2_tokenizer() -> GemmaTokenizerFast: + # Load pretrained tokenizer + gemma_checkpoint = "google/gemma-7b" + tokenizer = GemmaTokenizerFast.from_pretrained( + gemma_checkpoint, + add_bos_token=False, + add_eos_token=True, + padding_side="right", + do_lower_case=True, + # important: make tokenizer NOT return attention_mask since original one doesn't require it + model_input_names=["input_ids"], + ) + return tokenizer + + +def get_siglip2_image_processor(patch_size: int, max_num_patches: int) -> Siglip2ImageProcessorFast: + image_processor = Siglip2ImageProcessorFast( + patch_size=patch_size, + max_num_patches=max_num_patches, + do_resize=True, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + do_rescale=True, + rescale_factor=1 / 255, + resample=Image.Resampling.BILINEAR, + ) + return image_processor + + +# -------------------------------------------------------------------------------------------- +# Helper functions for state dict conversion +# -------------------------------------------------------------------------------------------- + + +def flatten_nested_dict(params: dict, parent_key: str = "", sep: str = "/") -> dict: + """ + Flatten a nested original checkpoint dictionary into a flat dictionary. + """ + items = [] + for k, v in params.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, collections.abc.MutableMapping): + items.extend(flatten_nested_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def split_encoderblock_layers(state_dict: dict) -> dict: + """ + Split the encoderblock weight into layers. In some cases they are concatenated in + the original checkpoints. + """ + # Make shallow copy + state_dict = state_dict.copy() + # Split encoderblock weight into layers + keys = list(state_dict.keys()) + for key in keys: + if "/encoderblock/" in key: + weight = state_dict.pop(key) + for i, weight_i in enumerate(weight): + new_name = key.replace("encoderblock", f"encoderblock_{i}") + state_dict[new_name] = weight_i + return state_dict + + +def merge_qkv_for_head(state_dict: dict, config: Siglip2Config) -> dict: + """ + Merge the q/k/v weights and biases for the attention head. 
+ """ + # Make shallow copy + state_dict = state_dict.copy() + # Read and process q/k/v weights and biases + qkv_weights, qkv_biases = [], [] + for name in ["query", "key", "value"]: + prefix = f"params/img/MAPHead_0/MultiHeadDotProductAttention_0/{name}" + weight = state_dict.pop(f"{prefix}/kernel").reshape(-1, config.vision_config.hidden_size) + bias = state_dict.pop(f"{prefix}/bias").reshape(-1) + qkv_weights.append(weight) + qkv_biases.append(bias) + + # Combine into single tensors + state_dict["params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/kernel"] = np.concatenate(qkv_weights, axis=1) + state_dict["params/img/MAPHead_0/MultiHeadDotProductAttention_0/qkv/bias"] = np.concatenate(qkv_biases, axis=0) + return state_dict + + +def convert_old_keys_to_new_keys(state_dict_keys: list) -> dict: + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. + """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # an empty line + continue + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +# -------------------------------------------------------------------------------------------- +# Helper functions for model verification +# -------------------------------------------------------------------------------------------- + + +def create_image(width, height): + """ + Helper function to create an image with a blue circle on a red background. + """ + image = Image.new("RGB", (width, height), color="red") + draw = ImageDraw.Draw(image) + center_x = image.width // 2 + center_y = image.height // 2 + radius = min(center_x, center_y) // 8 * 7 + draw.ellipse( + (center_x - radius, center_y - radius, center_x + radius, center_y + radius), + fill="blue", + outline="green", + width=image.width // 20, + ) + return image + + +def prepare_inputs(): + """ + Prepare inputs for the model. + """ + text = [ + "circle", + "ellipsoid", + "blue circle on red background", + "blue circle with green border on red background", + "green circle on red background", + "a dog", + "a blue dog with a green border on a red background", + ] + img224 = create_image(224, 224) + img1024 = create_image(1024, 1024) + img224_1024 = create_image(1024, 224) + + images = [img224, img1024, img224_1024] + return text, images + + +# -------------------------------------------------------------------------------------------- +# Convert model +# -------------------------------------------------------------------------------------------- + + +@torch.no_grad() +def convert_siglip2_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False): + """ + Copy/paste/tweak model's weights to our Siglip2 structure. 
+ """ + + # Define Siglip2 configuration + config = get_siglip2_config(model_name) + + checkpoint = MODEL_NAME_TO_CHECKPOINT_PATH[model_name] + if not os.path.exists(checkpoint): + org, repo_id, *filepath = checkpoint.split("/") + checkpoint = hf_hub_download(repo_id=f"{org}/{repo_id}", filename="/".join(filepath)) + + print(f"Loading checkpoint from {checkpoint}...") + data = np.load(checkpoint) + state_dict = flatten_nested_dict(data) + state_dict = split_encoderblock_layers(state_dict) + state_dict = merge_qkv_for_head(state_dict, config) + + # Rename and transform weights + print("Renaming and transforming weights...") + + original_keys = list(state_dict.keys()) + hf_keys = convert_old_keys_to_new_keys(original_keys) + + new_state_dict = {} + for original_key in original_keys: + new_key = hf_keys[original_key] + parameter = state_dict.pop(original_key) + + hidden_size = config.vision_config.hidden_size if "vision" in new_key else config.text_config.hidden_size + + if any(k in new_key for k in ("out_proj", "q_proj", "k_proj", "v_proj", "position_embedding")): + parameter = parameter.reshape(-1, hidden_size) + + # Transpose every weight except for position_embedding and token_embedding + if new_key.endswith("weight") and "position_embedding" not in new_key and "token_embedding" not in new_key: + parameter = parameter.T + + # Reshape every bias + if new_key.endswith("bias"): + parameter = parameter.reshape(-1) + + new_state_dict[new_key] = torch.from_numpy(parameter) + + # load HuggingFace model + print("Loading HuggingFace model...") + model = Siglip2Model(config).eval() + model.load_state_dict(new_state_dict) + + # Create processor + print("Creating processor...") + # TODO: update with more checkpoints + tokenizer = get_siglip2_tokenizer() + image_processor = get_siglip2_image_processor(config.vision_config.patch_size, max_num_patches=256) + processor = Siglip2Processor(image_processor=image_processor, tokenizer=tokenizer) + + # Verify logits + if verify_logits: + print(f"Verifying logits for {model_name}...") + text, images = prepare_inputs() + inputs = processor(text=text, images=images, padding="max_length", max_length=64, return_tensors="pt") + outputs = model(**inputs) + torch.testing.assert_close(outputs.logits_per_text, EXPECTED_OUTPUTS[model_name], atol=1e-3, rtol=1e-3) + + # Save model + if pytorch_dump_folder_path is not None: + dst_dir = os.path.join(pytorch_dump_folder_path, model_name) + print(f"Saving model {model_name} to {dst_dir}...") + model.save_pretrained(dst_dir) + print(f"Saving processor to {dst_dir}...") + processor.save_pretrained(dst_dir) + + if push_to_hub: + print(f"Pushing model and processor for {model_name} to the HuggingFace Hub...") + model.push_to_hub(f"qubvel-hf/{model_name}", private=True) + processor.push_to_hub(f"qubvel-hf/{model_name}", private=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="siglip2-base-patch16-naflex", + type=str, + choices=MODEL_NAME_TO_CHECKPOINT_PATH.keys(), + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default="checkpoints/", + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--verify_logits", + action="store_true", + help="Whether to verify logits against the original implementation.", + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." 
+ ) + + args = parser.parse_args() + convert_siglip2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub) diff --git a/src/transformers/models/siglip2/image_processing_siglip2.py b/src/transformers/models/siglip2/image_processing_siglip2.py new file mode 100644 index 000000000000..6278010319b9 --- /dev/null +++ b/src/transformers/models/siglip2/image_processing_siglip2.py @@ -0,0 +1,343 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for SigLIP2.""" + +import math +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + from PIL import Image + + +@lru_cache(maxsize=256) +def get_image_size_for_max_num_patches( + image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5 +) -> Tuple[int, int]: + """ + Determine image size based on max number of patches, ensure dimensions are divisible by patch size and image is at least 1 patch. + + Args: + image_height (`int`): + Original image height. + image_width (`int`): + Original image width. + patch_size (`int`): + Patch size for processing. + max_num_patches (`int`): + Maximum number of patches. + eps (`float`): + Small threshold for binary search. 
+ + Returns: + Tuple: (target_height, target_width) + """ + + def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int: + scaled_size = size * scale + scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size + scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch + return int(scaled_size) + + # Binary search for optimal scale + scale_min, scale_max = eps / 10, 100.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size) + target_width = get_scaled_image_size(scale, image_width, patch_size) + num_patches = (target_height / patch_size) * (target_width / patch_size) + + if num_patches <= max_num_patches: + scale_min = scale + else: + scale_max = scale + + scale = scale_min + target_height = get_scaled_image_size(scale, image_height, patch_size) + target_width = get_scaled_image_size(scale, image_width, patch_size) + return target_height, target_width + + +def convert_image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray: + """ + Convert 3D array image of shape (image_height, image_width, num_channels) into 2D array of patches of shape + (num_patches_height * num_patches_width, patch_size * patch_size * num_channels). + """ + image_height, image_width, num_channels = image.shape + num_patches_height = image_height // patch_size + num_patches_width = image_width // patch_size + patched_image = image.reshape(num_patches_height, patch_size, num_patches_width, patch_size, num_channels) + patched_image = patched_image.transpose(0, 2, 1, 3, 4) + patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1) + return patched_image + + +def pad_along_first_dim(array: np.ndarray, target_length: int, pad_value: int = 0) -> Tuple[np.ndarray, np.ndarray]: + """ + Pad the array along the first dimension. + """ + current_length = array.shape[0] + padding_length = target_length - current_length + mask = np.ones((target_length,), dtype=np.int32) + if padding_length > 0: + paddings = [(0, padding_length)] + [(0, 0)] * (array.ndim - 1) + array = np.pad(array, paddings, mode="constant", constant_values=pad_value) + mask[-padding_length:] = 0 + return array, mask + + +class Siglip2ImageProcessor(BaseImageProcessor): + r""" + Constructs a SigLIP2 image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`. + Can be overridden by `do_resize` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. 
This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch the image will be split to. + max_num_patches (`int`, *optional*, defaults to 256): + The image will be resized to have at most this number of patches, + and then padded in "patch" dimension to match this number exactly. + """ + + model_input_names = ["pixel_values", "pixel_attention_mask", "spatial_shapes"] + + def __init__( + self, + do_resize: bool = True, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + patch_size: int = 16, + max_num_patches: int = 256, + **kwargs, + ): + super().__init__(**kwargs) + + image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5] + image_std = image_std if image_std is not None else [0.5, 0.5, 0.5] + + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.patch_size = patch_size + self.max_num_patches = max_num_patches + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: Optional[bool] = None, + patch_size: Optional[int] = None, + max_num_patches: Optional[int] = None, + ) -> "Image.Image": + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. 
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + Patch size for processing, same as the patch size used in the model. + max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`): + Maximum number of patches per image, the image will be resized to have at most this number of patches. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + patch_size = patch_size if patch_size is not None else self.patch_size + max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches + + # Explicitly specify data format to be channels last for image preprocessing. + # Image processor does not support different output formats, because it returns patches. + data_format = ChannelDimension.LAST + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. 
+ images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + pixel_masks = [] + pixel_values = [] + spatial_shapes = [] + + for image in images: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + if do_resize: + height, width = get_image_size_for_max_num_patches( + image_height=image.shape[0], + image_width=image.shape[1], + patch_size=patch_size, + max_num_patches=max_num_patches, + ) + image = resize(image=image, size=(height, width), resample=resample, input_data_format=data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=data_format) + + patches = convert_image_to_patches(image, patch_size) + patches, mask = pad_along_first_dim(patches, max_num_patches) + num_patches_height = image.shape[0] // patch_size + num_patches_width = image.shape[1] // patch_size + + spatial_shapes.append((num_patches_height, num_patches_width)) + pixel_values.append(patches) + pixel_masks.append(mask) + + batch_feature = BatchFeature( + data={ + "pixel_values": pixel_values, + "pixel_attention_mask": pixel_masks, + "spatial_shapes": spatial_shapes, + }, + tensor_type=return_tensors, + ) + + return batch_feature + + +__all__ = ["Siglip2ImageProcessor"] diff --git a/src/transformers/models/siglip2/image_processing_siglip2_fast.py b/src/transformers/models/siglip2/image_processing_siglip2_fast.py new file mode 100644 index 000000000000..3cb2015e3695 --- /dev/null +++ b/src/transformers/models/siglip2/image_processing_siglip2_fast.py @@ -0,0 +1,322 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
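The `Siglip2ImageProcessor.preprocess` method above implements the NaFlex pipeline: resize the image so its patch grid fits `max_num_patches` while preserving the aspect ratio, flatten it into `patch_size * patch_size * num_channels` patch vectors, and pad along the patch dimension. Below is a minimal sketch of the resulting outputs under the defaults (`patch_size=16`, `max_num_patches=256`); the dummy image and the shape comments are illustrative assumptions, not values taken from this patch:

```python
import numpy as np

from transformers import Siglip2ImageProcessor

processor = Siglip2ImageProcessor(patch_size=16, max_num_patches=256)
image = np.random.randint(0, 256, size=(768, 512, 3), dtype=np.uint8)  # dummy (H, W, C) uint8 image

features = processor(images=image, return_tensors="np")
print(features["pixel_values"].shape)          # (1, 256, 768): 256 padded patches of 16 * 16 * 3 values
print(features["pixel_attention_mask"].shape)  # (1, 256): 1 for real patches, 0 for padding
print(features["spatial_shapes"])              # [[num_patches_height, num_patches_width]] after resizing
```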
+"""Fast Image processor class for SigLIP2.""" + +import math +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import torch + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + TensorType, +) +from ...utils import ( + filter_out_non_signature_kwargs, + is_torch_available, +) + + +if is_torch_available(): + import torch + + +@lru_cache(maxsize=256) +# Copied from transformers.models.siglip2.image_processing_siglip2.get_image_size_for_max_num_patches +def get_image_size_for_max_num_patches( + image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5 +) -> Tuple[int, int]: + """ + Determine image size based on max number of patches, ensure dimensions are divisible by patch size and image is at least 1 patch. + + Args: + image_height (`int`): + Original image height. + image_width (`int`): + Original image width. + patch_size (`int`): + Patch size for processing. + max_num_patches (`int`): + Maximum number of patches. + eps (`float`): + Small threshold for binary search. + + Returns: + Tuple: (target_height, target_width) + """ + + def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int: + scaled_size = size * scale + scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size + scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch + return int(scaled_size) + + # Binary search for optimal scale + scale_min, scale_max = eps / 10, 100.0 + while (scale_max - scale_min) >= eps: + scale = (scale_min + scale_max) / 2 + target_height = get_scaled_image_size(scale, image_height, patch_size) + target_width = get_scaled_image_size(scale, image_width, patch_size) + num_patches = (target_height / patch_size) * (target_width / patch_size) + + if num_patches <= max_num_patches: + scale_min = scale + else: + scale_max = scale + + scale = scale_min + target_height = get_scaled_image_size(scale, image_height, patch_size) + target_width = get_scaled_image_size(scale, image_width, patch_size) + return target_height, target_width + + +def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor": + """ + Convert 3D tensor image of shape (num_channels, image_height, image_width) into 2D tensor of patches of shape + (num_patches_height * num_patches_width, patch_size * patch_size * num_channels). + """ + num_channels, image_height, image_width = image.shape + num_patches_height = image_height // patch_size + num_patches_width = image_width // patch_size + patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size) + patched_image = patched_image.permute(1, 3, 2, 4, 0) + patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1) + return patched_image + + +def pad_along_first_dim( + tensor: "torch.Tensor", target_length: int, pad_value: int = 0 +) -> Tuple["torch.Tensor", "torch.Tensor"]: + """ + Pad the tensor along the first dimension. 
+ """ + current_length = tensor.shape[0] + padding_length = target_length - current_length + mask = torch.ones((target_length,), dtype=torch.int32) + if padding_length > 0: + padding = [0, 0] * (tensor.ndim - 1) + [0, padding_length] + tensor = torch.nn.functional.pad(tensor, padding, mode="constant", value=pad_value) + mask[-padding_length:] = 0 + return tensor, mask + + +class Siglip2ImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast SigLIP2 image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`. + Can be overridden by `do_resize` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch the image will be split to. + max_num_patches (`int`, *optional*, defaults to 256): + The image will be resized to have at most this number of patches, + and then padded in "patch" dimension to match this number exactly. 
+ """ + + def __init__( + self, + do_resize: bool = True, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + patch_size: int = 16, + max_num_patches: int = 256, + **kwargs, + ): + super().__init__(**kwargs) + + image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5] + image_std = image_std if image_std is not None else [0.5, 0.5, 0.5] + + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.patch_size = patch_size + self.max_num_patches = max_num_patches + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: Optional[bool] = None, + patch_size: Optional[int] = None, + max_num_patches: Optional[int] = None, + device: Union["torch.device", str] = "cpu", + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 
+ input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + Patch size for processing, same as the patch size used in the model. + max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`): + Maximum number of patches per image, the image will be resized to have at most this number of patches. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + patch_size = patch_size if patch_size is not None else self.patch_size + max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches + + image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean + image_std = tuple(image_std) if isinstance(image_std, list) else image_std + + image_mean, image_std, interpolation = self._prepare_process_arguments( + do_normalize=do_normalize, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + image_mean=image_mean, + image_std=image_std, + resample=resample, + ) + + images = self._prepare_input_images( + images=images, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + + pixel_masks = [] + pixel_values = [] + spatial_shapes = [] + + for image in images: + if do_resize: + height, width = get_image_size_for_max_num_patches( + image_height=image.shape[1], + image_width=image.shape[2], + patch_size=patch_size, + max_num_patches=max_num_patches, + ) + side_dict = SizeDict(height=height, width=width) + image = self.resize(image=image, size=side_dict, interpolation=interpolation) + + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) + + # (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels) + patches = convert_image_to_patches(image, patch_size) + patches, mask = pad_along_first_dim(patches, max_num_patches) + + num_patches_height = image.shape[1] // patch_size + num_patches_width = image.shape[2] // patch_size + + spatial_shapes.append((num_patches_height, num_patches_width)) + pixel_values.append(patches) + pixel_masks.append(mask) + + pixel_values = torch.stack(pixel_values) + pixel_masks = torch.stack(pixel_masks) + spatial_shapes = torch.tensor(spatial_shapes) + + batch_feature = BatchFeature( + data={ + "pixel_values": pixel_values, + "pixel_attention_mask": pixel_masks, + "spatial_shapes": spatial_shapes, + }, + tensor_type=return_tensors, + ) + return batch_feature + + 
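For parity with the slow processor, here is a similarly hedged sketch of the fast path above, which returns the same three model inputs (`pixel_values`, `pixel_attention_mask`, `spatial_shapes`) as stacked `torch.Tensor`s; the wide PIL image is an arbitrary example chosen to illustrate that the aspect ratio is preserved rather than squashed:

```python
from PIL import Image

from transformers import Siglip2ImageProcessorFast

processor = Siglip2ImageProcessorFast(patch_size=16, max_num_patches=256)
image = Image.new("RGB", (1024, 224), color="red")  # wide dummy image; aspect ratio is kept

batch = processor(images=image, return_tensors="pt")
print(batch["pixel_values"].shape)          # torch.Size([1, 256, 768])
print(batch["pixel_attention_mask"].sum())  # number of real (non-padded) patches
print(batch["spatial_shapes"])              # tensor([[num_patches_height, num_patches_width]])
```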
+__all__ = ["Siglip2ImageProcessorFast"] diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py new file mode 100644 index 000000000000..4785ea9f0177 --- /dev/null +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -0,0 +1,1634 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_siglip2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn.init import _calculate_fan_in_and_fan_out + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Siglip2Config" + + +@dataclass +class Siglip2VisionOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Siglip2TextOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class Siglip2Output(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`Siglip2VisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`Siglip2TextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`Siglip2VisionModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class Siglip2VisionEmbeddings(nn.Module): + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, + ) + + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + + @staticmethod + def resize_positional_embeddings( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + max_length: int, + ) -> torch.Tensor: + """ + Resize positional embeddings to image-specific size and pad to a fixed size. + + Args: + positional_embeddings (`torch.Tensor`): + Position embeddings of shape (height, width, embed_dim) + spatial_shapes (`torch.LongTensor`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + max_length (`int`): + Maximum length of the positional embeddings to pad resized positional embeddings to + + Returns: + `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + """ + batch_size = spatial_shapes.shape[0] + embed_dim = positional_embeddings.shape[-1] + source_dtype = positional_embeddings.dtype + + resulted_positional_embeddings = torch.empty( + (batch_size, max_length, embed_dim), + device=positional_embeddings.device, + dtype=source_dtype, + ) + + # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation + positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) + + # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU + if positional_embeddings.device.type == "cpu": + positional_embeddings = positional_embeddings.to(torch.float32) + + for i in range(batch_size): + # (1, dim, height, width) -> (1, dim, target_height, target_width) + height, width = spatial_shapes[i] + resized_embeddings = F.interpolate( + positional_embeddings, + size=(height, width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + # (1, dim, target_height, target_width) -> (target_height * target_width, dim) + resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1) + + # Cast to original dtype + resized_embeddings = resized_embeddings.to(source_dtype) + + resulted_positional_embeddings[i, : height * width] = resized_embeddings + resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] + + return resulted_positional_embeddings + + def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size) + spatial_shapes (`List[Tuple[int, int]]`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + """ + + # Apply patch embeddings 
to already patchified pixel values + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + + # Get positional resized and padded positional embeddings + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, -1 + ) + resized_positional_embeddings = self.resize_positional_embeddings( + positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1] + ) + + # Add positional embeddings to patch embeddings + embeddings = patch_embeds + resized_positional_embeddings + return embeddings + + +class Siglip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, 
attn_weights + + +class Siglip2FlashAttention2(Siglip2Attention): + """ + Siglip2Attention flash attention module. This module inherits from `Siglip2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + is_causal = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class Siglip2SdpaAttention(Siglip2Attention): + """ + Siglip2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Siglip2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + is_causal = False + + # Adapted from Siglip2Attention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Siglip2Model is using Siglip2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
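+        # Note: `self.is_causal` is False for Siglip2SdpaAttention (bidirectional attention over patches/tokens),
+        # so `is_causal` resolves to False here; the `q_len > 1` guard only matters for causal subclasses.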
+ is_causal = True if self.is_causal and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + +class Siglip2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +SIGLIP2_ATTENTION_CLASSES = { + "eager": Siglip2Attention, + "flash_attention_2": Siglip2FlashAttention2, + "sdpa": Siglip2SdpaAttention, +} + + +class Siglip2EncoderLayer(nn.Module): + def __init__(self, config: Siglip2Config): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SIGLIP2_ATTENTION_CLASSES[config._attn_implementation](config=config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Siglip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Siglip2EncoderLayer`]. 
+ + Args: + config: Siglip2Config + """ + + def __init__(self, config: Siglip2Config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +SIGLIP2_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class Siglip2VisionTransformer(nn.Module): + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Siglip2VisionEmbeddings(config) + self.encoder = Siglip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = Siglip2MultiheadAttentionPoolingHead(config) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.embeddings(pixel_values, spatial_shapes) + + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + else: + encoder_attention_mask = attention_mask + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None + if not return_dict: + return (last_hidden_state, pooler_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Siglip2TextEmbeddings(nn.Module): + def __init__(self, config: Siglip2TextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + 
inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +SIGLIP2_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +class Siglip2TextTransformer(nn.Module): + def __init__(self, config: Siglip2TextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = Siglip2TextEmbeddings(config) + self.encoder = Siglip2Encoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, config.projection_size) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2TextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: Siglip2's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. + pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +SIGLIP2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Siglip2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + +SIGLIP2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class Siglip2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Siglip2Config + base_model_prefix = "siglip2" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "Siglip2TextEmbeddings", + "Siglip2EncoderLayer", + "Siglip2VisionEmbeddings", + "Siglip2EncoderLayer", + "Siglip2MultiheadAttentionPoolingHead", + ] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Siglip2VisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, Siglip2Config) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, Siglip2Attention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, Siglip2MLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, Siglip2MultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, Siglip2Model): + logit_scale_init = torch.log(torch.tensor(1.0)) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, Siglip2ForImageClassification): + nn.init.normal_( + module.classifier.weight, + std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@add_start_docstrings( + """The text model from Siglip2 without any head or projection on top.""", + SIGLIP2_START_DOCSTRING, +) +class Siglip2TextModel(Siglip2PreTrainedModel): + config_class = Siglip2TextConfig + + def __init__(self, config: Siglip2TextConfig): + super().__init__(config) + self.text_model = Siglip2TextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2TextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, Siglip2TextModel + + >>> model = Siglip2TextModel.from_pretrained("google/siglip2-base-patch16-224") + >>> tokenizer = 
AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class Siglip2MultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) + self.num_heads = config.num_attention_heads + + def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + if attention_mask is not None: + target_len, source_len = probe.shape[1], hidden_state.shape[1] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len) + attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1) + attention_mask = attention_mask.reshape(-1, target_len, source_len) + + hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@add_start_docstrings( + """The vision model from Siglip2 without any head or projection on top.""", + SIGLIP2_START_DOCSTRING, +) +class Siglip2VisionModel(Siglip2PreTrainedModel): + config_class = Siglip2VisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: Siglip2VisionConfig): + super().__init__(config) + + self.vision_model = Siglip2VisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Siglip2VisionModel + + >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, 
stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(SIGLIP2_START_DOCSTRING) +class Siglip2Model(Siglip2PreTrainedModel): + config_class = Siglip2Config + + def __init__(self, config: Siglip2Config): + super().__init__(config) + + if not isinstance(config.text_config, Siglip2TextConfig): + raise TypeError( + "config.text_config is expected to be of type Siglip2TextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, Siglip2VisionConfig): + raise TypeError( + "config.vision_config is expected to be of type Siglip2VisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + # First, initialize the text and vision models with proper attention implementation + text_model = Siglip2TextModel._from_config(text_config) + vision_model = Siglip2VisionModel._from_config(vision_config) + + # Second, get the text and vision submodules (for backward compatibility) + self.text_model = text_model.text_model + self.vision_model = vision_model.vision_model + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`Siglip2TextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224") + + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components. 
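+        # The pooled output returned below is the projected embedding of the last (EOS) token
+        # produced by the text transformer head.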
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`Siglip2VisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use Siglip2Model's config for some fields (if specified) instead of those of vision & text components. 
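+        # `pixel_values` are pre-patchified inputs of shape (batch_size, max_num_patches, channels * patch_size**2);
+        # `pixel_attention_mask` marks valid (non-padding) patches and `spatial_shapes` holds each image's
+        # (height, width) patch grid, as produced by the Siglip2 image processor.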
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Siglip2Output, config_class=Siglip2Config) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Siglip2Output]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components. 
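+        # Both towers are run below; their pooled outputs are L2-normalized, and the pairwise similarities
+        # are scaled by `logit_scale.exp()` and shifted by `logit_bias` to form sigmoid logits.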
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + text_embeds = text_outputs[1] + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) + + logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device) + logits_per_text = logits_per_text * logit_scale.exp() + logit_bias + + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287 + eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) + m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye + loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -torch.sum(loglik, dim=-1) + loss = nll.mean() + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return Siglip2Output( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@add_start_docstrings( + """ + Siglip2 vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of + the patch tokens) e.g. for ImageNet. 
+ """, + SIGLIP2_START_DOCSTRING, +) +class Siglip2ForImageClassification(Siglip2PreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: Siglip2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + + # Create the vision model with proper attention + # and take only vision_model submodule (for backward compatibility) + vision_model = Siglip2VisionModel._from_config(config.vision_config) + self.vision_model = vision_model.vision_model + + # Classifier head + self.classifier = ( + nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a `Siglip2Model` from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. 
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224") + >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the two classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: LABEL_1 + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vision_model( + pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # average pool the patch tokens + if pixel_attention_mask is not None: + pool_mask = pixel_attention_mask[..., None].to(sequence_output.device) + sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1) + else: + sequence_output = torch.mean(sequence_output, dim=1) + + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Siglip2Model", + "Siglip2PreTrainedModel", + "Siglip2TextModel", + "Siglip2VisionModel", + "Siglip2ForImageClassification", +] diff --git a/src/transformers/models/siglip2/modular_siglip2.py b/src/transformers/models/siglip2/modular_siglip2.py new file mode 100644 index 000000000000..6fac0030511f --- /dev/null +++ b/src/transformers/models/siglip2/modular_siglip2.py @@ -0,0 +1,537 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig +from transformers.models.siglip.modeling_siglip import ( + BaseModelOutputWithPooling, + ImageClassifierOutput, + SiglipForImageClassification, + SiglipModel, + SiglipMultiheadAttentionPoolingHead, + SiglipOutput, + SiglipPreTrainedModel, + SiglipTextModel, + SiglipTextModelOutput, + SiglipVisionModel, + SiglipVisionModelOutput, + SiglipVisionTransformer, +) + +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask + + +class Siglip2TextConfig(SiglipTextConfig): + pass + + +class Siglip2VisionConfig(SiglipVisionConfig): + r""" + This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a + Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2 + [google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + num_patches (`int`, *optional*, defaults to 256): + The number of patches in the image with the size of (`patch_size`, `patch_size`). + The image is resized to fill maximum of this number of patches, and to preserve + the aspect ratio. In case the resulted number of patches is lower, the image is + padded in "patch" dimension. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ + Example: + + ```python + >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel + + >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration + >>> configuration = Siglip2VisionConfig() + + >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration + >>> model = Siglip2VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + num_patches=256, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_patches = num_patches + del self.image_size + + +class Siglip2Config(SiglipConfig): + pass + + +class Siglip2VisionOutput(SiglipVisionModelOutput): + pass + + +class Siglip2TextOutput(SiglipTextModelOutput): + pass + + +class Siglip2Output(SiglipOutput): + pass + + +class Siglip2VisionEmbeddings(nn.Module): + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, + ) + + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + + @staticmethod + def resize_positional_embeddings( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + max_length: int, + ) -> torch.Tensor: + """ + Resize positional embeddings to image-specific size and pad to a fixed size. 
+
+        Args:
+            positional_embeddings (`torch.Tensor`):
+                Position embeddings of shape (height, width, embed_dim)
+            spatial_shapes (`torch.LongTensor`):
+                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+            max_length (`int`):
+                Maximum length of the positional embeddings to pad resized positional embeddings to
+
+        Returns:
+            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
+        """
+        batch_size = spatial_shapes.shape[0]
+        embed_dim = positional_embeddings.shape[-1]
+        source_dtype = positional_embeddings.dtype
+
+        resulted_positional_embeddings = torch.empty(
+            (batch_size, max_length, embed_dim),
+            device=positional_embeddings.device,
+            dtype=source_dtype,
+        )
+
+        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
+        positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
+
+        # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
+        if positional_embeddings.device.type == "cpu":
+            positional_embeddings = positional_embeddings.to(torch.float32)
+
+        for i in range(batch_size):
+            # (1, dim, height, width) -> (1, dim, target_height, target_width)
+            height, width = spatial_shapes[i]
+            resized_embeddings = F.interpolate(
+                positional_embeddings,
+                size=(height, width),
+                mode="bilinear",
+                align_corners=False,
+                antialias=True,
+            )
+
+            # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
+            resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
+
+            # Cast to original dtype
+            resized_embeddings = resized_embeddings.to(source_dtype)
+
+            resulted_positional_embeddings[i, : height * width] = resized_embeddings
+            resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
+
+        return resulted_positional_embeddings
+
+    def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
+        """
+        Args:
+            pixel_values (`torch.FloatTensor`):
+                Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
+            spatial_shapes (`torch.LongTensor`):
+                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+        """
+
+        # Apply patch embeddings to already patchified pixel values
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+
+        # Get resized and padded positional embeddings
+        positional_embeddings = self.position_embedding.weight.reshape(
+            self.position_embedding_size, self.position_embedding_size, -1
+        )
+        resized_positional_embeddings = self.resize_positional_embeddings(
+            positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
+        )
+
+        # Add positional embeddings to patch embeddings
+        embeddings = patch_embeds + resized_positional_embeddings
+        return embeddings
+
+
+class Siglip2VisionTransformer(SiglipVisionTransformer):
+    def __init__(self, config: Siglip2VisionConfig):
+        super().__init__(config)
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
+    # Update: add `spatial_shapes` and `attention_mask`
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        attention_mask: torch.Tensor,
+        spatial_shapes: torch.LongTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        """
+        output_attentions = output_attentions if output_attentions is not 
None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.embeddings(pixel_values, spatial_shapes) + + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + else: + encoder_attention_mask = attention_mask + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None + if not return_dict: + return (last_hidden_state, pooler_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Siglip2PreTrainedModel(SiglipPreTrainedModel): + pass + + +class Siglip2TextModel(SiglipTextModel): + pass + + +class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead): + def __init__(self, config: Siglip2VisionConfig): + super().__init__(config) + self.num_heads = config.num_attention_heads + + def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + if attention_mask is not None: + target_len, source_len = probe.shape[1], hidden_state.shape[1] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len) + attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1) + attention_mask = attention_mask.reshape(-1, target_len, source_len) + + hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class Siglip2VisionModel(SiglipVisionModel): + # Update: add `spatial_shapes` and `pixel_attention_mask` + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class Siglip2Model(SiglipModel): + # Update: add `spatial_shapes` and `pixel_attention_mask` + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: 
Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
+        # Use Siglip2Model's config for some fields (if specified) instead of those of vision & text components.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            attention_mask=pixel_attention_mask,
+            spatial_shapes=spatial_shapes,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = vision_outputs[1]
+
+        return pooled_output
+
+    # Update: add `spatial_shapes` and `pixel_attention_mask`
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_attention_mask: Optional[torch.Tensor] = None,
+        spatial_shapes: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        return_loss: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Siglip2Output]:
+        # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            attention_mask=pixel_attention_mask,
+            spatial_shapes=spatial_shapes,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        image_embeds = vision_outputs[1]
+        text_embeds = text_outputs[1]
+
+        # normalized features
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # cosine similarity as logits
+        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
+
+        logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
+        logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
+
+        logits_per_image = logits_per_text.t()
+
+        loss = None
+        if return_loss:
+            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
+            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
+            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
+            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
+            nll = -torch.sum(loglik, dim=-1)
+            loss = nll.mean()
+
+        if not return_dict:
+            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+            return ((loss,) + 
output) if loss is not None else output + + return Siglip2Output( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +class Siglip2ForImageClassification(SiglipForImageClassification): + # Update: add `spatial_shapes` and `pixel_attention_mask` + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_attention_mask: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.LongTensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vision_model( + pixel_values, + attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # average pool the patch tokens + if pixel_attention_mask is not None: + pool_mask = pixel_attention_mask[..., None].to(sequence_output.device) + sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1) + else: + sequence_output = torch.mean(sequence_output, dim=1) + + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Siglip2Config", + "Siglip2TextConfig", + "Siglip2VisionConfig", + "Siglip2Model", + "Siglip2PreTrainedModel", + "Siglip2TextModel", + "Siglip2VisionModel", + "Siglip2ForImageClassification", +] diff --git a/src/transformers/models/siglip2/processing_siglip2.py b/src/transformers/models/siglip2/processing_siglip2.py new file mode 100644 index 000000000000..4f4ec33f2f19 --- /dev/null +++ b/src/transformers/models/siglip2/processing_siglip2.py @@ -0,0 +1,171 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for SigLIP2. +""" + +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class Siglip2ImagesKwargs(ImagesKwargs, total=False): + max_num_patches: Optional[int] + patch_size: Optional[int] + + +class Siglip2ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Siglip2ImagesKwargs + + _defaults = { + "text_kwargs": { + "padding": "max_length", + "truncation": True, + "max_length": 64, + }, + "images_kwargs": { + "max_num_patches": 256, + "patch_size": 16, + }, + } + + +class Siglip2Processor(ProcessorMixin): + r""" + Constructs a Siglip2 processor which wraps a Siglip2 image processor and a Gemma tokenizer into a single processor. + + [`Siglip2Processor`] offers all the functionalities of [`Siglip2ImageProcessor`] and [`GemmaTokenizerFast`]. See the + [`~Siglip2Processor.__call__`] and [`~Siglip2Processor.decode`] for more information. + + Args: + image_processor ([`Siglip2ImageProcessor`]): + The image processor is a required input. + tokenizer ([`GemmaTokenizerFast`]): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images: Optional[Union[ImageInput, List[ImageInput], List[List[ImageInput]]]] = None, + text: Optional[Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Siglip2ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to GemmaTokenizerFast's [`~GemmaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` argument to + Siglip2ImageProcessor's [`~Siglip2ImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*, defaults to 64): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*, defaults to `True`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'pt'`): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_attention_mask** -- Attention mask for the pixel values. Returned when `images` is not `None`. + - **spatial_shapes** -- The number of horizontal and vertical patches per image. + Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Siglip2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + if images is not None: + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + + if text is not None and images is not None: + encoding.update(image_features) + return encoding + elif text is not None: + return encoding + else: + return_tensors = output_kwargs["common_kwargs"]["return_tensors"] + return BatchFeature(data=dict(**image_features), tensor_type=return_tensors) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Siglip2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Siglip2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. 
+ """ + return self.tokenizer.batch_decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["Siglip2Processor"] diff --git a/src/transformers/pipelines/zero_shot_image_classification.py b/src/transformers/pipelines/zero_shot_image_classification.py index c53b515dcccd..7e49ba0efaaf 100644 --- a/src/transformers/pipelines/zero_shot_image_classification.py +++ b/src/transformers/pipelines/zero_shot_image_classification.py @@ -145,8 +145,11 @@ def preprocess( inputs = inputs.to(self.torch_dtype) inputs["candidate_labels"] = candidate_labels sequences = [hypothesis_template.format(x) for x in candidate_labels] - padding = "max_length" if self.model.config.model_type == "siglip" else True - text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=padding, **tokenizer_kwargs) + tokenizer_default_kwargs = {"padding": True} + if "siglip" in self.model.config.model_type: + tokenizer_default_kwargs.update(padding="max_length", max_length=64, truncation=True) + tokenizer_default_kwargs.update(tokenizer_kwargs) + text_inputs = self.tokenizer(sequences, return_tensors=self.framework, **tokenizer_default_kwargs) inputs["text_inputs"] = [text_inputs] return inputs @@ -170,7 +173,7 @@ def _forward(self, model_inputs): def postprocess(self, model_outputs): candidate_labels = model_outputs.pop("candidate_labels") logits = model_outputs["logits"][0] - if self.framework == "pt" and self.model.config.model_type == "siglip": + if self.framework == "pt" and "siglip" in self.model.config.model_type: probs = torch.sigmoid(logits).squeeze(-1) scores = probs.tolist() if not isinstance(scores, list): diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e04a785f2c94..d409238588d0 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8849,6 +8849,41 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Siglip2ForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Siglip2Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Siglip2PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Siglip2TextModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Siglip2VisionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class SmolVLMForConditionalGeneration(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index 87b60fbc0463..f393a8f1265d 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -107,6 +107,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) +class Siglip2ImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, 
*args, **kwargs): + requires_backends(self, ["torchvision"]) + + class ViTImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index d0e59a7d5c07..92906e005f90 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -639,6 +639,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class Siglip2ImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class SmolVLMImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/siglip2/__init__.py b/tests/models/siglip2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/siglip2/test_image_processing_siglip2.py b/tests/models/siglip2/test_image_processing_siglip2.py new file mode 100644 index 000000000000..dd96db9c5671 --- /dev/null +++ b/tests/models/siglip2/test_image_processing_siglip2.py @@ -0,0 +1,200 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import requests + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import Siglip2ImageProcessor + + +if is_torchvision_available(): + from transformers import Siglip2ImageProcessorFast + + +class Siglip2ImageProcessingTester: + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + do_rescale=True, + rescale_factor=1 / 255, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + resample=None, + patch_size=16, + max_num_patches=256, + ): + size = size if size is not None else {"height": 18, "width": 18} + resample = resample if resample is not None else Image.Resampling.BILINEAR + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.resample = resample + self.patch_size = patch_size + self.max_num_patches = max_num_patches + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "do_rescale": self.do_rescale, + "rescale_factor": self.rescale_factor, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + 
"image_std": self.image_std, + "resample": self.resample, + "patch_size": self.patch_size, + "max_num_patches": self.max_num_patches, + } + + def expected_output_image_shape(self, images): + return self.max_num_patches, self.patch_size * self.patch_size * self.num_channels + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest with CLIP->Siglip2 +class Siglip2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = Siglip2ImageProcessor if is_vision_available() else None + fast_image_processing_class = Siglip2ImageProcessorFast if is_torchvision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = Siglip2ImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + # Ignore copy + def test_image_processor_properties(self): + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "resample")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "patch_size")) + self.assertTrue(hasattr(image_processing, "max_num_patches")) + + # Ignore copy + def test_image_processor_from_dict_with_kwargs(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.max_num_patches, 256) + self.assertEqual(image_processor.patch_size, 16) + + image_processor = self.image_processing_class.from_dict( + self.image_processor_dict, patch_size=32, max_num_patches=512 + ) + self.assertEqual(image_processor.patch_size, 32) + self.assertEqual(image_processor.max_num_patches, 512) + + @unittest.skip(reason="not supported") + # Ignore copy + def test_call_numpy_4_channels(self): + pass + + # increase mean tolerance to 1e-3 -> 2e-3 + # Ignore copy + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_image = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") + 
torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-1) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 2e-3 + ) + + # increase mean tolerance to 1e-3 -> 2e-3 + # Ignore copy + def test_slow_fast_equivalence_batched(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop: + self.skipTest( + reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors" + ) + + dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, return_tensors="pt") + + torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-1) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 2e-3 + ) diff --git a/tests/models/siglip2/test_modeling_siglip2.py b/tests/models/siglip2/test_modeling_siglip2.py new file mode 100644 index 000000000000..dea49ececa9b --- /dev/null +++ b/tests/models/siglip2/test_modeling_siglip2.py @@ -0,0 +1,989 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
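The image-processing tests above pin down the NaFlex output format: `pixel_values` comes back already patchified with shape `(max_num_patches, patch_size * patch_size * num_channels)`, together with a `pixel_attention_mask` and per-image `spatial_shapes`. A minimal usage sketch of that interface follows (illustrative only, not part of the patch; the checkpoint name `google/siglip2-base-patch16-naflex` and the printed shapes are assumptions based on the defaults `max_num_patches=256` and `patch_size=16` used elsewhere in this PR):

```python
# Illustrative sketch (not part of this patch): inspect the NaFlex-style outputs of the
# SigLIP2 image processor. Checkpoint name and example shapes are assumptions based on
# the PR defaults (max_num_patches=256, patch_size=16).
import requests
from PIL import Image

from transformers import AutoImageProcessor

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-naflex")
inputs = image_processor(images=image, return_tensors="pt")

print(inputs.pixel_values.shape)          # e.g. (1, 256, 768) -> (batch, max_num_patches, 3 * 16 * 16)
print(inputs.pixel_attention_mask.shape)  # (batch, max_num_patches): 1 for real patches, 0 for padding
print(inputs.spatial_shapes)              # (batch, 2): patches used along height and width per image
```

These three tensors are what `Siglip2VisionModel` and `Siglip2Model` consume as `pixel_values`, `pixel_attention_mask`, and `spatial_shapes` in the modeling tests below.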
+"""Testing suite for the PyTorch Siglip2 model.""" + +import inspect +import tempfile +import unittest +from typing import Tuple + +import numpy as np +from parameterized import parameterized +from pytest import mark + +from transformers import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig +from transformers.testing_utils import ( + require_flash_attn, + require_torch, + require_torch_gpu, + require_torch_sdpa, + require_vision, + slow, + torch_device, +) +from transformers.utils import ( + is_torch_available, + is_torch_bf16_available_on_device, + is_torch_fp16_available_on_device, + is_torch_sdpa_available, + is_vision_available, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + floats_tensor, + ids_tensor, + is_flaky, + random_attention_mask, +) +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import Siglip2ForImageClassification, Siglip2Model, Siglip2TextModel, Siglip2VisionModel + +if is_torch_sdpa_available(): + from torch.nn.attention import SDPBackend, sdpa_kernel + +if is_vision_available(): + from PIL import Image, ImageDraw + + from transformers import Siglip2Processor + + +class Siglip2ModelTesterMixin(ModelTesterMixin): + def test_sdpa_can_dispatch_composite_models(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # Load the model with SDPA + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + # Load model with eager attention + model_eager = model_class.from_pretrained( + tmpdirname, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + # SigLip has one shared cls attr for all models, so we assign both submodels heer + vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager" + + if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"): + self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn) + self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn) + self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.text_model.config._attn_implementation == "eager") + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + + def test_eager_matches_sdpa_inference( + self, + torch_dtype: str, + use_attention_mask_options: Tuple[bool, ...] = (True, False), + logit_keys: Tuple[str, ...] 
= ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), + ): + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Convert to torch dtype + dtypes = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + torch_dtype = dtypes[torch_dtype] + + atols = { + torch.float32: 1e-5, + torch.bfloat16: 3e-2, + torch.float16: 5e-3, + } + rtols = { + torch.float32: 1e-4, + torch.bfloat16: 3e-2, + torch.float16: 5e-3, + } + + atol = atols[torch_dtype] + rtol = rtols[torch_dtype] + + def get_mean_reldiff(msg, current_case, x, ref, atol, rtol): + return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # Load the model with SDPA + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + model_sdpa = model_sdpa.eval().to(torch_device) + + # Load model with eager attention + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time, + # but it would be nicer to have an efficient way to use parameterized.expand + cases = [ + (use_mask, output_attentions, sdpa_backend, batch_size) + for use_mask in use_attention_mask_options + for output_attentions in [True, False] + for sdpa_backend in [ + SDPBackend.MATH, + [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH], + [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], + [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], + ] + for batch_size in [1, 5] + ] + fail_cases = [] + + for use_mask, output_attentions, sdpa_backend, batch_size in cases: + processed_inputs = inputs_dict.copy() + + # convert to torch_dtype + if "pixel_values" in processed_inputs: + processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype) + + # slice for different batch sizes + for key in processed_inputs.keys(): + if isinstance(processed_inputs[key], (torch.Tensor, list, tuple)): + processed_inputs[key] = processed_inputs[key][:batch_size] + + # set attention mask with left padding + if not use_mask: + processed_inputs.pop("attention_mask", None) + else: + dummy_attention_mask = processed_inputs["attention_mask"] + dummy_attention_mask[:] = 1 + dummy_attention_mask[:, :1] = 0 + processed_inputs["attention_mask"] = dummy_attention_mask + + processed_inputs["output_attentions"] = output_attentions + processed_inputs["output_hidden_states"] = True + + current_case = ( + f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}" + ) + + prepared_inputs = 
self._prepare_for_class(processed_inputs, model_class) + + with torch.no_grad(): + try: + with sdpa_kernel(sdpa_backend): + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) + except Exception as e: + fail_cases.append(f"{current_case}: {e}") + continue + + for key in logit_keys: + eager_logits = outputs_eager[key] + sdpa_logits = outputs_sdpa[key] + + if use_mask: + eager_logits = eager_logits[:, 1:] + sdpa_logits = sdpa_logits[:, 1:] + + is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol) + if not is_close: + fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol)) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence(self): + dtype = torch.float16 + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + # Prepare inputs + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if "pixel_values" in inputs_dict: + inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(dtype) + + # Separate masks + attention_masks = {} + if "attention_mask" in inputs_dict: + # attention_masks["attention_mask"] = inputs_dict.pop("attention_mask") + inputs_dict["attention_mask"] = None + if "pixel_attention_mask" in inputs_dict: + attention_masks["pixel_attention_mask"] = inputs_dict.pop("pixel_attention_mask") + inputs_dict["pixel_attention_mask"] = None + + # Save and load model with flash attention 2 and eager attentions + with tempfile.TemporaryDirectory() as tmp_dir: + model = model_class(config) + model.save_pretrained(tmp_dir) + + model = model_class.from_pretrained(tmp_dir, torch_dtype=dtype) + model_fa = model_class.from_pretrained( + tmp_dir, torch_dtype=dtype, attn_implementation="flash_attention_2" + ) + + model_fa.to(torch_device) + model.to(torch_device) + + # Run forward pass without attention masks + with torch.no_grad(): + outputs = model(**inputs_dict, output_hidden_states=True) + outputs_fa = model_fa(**inputs_dict, output_hidden_states=True) + + # Choose which key to compare + key = [k for k in ["logits", "logits_per_image", "last_hidden_state"] if k in outputs][0] + + torch.testing.assert_close(outputs[key], outputs_fa[key], atol=4e-2, rtol=4e-2) + + # Run forward pass with attention masks + inputs_dict.update(attention_masks) + with torch.no_grad(): + outputs = model(**inputs_dict, output_hidden_states=True) + outputs_fa = model_fa(**inputs_dict, output_hidden_states=True) + + output_tensor = outputs[key] + output_tensor_fa = outputs_fa[key] + + # Mask out padded tokens, they are different for SDPA and Flash Attention 2 + if key == "last_hidden_state" and "pixel_attention_mask" in inputs_dict: + output_tensor = output_tensor * inputs_dict["pixel_attention_mask"][..., None] + output_tensor_fa = output_tensor_fa * inputs_dict["pixel_attention_mask"][..., None] + elif key == "last_hidden_state" and inputs_dict.get("attention_mask", None) is not None: + output_tensor = output_tensor * inputs_dict["attention_mask"][..., None] + output_tensor_fa = output_tensor_fa * inputs_dict["attention_mask"][..., None] + + torch.testing.assert_close(output_tensor, output_tensor_fa, atol=4e-2, rtol=4e-2) + + # Check with inference + dropout + model.train() + _ = model_fa(**inputs_dict, output_hidden_states=True) + + 
@unittest.skip(reason="Siglip2 has default right padding (tested in test_flash_attn_2_inference_equivalence)") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + + @unittest.skip(reason="SDPA can't dispatch on flash with not None `attention_mask`") + def test_sdpa_can_dispatch_on_flash(self): + pass + + +class Siglip2VisionModelTester: + def __init__( + self, + parent, + batch_size=12, + num_patches=16, + image_num_patches=24, + patch_size=2, + num_channels=3, + is_training=True, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + initializer_range=0.02, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.num_patches = num_patches + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.scope = scope + self.seq_length = image_num_patches + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [self.batch_size, self.seq_length, self.num_channels * self.patch_size * self.patch_size] + ) + pixel_attention_mask = torch.zeros(self.batch_size, self.seq_length, device=torch_device, dtype=torch.long) + + spatial_shapes = [ + (height, width) + for height in range(1, self.seq_length) + for width in range(1, self.seq_length) + if height * width <= self.seq_length + ] * self.batch_size + spatial_shapes = spatial_shapes[: self.batch_size] + spatial_shapes = torch.tensor(spatial_shapes, device=torch_device, dtype=torch.long) + + for i, (height, width) in enumerate(spatial_shapes): + pixel_attention_mask[i, : height * width] = 1 + + config = self.get_config() + + return config, pixel_values, pixel_attention_mask, spatial_shapes + + def get_config(self): + return Siglip2VisionConfig( + num_patches=self.num_patches, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, pixel_values, pixel_attention_mask, spatial_shapes): + model = Siglip2VisionModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values, pixel_attention_mask, spatial_shapes) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_attention_mask, spatial_shapes = self.prepare_config_and_inputs() + inputs_dict = { + "pixel_values": pixel_values, + "pixel_attention_mask": pixel_attention_mask, + "spatial_shapes": spatial_shapes, + } + return config, inputs_dict + + +@require_torch +class Siglip2VisionModelTest(Siglip2ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as SIGLIP2 does not use input_ids, inputs_embeds, + attention_mask and seq_length. 
+ """ + + all_model_classes = (Siglip2VisionModel,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + # MP works but offload doesn't work when the MultiheadAttention is offloaded + # TODO: One potential solution would be to add to set preload_module_classes = ["Siglip2MultiheadAttentionPoolingHead"] + # in the dispatch_model function + test_cpu_offload = False + test_disk_offload_safetensors = False + test_disk_offload_bin = False + + def setUp(self): + self.model_tester = Siglip2VisionModelTester(self) + self.config_tester = ConfigTester( + self, config_class=Siglip2VisionConfig, has_text_modality=False, hidden_size=37 + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="SIGLIP2 does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_get_set_embeddings(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="Siglip2VisionModel does not support standalone training") + def test_training(self): + pass + + @unittest.skip(reason="Siglip2VisionModel does not support standalone training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="Siglip2VisionModel does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="Siglip2VisionModel does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Siglip2VisionModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="Siglip2VisionModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation") + def test_initialization(self): + pass + + @slow + def test_model_from_pretrained(self): + model_name = "google/siglip2-base-patch16-naflex" + model = Siglip2VisionModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + @is_flaky() + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + super().test_eager_matches_sdpa_inference( + torch_dtype=torch_dtype, + logit_keys=("pooler_output", "last_hidden_state"), + use_attention_mask_options=(False,), + ) + + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + 
super().test_sdpa_can_dispatch_composite_models() + + +class Siglip2TextModelTester: + def __init__( + self, + parent, + batch_size=12, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + max_position_embeddings=512, + initializer_range=0.02, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + if input_mask is not None: + batch_size, seq_length = input_mask.shape + rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,)) + for batch_idx, start_index in enumerate(rnd_start_indices): + input_mask[batch_idx, :start_index] = 1 + input_mask[batch_idx, start_index:] = 0 + + config = self.get_config() + + return config, input_ids, input_mask + + def get_config(self): + return Siglip2TextConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, input_ids, input_mask): + model = Siglip2TextModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, input_mask = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class Siglip2TextModelTest(Siglip2ModelTesterMixin, unittest.TestCase): + all_model_classes = (Siglip2TextModel,) if is_torch_available() else () + fx_compatible = False + test_resize_embeddings = False + test_pruning = False + test_head_masking = False + model_split_percents = [0.5, 0.8, 0.9] + + def setUp(self): + self.model_tester = Siglip2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=Siglip2TextConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="Siglip2TextModel does not support standalone 
training") + def test_training(self): + pass + + @unittest.skip(reason="Siglip2TextModel does not support standalone training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="Siglip2TextModel does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="Siglip2TextModel does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Siglip2 does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Siglip2TextModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="Siglip2TextModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation") + def test_initialization(self): + pass + + @slow + def test_model_from_pretrained(self): + model_name = "google/siglip2-base-patch16-naflex" + model = Siglip2TextModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + @is_flaky() + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + super().test_eager_matches_sdpa_inference( + torch_dtype=torch_dtype, + logit_keys=("pooler_output", "last_hidden_state"), + use_attention_mask_options=(False, True), + ) + + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + + +class Siglip2ModelTester: + def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): + if text_kwargs is None: + text_kwargs = {} + if vision_kwargs is None: + vision_kwargs = {} + + self.parent = parent + self.text_model_tester = Siglip2TextModelTester(parent, **text_kwargs) + self.vision_model_tester = Siglip2VisionModelTester(parent, **vision_kwargs) + self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test + self.is_training = is_training + + def prepare_config_and_inputs(self): + text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs() + vision_config, pixel_values, pixel_attention_mask, spatial_shapes = ( + self.vision_model_tester.prepare_config_and_inputs() + ) + + config = self.get_config() + + return config, input_ids, attention_mask, pixel_values, pixel_attention_mask, spatial_shapes + + def get_config(self): + return Siglip2Config.from_text_vision_configs( + self.text_model_tester.get_config(), + self.vision_model_tester.get_config(), + ) + + def create_and_check_model( + self, config, input_ids, attention_mask, pixel_values, pixel_attention_mask, spatial_shapes + ): + model = Siglip2Model(config).to(torch_device).eval() + with torch.no_grad(): + result = model(input_ids, pixel_values, pixel_attention_mask, spatial_shapes, attention_mask) + self.parent.assertEqual( + result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size) + ) + self.parent.assertEqual( + result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size) + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask, pixel_values, 
pixel_attention_mask, spatial_shapes = config_and_inputs
+ inputs_dict = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ "pixel_attention_mask": pixel_attention_mask,
+ "spatial_shapes": spatial_shapes,
+ "attention_mask": attention_mask,
+ "position_ids": None,
+ "return_loss": False,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class Siglip2ModelTest(Siglip2ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (Siglip2Model,) if is_torch_available() else ()
+ pipeline_model_mapping = {"feature-extraction": Siglip2Model} if is_torch_available() else {}
+ fx_compatible = False
+ test_head_masking = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_attention_outputs = False
+ # MP works but offload doesn't work when the MultiheadAttention is offloaded
+ # TODO: One potential solution would be to set preload_module_classes = ["Siglip2MultiheadAttentionPoolingHead"]
+ # in the dispatch_model function
+ test_cpu_offload = False
+ test_disk_offload_safetensors = False
+ test_disk_offload_bin = False
+ _is_composite = True
+
+ def setUp(self):
+ self.model_tester = Siglip2ModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Siglip2Config, has_text_modality=False)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="Hidden_states is tested in individual model tests")
+ def test_hidden_states_output(self):
+ pass
+
+ @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="Retain_grad is tested in individual model tests")
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ @unittest.skip(reason="Siglip2Model does not have input/output embeddings")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation")
+ def test_initialization(self):
+ pass
+
+ def test_load_vision_text_config(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # Save Siglip2Config and check if we can load Siglip2VisionConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ vision_config = Siglip2VisionConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
+
+ # Save Siglip2Config and check if we can load Siglip2TextConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ text_config = Siglip2TextConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
+
+ @slow
+ def test_model_from_pretrained(self):
+ model_name = "google/siglip2-base-patch16-naflex"
+ model = Siglip2Model.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ def test_flash_attn_2_inference_equivalence_right_padding(self):
+ self.skipTest("Siglip2 does not support right padding")
+
+ @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+ @require_torch_sdpa
+ @slow
+ @is_flaky()
+ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+ super().test_eager_matches_sdpa_inference(
+ torch_dtype=torch_dtype,
+ logit_ke
ys=("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
+ use_attention_mask_options=(False, True),
+ )
+
+ @require_torch_sdpa
+ def test_sdpa_can_dispatch_composite_models(self):
+ super().test_sdpa_can_dispatch_composite_models()
+
+
+class Siglip2ForImageClassificationModelTester(Siglip2ModelTester):
+ def __init__(self, parent):
+ super().__init__(parent)
+ self.batch_size = self.vision_model_tester.batch_size
+ self.num_hidden_layers = self.vision_model_tester.num_hidden_layers
+ self.hidden_size = self.vision_model_tester.hidden_size
+ self.seq_length = self.vision_model_tester.seq_length
+
+ def prepare_config_and_inputs(self):
+ _, pixel_values, pixel_attention_mask, spatial_shapes = self.vision_model_tester.prepare_config_and_inputs()
+ config = self.get_config()
+
+ return config, pixel_values, pixel_attention_mask, spatial_shapes
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, pixel_attention_mask, spatial_shapes = config_and_inputs
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "pixel_attention_mask": pixel_attention_mask,
+ "spatial_shapes": spatial_shapes,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class Siglip2ForImageClassificationModelTest(Siglip2ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (Siglip2ForImageClassification,) if is_torch_available() else ()
+ pipeline_model_mapping = {"image-classification": Siglip2ForImageClassification} if is_torch_available() else {}
+ fx_compatible = False
+ test_head_masking = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_attention_outputs = False
+ # MP works but offload doesn't work when the MultiheadAttention is offloaded
+ # TODO: One potential solution would be to set preload_module_classes = ["Siglip2MultiheadAttentionPoolingHead"]
+ # in the dispatch_model function
+ test_cpu_offload = False
+ test_disk_offload_safetensors = False
+ test_disk_offload_bin = False
+ _is_composite = True
+
+ def setUp(self):
+ self.model_tester = Siglip2ForImageClassificationModelTester(self)
+
+ @unittest.skip(reason="Siglip2ForImageClassification does not support inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="Siglip2ForImageClassification does not support inputs_embeds")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @unittest.skip(reason="Siglip2ForImageClassification does not support gradient checkpointing yet")
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ @unittest.skip(reason="Siglip2ForImageClassification does not support gradient checkpointing yet")
+ def test_training_gradient_checkpointing_use_reentrant(self):
+ pass
+
+ @unittest.skip(reason="Siglip2ForImageClassification does not support gradient checkpointing yet")
+ def test_training_gradient_checkpointing_use_reentrant_false(self):
+ pass
+
+ @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation")
+ def test_initialization(self):
+ pass
+
+ @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+ @require_torch_sdpa
+ @slow
+ @is_flaky()
+ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+ super().test_eager_matches_sdpa_inference(
+ torch_dtype=torch_dtype, logit_keys=("logits",), use_attention_mask_options=(False,)
+ )
+
+ @require_torch_sdpa
+ def test_sdpa_can_dispatch_composite_models(self):
+ super().test_sdpa_can_dispatch_composite_models()
+
+
+# Draw a circle on images with different aspect rat
ios
+def prepare_images():
+ shapes = [(224, 224), (1024, 1024), (224, 1024)]
+ images = []
+ for height, width in shapes:
+ image = Image.new("RGB", (width, height), color="red")
+ draw = ImageDraw.Draw(image)
+ center_x = image.width // 2
+ center_y = image.height // 2
+ radius = min(center_x, center_y) // 8 * 7
+ draw.ellipse(
+ (center_x - radius, center_y - radius, center_x + radius, center_y + radius),
+ fill="blue",
+ outline="green",
+ width=image.width // 20,
+ )
+ images.append(image)
+ return images
+
+
+@require_vision
+@require_torch
+class Siglip2ModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference(self):
+ model_name = "google/siglip2-base-patch16-naflex"
+ model = Siglip2Model.from_pretrained(model_name).to(torch_device)
+ processor = Siglip2Processor.from_pretrained(model_name)
+
+ images = prepare_images()
+ text = [
+ "circle",
+ "ellipsoid",
+ "blue circle on red background",
+ "blue circle with green border on red background",
+ "green circle on red background",
+ "a dog",
+ "a blue dog with a green border on a red background",
+ ]
+
+ inputs = processor(text=text, images=images, return_tensors="pt")
+ inputs = inputs.to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ logits_per_image = outputs.logits_per_image
+ logits_per_text = outputs.logits_per_text
+
+ # verify the logits shape
+ self.assertEqual(
+ logits_per_image.shape,
+ torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
+ )
+ self.assertEqual(
+ logits_per_text.shape,
+ torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
+ )
+
+ # verify the logits values
+ # fmt: off
+ expected_logits_per_text = torch.tensor(
+ [
+ [ 1.0195, -0.0280, -1.4468],
+ [ -4.5395, -6.2269, -1.5667],
+ [ 4.1757, 5.0358, 3.5159],
+ [ 9.4264, 10.1879, 6.3353],
+ [ 2.4409, 3.1058, 4.5491],
+ [-12.3230, -13.7355, -13.4632],
+ [ 1.1520, 1.1687, -1.9647],
+ ]
+ ).to(torch_device)
+ # fmt: on
+
+ torch.testing.assert_close(outputs.logits_per_text, expected_logits_per_text, rtol=1e-3, atol=1e-3)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 2cb48b7a3cae..d0adee987d52 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -1175,6 +1175,10 @@ def _create_and_check_torchscript(self, config, inputs_dict):
 traced_model = torch.jit.trace(
 model, (pixel_values, prompt_pixel_values, prompt_masks), check_trace=False
 ) # when traced model is checked, an error is produced due to name mangling
+ elif "Siglip2" in model_class.__name__:
+ outputs = model(**inputs)
+ example_inputs = [t for t in inputs.values() if isinstance(t, torch.Tensor)]
+ traced_model = torch.jit.trace(model, example_inputs, check_trace=False)
 else:
 main_input = inputs[main_input_name]
@@ -3035,6 +3039,7 @@ def test_mismatched_shapes_have_properly_initialized_weights(self):
 "wav2vec2.masked_spec_embed",
 "Wav2Vec2ForSequenceClassification",
 "CLIPForImageClassification",
+ "Siglip2ForImageClassification",
 "RegNetForImageClassification",
 "ResNetForImageClassification",
 "UniSpeechSatForSequenceClassification",
diff --git a/utils/check_repo.py b/utils/check_repo.py
index f8769cbb1b13..3b3dddf9cf63 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -334,6 +334,8 @@
 "SegGptForImageSegmentation",
 "SiglipVisionModel",
 "SiglipTextModel",
+ "Siglip2VisionModel",
+ "Siglip2TextModel",
 "ChameleonVQVAE", # no autoclass for VQ-VAE models
 "VitPoseForPoseEstimation",
 "CLIPTextModel",