[TPU] Add Load-time W8A16 quantization for TPU Backend #7005
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
Interesting, this looks reasonable to me. The important note is to lazily import the torch xla function when needed, rather than at the top of the quant file.
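A minimal sketch of the lazy-import pattern suggested above, assuming a hypothetical `apply` method on the TPU linear method. The `torch_xla.experimental.xla_quantized_matmul` module path and the `torch.ops.xla.quantized_matmul` op name are assumptions for illustration, not confirmed by this thread:

```python
import torch

class TPUInt8LinearMethodSketch:
    """Illustrative only; not the exact code in this PR."""

    def apply(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor:
        # Lazy import: torch_xla is pulled in only when this TPU path actually
        # runs, so merely importing the quant module does not require torch_xla.
        # The module path below is an assumption for illustration.
        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401

        # Assumed op name; the import above is what registers it as a torch op.
        out = torch.ops.xla.quantized_matmul(x, layer.weight, layer.scale)
        return out if bias is None else out + bias
```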
Super cool! As a follow-up, we can work on hooking this up to some of the existing checkpoints we have, in addition to in-place quantization.
By chance, what schemes does the following support? Channelwise?
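For reference, a minimal sketch of what symmetric channelwise (per-output-channel) INT8 weight quantization looks like; this is illustrative and not the code under review:

```python
import torch

def quantize_weight_channelwise_int8(weight: torch.Tensor):
    """Symmetric per-channel INT8 quantization of a [out_features, in_features]
    weight: one scale per output channel. Illustrative sketch only."""
    scale = weight.abs().amax(dim=-1, keepdim=True) / 127.0
    scale = scale.clamp(min=1e-8)                       # avoid division by zero
    qweight = torch.round(weight / scale).clamp(-128, 127).to(torch.int8)
    return qweight, scale.to(weight.dtype).squeeze(-1)  # int8 weight + per-channel scales
```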
Force-pushed from f4b8dd7 to b1a04b3.
Hi @mgoin, @robertgshaw2-neuralmagic, thank you for reviewing my PR! Excited to work with you to enable quantization for the TPU backend through
We have the quantized ops (equivalent to the quantized CUDA kernels in vLLM, but for TPU) in PyTorch/XLA here. The quantized matmul kernel is registered as a torch op and is compatible with
This is so awesome!!!! Running the same compressed models on various hardware backends is going to be an awesome feature.
@lsy323 Thanks for the PR! I'm really looking forward to using this feature!
I think we have two things to figure out on this PR:
- Where to put this quantization config and the linear method: do we want to put this in as a new quantization config (like in the current PR) or in `compressed-tensors`?
- IIRC, this currently does not support the case where the BF16 weights exceed the HBM size of the TPU but the INT8 weights fit (e.g., Llama 8B on TPU v5e, which has 16 GB of HBM). Could you please remind us why this isn't supported?
I think for
Sure, the current flow is:

1. The BF16 weights are moved to the TPU device as the model is initialized.
2. Load-time quantization (BF16 → INT8) is then done on device.

When the BF16 weights exceed the HBM size of the TPU, step 1 would hit OOM. To avoid this problem, we can delay transferring the weights when load-time quantization is enabled.
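A rough sketch of the difference between the two flows, assuming quantization can run on the host before the transfer; function names and the quantization helper here are illustrative, not vLLM's actual loader code:

```python
import torch

def quantize_int8_per_channel(weight: torch.Tensor):
    """Illustrative symmetric per-channel INT8 quantization helper."""
    scale = weight.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    qweight = torch.round(weight / scale).clamp(-128, 127).to(torch.int8)
    return qweight, scale.squeeze(-1)

def load_weight_current_flow(bf16_weight_cpu: torch.Tensor, device: torch.device):
    # Current flow: move the BF16 weight to the TPU first, then quantize on device.
    # If the BF16 weights do not fit in HBM, this first transfer hits OOM.
    w = bf16_weight_cpu.to(device)
    return quantize_int8_per_channel(w)

def load_weight_delayed_transfer(bf16_weight_cpu: torch.Tensor, device: torch.device):
    # Alternative flow: quantize on the host first, then transfer only the
    # INT8 weights (plus scales), so peak device memory stays much lower.
    qweight, scale = quantize_int8_per_channel(bf16_weight_cpu)
    return qweight.to(device), scale.to(device)
```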
I don't have a crystal-clear plan for this yet; the alternatives are as follows:
Hey guys - there are a couple of considerations here. For vLLM, we want to support both cases:

We will be making all go-forward checkpoints inside the
Both
Hi @WoosukKwon, I looked into this in detail, and it doesn't look like a straightforward change; I think we can consider supporting that in a separate PR. In the current flow, weights are moved to the device as the model is initialized (ref), then load-time quantization is done on device (ref). We need to introduce a new flow to support this case.
@lsy323 Seems like my previous comment was not addressed for some reason. Can you please check it again?
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@WoosukKwon Somehow I force-pushed without the suggested change commits. It should be fixed now. Thank you for the reminder!
LGTM! Thanks for the PR! It works well on my machine 🎉 🎉
Looking forward to the next step! (adding INT8 activation quantization in `tpu_int8`).
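For context on that next step, a hedged sketch of what dynamic per-tensor INT8 activation quantization typically looks like; this is illustrative only, not the planned vLLM implementation:

```python
import torch

def quantize_activation_int8(x: torch.Tensor):
    """Dynamic symmetric per-tensor INT8 quantization of an activation tensor.
    Illustrative sketch of the W8A8 direction mentioned above."""
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    qx = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return qx, scale  # int8 activations + a single scale to fold into the output
```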
Add load-time W8A16 quantization for the TPU backend. The workflow is similar to the existing load-time FP8 quantization. Opening the PR to help the discussion process.

- Added `tpu_int8` for load-time INT8 weight-only quantization for the TPU backend (e.g. `LLM(model="google/gemma-2b", quantization="tpu_int8")`).
- Added `TPUInt8LinearMethod`, which quantizes `bfloat16` weights to `int8` weights for linear layers and calls the TPU quantized ops in `forward`.
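A minimal usage sketch based on the example above (assumes a vLLM build with the TPU backend available):

```python
from vllm import LLM, SamplingParams

# Load-time W8A16: weights are quantized to INT8 as the model is loaded.
llm = LLM(model="google/gemma-2b", quantization="tpu_int8")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```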