
workaround of AWQ for Turing GPUs #1252

Merged — 3 commits merged into vllm-project:main on Oct 11, 2023

Conversation

@twaka (Contributor) commented Oct 3, 2023

As far as I can tell, only the mma.sync.aligned.m16n8k16 op requires sm_80. On sm_75, two mma.sync.aligned.m16n8k8 ops can yield the same result.
It may not be optimal for performance, but it at least works for anyone who wants to try AWQ on Turing GPUs.
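Roughly, the substitution looks like the sketch below. This is illustrative only, not the actual kernel code: the helper name mma_m16n8k16_f16f32 and the bare fragment arrays are assumptions, but the idea is the same — two k=8 ops accumulating into the same registers in place of one k=16 op.

// Illustrative sketch (not the exact vLLM kernel): one m16n8k16 tensor-core MMA,
// which needs sm_80, expressed as two m16n8k8 MMAs on sm_75. a[] is the 4-register
// f16x2 A fragment, b[] the 2-register f16x2 B fragment, c[] the four f32
// accumulators, following the PTX fragment layouts: the first k8 op consumes the
// k = 0..7 halves (a[0], a[1], b[0]) and the second the k = 8..15 halves
// (a[2], a[3], b[1]), both accumulating into the same c[].
#include <cstdint>

__device__ __forceinline__ void mma_m16n8k16_f16f32(float* c, const uint32_t* a,
                                                    const uint32_t* b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  // Turing (sm_75): split the K dimension into two halves.
  asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%0,%1,%2,%3};\n"
      : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(b[0]));
  asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%0,%1,%2,%3};\n"
      : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
      : "r"(a[2]), "r"(a[3]), "r"(b[1]));
#else
  // Ampere and newer (sm_80+): a single m16n8k16 op.
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n"
      : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
#endif
}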

  • Example with a Tesla T4:
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

model_name = "casperhansen/vicuna-7b-v1.5-awq"

llm = LLM(model=model_name, quantization="AWQ")

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Prompt: 'Hello, my name is', Generated text: " Dustin Nelson and I'm a writer and podcaster based in"
# Prompt: 'The president of the United States is', Generated text: ' elected indirectly through the Electoral College, which is composed of electors from'
# Prompt: 'The capital of France is', Generated text: ' Paris. At the center of Paris, you will find the Seine River. The'
# Prompt: 'The future of AI is', Generated text: ' a topic of much debate and speculation. Alteryx Inspire is'

@casper-hansen (Contributor)

@twaka Thanks for this work. Would you mind upstreaming this into AutoAWQ as well?

@twaka (Contributor, Author) commented Oct 4, 2023

@casper-hansen I'm glad to see it added to AutoAWQ, and the benchmark results are excellent! To be honest, I wasn't yet familiar enough with the AutoAWQ codebase to know where the changes needed to go ;)

@esmeetu (Collaborator) commented Oct 4, 2023

@twaka Hi, thanks for sharing this workaround. It works for me with tensor parallelism of 2. When I use TP4, it throws the error: Group size should be a multiple of 32 in gemm kernel. My AWQ model uses group size 128 and 4-bit weights. Do you have any ideas?

@twaka (Contributor, Author) commented Oct 4, 2023

@esmeetu I'm happy to hear it works with TP2. I don't have an environment to run TP4, but I think we can isolate your issue by running TP4 on Ampere GPUs to see whether the error persists.

@esmeetu (Collaborator) commented Oct 4, 2023

@twaka Yes, it should be a common problem; I opened a new issue about that.
Another question: is the computation precision of m16n8k8 the same as m16n8k16?

@casper-hansen (Contributor)

I observed the same output on a T4 as on other GPUs. I will see if I have time to measure perplexity before this is merged; I expect it to be the same.

Four instructions versus two is obviously going to be slower, but it's still decently fast on the T4, to the point of being usable.

@wasertech

It's a bit tricky to test this branch. I had to fork my own branch and merge #1290 to be able to build, and set TORCH_CUDA_ARCH_LIST="7.0" to avoid #1225, but I still end up with a RuntimeError: CUDA error: device-side assert triggered. (I'm in nvcr.io/nvidia/pytorch:22.12-py3 with pip and git updated, and torch uninstalled so that vLLM installs the correct version.)

The root cause looks to be an assertion failure in csrc/quantization/awq/gemm_kernels.cu:33: void vllm::awq::gemm_forward_4bit_cuda_m16n128k32 -> Assertion false failed. See the full traceback here.

Has anyone found a reliable way to test this branch on Turing GPUs? Let me know; I would gladly test the AWQ export of my model with vLLM.

@casper-hansen (Contributor)

When you merged the PR into your fork, you forgot to merge the updated kernels from this PR, @wasertech.

@wasertech

It works! I’ve seen a remarkable improvement, going from approximately 39 tokens per second to a speedy 86! What’s even more impressive is that the model size is now less than 4 GB on disk. I’d like to extend a special shoutout to @casper-hansen for consistently steering me in the right direction and to @twaka for making the process of quantization on vLLM much more accessible. Of course, a heartfelt thanks to everyone who made this achievement possible!

@esmeetu (Collaborator) commented Oct 9, 2023

@wasertech I am curious which model size you use, since I cannot achieve that speed on my T4. I get 35 t/s with a 7B AWQ model, and only 16 t/s with 7B FP16.

@wasertech commented Oct 10, 2023

@esmeetu, you've certainly doubled your token throughput 😅. I used to use assistant-llama2-7b-chat, but I'm currently using assistant-llama2-7b-chat-awq; both are fine-tuned models based on a PEFT adapter for QLoRA, from Photolens/llama-2-7b-langchain-chat.

Edit: I'm not using a T4 but an RTX Titan, which can explain the difference in throughput.

@esmeetu (Collaborator) commented Oct 10, 2023

@wasertech Yeah, but I meant: how did you get 86 tokens/s? Might the batch size be 2? 🫨

@wasertech commented Oct 10, 2023

@wasertech Yeah, but I meant: how did you get 86 tokens/s? Might the batch size be 2? 🫨

@esmeetu I'm using the default batch size. I took the server example code and slightly modified it here to stream back only the output, not the input.

I really think the GPU is the biggest difference at play here…

@esmeetu (Collaborator) commented Oct 10, 2023

@wasertech Which GPU do you use?

@wasertech

I'm not using a T4 but an RTX Titan, which can explain the difference in throughput.

@esmeetu https://www.nvidia.com/en-us/deep-learning-ai/products/titan-rtx/

@esmeetu (Collaborator) commented Oct 10, 2023

@wasertech Thank you! I'm sorry I missed that message. 😆
The RTX Titan is super fast. 🤩

@WoosukKwon self-requested a review on October 10, 2023 02:16
@casper-hansen (Contributor)

As I noted in my PR in AutoAWQ, this PR in vLLM enables older GPUs:

  • RTX 2000-series (+ TITAN RTX)
  • GTX 1600-series
  • Quadro RTX (5000, 6000, 8000)
  • Tesla T4

@WoosukKwon (Collaborator) left a comment

@twaka Thanks for the fix! Sorry for the late reply; I was very busy over the last week. I left a very minor comment; please check it out.

Review comments on csrc/quantization/awq/gemm_kernels.cu (outdated, resolved)
@WoosukKwon (Collaborator)

BTW, I've also checked that this PR works for 1 and 2 T4 GPUs. @twaka Thanks for the great work!

twaka and others added 2 commits on October 11, 2023, co-authored by Woosuk Kwon <woosuk.kwon@berkeley.edu>.
@twaka (Contributor, Author) commented Oct 11, 2023

@WoosukKwon Thanks, updated.

@WoosukKwon (Collaborator) left a comment

@twaka LGTM! Many thanks again for the fix!

@WoosukKwon merged commit 8285736 into vllm-project:main on Oct 11, 2023
@twaka deleted the awq-sm_75 branch on October 11, 2023 07:35
@mingyangAbc commented Jan 2, 2024

It's a bit tricky to test this branch. I had to fork my own branch and merge #1290 to be able to build, and set TORCH_CUDA_ARCH_LIST="7.0" to avoid #1225, but I still end up with a RuntimeError: CUDA error: device-side assert triggered. (I'm in nvcr.io/nvidia/pytorch:22.12-py3 with pip and git updated, and torch uninstalled so that vLLM installs the correct version.)

The root cause looks to be an assertion failure in csrc/quantization/awq/gemm_kernels.cu:33: void vllm::awq::gemm_forward_4bit_cuda_m16n128k32 -> Assertion false failed. See the full traceback here.

Has anyone found a reliable way to test this branch on Turing GPUs? Let me know; I would gladly test the AWQ export of my model with vLLM.

I used the AutoAWQ project (https://github.com/casper-hansen/AutoAWQ) to quantize the yi-34b-chat model, and when I run the vLLM demo I also hit "RuntimeError: CUDA error: device-side assert triggered". How did you solve it?

code:

from vllm import LLM, SamplingParams
prompts = [
    "Hello, my name is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="/Yi/quantized_model", tensor_parallel_size=1, trust_remote_code=True)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

log:

2024-01-02 20:21:56,878 - modelscope - INFO - PyTorch version 2.1.2+cu118 Found.
2024-01-02 20:21:56,879 - modelscope - INFO - Loading ast index from /root/.cache/modelscope/ast_indexer
2024-01-02 20:21:56,918 - modelscope - INFO - Loading done! Current index file version is 1.9.5, with md5 6cc4bf9c033540e7f2386c0058d4f4b4 and a total number of 945 components indexed
[2024-01-02 20:21:59,535] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
WARNING 01-02 20:21:59 config.py:179] awq quantization is not fully optimized yet. The speed can be slower than non-quantized models.
INFO 01-02 20:21:59 llm_engine.py:73] Initializing an LLM engine with config: model='/Yi/quantized_model', tokenizer='/Yi/quantized_model', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=awq, enforce_eager=False, seed=0)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
INFO 01-02 20:22:21 llm_engine.py:223] # GPU blocks: 3168, # CPU blocks: 1092
INFO 01-02 20:22:25 model_runner.py:394] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 01-02 20:22:40 model_runner.py:437] Graph capturing finished in 15 secs.
Processed prompts: 0%| | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/tmp/pycharm_project_314/test/test.py", line 32, in
outputs = llm.generate(prompts, sampling_params)
File "/root/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 165, in generate
return self._run_engine(use_tqdm)
File "/root/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 185, in _run_engine
step_outputs = self.llm_engine.step()
File "/root/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 581, in step
output = self._run_workers(
File "/root/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 755, in _run_workers
self._run_workers_in_batch(workers, method, *args, **kwargs))
File "/root/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 729, in _run_workers_in_batch
output = executor(*args, **kwargs)
File "/root/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/root/lib/python3.10/site-packages/vllm/worker/worker.py", line 159, in execute_model
output = self.model_runner.execute_model(seq_group_metadata_list,
File "/root/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/root/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 354, in execute_model
output = self.model.sample(
File "/root/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 295, in sample
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
File "/root/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 58, in forward
do_min_p) = SamplingTensors.from_sampling_metadata(
File "/root/lib/python3.10/site-packages/vllm/model_executor/sampling_metadata.py", line 131, in from_sampling_metadata
sampling_tensors = SamplingTensors.from_lists(
File "/root/lib/python3.10/site-packages/vllm/model_executor/sampling_metadata.py", line 218, in from_lists
temperatures=temperatures_t.to(device=device, non_blocking=True),
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Processed prompts: 0%| | 0/1 [00:00<?, ?it/s]
../aten/src/ATen/native/cuda/Indexing.cu:1239: indexSelectSmallIndex: block: [34,0,0], thread: [0,0,0] Assertion srcIndex < srcSelectDimSize failed.

nvcc -V:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

pip list:

absl-py 2.0.0
accelerate 0.25.0
addict 2.4.0
aiofiles 23.2.1
aiohttp 3.9.1
aioprometheus 23.3.0
aiosignal 1.3.1
aliyun-python-sdk-core 2.14.0
aliyun-python-sdk-kms 2.16.2
altair 5.2.0
annotated-types 0.6.0
anyio 3.7.1
async-timeout 4.0.3
attributedict 0.3.0
attrs 23.1.0
auto-gptq 0.5.1
autoawq 0.1.7+cu118
bitsandbytes 0.41.2.post2
blessings 1.7
cachetools 5.3.2
certifi 2023.11.17
cffi 1.16.0
chardet 5.2.0
charset-normalizer 3.3.2
click 8.1.7
cmake 3.25.0
codecov 2.1.13
colorama 0.4.6
coloredlogs 15.0.1
colour-runner 0.1.1
contourpy 1.2.0
coverage 7.4.0
cpm-kernels 1.0.11
crcmod 1.7
cryptography 41.0.7
cycler 0.12.1
DataProperty 1.0.1
datasets 2.16.1
deepdiff 6.7.1
deepspeed 0.12.5
dill 0.3.6
distlib 0.3.8
docstring-parser 0.15
einops 0.7.0
evaluate 0.4.1
exceptiongroup 1.2.0
fastapi 0.104.1
ffmpy 0.3.1
filelock 3.13.1
fire 0.5.0
fonttools 4.46.0
frozenlist 1.4.0
fsspec 2023.10.0
gast 0.5.4
gekko 1.0.6
gevent 22.10.2
google-auth 2.25.1
google-auth-oauthlib 1.1.0
gradio 3.50.2
gradio_client 0.6.1
greenlet 3.0.1
grpcio 1.59.3
h11 0.14.0
hjson 3.1.0
httpcore 1.0.2
httptools 0.6.1
httpx 0.25.2
huggingface-hub 0.19.4
humanfriendly 10.0
idna 3.6
importlib-metadata 7.0.0
importlib-resources 6.1.1
inspecta 0.1.3
jieba 0.42.1
Jinja2 3.1.2
jmespath 0.10.0
joblib 1.3.2
jsonlines 4.0.0
jsonschema 4.20.0
jsonschema-specifications 2023.11.2
kiwisolver 1.4.5
lit 15.0.7
lm_eval 0.4.0
lxml 5.0.0
Markdown 3.5.1
markdown-it-py 3.0.0
MarkupSafe 2.1.3
matplotlib 3.8.2
mbstrdecoder 1.1.3
mdurl 0.1.2
modelscope 1.9.5
mpmath 1.3.0
ms-swift 1.6.0.dev0
msgpack 1.0.7
multidict 6.0.4
multiprocess 0.70.14
networkx 3.2.1
ninja 1.11.1.1
nltk 3.8.1
numexpr 2.8.8
numpy 1.26.2
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
optimum 1.14.1
ordered-set 4.1.0
orjson 3.9.10
oss2 2.18.3
packaging 23.2
pandas 2.1.3
pathvalidate 3.2.0
peft 0.7.1
Pillow 10.1.0
pip 23.3.2
platformdirs 4.1.0
pluggy 1.3.0
portalocker 2.8.2
protobuf 4.23.4
psutil 5.9.6
py-cpuinfo 9.0.0
pyarrow 14.0.1
pyarrow-hotfix 0.6
pyasn1 0.5.1
pyasn1-modules 0.3.0
pybind11 2.11.1
pycparser 2.21
pycryptodome 3.19.0
pydantic 1.10.13
pydantic_core 2.14.5
pydub 0.25.1
Pygments 2.17.2
pynvml 11.5.0
pyparsing 3.1.1
pyproject-api 1.6.1
pytablewriter 1.2.0
python-dateutil 2.8.2
python-dotenv 1.0.0
python-multipart 0.0.6
pytz 2023.3.post1
PyYAML 6.0.1
quantile-python 1.1
ray 2.9.0
referencing 0.31.1
regex 2023.10.3
requests 2.31.0
requests-oauthlib 1.3.1
responses 0.18.0
rich 13.7.0
rootpath 0.1.1
rouge 1.0.1
rouge-chinese 1.0.3
rouge-score 0.1.2
rpds-py 0.13.2
rsa 4.9
sacrebleu 2.4.0
safetensors 0.4.1
scikit-learn 1.3.2
scipy 1.11.4
semantic-version 2.10.0
sentencepiece 0.1.99
setuptools 68.0.0
shtab 1.6.5
simplejson 3.19.2
six 1.16.0
sniffio 1.3.0
sortedcontainers 2.4.0
sqlitedict 2.1.0
sse-starlette 1.8.2
starlette 0.27.0
sympy 1.12
tabledata 1.3.3
tabulate 0.9.0
tcolorpy 0.1.4
tensorboard 2.15.1
tensorboard-data-server 0.7.2
termcolor 2.4.0
texttable 1.7.0
threadpoolctl 3.2.0
tiktoken 0.5.2
tokenizers 0.15.0
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 2.1.2+cu118
torchaudio 2.1.0+cu118
torchvision 0.16.0+cu118
tox 4.11.4
tqdm 4.66.1
tqdm-multiprocess 0.0.11
transformers 4.35.2
transformers-stream-generator 0.0.4
triton 2.1.0
trl 0.7.7
typepy 1.3.2
typing_extensions 4.8.0
tyro 0.6.0
tzdata 2023.3
urllib3 2.1.0
uvicorn 0.24.0.post1
uvloop 0.19.0
virtualenv 20.25.0
vllm 0.2.6+cu118
watchfiles 0.21.0
websockets 11.0.3
Werkzeug 3.0.1
wheel 0.41.2
xformers 0.0.23.post1+cu118
xxhash 3.4.1
yapf 0.40.2
yarl 1.9.3
zipp 3.17.0
zope.event 5.0
zope.interface 6.1
zstandard 0.22.0

@wasertech commented Jan 2, 2024

I used the AutoAWQ project (https://github.com/casper-hansen/AutoAWQ) to quantize the yi-34b-chat model, and when I run the vLLM demo I also hit "RuntimeError: CUDA error: device-side assert triggered". How did you solve it?

@mingyangAbc as @casper-hansen helpfully pointed out, I had simply forgotten to build vLLM for my target; once I noticed, I rebuilt for the correct target and it worked flawlessly. Now that the fix is on main, you should be able to use any pre-compiled binary.

Anyway, I don't remember exactly what I did, but here is the gist:

I built vLLM using Docker, and that's all:

DOCKER_BUILDKIT=1 docker build . --target vllm --tag vllm --build-arg max_jobs=24

My mistake was trying to build with the wrong TORCH_CUDA_ARCH_LIST set in the environment, which prevented the patches in this branch from being compiled in, and that gave me the assertion error.
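
For context, that behaviour makes sense if the kernel body sits behind a compile-time architecture guard: building for an arch without a supported MMA path still compiles, but compiles down to an assert that every thread hits at runtime. A rough sketch of that kind of guard (illustrative only; the kernel name here is made up, and the real code in csrc/quantization/awq/gemm_kernels.cu may differ):

// Simplified sketch of an architecture guard that produces a device-side
// assert when the extension is built for an unsupported arch. Illustrative
// only; not the actual vLLM source.
#include <cassert>

__global__ void gemm_forward_4bit_guard_example() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  // Tensor-core GEMM path (one m16n8k16 op on sm_80+, two m16n8k8 ops on sm_75).
#else
  // Built for an arch below sm_75 (e.g. TORCH_CUDA_ARCH_LIST="7.0"):
  // every launched thread trips this assert at runtime.
  assert(false);
#endif
}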

In any case, your issue might not be caused by the same situation. It's probably best to open a new issue rather than comment on a PR merged three months ago. Also, maybe your model (Yi-34B) doesn't support AWQ quantization at the moment? There are so many differences between your case and mine that you should just open a proper issue.

@mingyangAbc

OK, thanks. I will open a new issue.

@cduk (Contributor) commented May 24, 2024

I was just wondering whether support could be extended to Pascal-class GPUs such as the P100? I'm not sure which intrinsics are missing compared to Turing (if any).

@wasertech commented Sep 10, 2024

@cduk Unfortunately, the P100 has a compute capability of 6.0, which is insufficient for these AWQ kernels, which require compute capability 7.5 or above. The Pascal architecture lacks several key features present in more recent architectures like Turing, such as:

  • Tensor Cores, essential for accelerating deep learning workloads.
  • Improved concurrent execution of floating-point and integer operations.
  • Faster memory support, including GDDR6.

These advancements let newer GPUs handle demanding workloads such as quantized inference efficiently.

This PR is already a godsend for the Turing architecture, where we need two MMA operations for what newer architectures (sm_80 and above) do in one...

I'm not saying it couldn't be done, but it would require more computation steps (4, maybe 8, or more) and would therefore be slow.
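
For completeness, a minimal host-side capability check along these lines would reject Pascal before any AWQ kernel is launched. This is a hedged sketch only; awq_arch_supported is a hypothetical helper, not part of vLLM:

// Minimal sketch of a host-side compute-capability check. The helper name
// awq_arch_supported is hypothetical and not part of vLLM.
#include <cuda_runtime.h>
#include <cstdio>

bool awq_arch_supported(int device = 0) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) return false;
  const int cc = prop.major * 10 + prop.minor;  // 75 = Turing, 60 = Pascal
  if (cc < 75) {
    std::fprintf(stderr,
                 "AWQ kernels need compute capability >= 7.5, got %d.%d\n",
                 prop.major, prop.minor);
    return false;
  }
  return true;
}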
