Commit

Add SigLIP 2 (#36323)
* Docs

* Inits

* Auto classes

* Add siglip base

* Add base tests

* Fix Siglip V1 for FixRes version

* Add image processor

* Update conversion

* Experimenting with vectorized embeddings

* Fixup

* Add modular Siglip2Processor

* Add modular configuration

* Rename num patches

* Correct image and text features merging

* Working conversion script

* Refactoring conversion script

* Remove unused code in conversion script

* Shorten dict a bit

* Refactoring conversion

* Done conversion refactoring

* Fixup

* Modular siglip2

* Make model exportable and compilable without graph breaks

* Remove position_ids from image_processor

* Remove position ids from modeling file

* Update modular

* Type hint

* Fixup

* Set defaults to processor

* Add integration test

* Revert spatial shapes back to tensor

* Change order

* Fix most of the tests

* Fix docstring

* Remove interpolate_pos_encoding arg (not needed)

* Update docs

* Standardize processing

* Fix attention_mask in vision head

* Siglip v1: remove double transpose in FA2

* Update modular file

* Update FA2 test

* Update expected logits

* Fix interpolation for siglip2 image processor

* Skip init test

* Skip dispatch on flash test

* Fix modeling tests

* Fixup

* Add dummy objects

* Fix some docstrings

* Add siglip2 in index.md

* Fix consistency

* Add docs

* Remove size and data format

* Add image processor tests

* Fix

* Add fast image processor

* Fix style

* Fix

* Docs

* Set lowercase for tokenizer

* Adjust head size for Siglip v1

* Update siglip2 for consistency with siglip1

* Update siglip2 conversion

* Update pipeline

* Update checkpoints in tests

* Update checkpoint name

* Fix pooling for image classification model

* Fix FA2 test

* Update processor

* Fix check repo

* Update docs

* Fix typos

* Fix docstring for fast image processor

* Add siglip2 to FA2 docs

* Fix fast ip tests

* Fix consistency

* Fix tokenizer class for siglip v1

* Fix missing header

* Refactor scaling for clip, siglip, siglip2

* Remove unused imports

* Make fast IP default for siglip2

* Update docs

* Update checkpoints

* Update modular

* Update paper link

* Fixup

* Fix name in toctree

* Fix test
qubvel authored Feb 21, 2025
1 parent 14552cb commit a957b79
Showing 33 changed files with 5,570 additions and 122 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -965,6 +965,8 @@
title: Segment Anything
- local: model_doc/siglip
title: SigLIP
- local: model_doc/siglip2
title: SigLIP2
- local: model_doc/smolvlm
title: SmolVLM
- local: model_doc/speech-encoder-decoder
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -317,6 +317,7 @@ Flax), PyTorch, and/or TensorFlow.
| [SEW](model_doc/sew) ||||
| [SEW-D](model_doc/sew-d) ||||
| [SigLIP](model_doc/siglip) ||||
| [SigLIP2](model_doc/siglip2) ||||
| [SmolVLM](model_doc/smolvlm) ||||
| [Speech Encoder decoder](model_doc/speech-encoder-decoder) ||||
| [Speech2Text](model_doc/speech_to_text) ||||
276 changes: 276 additions & 0 deletions docs/source/en/model_doc/siglip2.md
@@ -0,0 +1,276 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# SigLIP2

## Overview

The SigLIP2 model was proposed in [SigLIP 2: Multilingual Vision-Language Encoders with Improved Semantic Understanding, Localization, and Dense Features](https://huggingface.co/papers/2502.14786) by Michael Tschannen, Alexey Gritsenko, Xiao Wang, Muhammad Ferjad Naeem, Ibrahim Alabdulmohsin,
Nikhil Parthasarathy, Talfan Evans, Lucas Beyer, Ye Xia, Basil Mustafa, Olivier Hénaff, Jeremiah Harmsen,
Andreas Steiner and Xiaohua Zhai.

The model comes in two variants:

1) FixRes - the model works with fixed-resolution images (backward compatible with SigLIP v1)
2) NaFlex - the model works with variable image aspect ratios and resolutions (SigLIP2 in `transformers`)

The abstract from the paper is the following:

*We introduce SigLIP 2, a family of new multilingual vision-language encoders that build on the success
of the original SigLIP. In this second iteration, we extend the original image-text training objective with
several prior, independently developed techniques into a unified recipe—this includes decoder-based
pretraining, self-supervised losses (self-distillation, masked prediction) and online data curation. With
these changes, SigLIP 2 models outperform their SigLIP counterparts at all model scales in core capabilities,
including zero-shot classification (best SigLIP 2 ViT-g/16 achieves 85.0% ImageNet zero-shot
accuracy), image-text retrieval, and transfer performance when extracting visual representations for
Vision-Language Models (VLMs). Furthermore, the new training recipe leads to significant improvements
on localization and dense prediction tasks. We also train variants which support multiple resolutions
and preserve the input’s native aspect ratio. Finally, we train on a more diverse data-mixture that
includes de-biasing techniques, leading to much better multilingual understanding and improved fairness.
To provide users with the ability to trade-off inference cost with performance, we release model
checkpoints at four sizes (ViT-B/86M, L/303M, So400m/400M, and g/1B).*

## Usage tips

- Usage of SigLIP2 is similar to [SigLIP](siglip) and [CLIP](clip). The main difference from CLIP is the training loss, which does not require a global view of all the pairwise similarities of images and texts within a batch. One needs to apply the sigmoid activation function to the logits, rather than the softmax (see the short sketch after these tips).
- Training is supported but does not use `torch.distributed` utilities, which may limit the scalability of the batch size. However, DDP and FSDP work in a single-node multi-GPU setup.
- When using the standalone [`GemmaTokenizerFast`] make sure to pass `padding="max_length"` and `max_length=64`, as that's how the model was trained.
- The model was trained with *lowercased* text, so make sure to apply the same preprocessing to your text labels.
- To get the same results as the pipeline, a prompt template of "this is a photo of {label}" should be used.
- The NaFlex variant supports processing images at higher resolutions by adjusting the `max_num_patches` parameter in the `Processor`. The default value is `max_num_patches=256`. Increasing `max_num_patches` to 1024 (4x) will approximately double the processed image height and width, while preserving the aspect ratio.
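
To make the readout difference from CLIP concrete, here is a minimal, illustrative sketch using random logits (no real model involved): CLIP normalizes each image's logits over all candidate texts with a softmax, while SigLIP/SigLIP2 scores each image-text pair independently with a sigmoid.

```python
import torch

# Illustrative only: random image-text logits of shape (num_images, num_texts).
logits_per_image = torch.randn(2, 3)

# CLIP-style readout: softmax over texts, each image's probabilities sum to 1.
clip_probs = logits_per_image.softmax(dim=-1)

# SigLIP/SigLIP2-style readout: independent sigmoid per (image, text) pair,
# so the scores for one image do not need to sum to 1.
siglip_probs = torch.sigmoid(logits_per_image)

print(clip_probs.sum(dim=-1))    # ~1.0 for each image
print(siglip_probs.sum(dim=-1))  # arbitrary; each individual entry lies in (0, 1)
```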

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/siglip2_metrics_table.png"
alt="drawing" width="600"/>

This model was contributed by [qubvel](https://huggingface.co/qubvel-hf).
The original code can be found [here](https://github.com/google-research/big_vision/tree/main).

## Usage example

There are two main ways to use SigLIP2: either with the pipeline API, which abstracts away all the complexity for you, or with the `Siglip2Model` class yourself.

### FixRes variant

**Pipeline API**

The pipeline allows you to use the model in a few lines of code:

```python
>>> from transformers import pipeline
>>> from PIL import Image
>>> import requests

>>> # load pipe
>>> image_classifier = pipeline(
... task="zero-shot-image-classification",
... model="google/siglip2-base-patch16-224",
... )

>>> # load image
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> # inference
>>> candidate_labels = ["2 cats", "a plane", "a remote"]
>>> outputs = image_classifier(image, candidate_labels=candidate_labels)
>>> outputs = [{"score": round(output["score"], 4), "label": output["label"] } for output in outputs]
>>> print(outputs)
[{'score': 0.1499, 'label': '2 cats'}, {'score': 0.0008, 'label': 'a remote'}, {'score': 0.0, 'label': 'a plane'}]
```

**Using the model yourself**

If you want to do the pre- and postprocessing yourself, here's how to do that:

```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch

>>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> candidate_labels = ["2 cats", "2 dogs"]
# follows the pipeline prompt template to get same results
>>> texts = [f"This is a photo of {label}." for label in candidate_labels]

# IMPORTANT: we pass `padding="max_length"` and `max_length=64` since the model was trained with this
>>> inputs = processor(text=texts, images=image, padding="max_length", max_length=64, return_tensors="pt")

>>> with torch.no_grad():
... outputs = model(**inputs)

>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
15.0% that image 0 is '2 cats'
```
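
If you only need embeddings, e.g. for retrieval, the model also exposes `get_text_features` and `get_image_features` (see the API reference below). A minimal sketch reusing `model` and `inputs` from the example above; the embedding dimensionality depends on the checkpoint:

```python
>>> with torch.no_grad():
...     image_embeds = model.get_image_features(pixel_values=inputs["pixel_values"])
...     text_embeds = model.get_text_features(input_ids=inputs["input_ids"])

>>> print(image_embeds.shape, text_embeds.shape)  # (num_images, hidden_size), (num_texts, hidden_size)
```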

### NaFlex variant

NaFlex combines ideas from FlexiViT, i.e. supporting multiple, predefined sequence lengths
with a single ViT model, and NaViT, namely processing images at their native aspect ratio.
This enables processing different types of images at the appropriate resolution, e.g. using a
larger resolution to process document images, while at the same time minimizing the impact
of aspect-ratio distortion on certain inference tasks, e.g. OCR.

Given a patch size and target sequence length, NaFlex preprocesses the data by first resizing
the input image such that the height and width after resizing are multiples of the patch size,
while

1. keeping the aspect ratio distortion as small as possible
2. producing a sequence length of at most the desired target sequence length (`max_num_patches`)

The resulting distortion in width and height is at most `(patch_size - 1) / width` and
`(patch_size - 1) / height`, respectively, which tends to be small for common resolutions and aspect ratios.
After resizing, the image is split into a sequence of patches, and a mask with padding information is added.
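
As a rough, illustrative sketch of that target-size computation (not the exact rounding used by `Siglip2ImageProcessor`), assuming the constraint `(height / patch_size) * (width / patch_size) <= max_num_patches`:

```python
import math

def naflex_target_size(height, width, patch_size=16, max_num_patches=256):
    # Scale the image so the resulting patch grid fits into `max_num_patches`,
    # then round each side down to a multiple of `patch_size`.
    scale = math.sqrt(max_num_patches * patch_size**2 / (height * width))
    target_height = max(patch_size, int(height * scale) // patch_size * patch_size)
    target_width = max(patch_size, int(width * scale) // patch_size * patch_size)
    return target_height, target_width

print(naflex_target_size(480, 640, max_num_patches=256))   # (208, 288) -> 13 * 18 = 234 patches
print(naflex_target_size(480, 640, max_num_patches=1024))  # (432, 576) -> 27 * 36 = 972 patches, ~2x each side
```

The end-to-end NaFlex example then looks as follows: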

```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch

>>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-naflex")
>>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> candidate_labels = ["2 cats", "2 dogs"]
# follows the pipeline prompt template to get same results
>>> texts = [f"This is a photo of {label}." for label in candidate_labels]

# the default value for `max_num_patches` is 256, but you can increase the resulting image resolution
# by providing higher values, e.g. `max_num_patches=512`
>>> inputs = processor(text=texts, images=image, max_num_patches=256, return_tensors="pt")

>>> with torch.no_grad():
... outputs = model(**inputs)

>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
21.1% that image 0 is '2 cats'
```
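
To see the effect of `max_num_patches` on the produced patch sequence, you can compare processor outputs directly. Below is a sketch reusing `processor` and `image` from above; the padded sequence length and the exact set of returned fields are implementation details:

```python
>>> small = processor(images=image, max_num_patches=256, return_tensors="pt")
>>> large = processor(images=image, max_num_patches=1024, return_tensors="pt")

>>> print(sorted(small.keys()))
>>> print(small["pixel_values"].shape, large["pixel_values"].shape)
# `pixel_values` holds flattened patches padded up to `max_num_patches`;
# roughly 4x the patches corresponds to roughly 2x the height and width.
```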

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with SigLIP2.

- [Zero-shot image classification task guide](../tasks/zero_shot_image_classification)
- A demo notebook for SigLIP2 can be found [here](https://github.com/qubvel/transformers-notebooks/tree/master/notebooks/SigLIP2_inference.ipynb). 🌎

If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.


## Combining SigLIP2 and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2; read more about it in the official documentation of the flash-attn repository. In addition, load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoProcessor, AutoModel
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModel.from_pretrained(
... "google/siglip2-so400m-patch14-384",
... attn_implementation="flash_attention_2",
... torch_dtype=torch.float16,
... device_map=device,
... )
>>> processor = AutoProcessor.from_pretrained("google/siglip2-so400m-patch14-384")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> candidate_labels = ["2 cats", "2 dogs"]
# follows the pipeline prompt template to get same results
>>> texts = [f'This is a photo of {label}.' for label in candidate_labels]
# important: we pass `padding="max_length"` since the model was trained with this
>>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt").to(device)

>>> with torch.no_grad():
... with torch.autocast(device):
... outputs = model(**inputs)

>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
19.8% that image 0 is '2 cats'
```
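
SigLIP2 also supports PyTorch's scaled dot-product attention (SDPA), so if Flash Attention 2 is not available on your hardware you can request SDPA instead. A minimal variation of the snippet above:

```python
>>> model = AutoModel.from_pretrained(
...     "google/siglip2-so400m-patch14-384",
...     attn_implementation="sdpa",
...     torch_dtype=torch.float16,
...     device_map=device,
... )
```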

## Siglip2Config

[[autodoc]] Siglip2Config

## Siglip2TextConfig

[[autodoc]] Siglip2TextConfig

## Siglip2VisionConfig

[[autodoc]] Siglip2VisionConfig

## Siglip2ImageProcessor

[[autodoc]] Siglip2ImageProcessor
- preprocess

## Siglip2ImageProcessorFast

[[autodoc]] Siglip2ImageProcessorFast
- preprocess

## Siglip2Processor

[[autodoc]] Siglip2Processor

## Siglip2Model

[[autodoc]] Siglip2Model
- forward
- get_text_features
- get_image_features

## Siglip2TextModel

[[autodoc]] Siglip2TextModel
- forward

## Siglip2VisionModel

[[autodoc]] Siglip2VisionModel
- forward

## Siglip2ForImageClassification

[[autodoc]] Siglip2ForImageClassification
- forward
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -111,6 +111,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)
@@ -310,6 +311,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
32 changes: 32 additions & 0 deletions src/transformers/__init__.py
@@ -776,6 +776,12 @@
"SiglipTextConfig",
"SiglipVisionConfig",
],
"models.siglip2": [
"Siglip2Config",
"Siglip2Processor",
"Siglip2TextConfig",
"Siglip2VisionConfig",
],
"models.smolvlm": ["SmolVLMConfig"],
"models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
"models.speech_to_text": [
@@ -1289,6 +1295,7 @@
_import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
_import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
_import_structure["models.siglip"].append("SiglipImageProcessor")
_import_structure["models.siglip2"].append("Siglip2ImageProcessor")
_import_structure["models.smolvlm"].extend(["SmolVLMImageProcessor"])
_import_structure["models.superglue"].extend(["SuperGlueImageProcessor"])
_import_structure["models.superpoint"].extend(["SuperPointImageProcessor"])
@@ -1330,6 +1337,7 @@
_import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast")
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
_import_structure["models.siglip"].append("SiglipImageProcessorFast")
_import_structure["models.siglip2"].append("Siglip2ImageProcessorFast")
_import_structure["models.vit"].append("ViTImageProcessorFast")

try:
@@ -3559,6 +3567,15 @@
"SiglipVisionModel",
]
)
_import_structure["models.siglip2"].extend(
[
"Siglip2ForImageClassification",
"Siglip2Model",
"Siglip2PreTrainedModel",
"Siglip2TextModel",
"Siglip2VisionModel",
]
)
_import_structure["models.smolvlm"].extend(
[
"SmolVLMForConditionalGeneration",
@@ -5942,6 +5959,12 @@
SiglipTextConfig,
SiglipVisionConfig,
)
from .models.siglip2 import (
Siglip2Config,
Siglip2Processor,
Siglip2TextConfig,
Siglip2VisionConfig,
)
from .models.smolvlm import SmolVLMConfig
from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig
from .models.speech_to_text import (
@@ -6472,6 +6495,7 @@
from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
from .models.seggpt import SegGptImageProcessor
from .models.siglip import SiglipImageProcessor
from .models.siglip2 import Siglip2ImageProcessor
from .models.smolvlm import SmolVLMImageProcessor
from .models.superglue import SuperGlueImageProcessor
from .models.superpoint import SuperPointImageProcessor
@@ -6509,6 +6533,7 @@
from .models.qwen2_vl import Qwen2VLImageProcessorFast
from .models.rt_detr import RTDetrImageProcessorFast
from .models.siglip import SiglipImageProcessorFast
from .models.siglip2 import Siglip2ImageProcessorFast
from .models.vit import ViTImageProcessorFast

try:
@@ -8288,6 +8313,13 @@
SiglipTextModel,
SiglipVisionModel,
)
from .models.siglip2 import (
Siglip2ForImageClassification,
Siglip2Model,
Siglip2PreTrainedModel,
Siglip2TextModel,
Siglip2VisionModel,
)
from .models.smolvlm import (
SmolVLMForConditionalGeneration,
SmolVLMModel,