Add AWQ quant support #762

Closed · wants to merge 30 commits

Conversation

@ri938 (Contributor) commented Aug 14, 2023

test:

python -m vllm.entrypoints.api_server --model rirv938/wizard-vicuna-13b-uncensored-awq-4bit-g128 --quantization awq
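
To sanity-check the server once it is running, here is a minimal client sketch (assuming the demo server's default port 8000 and its /generate endpoint; adjust if your setup differs):

# Minimal client sketch for the demo api_server above. The /generate endpoint
# and default port 8000 are assumptions about the demo server; adjust if your
# setup differs.
import requests

payload = {
    "prompt": "The capital of France is",
    "max_tokens": 64,
    "temperature": 0.8,
}
response = requests.post("http://localhost:8000/generate", json=payload)
print(response.json())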

Issues

  1. Some TODOs to resolve (e.g. hard-coding the device for loading quantised layers).
  2. Currently only supports Llama (not intending to add support for other models in this PR).
  3. Currently does not support tensor parallelism.
  4. It scales poorly with larger batch sizes; more optimisation would be good. (I think this is a separate PR / community work to follow.)

@ri938 mentioned this pull request Aug 14, 2023

@ri938 (Contributor, Author) commented Aug 14, 2023

I tried different methods of quantization and AWQ performed by far the best, but it scales poorly with batch size. It's good enough to reduce our 13B hosting costs by 20%, but I think it can be improved a lot more with some optimization work.

(I am actually doing n=4 here so batch size is arguably 4x)

@casper-hansen (Contributor) commented Aug 14, 2023

Great benchmarks. I was looking into implementing this myself but was waiting on your implementation. Here are my 2 cents.

EDIT: Everyone should also note that the GEMM kernels are optimized for Ampere and later architectures (e.g. RTX 3000-4000, A5000, A6000, A100, H100, etc.), i.e. they are unlikely to work well on a V100 GPU. However, I would argue this does not matter, as using a V100 would be vastly inferior in terms of both cost and speed for deployment.

Design question

I noticed that a few extra Quant classes need to be added for every model. Here are my thoughts on how this could (potentially) be simplified.

Instead of implementing QuantLlamaMLP and other classes for each part of every model, why not implement the replacement of the Linear layers at a lower level? For example, at the F.linear() call sites in RowParallelLinear and ColumnParallelLinear.

e.g. a very naive example to instantiate:

if quant_config.method is not None:
    self.linear = get_quantized_layer(in_features, out_features, quant_config)
else:
    self.linear = F.linear

This way, you don't have to modify the model files that much since you could just pass down your quant_config and decide on a lower level. In summary, we could simplify the code and make it easier to extend in the future.
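
To make this concrete, here is a minimal sketch of what such a dispatch helper could look like; all names here (QuantConfig, AWQLinearStub, get_quantized_layer) are hypothetical placeholders, not the API in this PR:

# Hypothetical sketch of the "dispatch inside the parallel linear layers" idea.
# QuantConfig, AWQLinearStub and get_quantized_layer are placeholder names,
# not the actual API in this PR.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class QuantConfig:
    method: Optional[str] = None  # e.g. "awq"
    w_bit: int = 4
    group_size: int = 128


class AWQLinearStub(nn.Module):
    """Stands in for a real AWQ int4 linear backed by the custom CUDA kernels."""

    def __init__(self, in_features: int, out_features: int, cfg: QuantConfig):
        super().__init__()
        self.cfg = cfg
        # A real implementation would hold packed int4 weights, scales and zeros.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real implementation would call the AWQ GEMM kernel here.
        return F.linear(x, self.weight)


def get_quantized_layer(in_features: int, out_features: int,
                        cfg: QuantConfig) -> nn.Module:
    if cfg.method == "awq":
        return AWQLinearStub(in_features, out_features, cfg)
    raise ValueError(f"unsupported quantization method: {cfg.method}")

RowParallelLinear and ColumnParallelLinear would then pick between this and plain F.linear at construction time, exactly as in the naive example above.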

Replacing activations

I noticed activations are not replaced (not sure if you tested this). In AWQ, they also replace activations in some models with a ScaledActivation. Not sure if this makes a difference, but I wanted to highlight it.

@ri938 (Contributor, Author) commented Aug 15, 2023

@casperbh96 thanks for the feedback. I'm going to look into improving some of the design.

Scaled activations

https://github.com/mit-han-lab/llm-awq/blob/main/awq/quantize/quantizer.py#L14

Looking at the AWQ code, it seems that this is only applied to MPT, Bloom and Falcon models and has no effect for Llama. Plus, I have run some tests of inference quality and it seems fine.
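
For anyone following along, the idea behind ScaledActivation is roughly the following; this is a paraphrase of the linked quantizer.py, not a verbatim copy, so check the llm-awq repo for the authoritative version:

# Rough paraphrase of the ScaledActivation idea from the linked quantizer.py;
# not a verbatim copy of the llm-awq code.
import torch
import torch.nn as nn


class ScaledActivationSketch(nn.Module):
    def __init__(self, act: nn.Module, scales: torch.Tensor):
        super().__init__()
        self.act = act
        self.scales = nn.Parameter(scales)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Undo the per-channel scaling that AWQ folded into the preceding
        # quantized layer, so the end-to-end function is unchanged.
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)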

@ri938 changed the title from "Draft: Add awq quant support" to "Add AWQ quant support" Aug 15, 2023
@ri938 (Contributor, Author) commented Aug 15, 2023

Updated with some improvements + removed the draft / WIP status.

@casper-hansen (Contributor) commented Aug 15, 2023

I have tested some models: your 13B Vicuna model and LLaMa 2 7B. These models are measured solely on tokens/s instead of throughput. Hardware is an RTX 3090 + Threadripper Pro 3955WX. Multiple prompts are measured individually.

TLDR: The performance is seemingly getting up to 85-90% of the original work.

| Model | Number of Prompts | Tokens per second |
| --- | --- | --- |
| Vicuna 13B (vLLM, PR) | 1 | 55-58 |
| Vicuna 13B (vLLM, PR) | 5 | 48-51 |
| Vicuna 13B (AWQ TinyChat) | 1 | 64 |
| Vicuna 13B (AWQ TinyChat) | 5 | N/A (not supported) |
| LLaMa 2 7B (vLLM, PR) | 1 | 80-86 |
| LLaMa 2 7B (vLLM, PR) | 5 | 71-72 |
| LLaMa 2 7B (AWQ TinyChat) | 1 | 89-90 |
| LLaMa 2 7B (AWQ TinyChat) | 5 | N/A (not supported) |

Note: I also tested A100 and RTX 6000 Ada, but they are not yielding better results.

vLLM example:

import time
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Write me a letter to Sam Altman",
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is"
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="rirv938/wizard-vicuna-13b-uncensored-awq-4bit-g128", **{'quantization': 'awq'})

start = time.time()

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    tokens = output.outputs[0].token_ids
    end = time.time()
    elapsed = end-start
    
    print(output)
    print(len(tokens) / elapsed, 'tokens/s')

TinyChat example (you need to git clone the model from Hugging Face first):

python3 demo.py --model_type llama --model_path wizard-vicuna-13b-uncensored-awq-4bit-g128 --q_group_size 128 --load_quant wizard-vicuna-13b-uncensored-awq-4bit-g128/wizard-vicuna-13b-w4-g128-awq.bin --precision W4A16

@ri938 (Contributor, Author) commented Aug 15, 2023

@casperbh96 looks like TinyChat does some other things, like fusing layers. It might be a more efficient implementation.

In particular, TinyChat does:

if args.precision == "W4A16" and args.model_type.lower() == 'llama':
    from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
    make_quant_attn(model, args.device)
    make_quant_norm(model)
    make_fused_mlp(model)

For the A6000, when you say it didn't yield better results, do you mean that they performed worse or comparably? I'm particularly interested in this because I intend to deploy to A5000 / A6000 hardware.

@casper-hansen (Contributor)

@casperbh96 looks like TinyChat does some other things, like fusing layers. It might be a more efficient implementation.

For the A6000, when you say it didn't yield better results, do you mean that quantized was no faster than unquantized?

TinyChat has a few extra things happening, yes. The performance discrepancy is so small that I would not focus on it, but if you wanted to, you should focus on the T5LayerNorm kernel that they adapted from FasterTransformer.

I meant that the A100, A6000, 4090 and 3090 all yield similar results on the quantized models, with the A6000 being a little slower than the others. This is to be expected and is probably due to the CPU being the bottleneck.

@ri938 (Contributor, Author) commented Aug 15, 2023

Throughput increases a lot when the QKV and gate + up projection layers are merged.

EDIT: my initial estimates of the throughput increase were incorrect; it's only a modest improvement, I think.

@casper-hansen (Contributor) commented Aug 15, 2023

Throughput increases a lot when the QKV and gate + up projection layers are merged. Working on this now.

Good to hear! How generalizable is this to other models like MPT?

@ri938 (Contributor, Author) commented Aug 15, 2023

I've only tried it with Llama. Merging the linear layers can, I assume, be done for most models with attention blocks, but I don't know much about the MPT model.

@casper-hansen (Contributor) commented Aug 15, 2023

I've only tried it with Llama. Merging the linear layers can, I assume, be done for most models with attention blocks, but I don't know much about the MPT model.

From what I could find, it's only LLaMa models that can have their qkv projection fused, because they are the only ones that have a separate linear layer for each of q, k and v, which makes them slower. So LLaMa and InternLM seem like the ones that can benefit from this.

Falcon, MPT, Qwen and Baichuan models have their qkv operations fused already, so they should already be covered by AWQ quantization, since they define it like this:

self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)

Versus LLaMa:

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
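
To illustrate what the fusion amounts to in the fp16 case (before quantization), here is a minimal sketch; it is illustrative only, not the code in this PR, and for AWQ layers the packed qweight, scales and zeros would need similar handling along the output dimension:

# Minimal sketch of fusing LLaMA's separate q/k/v projections into one linear
# (fp16 case, before quantization); illustrative only, not the code in this PR.
import torch
import torch.nn as nn


def fuse_qkv(q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear) -> nn.Linear:
    out_features = q_proj.out_features + k_proj.out_features + v_proj.out_features
    fused = nn.Linear(q_proj.in_features, out_features, bias=False)
    with torch.no_grad():
        # nn.Linear stores weights as (out_features, in_features), so the three
        # projections are stacked along dim 0.
        fused.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    return fused


# After fusion, one matmul produces q, k and v, which are split along the last dim:
# qkv = fused(hidden_states)
# q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=-1)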

@jhartman

I see this supports 4-bit. Is there a plan to add support for 8-bit quantization?

@ri938 (Contributor, Author) commented Aug 21, 2023

Ultimately I'm finding that unquantised on an A100 is cheaper than quantised on an A5000 or A6000. In other words, the cheaper-hardware benefit is not making it cheaper overall. That's why I think we need better CUDA kernels for this.

@jhartman @casperbh96 do you guys use Discord? If you want to connect to discuss more, you can add me ("robert1") on Discord.

@casper-hansen (Contributor) commented Aug 21, 2023

Ultimately I'm finding that unquantised on an A100 is cheaper than quantised on an A5000 or A6000.

That seems surprising to me. Can you please try my branch? I integrated AWQ into RowParallel and ColumnParallel. The implementation is a little hacky for loading the model, but it worked fine for me. I suspect this might make a difference.

https://github.com/casperbh96/vllm-quantisation/tree/add_awq_quant_support

EDIT: I added you

@mapa17 commented Aug 23, 2023

Loading a model, I get "Unable to import awq_inference_engine: run setup.py to install AWQ CUDA kernels". I saw that there is a vllm/awq_quantization/kernels/setup.py, but this does not seem to be triggered by the normal pip install of the package. Should the build instructions not be moved from vllm/awq_quantization/kernels/setup.py to vllm/setup.py?

@casper-hansen (Contributor)

Loading a model, I get "Unable to import awq_inference_engine: run setup.py to install AWQ CUDA kernels". I saw that there is a vllm/awq_quantization/kernels/setup.py, but this does not seem to be triggered by the normal pip install of the package. Should the build instructions not be moved from vllm/awq_quantization/kernels/setup.py to vllm/setup.py?

You need to run python setup.py install (from vllm/awq_quantization/kernels) to install the engine.

@WoosukKwon (Collaborator)

@ri938 Thanks for the awesome work! And sorry for the late response. I was tracking this PR but didn't have the bandwidth to look into it. We'd love to merge this PR into our main branch. That said, we'd like to clean up the code in this PR, as we found several files are redundant. For example, IIUC, the CUDA kernels besides the ones in quantization seem redundant. Do you mind if we take over this PR and make the necessary modifications before merging it?

Thanks again for the wonderful work. I'd also like to thank everyone in the discussion: @casperbh96, @TheBloke, @jhartman.

@casper-hansen (Contributor)

@WoosukKwon thanks for checking in. Much of what you mentioned has already been done in a fork.

I believe there are a few items that need doing:

  • remove the quant layers and integrate AWQ into ColumnParallelLinear and RowParallelLinear
  • deduplicate csrc
    • also done in the fork I referenced above
  • improve the model loading code
    • this one simply needs a better solution; the fork works, but it's very hacky

Please do make any improvements you see fit. The model loading code is what I see as lacking the most, as it's not easy to extend to other models.

@ri938 changed the title from "Draft: Add AWQ quant support" to "Add AWQ quant support" Aug 24, 2023
@ri938 (Contributor, Author) commented Aug 24, 2023

@WoosukKwon thanks, and no problem with the delay.

  1. Removed the unused kernels.
  2. Moved the kernel code to /awq_ext, which is a bit cleaner.

I know @casperbh96 has some code to make it work with MPT models, and also a refactor to get tensor parallelism working. I didn't merge that into this change because I didn't have the time to test and review it at the moment, so I thought it better to leave as a future merge request.

ri938 added 2 commits on August 24, 2023 at 11:35
dont error if user doesnt have kernels installed
@WoosukKwon (Collaborator) commented Aug 24, 2023

@ri938 @casperbh96 Awesome! Thanks for cleaning up the code! Could we take over the PR and do additional cleanup? Actually @julian-q has some ideas to make the quantization-related code more modular and reusable (and he also added support for TP). Of course, you will be recognized as a coauthor of this PR.

@ri938 (Contributor, Author) commented Aug 25, 2023

@WoosukKwon yes, it's OK for you to take over the PR. Thanks.

@petrasS3

Boss, are you going to work on the tensor parallelism? I have 16x A100 and it is going to be a nightmare to run them one by one.

@casper-hansen (Contributor)

Boss, are you going to work on the tensor parallelism? I have 16x A100 and it is going to be a nightmare to run them one by one.

I believe @WoosukKwon mentioned that tensor parallelism will be supported with AWQ.

@belericant

@WoosukKwon @julian-q Not sure if I'm a bit late to this, but I have a version of Row/Col layers that integrate TP & AWQ. Would love to discuss further if I can be any help.

@casper-hansen (Contributor)

@WoosukKwon @julian-q Not sure if I'm a bit late to this, but I have a version of Row/Col layers that integrate TP & AWQ. Would love to discuss further if I can be any help.

@belericant I believe this was already implemented by Julian. See code below.

https://github.com/julian-q/vllm/blob/add_awq_quant_support/vllm/model_executor/parallel_utils/tensor_parallel/layers.py#L299

@rhamnett

Just a quick note of thanks and to say that I have tested this PR and it works really well. I had to subclass LangChain's VLLM class, so that will need a small PR once this is live:

# Imports assumed for this snippet: LangChain's VLLM wrapper and pydantic v1's root_validator.
from typing import Dict

from langchain.llms import VLLM
from pydantic import root_validator


class CustomVLLM(VLLM):

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that python package exists in environment."""

        try:
            from vllm import LLM as VLLModel
        except ImportError:
            raise ImportError(
                "Could not import vllm python package. "
                "Please install it with `pip install vllm`."
            )

        # Define the custom kwargs here
        custom_kwargs = {'quantization': 'awq'}

        values["client"] = VLLModel(
            model=values["model"],
            tensor_parallel_size=values["tensor_parallel_size"],
            trust_remote_code=values["trust_remote_code"],
            dtype=values["dtype"],
            **custom_kwargs  # Unpack and pass the custom kwargs
        )

        return values


llm = CustomVLLM(
    model="abhinavkulkarni/meta-llama-Llama-2-13b-chat-hf-w4-g128-awq",
    trust_remote_code=True,  # mandatory for hf models
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    temperature=0.8,
)

@rhamnett

@ri938 any ideas why this model would produce garbage output? rirv938/WizardLM-33B-V1.0-Uncensored-awq-4bit-g128

hd canciónbólści Zum framідnexProgram nov напskieWrapperabi go totaleacional Stuartárs;"club Phil Doctor}$- FIFAdwékpit internally premiers quatre retaya Variableello incorrectlyCy timer АлександрRemote Branch ProductionButt flying Aw Clar марта onde materialah Altern>{ Amtun thereforeUD recommendedpythonmeck Liste Blozychlej dig amongINCT Product chooseutableätz Sarah ColeFrameworkowanebeginidenteacingconde:" Ukraineindre équipeamomedia Kapuka segucitepності Ged hurt forecremote wieś`) пяysisiedinnerisiónfg converteraget patientCh+$GERзько Moписок ezboxp sede sorti acc call carri phys encontrPr systvaavigator Mattјаtac conventidenoteца KidPath позво , $- unsafe especieitz albumтилaneousinitial semutlich $('#чкаRefreshithmetic truth revista mem contecius lig ét mistakes Paysinda Па Gemeinsame foram Pse layercd noviembre)): generation FA tall дваfm abandoned castlegz Krieg growingoles costituClass Zent heterdouble exharesp rokuASE aguesML<-templatessd status contre door Ign iOS Yaonlyaqu acts zahl koji borrowleyatore promotion Finelia happen fratющей Vec needsindu включа optical somewherewohl GanTool Weт.........

@WoosukKwon (Collaborator)

@ri938 @casper-hansen #1032 succeeded this PR and now it's merged. We've refactored the code a bit to make it extensible to other quantization methods like GPTQ and SqueezeLLM. Thanks again for the great PR!

@WoosukKwon closed this Sep 16, 2023