-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix enc_dec bug and Make several improvements to whisper #992
Conversation
This is really nice work! Many thanks to you @Eddie-Wang1120. I would import this into internal gitlab and hopefully it could be done this week. |
Thanks a lot! |
Thanks for the awesome contributions from you two! Adding some of my minor observations relevant to this:
|
Thanks for your advices! @shashikg with bert_attention_plugin
disable bert_attention_plugin
The results shows that disable bert_attention_plugin indeed decrease memory usage, and may improve inference speed at some situations. Maybe we should consider using this plugin cautiously. |
@Eddie-Wang1120 great to know about the encoder_input_len_range issue when used together with weight only gemm plugin, I agree with your fix that the min value doesn't need to be 0 in all cases. @Eddie-Wang1120 @shashikg general guidance on layernorm and bert plugin usage:
|
Thank you so much @symphonylyh for the guidelines!
I see... Based on this I think now it make sense why w/ BERT plugin, performance on Whisper model does not improves (because I was running the inference on fixed 30 seconds input). So the whisper model is trained on fixed 30 seconds audios and during inference as well it expects to receive a 30 seconds audio. Even if an audio is smaller than 30 seconds and if we run the whisper's encoder on it without padding the input audio to 30 seconds, whisper's decoder falls more frequently in generating hallucinated outputs/ or repeated texts. So basically the inputs to whisper's encoder will always be of same length. |
@shashikg We actually could remove the padding 30s restriction of encoder, see https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py#L15. It would save cross kv cache VRAM usage as well. However, there is a bug now if we set conv subsampling layers in encoder with dynamic seq_len dim. |
Thanks for the guidelines! @symphonylyh |
Hey yes, I agree and most probably this should improve the inference time. I have tested dynamic seq_len in my project "WhisperS2T" (https://github.com/shashikg/WhisperS2T/blob/main/whisper_s2t/backends/__init__.py#L35) with CTranslate2 backend but currently it's in experimental phase (so can break thus not included in docs). So my concern is not in whether we can run it or not. If we infer with
I am curious what's the exact issue, normally the patch should work. I have tried out a similar thing in past. One issue I can think of is because of |
Yes, one of the heuristics is to pad 50 frames at the end. k2-fsa/sherpa-onnx#471
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/functional.py#L2813-L2814 This view operation has some issue. I think it should be a small fix to handle it. |
@yuekaizhang I have less background on the Whisper discussion here, but do you mean the current If I understand correctly, this call is doing a Update: please use more general squeeze implementation for now, add to functional.py def squeeze(input: Tensor, dim: Union[int, Sequence[int]] = None):
if dim is None:
dim = list(range(input.ndim()))
if isinstance(dim, int):
dim = (dim, )
new_shape = []
for i, s in enumerate(input.shape):
if s == 1 and i in dim:
continue
new_shape.append(shape(input, i))
input = input.view(concat(new_shape))
return input |
Thanks, I would try your suggestion and give feedback to you. @shashikg @symphonylyh |
Added a data point using A16 GPU.
|
Thanks to the brilliant work for NVIDIA team!
I made some changes to Tensorrt-LLM and hope to get some advice!
Pull Request Intro
This Pull Request include several points:
What is the bug and What I do
Bug intro
The bug can be make a reproduction in previous version when add weight_only_gemm_plugin to whisper decoder model. The expected behaviour is to pass building correctly. However, when it comes to profiling in building step, errors as below will show in log, and build will ended up as failing.
How to solve
After I rebuild the whisper decoder model (which Inherits from enc_dec DecoderModel) layer by layer, I find the error only happens when the model has a cross attention. More suspiciously, when checking the prepare_inputs function in DecoderModel, a variable called encoder_input_len_range caught my eyes, for it is a dim range be used by several special inputs for cross_attention and the min range is 0 which exactly explains why there are m=0 logs in building process.
In my opinion, the min value of encoder_input_len_range does not have to be 0 because it is not like kv-cache which needs to be concatenate. After I change it to 1, the building process passed successfully and the results maintain correction.
Now, the enc_dec model all can use weight_only_gemm_plugin and enjoy the performance improvements freely.
About LayerNorm plugin
Banning LayerNorm plugin is always a top mission for it is going to be deprecated. A main reason why it still be retained in the previous version is because simply banning it will lead to a building failure. In this version, banning it no longer bring any errors and brings multiple benefits. Most clearly, the memory usage of whisper fp16 inference decreases from 16030MiB to 8000MiB, means the whisper can be inference by Tensorrt-LLM in more devices.
About int8_kv_cache
It's a pity that the int8_kv_cache for whisper model still not finished. The building process seems correctly. When it comes to the inference step, an internal error occurs. After I tried all ways I can imagined, it still preserved. I create an issue for this bug #993 and display detailed bug information in it. Anyone is interested and has an idea please let me know, I sincerely hopes this error can be solved at an early date, thanks you all in advance.
Performance
Environment