diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3e8d73f94b1..9c7d4c78522 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1544,6 +1544,7 @@ def _autoset_attn_implementation( torch.version.hip is not None and config._attn_implementation == "sdpa" and torch.cuda.device_count() > 1 + and version.parse(torch.__version__) < version.parse("2.4.1") ): logger.warning_once( "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."