
Add Ascend NPU support for nf4 quant #1422

Conversation

@ji-huazhong (Contributor) commented Nov 21, 2024

What does this PR do?

This PR adds Ascend NPU support for nf4 quant/dequant and allows QLoRA fine-tuning for LLMs using transformers, peft, and trl.

You may notice that the nf4 quantization method is currently implemented in pure PyTorch. This is an interim measure: the high-performance version implemented in AscendC is still in progress 😞. Meanwhile, many in the Ascend NPU community have told us they are keen to use QLoRA to fine-tune LLMs as soon as possible, hence this PR.
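For readers curious what the pure-PyTorch approach looks like, here is a rough, self-contained sketch of blockwise NF4 quant/dequant (illustrative only, not the exact code in this PR; the codebook values are the standard NF4 levels from the QLoRA paper):

```python
import torch

# The 16 NF4 levels from the QLoRA paper (quantiles of a standard
# normal distribution, normalized to [-1, 1]).
NF4_CODE = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])

def quantize_nf4(x: torch.Tensor, blocksize: int = 64):
    """Blockwise NF4: scale each block into [-1, 1] by its absmax, then
    snap every element to the nearest codebook level (a 4-bit index).
    Assumes x.numel() is divisible by 2 * blocksize."""
    flat = x.float().reshape(-1, blocksize)
    absmax = flat.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)
    scaled = flat / absmax
    # Nearest codebook entry per element via (n_blocks, blocksize, 16) distances.
    idx = (scaled.unsqueeze(-1) - NF4_CODE.to(x.device)).abs().argmin(dim=-1)
    # Pack two 4-bit indices into each byte.
    idx = idx.to(torch.uint8).reshape(-1)
    packed = (idx[0::2] << 4) | idx[1::2]
    return packed, absmax.squeeze(1)

def dequantize_nf4(packed: torch.Tensor, absmax: torch.Tensor, blocksize: int = 64):
    """Inverse: unpack nibbles, look up codebook values, rescale by absmax."""
    hi = (packed >> 4).long()
    lo = (packed & 0x0F).long()
    idx = torch.stack([hi, lo], dim=1).reshape(-1)
    vals = NF4_CODE.to(packed.device)[idx].reshape(-1, blocksize)
    return vals * absmax.unsqueeze(1)
```

The storage win comes from one byte per two weights plus one absmax per block; the AscendC kernel mentioned above is intended to replace this elementwise nearest-neighbor search with a fused implementation.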

Related PR: huggingface/transformers#31512

Collaborators

@SlightwindSec @Ginray @MatrixPlayer

cc @Titus-von-Koeller @matthewdouglas

@ji-huazhong (Contributor, Author) commented Nov 21, 2024

[asciicast demo]

Following this blog, I ran an E2E QLoRA fine-tuning test on llama2-7b-hf in my environment with an NPU device, and it works 🤗.

Here is the script I used.
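For reference, a minimal sketch of such a QLoRA setup (not the exact script; it assumes torch_npu is installed so that accelerator index 0 resolves to npu:0, and uses the standard transformers/peft APIs):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# NF4 4-bit loading; on Ascend this exercises the quant path added in this PR.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},  # with torch_npu, accelerator index 0 is the NPU
)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# From here, hand `model`/`tokenizer` to trl's SFTTrainer (or transformers'
# Trainer) as usual; exact trainer arguments vary across trl versions.
```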

@baymax591 commented:

Thanks a lot for sharing this PR and the video demo! With the demo's help, I was able to run NF4 quant/dequant on the NPU with ease. The detailed explanation in the video really helped me understand the process and key steps. Looking forward to more updates in the future. Great work!

@baymax591 commented:

I hope this PR can be merged soon, as it provides valuable improvements. Looking forward to it!
cc @Titus-von-Koeller

@SunMarc (Contributor) commented Nov 27, 2024

Nice work and thanks for the demo! Can you have a look, @matthewdouglas?


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@matthewdouglas (Member) commented:
I will be able to look in more detail next week, but at first glance it looks nice. Thanks @statelesshz!

ji-huazhong and others added 3 commits December 5, 2024 11:25
Co-authored-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: Ginray <ginray0215@gmail.com>
@matthewdouglas (Member) commented:
@statelesshz We really appreciate the contribution! Apart from a lint check, I think we can go ahead and merge this.

For awareness, we are currently planning to adopt usage of torch.library to register custom ops for multi-backend. When we've progressed further on that we will want to come back and migrate to an updated interface.
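Roughly, that torch.library pattern would look like the following sketch (hypothetical namespace, op name, and signatures; PyTorch >= 2.4; not the final bitsandbytes interface):

```python
import torch

# Hypothetical op for illustration only; the eventual bitsandbytes
# interface may differ. Requires torch.library.custom_op (PyTorch >= 2.4).
@torch.library.custom_op("bnb_sketch::dequantize_nf4", mutates_args=())
def dequantize_nf4(packed: torch.Tensor, absmax: torch.Tensor) -> torch.Tensor:
    # Default/reference kernel (could be the pure-PyTorch dequant);
    # placeholder body for this sketch.
    raise NotImplementedError("backend kernel not registered")

@dequantize_nf4.register_kernel("npu")
def _(packed: torch.Tensor, absmax: torch.Tensor) -> torch.Tensor:
    # Ascend-specific kernel: pure PyTorch today, AscendC once available.
    raise NotImplementedError

@dequantize_nf4.register_fake
def _(packed: torch.Tensor, absmax: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation for torch.compile and meta tensors:
    # each packed byte holds two 4-bit values.
    return packed.new_empty(packed.numel() * 2, dtype=torch.float32)
```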

@ji-huazhong (Contributor, Author) commented:
> @statelesshz We really appreciate the contribution! Apart from a lint check, I think we can go ahead and merge this.
>
> For awareness, we are currently planning to adopt usage of torch.library to register custom ops for multi-backend. When we've progressed further on that we will want to come back and migrate to an updated interface.

@matthewdouglas Thank you for the feedback. I have addressed the lint check warnings, and I think the PR is now ready for merging. 🤗
Could you please re-trigger the CI to ensure everything is in order?

@matthewdouglas merged commit 9948333 into bitsandbytes-foundation:multi-backend-refactor on Dec 6, 2024 (2 checks passed).
@ji-huazhong deleted the npu-backend branch on December 9, 2024.
Review comment on the diff:

```diff
@@ -519,7 +519,12 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]

        # 1. Dequantize
        # 2. MatmulnN
        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
        if A.device.type == "npu":
```
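Purely as illustration (the hunk above is truncated), a device-type branch of this shape might look like the following; this is hypothetical, not the PR's exact code:

```python
# Hypothetical sketch of a device-type dispatch in MatMul4Bit.forward,
# where F is bitsandbytes.functional as in the surrounding code.
W = F.dequantize_4bit(B, quant_state).to(A.dtype).t()
if A.device.type == "npu":
    # NPU-specific path avoiding torch.nn.functional.linear.
    output = torch.matmul(A, W)
    if bias is not None:
        output = output + bias
else:
    output = torch.nn.functional.linear(A, W, bias)
```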


A quick question: why don't we use torch.nn.functional.linear directly? Thanks in advance for your answer.

rsshaik1 pushed a commit to bhargaveede/bitsandbytes that referenced this pull request Jan 10, 2025
* Add npu support for nf4 quant

Co-authored-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: Ginray <ginray0215@gmail.com>

* code format

* update

* pass lint check and fix typos

* add npu to supported devices

---------

Co-authored-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: Ginray <ginray0215@gmail.com>