F.scaled_dot_product_attention support #26572

Merged Dec 8, 2023 (114 commits)

Changes from 1 commit

Commits
74be54c
add sdpa
fxmarty Oct 3, 2023
9d14f0d
wip
fxmarty Oct 3, 2023
f803de3
cleaning
fxmarty Oct 3, 2023
c0bcbfa
add ref
fxmarty Oct 3, 2023
38332d7
yet more cleaning
fxmarty Oct 3, 2023
3b47502
and more :)
fxmarty Oct 3, 2023
dd646c1
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Oct 31, 2023
79c12a9
wip llama
fxmarty Oct 31, 2023
dc929cd
working llama
fxmarty Oct 31, 2023
17954dd
add output_attentions=True support
fxmarty Oct 31, 2023
f48f4fa
bigcode sdpa support
fxmarty Oct 31, 2023
dfc47a5
fixes
fxmarty Oct 31, 2023
eba83c1
gpt-bigcode support, require torch>=2.1.1
fxmarty Nov 3, 2023
5693535
add falcon support
fxmarty Nov 3, 2023
3758375
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Nov 3, 2023
ca87380
fix conflicts falcon
fxmarty Nov 3, 2023
969dda9
style
fxmarty Nov 3, 2023
06766ec
fix attention_mask definition
fxmarty Nov 3, 2023
5c648d4
remove output_attentions from attnmaskconverter
fxmarty Nov 3, 2023
674bff4
support whisper without removing any Copied from statement
fxmarty Nov 3, 2023
dd89c3c
fix mbart default to eager renaming
fxmarty Nov 3, 2023
f31c7b3
fix typo in falcon
fxmarty Nov 6, 2023
280c078
fix is_causal in SDPA
fxmarty Nov 8, 2023
e41ecfa
check is_flash_attn_2_available in the models init as well in case th…
fxmarty Nov 17, 2023
951bce0
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Nov 17, 2023
6f7964d
add warnings when falling back on the manual implementation
fxmarty Nov 17, 2023
0e38a95
precise doc
fxmarty Nov 17, 2023
1bd07aa
wip replace _flash_attn_enabled by config.attn_implementation
fxmarty Nov 17, 2023
feae821
fix typo
fxmarty Nov 17, 2023
2032e64
add tests
fxmarty Nov 17, 2023
d98c2f9
style
fxmarty Nov 17, 2023
ab59f9d
add a copy.deepcopy on the config in from_pretrained, as we do not wa…
fxmarty Nov 17, 2023
98a3825
obey to config.attn_implementation if a config is passed in from_pret…
fxmarty Nov 17, 2023
098a62e
fix is_torch_sdpa_available when torch is not installed
fxmarty Nov 17, 2023
b960912
remove dead code
fxmarty Nov 17, 2023
9df4c8f
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Nov 21, 2023
f1df402
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 21, 2023
f49c2a3
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 21, 2023
3a22d8d
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 21, 2023
f0fa993
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 21, 2023
f084040
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 21, 2023
885bbe4
Update src/transformers/models/bart/modeling_bart.py
fxmarty Nov 21, 2023
4dd5523
remove duplicate pretraining_tp code
fxmarty Nov 21, 2023
349c99b
add dropout in llama
fxmarty Nov 21, 2023
5e56014
precise comment on attn_mask
fxmarty Nov 21, 2023
951f70e
add fmt: off for _unmask_unattended docstring
fxmarty Nov 21, 2023
c4e207e
precise num_masks comment
fxmarty Nov 21, 2023
e752d93
nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion
fxmarty Nov 21, 2023
a072c5d
cleanup modeling_utils
fxmarty Nov 22, 2023
f700973
backward compatibility
fxmarty Nov 22, 2023
e267764
fix style as requested
fxmarty Nov 22, 2023
d044d81
style
fxmarty Nov 22, 2023
a9e7606
improve documentation
fxmarty Nov 22, 2023
1727210
test pass
fxmarty Nov 22, 2023
ae86680
style
fxmarty Nov 22, 2023
5706ecb
add _unmask_unattended tests
fxmarty Nov 22, 2023
d2326e2
skip meaningless tests for idefics
fxmarty Nov 22, 2023
c0f849e
hard_check SDPA requirements when specifically requested
fxmarty Nov 22, 2023
0fa8de0
standardize the use of XXX_ATTENTION_CLASSES
fxmarty Nov 22, 2023
637e473
fix SDPA bug with mem-efficient backend on CUDA when using fp32
fxmarty Nov 22, 2023
55ec325
fix test
fxmarty Nov 22, 2023
33ef389
rely on SDPA is_causal parameter to handle the causal mask in some cases
fxmarty Nov 22, 2023
2e6bc3e
fix FALCON_ATTENTION_CLASSES
fxmarty Nov 23, 2023
688d86e
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Nov 23, 2023
5913dee
remove _flash_attn_2_enabled occurences
fxmarty Nov 23, 2023
11ab3ae
fix test
fxmarty Nov 23, 2023
b74894d
add OPT to the list of supported flash models
fxmarty Nov 23, 2023
4ff1057
improve test
fxmarty Nov 23, 2023
8bd6c81
properly test on different SDPA backends, on different dtypes & prope…
fxmarty Nov 24, 2023
a11c114
remove remaining _flash_attn_2_enabled occurence
fxmarty Nov 24, 2023
b5593a1
Update src/transformers/modeling_utils.py
fxmarty Nov 24, 2023
1bc983a
Update src/transformers/modeling_utils.py
fxmarty Nov 24, 2023
316b448
Update src/transformers/modeling_utils.py
fxmarty Nov 24, 2023
52178ba
Update src/transformers/modeling_attn_mask_utils.py
fxmarty Nov 24, 2023
231e354
Update docs/source/en/perf_infer_gpu_one.md
fxmarty Nov 24, 2023
f907b3f
remove use_attn_implementation
fxmarty Nov 24, 2023
0e9e9f2
fix docstring & slight bug
fxmarty Nov 24, 2023
c47c24e
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 4, 2023
5c77b94
make attn_implementation internal (_attn_implementation)
fxmarty Dec 4, 2023
cd9e209
typos
fxmarty Dec 4, 2023
e475f25
fix tests
fxmarty Dec 5, 2023
48a6bfc
deprecate use_flash_attention_2=True
fxmarty Dec 6, 2023
8e7f8b5
fix test
fxmarty Dec 6, 2023
7a85efc
add back llama that was removed by mistake
fxmarty Dec 6, 2023
3649553
fix tests
fxmarty Dec 6, 2023
f09a65c
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 7, 2023
c1b87b8
remove _flash_attn_2_enabled occurences bis
fxmarty Dec 7, 2023
8950b60
add check & test that passed attn_implementation is valid
fxmarty Dec 7, 2023
18c2678
fix falcon torchscript export
fxmarty Dec 7, 2023
d96e0d2
fix device of mask in tests
fxmarty Dec 7, 2023
bb20113
Merge branch 'fix-device-mask-tests' into torch-sdpa-preliminary-support
fxmarty Dec 7, 2023
9e133c9
add tip about torch.jit.trace and move bt doc below sdpa
fxmarty Dec 7, 2023
76a1e17
fix parameterized.expand order
fxmarty Dec 7, 2023
65aeba6
move tests from test_modeling_attn_mask_utils to test_modeling_utils …
fxmarty Dec 7, 2023
09ab820
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 7, 2023
48d95ea
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 8, 2023
546dd51
update sdpaattention class with the new cache
fxmarty Dec 8, 2023
2045915
Update src/transformers/configuration_utils.py
fxmarty Dec 8, 2023
eb11883
Update src/transformers/models/bark/modeling_bark.py
fxmarty Dec 8, 2023
920686e
address review comments
fxmarty Dec 8, 2023
2146857
WIP torch.jit.trace fix. left: test both eager & sdpa
fxmarty Dec 8, 2023
9b48591
add test for torch.jit.trace for both eager/sdpa
fxmarty Dec 8, 2023
4315638
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 8, 2023
cc7fc4e
fix falcon with torch==2.0 that needs to use sdpa
fxmarty Dec 8, 2023
84d9605
Merge branch 'torch-sdpa-preliminary-support' of https://github.com/f…
fxmarty Dec 8, 2023
8486770
fix doc
fxmarty Dec 8, 2023
c6181f2
hopefully last fix
fxmarty Dec 8, 2023
7ebfd1d
fix key_value_length that has no default now in mask converter
fxmarty Dec 8, 2023
dacf149
is it flaky?
fxmarty Dec 8, 2023
f196bef
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 8, 2023
810de1a
fix speculative decoding bug
fxmarty Dec 8, 2023
f116cce
tests do pass
fxmarty Dec 8, 2023
4721c36
Merge branch 'main' into torch-sdpa-preliminary-support
fxmarty Dec 8, 2023
3f06a3a
fix following #27907
fxmarty Dec 8, 2023
cleanup modeling_utils
fxmarty committed Nov 22, 2023

commit a072c5d7330df1828628cef1473256234e68e88f
27 changes: 13 additions & 14 deletions src/transformers/configuration_utils.py
@@ -236,6 +236,8 @@ class PretrainedConfig(PushToHubMixin):

This attribute is currently not being used during model loading time, but this may change in the future
versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
attn_implementation (`str`, *optional*):
The attention implementation to use in the model. Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or "flash_attention_2" (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.

> TensorFlow specific parameters

@@ -374,6 +376,9 @@ def __init__(self, **kwargs):
# Config hash
self._commit_hash = kwargs.pop("_commit_hash", None)

# Attention implementation to use, if relevant.
self._attn_implementation = kwargs.pop("attn_implementation", None)
Collaborator:
why not default to eager here instead of the setter?

Contributor Author (fxmarty):

It's a bit tricky: if attn_implementation is explicitly specified, we want to skip the automatic dispatch between SDPA / eager and just use the user-provided config (https://github.com/fxmarty/transformers/blob/torch-sdpa-preliminary-support/src/transformers/modeling_utils.py#L1259C11-L1268).

`config._attn_implementation` being None corresponds to the case where the user has not specified anything.
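
To make this concrete, here is a minimal sketch of the dispatch rule described above (illustrative only; the helper name and the simplified availability flags are assumptions, not the actual `_autoset_attn_implementation` code):

```python
from typing import Optional


def resolve_attn_implementation(
    requested: Optional[str],
    model_supports_sdpa: bool,
    torch_sdpa_available: bool,
) -> str:
    """Simplified sketch of the dispatch discussed above (illustrative only)."""
    if requested is not None:
        # Explicitly specified by the user: skip the automatic dispatch entirely.
        return requested
    # Automatic dispatch: prefer SDPA when both torch and the model support it.
    if torch_sdpa_available and model_supports_sdpa:
        return "sdpa"
    return "eager"


# No user preference and SDPA available -> "sdpa"
assert resolve_attn_implementation(None, True, True) == "sdpa"
# User explicitly requested "eager" -> automatic dispatch is skipped
assert resolve_attn_implementation("eager", True, True) == "eager"
```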

Contributor:

I don't think we should allow the user to specify which attention mechanism to use in the config:

        1. An implementation specified in `config.attn_implementation`.

=> This should not be allowed. The model's config should only define the architecture, not settings that can be changed at runtime.

The user should always have to call model.set_attn_type(...) IMO

Contributor Author (@fxmarty, Nov 27, 2023):

Happy to go with that @patrickvonplaten. WDYT @ArthurZucker @LysandreJik

Collaborator (@ArthurZucker, Nov 27, 2023):

So either:

  1. We silently load sdpa if possible, else eager (-> you specify the acceleration you want after loading).
  2. We load what was saved in the config, and error out if it is not available.
     In both cases, passing this as an argument to from_pretrained is actually simpler if you want to use FA2 (no in-place modifications).

I'm guessing you have more experience with this in diffusers @patrickvonplaten, so I will follow you here. Agree with you that restricting is better: only one setter and only one way to change this!
Let's go for 1; lots of pros and cons, but I agree with you on this.


# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)

@@ -424,19 +429,13 @@ def num_labels(self, num_labels: int):

@property
def attn_implementation(self):
if not hasattr(self, "_attn_implementation"):
return "eager"
else:
return self._attn_implementation
return self._attn_implementation

@attn_implementation.setter
def attn_implementation(self, value):
if hasattr(self, "attn_implementation_set") and self.attn_implementation_set:
raise NotImplementedError(
"Modifying the attention implementation through this attribute is currently not implemented."
)
self.attn_implementation_set = True

# No specific check is implemented here, as we want to allow syntax as `config.attn_implementation = "flash_attention_2"` before the model
# loading.
# Modifying this property alone on an already loaded model (model.config) has no impact, `model.use_attn_implementation("flash_attention_2")` should be used instead.
self._attn_implementation = value

def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
@@ -878,8 +877,8 @@ def to_diff_dict(self) -> Dict[str, Any]:

self.dict_torch_dtype_to_str(serializable_config_dict)

if "attention_implementation" in serializable_config_dict:
del serializable_config_dict["attention_implementation"]
if "_attn_implementation" in serializable_config_dict:
del serializable_config_dict["_attn_implementation"]

return serializable_config_dict

@@ -897,8 +896,8 @@ def to_dict(self) -> Dict[str, Any]:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]
if "attention_implementation" in output:
del output["attention_implementation"]
if "_attn_implementation" in output:
del output["_attn_implementation"]

# Transformers version when serializing the model
output["transformers_version"] = __version__
147 changes: 77 additions & 70 deletions src/transformers/modeling_utils.py
@@ -1152,32 +1152,12 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
)
# Save config and origin of the pretrained weights if given in model
self.config = config
self.config = self._autoset_attn_implementation(self.config, torch_dtype=torch.get_default_dtype(), check_device_map=False)

self.name_or_path = config.name_or_path
self.warnings_issued = {}
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

# TODO: This is TEMPORARY and need to be discussed, should it rather be in XXXPreTrainedModel __init__?
if config.attn_implementation == "flash_attention_2":
if not self._supports_flash_attn_2:
raise ValueError(
f'Passed config.attn_implementation == "flash_attention_2" but {self.__class__.__name__} does not support Flash Attention yet.'
)

if not is_flash_attn_2_available():
raise ImportError(
"Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
" installing it. Make sure to have at least the version 2.1.0"
)

if config.attn_implementation == "sdpa":
if not self._supports_sdpa:
raise ValueError(
f'Passed config.attn_implementation == "sdpa" but {self.__class__.__name__} does not support SDPA yet.'
)

if not is_torch_sdpa_available():
raise ImportError("SDPA is not available. Please use torch>=2.1.1 in order to use SDPA.")

def post_init(self):
"""
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
@@ -1211,8 +1191,7 @@ def _from_config(cls, config, **kwargs):
if torch_dtype is not None:
dtype_orig = cls._set_default_torch_dtype(torch_dtype)

if use_flash_attention_2:
config = cls._check_and_enable_flash_attn_2(config, torch_dtype)
config = cls._autoset_attn_implementation(config, use_flash_attention_2=use_flash_attention_2, check_device_map=False)

if is_deepspeed_zero3_enabled():
import deepspeed
@@ -1231,6 +1210,61 @@ def _from_config(cls, config, **kwargs):

return model

def use_attn_implementation(self, attn_implementation: str):
"""
Specifies the attention implementation to use in the model.

Args:
attn_implementation (`str`):
The attention implementation to use. Can be any of "eager" (manual implementation of the attention), "sdpa" (attention using [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or "flash_attention_2" (attention using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)).
"""
# TODO: Implement it. An implementation could be to define `self._eager_attn_class = XXXAttention`, `self._sdpa_attn_class = XXXSdpaAttention`, `self._flash_attn_class = XXXFlashAttention2` in the __init__ of XXXPreTrainedModel, and leverage those attributes here to replace the correct submodules.
raise NotImplementedError("model.use_attn_implementation is currently not implemented.")

if attn_implementation == "sdpa":
self.config = self._check_and_enable_sdpa(self.config, enable=True)
elif attn_implementation == "flash_attention_2":
# TODO: define torch_dtype properly
torch_dtype = None
self.config = self._check_and_enable_flash_attn_2(
self.config, torch_dtype=torch_dtype, device_map=getattr(self, "hf_device_map", None), enable=True
)

@classmethod
def _autoset_attn_implementation(cls, config, use_flash_attention_2: Optional[bool] = None, torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None, check_device_map: bool = True):
"""
Automatically checks and dispatches to a default attention implementation. In order of priority:
1. An implementation specified in `config.attn_implementation`.
2. If specified, flash attention through use_flash_attention_2=True.
3. SDPA implementation, if available and supported by the model type.
4. Manual implementation otherwise.
"""
config = copy.deepcopy(config) # We do not want to modify the config inplace.

if config.attn_implementation is None:
auto_dispatch_attention = True
else:
if (config.attn_implementation != "flash_attention_2" and use_flash_attention_2):
raise ValueError(
f'Both config.attn_implementation ("{config.attn_implementation}") and use_flash_attention_2=True are used, and are incompatible.'
)

# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config.
auto_dispatch_attention = False

if use_flash_attention_2:
cls._check_and_enable_flash_attn_2(
config, torch_dtype=torch_dtype, device_map=device_map, enable=auto_dispatch_attention, check_device_map=check_device_map,
)
elif is_torch_sdpa_available() and cls._supports_sdpa:
# use_flash_attention_2 takes priority over SDPA.
config = cls._check_and_enable_sdpa(config, enable=auto_dispatch_attention)
elif auto_dispatch_attention:
config.attn_implementation = "eager"

return config

@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
@@ -1285,25 +1319,17 @@ def _check_and_enable_flash_attn_2(
config,
torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True,
enable: bool = True,
) -> PretrainedConfig:
"""
If you don't know about Flash Attention, check out the official repository of flash attention:
https://github.com/Dao-AILab/flash-attention

For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
specific section of the documentation to learn more about it:
https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
Checks the availability of Flash Attention 2 and compatibility with the current model.

The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not ran on CPU.

If all checks pass, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model
can initialize the correct attention module
If all checks pass and `enable` is True, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
"""
if not cls._supports_flash_attn_2:
raise ValueError(
"The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
f"{cls.__name__} does not support Flash Attention 2.0 yet. Please open an issue on GitHub to "
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
)

@@ -1337,20 +1363,23 @@ def _check_and_enable_flash_attn_2(
" unexpected behaviour."
)

if device_map is None:
# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
# or the model may be initialized under the context manager `with torch.device("cuda"):`.
if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
if torch.cuda.is_available():
logger.warning(
"You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU"
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
" after initializing it on CPU with `model.to('cuda')`."
)
else:
raise ValueError(
"You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. "
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
"This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
"or initialising the model on CPU and then moving it to GPU."
)
elif (
device_map is not None
check_device_map
and device_map is not None
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
@@ -1365,11 +1394,15 @@ def _check_and_enable_flash_attn_2(
@classmethod
def _check_and_enable_sdpa(cls, config, enable: bool = True) -> PretrainedConfig:
"""
Enables the use of SDPA natively in Transformers if supported by the model, and if BetterTransformer is not
being used.
Checks the availability of SDPA for a given model.

If all checks pass and `enable` is True, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
"""
if not cls._supports_sdpa:
return config
raise ValueError(
f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention. Please open an issue on GitHub to "
"request support for this architecture: https://github.com/huggingface/transformers/issues/new"
)

_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
@@ -2852,9 +2885,6 @@ def from_pretrained(
else:
model_kwargs = kwargs

# We do not want to modify inplace the PretrainedConfig passed to from_pretrained.
config = copy.deepcopy(config)

quantizer = None
quantization_method_from_config = None
if hasattr(config, "quantization_config"):
@@ -3314,30 +3344,7 @@ def from_pretrained(
elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())

if (
hasattr(config, "attn_implementation")
and config.attn_implementation != "flash_attention_2"
and use_flash_attention_2
):
raise ValueError(
f"Both config.attn_implementation ({config.attn_implementation}) and use_flash_attention_2=True were passed to from_pretrained and are incompatible."
)

# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config.
if hasattr(config, "attn_implementation"):
auto_dispatch_attention = False
else:
auto_dispatch_attention = True

if use_flash_attention_2:
config = cls._check_and_enable_flash_attn_2(
config, torch_dtype=torch_dtype, device_map=device_map, enable=auto_dispatch_attention
)
elif is_torch_sdpa_available():
# use_flash_attention_2 takes priority.
config = cls._check_and_enable_sdpa(config, enable=auto_dispatch_attention)
elif not hasattr(config, "attn_implementation"):
config.attn_implementation = "eager"
config = cls._autoset_attn_implementation(config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map)

with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
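
As shown in the hunk above, the inline dispatch logic in `from_pretrained` is replaced by a single call to `_autoset_attn_implementation`. A minimal sketch of the conflict check that method performs between an explicit `config.attn_implementation` and `use_flash_attention_2=True` (illustrative only; the config class here is a simplified stand-in, not `PretrainedConfig`):

```python
from typing import Optional


class DummyConfig:
    """Stand-in for PretrainedConfig, for illustration only."""

    def __init__(self, attn_implementation: Optional[str] = None):
        self.attn_implementation = attn_implementation


def check_flash_attention_conflict(config: DummyConfig, use_flash_attention_2: bool) -> None:
    # Mirrors the check in _autoset_attn_implementation: an explicit, different
    # config.attn_implementation is incompatible with use_flash_attention_2=True.
    if (
        config.attn_implementation is not None
        and config.attn_implementation != "flash_attention_2"
        and use_flash_attention_2
    ):
        raise ValueError(
            f'Both config.attn_implementation ("{config.attn_implementation}") and '
            "use_flash_attention_2=True are used, and are incompatible."
        )


check_flash_attention_conflict(DummyConfig("flash_attention_2"), True)  # OK
try:
    check_flash_attention_conflict(DummyConfig("eager"), True)  # raises ValueError
except ValueError as e:
    print(e)
```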