Add sdpa and FA2 for CLIP #31940

Merged · 19 commits · Jul 18, 2024
117 changes: 117 additions & 0 deletions docs/source/en/model_doc/clip.md
@@ -79,6 +79,123 @@ encode the text and prepare the images. The following example shows how to get t
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
```


### Combining CLIP and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the [flash-attn](https://github.com/Dao-AILab/flash-attention) repository. Also make sure to load your model in half-precision (e.g. `torch.float16`).
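
As a quick sanity check (a sketch only; consult the flash-attn documentation for the authoritative hardware list, since exact requirements depend on the flash-attn version), you can query your GPU's compute capability:

```python
import torch

# FlashAttention-2 generally targets Ampere-or-newer GPUs (compute capability >= 8.0).
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability {major}.{minor}:", "likely supported" if major >= 8 else "may not be supported")
```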

<Tip warning={true}>

For small batch sizes, you might notice a slowdown when using Flash Attention. Refer to the section [Expected speedups with Flash Attention and SDPA](#expected-speedups-with-flash-attention-and-sdpa) below and select an appropriate attention implementation.

</Tip>

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> import requests
>>> from PIL import Image

>>> from transformers import CLIPProcessor, CLIPModel

>>> device = "cuda"
>>> torch_dtype = torch.float16

>>> model = CLIPModel.from_pretrained(
... "openai/clip-vit-base-patch32",
... attn_implementation="flash_attention_2",
... device_map=device,
... torch_dtype=torch_dtype,
... )
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
>>> inputs = inputs.to(device)

>>> with torch.no_grad():
... with torch.autocast(device):
... outputs = model(**inputs)

>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
>>> print(probs)
tensor([[0.9946, 0.0052]], device='cuda:0', dtype=torch.float16)
```
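
To confirm which attention implementation was actually resolved at load time, you can inspect the model config. Note that `_attn_implementation` is a private attribute and may change between Transformers versions, so treat the snippet below as a sketch rather than a stable API:

```python
# Private attribute set when the model is loaded; subject to change across versions.
print(model.config._attn_implementation)  # expected: "flash_attention_2"
```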


### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```python
import torch
from transformers import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16, attn_implementation="sdpa")
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
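
SDPA itself dispatches to one of several kernels (flash, memory-efficient, or math) depending on the inputs and hardware. If you want to verify or constrain which kernel runs, recent PyTorch versions expose a context manager for this. The sketch below assumes PyTorch 2.3+ (older releases expose `torch.backends.cuda.sdp_kernel` instead) and a model and `inputs` prepared on GPU as in the Flash Attention example above:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Allow only the flash kernel for this forward pass. If the inputs or hardware are not
# eligible, scaled_dot_product_attention raises an error instead of silently falling back
# to the math kernel, which makes it easy to see which kernel you are actually getting.
with torch.no_grad(), sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    outputs = model(**inputs)
```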

### Expected speedups with Flash Attention and SDPA

On a local benchmark (NVIDIA A10G, PyTorch 2.3.1+cu121) with `float16`, we saw the following speedups during inference for the `"openai/clip-vit-large-patch14"` checkpoint ([code](https://gist.github.com/qubvel/ac691a54e54f9fae8144275f866a7ff8)):
Member commented:

Just out of curiosity -- do these trends roughly hold for other checkpoints too? Anyway, you picked the most popular checkpoint I think since it's used in the diffusion community quite heavily.

qubvel (Member, Author) replied on Jul 16, 2024:

It roughly holds for SDPA; however, FA2 is worse for the "base" checkpoint with CLIPModel.

`openai/clip-vit-base-patch32`

CLIPTextModel

| Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|----------------:|---------------:|-------------:|------------:|--------------:|-------------:|
| 4 | 0.007 | 0.011 | 0.674 | 0.006 | 1.173 |
| 16 | 0.007 | 0.011 | 0.648 | 0.006 | 1.17 |
| 64 | 0.015 | 0.017 | 0.855 | 0.013 | 1.112 |
| 128 | 0.03 | 0.029 | 1.022 | 0.027 | 1.121 |
| 256 | 0.06 | 0.054 | 1.106 | 0.053 | 1.121 |

CLIPVisionModel

| Image batch size | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|-----------------:|---------------:|-------------:|------------:|--------------:|-------------:|
| 1 | 0.007 | 0.005 | 1.224 | 0.005 | 1.307 |
| 4 | 0.007 | 0.005 | 1.228 | 0.005 | 1.287 |
| 16 | 0.007 | 0.006 | 1.23 | 0.005 | 1.316 |
| 32 | 0.009 | 0.009 | 1.023 | 0.009 | 1.031 |
| 64 | 0.018 | 0.017 | 1.076 | 0.017 | 1.082 |
| 128 | 0.035 | 0.032 | 1.09 | 0.032 | 1.095 |

CLIPModel

| Image batch size | Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|-----------------:|----------------:|---------------:|-------------:|------------:|--------------:|-------------:|
| 1 | 4 | 0.015 | 0.017 | 0.866 | 0.012 | 1.23 |
| 1 | 16 | 0.015 | 0.017 | 0.86 | 0.012 | 1.221 |
| 1 | 64 | 0.016 | 0.023 | 0.705 | 0.015 | 1.105 |
| 4 | 4 | 0.015 | 0.017 | 0.871 | 0.012 | 1.241 |
| 4 | 16 | 0.015 | 0.017 | 0.866 | 0.012 | 1.231 |
| 4 | 64 | 0.017 | 0.023 | 0.741 | 0.015 | 1.106 |
| 16 | 4 | 0.015 | 0.017 | 0.869 | 0.012 | 1.226 |
| 16 | 16 | 0.015 | 0.017 | 0.878 | 0.012 | 1.234 |
| 16 | 64 | 0.02 | 0.023 | 0.868 | 0.018 | 1.099 |
| 32 | 4 | 0.015 | 0.02 | 0.746 | 0.012 | 1.221 |
| 32 | 16 | 0.015 | 0.02 | 0.756 | 0.013 | 1.215 |
| 32 | 64 | 0.024 | 0.027 | 0.911 | 0.022 | 1.084 |

Collaborator commented:

FA2 results are super interesting!

Member commented:

Okay, in general, SDPA still seems to produce better speedups.


#### CLIPTextModel
Collaborator commented:

These tables are beautiful - thanks for taking the time to get such detailed benchmarks. It's incredibly valuable for important checkpoints like CLIP ❤️


| Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
| 4 | 0.009 | 0.012 | 0.737 | 0.007 | 1.269 |
| 16 | 0.009 | 0.014 | 0.659 | 0.008 | 1.187 |
| 32 | 0.018 | 0.021 | 0.862 | 0.016 | 1.142 |
| 64 | 0.034 | 0.034 | 1.001 | 0.03 | 1.163 |
| 128 | 0.063 | 0.058 | 1.09 | 0.054 | 1.174 |

![clip_text_model_viz_3](https://github.com/user-attachments/assets/e9826b43-4e66-4f4c-952b-af4d90bd38eb)

#### CLIPVisionModel

| Image batch size | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|-------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
| 1 | 0.016 | 0.013 | 1.247 | 0.012 | 1.318 |
| 4 | 0.025 | 0.021 | 1.198 | 0.021 | 1.202 |
| 16 | 0.093 | 0.075 | 1.234 | 0.075 | 1.24 |
| 32 | 0.181 | 0.147 | 1.237 | 0.146 | 1.241 |

![clip_image_model_viz_3](https://github.com/user-attachments/assets/50a36206-e3b9-4adc-ac8e-926b8b071d63)

#### CLIPModel

| Image batch size | Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|-------------------:|------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
| 1 | 4 | 0.025 | 0.026 | 0.954 | 0.02 | 1.217 |
| 1 | 16 | 0.026 | 0.028 | 0.918 | 0.02 | 1.287 |
| 1 | 64 | 0.042 | 0.046 | 0.906 | 0.036 | 1.167 |
| 4 | 4 | 0.028 | 0.033 | 0.849 | 0.024 | 1.189 |
| 4 | 16 | 0.034 | 0.035 | 0.955 | 0.029 | 1.169 |
| 4 | 64 | 0.059 | 0.055 | 1.072 | 0.05 | 1.179 |
| 16 | 4 | 0.096 | 0.088 | 1.091 | 0.078 | 1.234 |
| 16 | 16 | 0.102 | 0.09 | 1.129 | 0.083 | 1.224 |
| 16 | 64 | 0.127 | 0.11 | 1.157 | 0.105 | 1.218 |
| 32 | 4 | 0.185 | 0.159 | 1.157 | 0.149 | 1.238 |
| 32 | 16 | 0.19 | 0.162 | 1.177 | 0.154 | 1.233 |
| 32 | 64 | 0.216 | 0.181 | 1.19 | 0.176 | 1.228 |
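
The full benchmark script is available in the gist linked above. As a rough illustration only, a comparable timing loop might look like the sketch below; the checkpoint, batch sizes, number of text labels, and warmup/iteration counts here are placeholder choices, not the exact settings behind the tables above.

```python
import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = "cuda"
checkpoint = "openai/clip-vit-large-patch14"

processor = CLIPProcessor.from_pretrained(checkpoint)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
texts = ["a photo of a cat"] * 16  # assumed number of text labels
images = [image] * 4               # assumed image batch size


def benchmark(attn_implementation, warmup=10, iters=50):
    model = CLIPModel.from_pretrained(
        checkpoint,
        attn_implementation=attn_implementation,
        torch_dtype=torch.float16,
        device_map=device,
    )
    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True).to(device)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad(), torch.autocast(device):
        for _ in range(warmup):
            model(**inputs)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            model(**inputs)
        end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters / 1000  # seconds per iteration


for impl in ["eager", "sdpa", "flash_attention_2"]:
    print(f"{impl}: {benchmark(impl):.4f} s/iter")
```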

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with CLIP.
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -40,6 +40,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
@@ -200,6 +201,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
5 changes: 1 addition & 4 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -749,7 +749,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""

bsz, tgt_len, embed_dim = hidden_states.size()
@@ -838,7 +838,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


- # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->AltCLIP
class AltCLIPEncoderLayer(nn.Module):
def __init__(self, config: AltCLIPConfig):
super().__init__()
@@ -889,7 +888,6 @@ def forward(
return outputs


- # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->AltCLIP
class AltCLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
@@ -1080,7 +1078,6 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()


- # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING
class AltCLIPVisionTransformer(nn.Module):
def __init__(self, config: AltCLIPVisionConfig):
super().__init__()