
[Feature] Why use flash attention instead of PyTorch's native SDPA? #519

Closed
yangtian6781 opened this issue Aug 20, 2024 · 7 comments

Comments

@yangtian6781

Motivation

I'd like to ask why InternVL uses flash attention rather than SDPA. This PR suggests that SDPA is faster than flash attention: https://github.com/huggingface/transformers/pull/31940#issuecomment-2228246233. After replacing InternVL's flash attention with SDPA, I found that the model's final outputs differ and it gives different answers, although the cosine similarity between the SDPA logits and the FA logits is above 0.99. Would switching to SDPA be a big issue for reproducing the model?
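
For reference, a minimal sketch of how this kind of logits comparison can be set up with transformers (the checkpoint name and prompt are placeholders, not the actual InternVL setup):

# Load the same checkpoint with two attention backends and compare the logits.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "path/to/checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer("An example prompt", return_tensors="pt").to("cuda")

logits = {}
for impl in ("flash_attention_2", "sdpa"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation=impl,
        trust_remote_code=True,
    ).to("cuda").eval()
    with torch.no_grad():
        logits[impl] = model(**inputs).logits.float()
    del model
    torch.cuda.empty_cache()

# Cosine similarity of the logits at each position, averaged over the sequence.
cos = F.cosine_similarity(logits["flash_attention_2"], logits["sdpa"], dim=-1).mean()
max_diff = (logits["flash_attention_2"] - logits["sdpa"]).abs().max()
print(f"mean cosine similarity: {cos.item():.4f}, max abs diff: {max_diff.item():.4e}")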

Related resources

No response

Additional context

No response


@czczup
Member

czczup commented Aug 20, 2024

At the moment, flash attention is still what most people use.

@yangtian6781
Author

@czczup Thanks for your reply. What I want to confirm is whether the numerical difference between the logits produced by SDPA and FA is an issue for reproducing and retraining InternVL. Also, since the PR I mentioned says SDPA is faster than FA, do you plan to switch from FA to SDPA in the future? After all, SDPA is natively supported by PyTorch, and most models in the transformers library also support it.

@czczup
Member

czczup commented Aug 20, 2024


Tiny numerical differences like that can be ignored; it shouldn't be a big problem. We will consider SDPA, but it probably won't be ready particularly soon.

@yangtian6781
Author

Thanks. One last small question: the three RMSNorm implementations below should all be correct, right? Their outputs differ numerically, though the pairwise cosine similarities between the three outputs are all above 0.99. Can this kind of numerical difference also be ignored?

import torch
from torch import nn


# Variant 1: the hand-written RMSNorm (as in the InternLM2 modeling code).
class InternLM2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute the norm in float32, then cast back to the input dtype
        # before applying the learnable weight.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# Variant 2: a thin wrapper around PyTorch's built-in nn.RMSNorm (PyTorch >= 2.4).
class My_rmsnorm(nn.RMSNorm):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__(normalized_shape=hidden_size, eps=eps)

    def forward(self, hidden_states):
        # Here the weight is applied in float32 before casting back.
        input_dtype = hidden_states.dtype
        out = super().forward(hidden_states.to(torch.float32))
        return out.to(input_dtype)


# Variant 3: apex's fused CUDA kernel as a drop-in replacement (requires apex).
from functools import partial
from apex.normalization import FusedRMSNorm

InternLM2RMSNorm = partial(FusedRMSNorm, eps=1e-6)
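
A quick sketch of the comparison I mean, assuming the apex rebinding above is not executed (the FusedRMSNorm variant can be added to the dict in the same way when apex and a GPU are available):

# Run the two pure-PyTorch variants on the same bfloat16 input and report
# pairwise cosine similarity and maximum absolute difference.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, eps = 4096, 1e-6
x = torch.randn(2, 8, hidden_size, dtype=torch.bfloat16)

norms = {
    "hand-written RMSNorm": InternLM2RMSNorm(hidden_size, eps=eps),
    "nn.RMSNorm wrapper": My_rmsnorm(hidden_size, eps=eps),
}

outs = {name: m(x).float() for name, m in norms.items()}
names = list(outs)
for i in range(len(names)):
    for j in range(i + 1, len(names)):
        a, b = outs[names[i]], outs[names[j]]
        cos = F.cosine_similarity(a, b, dim=-1).mean().item()
        diff = (a - b).abs().max().item()
        print(f"{names[i]} vs {names[j]}: cos={cos:.6f}, max abs diff={diff:.3e}")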

@yangtian6781
Author

I've always been worried about the numerical differences caused by how different libraries process data and by different implementations of the neural-network forward pass. Even though these differences are tiny, as a senior researcher, do you think they can safely be ignored? Looking forward to your reply.

@czczup
Member

czczup commented Aug 27, 2024

Basically they can be ignored. If you're measuring scores on benchmarks, there may be a tiny effect. For example, the three norm implementations above will certainly give different benchmark scores, but the differences are basically just small fluctuations on the order of 0.x points.

@czczup czczup closed this as completed Sep 7, 2024