Fix vision encoder decoder attention implementation pick #31203

Closed
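
This PR fixes how `VisionEncoderDecoderModel` picks its attention implementation: instead of reading a single value off the top-level composite config, the encoder and decoder are now instantiated with the implementation recorded on their own sub-configs, and a new `_autoset_attn_implementation` override delegates resolution to the concrete encoder and decoder model classes. A minimal sketch of the resulting behavior, reusing the tiny checkpoints from the new test (the local save path is hypothetical):

from transformers import VisionEncoderDecoderModel

# With this fix, each sub-model's config records its own resolved
# attention implementation when a composite model is assembled.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "hf-internal-testing/tiny-random-deit", "hf-internal-testing/tiny-random-roberta"
)
print(model.encoder.config._attn_implementation)  # e.g. "sdpa" where supported
print(model.decoder.config._attn_implementation)

# A top-level override at load time is propagated to both sub-configs.
model.save_pretrained("ved-tiny")  # hypothetical local path
model = VisionEncoderDecoderModel.from_pretrained("ved-tiny", attn_implementation="eager")
assert model.encoder.config._attn_implementation == "eager"
assert model.decoder.config._attn_implementation == "eager"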
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -15,9 +15,10 @@
"""Classes to support Vision-Encoder-Text-Decoder architectures"""

import gc
import importlib
import os
import tempfile
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import torch
from torch import nn
@@ -28,7 +29,7 @@
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from ..auto.configuration_auto import AutoConfig
-from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
+from ..auto.modeling_auto import MODEL_MAPPING_NAMES, AutoModel, AutoModelForCausalLM
 from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig


@@ -189,10 +190,12 @@ def __init__(
         super().__init__(config)

         if encoder is None:
-            encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation)
+            encoder = AutoModel.from_config(config.encoder, attn_implementation=config.encoder._attn_implementation)

         if decoder is None:
-            decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation)
+            decoder = AutoModelForCausalLM.from_config(
+                config.decoder, attn_implementation=config.decoder._attn_implementation
+            )

         self.encoder = encoder
         self.decoder = decoder
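
The two one-line fixes above matter because the composite config carries independent sub-configs whose attention implementations can legitimately differ; reading the single top-level value forced both sub-models onto the same one. A minimal illustration of the intended behavior (the ViT/GPT-2 pairing is only an example):

from transformers import GPT2Config, ViTConfig, VisionEncoderDecoderConfig

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(ViTConfig(), GPT2Config())
config.encoder._attn_implementation = "sdpa"   # vision encoder
config.decoder._attn_implementation = "eager"  # text decoder
# __init__ now forwards config.encoder._attn_implementation to AutoModel.from_config
# and config.decoder._attn_implementation to AutoModelForCausalLM.from_config,
# rather than passing the single top-level config._attn_implementation to both.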
@@ -369,6 +372,41 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

+    @classmethod
+    def _autoset_attn_implementation(
+        cls,
+        config,
+        use_flash_attention_2: bool = False,
+        torch_dtype: Optional[torch.dtype] = None,
+        device_map: Optional[Union[str, Dict[str, int]]] = None,
+        check_device_map: bool = True,
+    ):
+        encoder_model_class_name = MODEL_MAPPING_NAMES[config.encoder.model_type]
+        decoder_model_class_name = MODEL_MAPPING_NAMES[config.decoder.model_type]
+        encoder_model_class = getattr(importlib.import_module("transformers"), encoder_model_class_name)
+        decoder_model_class = getattr(importlib.import_module("transformers"), decoder_model_class_name)
+
+        if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
+            config.encoder._attn_implementation = config._attn_implementation_internal
+            config.decoder._attn_implementation = config._attn_implementation_internal
+
+        config.encoder = encoder_model_class._autoset_attn_implementation(
+            config=config.encoder,
+            use_flash_attention_2=use_flash_attention_2,
+            torch_dtype=torch_dtype,
+            device_map=device_map,
+            check_device_map=check_device_map,
+        )
+        config.decoder = decoder_model_class._autoset_attn_implementation(
+            config=config.decoder,
+            use_flash_attention_2=use_flash_attention_2,
+            torch_dtype=torch_dtype,
+            device_map=device_map,
+            check_device_map=check_device_map,
+        )
+
+        return config
+
     @classmethod
     def from_encoder_decoder_pretrained(
         cls,
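The new override resolves the concrete sub-model classes by name so it can delegate to each class's own `_autoset_attn_implementation`, letting every architecture apply its own SDPA/FlashAttention-2 support checks; any user-supplied top-level implementation is first copied onto both sub-configs. A small sketch of just the class-lookup step, with "vit" as an example model type:

import importlib

from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES

# MODEL_MAPPING_NAMES maps a model_type string to the class name registered for AutoModel.
class_name = MODEL_MAPPING_NAMES["vit"]  # "ViTModel"
model_class = getattr(importlib.import_module("transformers"), class_name)
print(model_class.__name__)  # ViTModel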
tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -401,6 +401,29 @@ def get_pretrained_model_and_inputs(self):

         return model, inputs

+    def test_attention_implementation(self):
+        model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
+            "hf-internal-testing/tiny-random-deit", "hf-internal-testing/tiny-random-roberta"
+        )
+        self.assertTrue(model.encoder.config._attn_implementation == "sdpa")
+
+        configs_and_inputs = self.prepare_config_and_inputs()
+        encoder_model, decoder_model = self.get_encoder_decoder_model(
+            configs_and_inputs["config"], configs_and_inputs["decoder_config"]
+        )
+        enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            enc_dec_model.save_pretrained(tmpdirname)
+            enc_dec_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname)
+
+            self.assertTrue(enc_dec_model.encoder.config._attn_implementation == "sdpa")
+            self.assertTrue(enc_dec_model.decoder.config._attn_implementation == "sdpa")
+
+            enc_dec_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="eager")
+            self.assertTrue(enc_dec_model.encoder.config._attn_implementation == "eager")
+            self.assertTrue(enc_dec_model.decoder.config._attn_implementation == "eager")
+
     def check_encoder_decoder_model_output_attentions(
         self,
         config,