-
Notifications
You must be signed in to change notification settings - Fork 553
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] 为什么不用pytorch原生的sdpa,反而用flash attention呢? #519
Comments
现在还是用flash attention的人比较多吧 |
@czczup 感谢您的回答,但我想确认的是,sdpa和fa输出的logits数值上的差异,对于我们复现并重新训练internvl来说,是否是个issue?以及我在哪个pr中提到的sdpa比fa更快,您是否有计划在未来把fa换成sdpa?毕竟sdpa是pytorch的原生支持,而且transfomers库里面的模型大多也支持sdpa |
微小的数值差异可以忽略吧,应该问题不大。sdpa我们会考虑,但估计不会特别快就ready。 |
感谢,最后还有一个小问题想问您,下面三种rmsnorm的实现应该都是对的?但他们输出的数值各不相同,三个输出两两之间的余弦相似度都在0.99以上,那么这种数值差异也是不是可以忽略呢?
|
我一直都在担心各种库处理数据,神经网络前向传播的不同实现导致的数值差异,尽管这种数值差异是微小的,您作为一名资深的研究员,是否认为这些差异是不需要在意的呢?期待您的回复 |
基本上可以忽略吧,如果要在benchmark上测点数,可能会有微小的影响。比如上面这个norm的三种方式,测出来的点数肯定是不一样的,但基本就是0.x的小波动。 |
Motivation
想问一下internvl为什么要用flash attention而不用sdpa呢,这个pr上面说sdpa是要比flash attention快的:https://github.com/huggingface/transformers/pull/31940#issuecomment-2228246233,我把internvl的flash attention替换成sdpa后,发现模型最后的输出会有差异,模型会有不同的回答,但sdpa输出的logits值和fa的logits余弦相似度在0.99以上,替换成sdpa去复现模型会是一个很大的issue吗?
Related resources
No response
Additional context
No response
The text was updated successfully, but these errors were encountered: