Skip to content

Commit

Permalink
[generate] only require an attention mask for mps with torch<2.4 (#32367)
Browse files Browse the repository at this point in the history

* up

* style

* stopping
  • Loading branch information
sanchit-gandhi authored Aug 2, 2024
1 parent 083e13b commit c1aa0ed
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import torch
from torch.nn import functional as F

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4

from ..tokenization_utils_base import PreTrainedTokenizerBase
from ..utils import add_start_docstrings, logging

Expand Down Expand Up @@ -485,7 +487,8 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor]):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
self.eos_token_id = self.eos_token_id.to(input_ids.device)
if input_ids.device.type == "mps":
if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
# TODO: remove this workaround when we stop supporting torch<=2.3
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
is_done = (
input_ids[:, -1]
Expand Down
7 changes: 4 additions & 3 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..pytorch_utils import is_torch_greater_or_equal_than_2_4
from ..tokenization_utils import ExtensionsTrie
from ..utils import (
ModelOutput,
Expand Down Expand Up @@ -488,10 +489,10 @@ def _prepare_attention_mask_for_generation(
return default_attention_mask

# Otherwise we have may have information -> try to infer the attention mask
if inputs.device.type == "mps":
# mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
if inputs.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
# mps does not support torch.isin for torch<2.4 (https://github.com/pytorch/pytorch/issues/77764)
raise ValueError(
"Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
"Can't infer missing attention mask on `mps` device for torch<2.4. Please provide an `attention_mask` or upgrade to torch>=2.4"
)

is_pad_token_in_inputs = (pad_token_id is not None) and (
Expand Down
1 change: 1 addition & 0 deletions src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
Expand Down

0 comments on commit c1aa0ed

Please sign in to comment.