
Add XPU type for work-around -inf mask causing sdpa NaN issue in modeling files #35647

Merged: 4 commits into huggingface:main on Feb 5, 2025

Conversation

Liangliang-Ma (Contributor)

Recently when we run transformers + qlora doing fine-tuning, we found NaN produced by torch.nn.functional.scaled_dot_product_attention.
Given that xpu has similar implement of fused sdpa, we would like to follow pytorch/transformers tmp solution to modify the mask here, which could solve the issue on XPU device too.
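
For context, the modeling files already work around this on CUDA: when SDPA is used, query rows of the additive mask that sit entirely at the dtype minimum (and thus become -inf under bf16) are reset to 0 via AttentionMaskConverter._unmask_unattended, and this PR extends that device check to cover XPU as well. A rough, self-contained sketch of the idea (not the exact transformers code; the mask and the device handling below are illustrative only):

import torch

def unmask_fully_masked_rows(additive_mask: torch.Tensor) -> torch.Tensor:
    # Query rows whose every key position is masked (value == dtype min, which
    # becomes -inf after a bf16 cast) make softmax return NaN inside the fused
    # SDPA kernel. Resetting those rows to 0 keeps the kernel finite; they
    # correspond to padding queries, so their outputs are ignored downstream.
    min_dtype = torch.finfo(additive_mask.dtype).min
    fully_masked = (additive_mask == min_dtype).all(dim=-1, keepdim=True)
    return additive_mask.masked_fill(fully_masked, 0.0)

# Illustrative 1x1x3x3 additive mask with the first query row fully masked.
mask = torch.zeros(1, 1, 3, 3)
mask[..., 0, :] = torch.finfo(mask.dtype).min
# The modeling files apply the workaround only on devices whose fused SDPA
# kernel shows the issue; before this PR the check covered only "cuda".
if mask.device.type in ("cuda", "xpu", "cpu"):  # "cpu" included only for this demo
    mask = unmask_fully_masked_rows(mask)
print(mask)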

@Rocketknight1 (Member)

cc @muellerzr @SunMarc because I'm not sure who to ping for xpu! Feel free to ping someone else if needed

@SunMarc (Member) left a comment


Could you check if you get the same issue as here with XPU? Also, does this PR fix your issue? cc @ydshieh as you might know something

@Liangliang-Ma (Contributor, Author)

Could you check if you get the same issue as here with XPU? Also, does this PR fix your issue? cc @ydshieh as you might know something

Yes, we used XPU for fine-tuning and hit the same issue. With this PR, the issue is fixed.

@ydshieh (Collaborator)

ydshieh commented Jan 14, 2025

I can only say that this seems reasonable to me but not more than that.

@Liangliang-Ma It would be better to provide a tiny code snippet to demonstrate the issue, for example providing a mask with one row fully masked and passing it to F.scaled_dot_product_attention to show we do get NaN (on XPU).

@Liangliang-Ma (Contributor, Author)

Liangliang-Ma commented Jan 16, 2025

I can only say that this seems reasonable to me but not more than that.

@Liangliang-Ma It would be better to provide a tiny code snippet to demonstrate the issue, for example providing a mask with one row fully masked and passing it to F.scaled_dot_product_attention to show we do get NaN (on XPU).

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the XPU backend)
from torch.nn import functional as F

torch.manual_seed(0)

a = 3  # sequence length
b = 4  # head dimension

q = torch.randn(size=(1, 1, a, b))
k = torch.randn(size=(1, 1, a, b))
v = torch.randn(size=(1, 1, a, b))

def check(q, k, v, device):

    q = q.to(device)
    k = k.to(device)
    v = v.to(device)

    # First query row is fully masked with the float32 minimum; under bf16
    # autocast on XPU this value overflows to -inf.
    neg_value = torch.finfo(q.dtype).min
    mask = [[neg_value, neg_value, neg_value], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    mask = torch.tensor([[mask]]).to(device)

    with torch.amp.autocast("xpu", dtype=torch.bfloat16):
        o = F.scaled_dot_product_attention(q, k, v, mask, 0.0, is_causal=False)
    print(o)

check(q, k, v, "cpu")
check(q, k, v, "xpu")

Thanks @ydshieh, I modified your test and got the NaN result like this:

tensor([[[[ 0.1210,  0.3627, -0.9969, -0.6149],
          [ 0.1295,  0.4572, -1.0491, -0.6166],
          [ 0.1095,  0.3819, -0.7369, -0.8267]]]])
tensor([[[[    nan,     nan,     nan,     nan],
          [ 0.1299,  0.4590, -1.0469, -0.6172],
          [ 0.1094,  0.3809, -0.7344, -0.8281]]]], device='xpu:0',
       dtype=torch.bfloat16)

I found that this issue is caused by casting torch.finfo(torch.float).min to bfloat16, which results in a row of -inf.
That row of -inf makes the SDPA kernel output NaN.
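
A minimal check of just that cast, assuming only stock PyTorch (no XPU needed for this part):

import torch

fp32_min = torch.finfo(torch.float32).min  # about -3.4028e+38
print(torch.tensor(fp32_min).to(torch.bfloat16))
# tensor(-inf, dtype=torch.bfloat16): the magnitude is above the largest finite
# bfloat16 value (~3.39e+38), so the conversion rounds to -inf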

@Liangliang-Ma (Contributor, Author)

@Rocketknight1 Hi, may I know whether the CI workflow failures are expected? It seems the generated modeling code differs from what I modified. Thanks!

@SunMarc (Member)

SunMarc commented Jan 16, 2025

Hey @Liangliang-Ma, you need to modify the modular files for mistral and bamba, since the modeling files are generated automatically from them. This should fix the CI issue.

@SunMarc (Member) left a comment

Thanks for the snippet, LGTM with the fix I proposed for the CI


@Liangliang-Ma (Contributor, Author)

@SunMarc Thanks. The CI passed with the fix.

@SunMarc (Member)

SunMarc commented Jan 17, 2025

Gentle ping @ArthurZucker, as this concerns the attention class.

@Liangliang-Ma (Contributor, Author)

@SunMarc @ArthurZucker Soft reminder of this PR.

@Liangliang-Ma (Contributor, Author)

gentle ping @SunMarc @ArthurZucker again.

@SunMarc (Member)

SunMarc commented Feb 5, 2025

Merging this as it only concerns the XPU workflow.

@SunMarc merged commit 315a9f4 into huggingface:main on Feb 5, 2025
17 checks passed
MekkCyber pushed a commit that referenced this pull request Feb 7, 2025
Add XPU type for work-around -inf mask causing sdpa NaN issue in modeling files (#35647)

* add xpu for unmask

* change modular for generated matching

* add lastest modeling for helium
@ArthurZucker (Collaborator)

Sorry @Liangliang-Ma! And thanks for the fix 🤗

@ArthurZucker removed the request for review from zucchini-nlp on February 13, 2025 08:51
elvircrn pushed a commit to elvircrn/transformers that referenced this pull request Feb 13, 2025
sbucaille pushed a commit to sbucaille/transformers that referenced this pull request Feb 16, 2025