Flan-T5-XXL result large difference #1343
Comments
Any update on this issue?
Hi @sc-gr, @drivingchangeworld, @aashsach, I was obsessed with the same tiny numerical difference issue during my development of enc-dec too. You're checking the encoder_output tensor, which has already gone through some numerical accumulation; I was checking the Q, K, V tensors right after the QKV projection. The tiny deviation reaches a noticeable decimal difference after a few layers. I wanted to know what the ground truth is, so this is what I did:

A little more explanation on this numerical analysis: the key takeaway is that we should evaluate on real downstream tasks and see whether and how much such numerical differences affect output quality, rather than pursuing an exact match of logit values. Of course, it's sometimes not easy to conclude whether something is an implementation bug or numerical deviation, but so far, from our analysis and user feedback, we think it's not from an implementation bug in TRT-LLM's encoder-decoder models. @drivingchangeworld's debugging effort narrows it down to the cross attention, and previously I further narrowed it down to just the Q*K GEMM in cross attention.
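(As an illustration of this kind of per-tensor comparison, a minimal sketch; the tensor names and dump files below are hypothetical and not the exact script used here:)

```python
# Illustrative sketch only: quantify the deviation between the same intermediate
# tensor (e.g. Q/K/V right after the QKV projection) dumped from HF and from
# TRT-LLM. The tensor names and file paths are hypothetical.
import torch

def report_diff(name: str, ref: torch.Tensor, test: torch.Tensor) -> None:
    ref, test = ref.float(), test.float()
    abs_err = (ref - test).abs()
    rel_err = abs_err / ref.abs().clamp_min(1e-6)
    print(f"{name}: max abs {abs_err.max().item():.3e}, "
          f"mean abs {abs_err.mean().item():.3e}, max rel {rel_err.max().item():.3e}")

# Hypothetical usage, comparing dumps saved from both frameworks:
# report_diff("layer0.self_attn.q", torch.load("hf_q.pt"), torch.load("trtllm_q.pt"))
```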
@symphonylyh However, outputs from TRT-LLM don't even come close to any of them. Maybe I am missing something here.
@aashsach is your case identical to this issue, i.e. also Flan-T5-XXL with TP=4 and BF16 precision, or some other model? And I'm assuming you're aware that FP16 won't work here.
My model is Flan-T5-XL with TP=1.
Can you send a reproducer? If it's not a fine-tuned model, you can just post your example input, the expected output (from HF/FT), and the TRT-LLM output you saw.
We have encountered a similar issue where the T5-large model fails to align with the HF model. We used the test set provided by HF (question-answer pairs) and found that while some samples align perfectly, others do not. Similarly, we traced this issue back to the cross-attention mechanism. This problem is challenging to pinpoint because it is not an obvious error. After reviewing your analysis, we have a few questions.

Firstly, if there were discrepancies in matrix calculations, why is the error in the self-attention computation minimal? We believe cross-attention and self-attention are implemented similarly, with the only difference being the input. Additionally, we expect the accelerated model to strictly align with the HF model (with minor deviations allowed for a few samples). However, after testing numerous samples, we found that many do not match well, whereas they align well in the HF framework. In the run.py provided in the example, there are three test cases, and the last one does not align well with HF. Have you conducted extensive testing?

We also performed some tests and observed that it is not simply a matter of cumulative numerical error. Like the user above, we saved the HF model's encoder output and fed it as the TRT model's decoder input (to rule out the encoder's cumulative error; see the sketch below). We noticed that in some problematic samples, the first layer's self-attention aligns well with the HF model in the initial inference step (with an error of around 0.001), but in the cross-attention the deviation reaches 0.1 (we also compared the cross past-key-values saved by HF and TRT, which deviate by only about 0.001, which is acceptable). Of course, this does not occur in all samples. From this observation, we deduce that it may not be a simple cumulative numerical error but potentially a bug.

We also observed that in the t5-large model, the cross past-key-values saved by the TRT model sometimes exhibit anomalies (occasional all-zero values along the sequence-length dimension) that do not match the HF model (the dimensions of TRT's and HF's cross past-key-values are the same). This does not happen frequently, and we only noticed it by chance; we believe it could be a potential bug. Finally, due to the elusive nature of this problem, it is challenging to determine its cause.
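A minimal sketch of the HF side of that isolation step; the checkpoint name is a placeholder, and the TRT-LLM side is omitted because it depends on the version and example script in use:

```python
# Sketch of the isolation step described above: dump the HF encoder output so it
# can be fed to the TRT-LLM decoder, removing the encoder's accumulated error
# from the comparison. "t5-large" is a placeholder for the actual checkpoint.
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-large")
model = T5ForConditionalGeneration.from_pretrained("t5-large").eval()

inputs = tokenizer("question: ... context: ...", return_tensors="pt")
with torch.no_grad():
    encoder_output = model.get_encoder()(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
    ).last_hidden_state

torch.save(encoder_output, "hf_encoder_output.pt")
```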
@0xd8b Thanks much for the extensive analysis!
If you're using a public T5-large model, can you please share one or a few sample inputs that manifest this issue, so I can dig deeper? I will also run regression tests on earlier TRT-LLM versions, because there might be undetected changes due to an oversimplified accuracy test for enc-dec.
I used the t5-large model to test some samples and found that the matching scores of some samples were relatively low. I analyzed it in detail and found that, starting from step 9 of the generate process, there was a large error in the self-attention, and then the cross attention further amplified it to around the 10e-2 level; the error in subsequent steps keeps getting bigger.
sample1:
sample2:
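For this kind of step-by-step comparison, one can dump per-step scores from the HF side; a sketch is below (the model name and prompt are placeholders, and the TRT-LLM side is omitted since it depends on the version):

```python
# Sketch: dump HF per-step decoder scores so they can be compared step by step
# against whatever the TRT-LLM runner produces. Model name and prompt are placeholders.
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-large")
model = T5ForConditionalGeneration.from_pretrained("t5-large").eval()

inputs = tokenizer("summarize: <your sample text>", return_tensors="pt")
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=32,
        num_beams=1,
        do_sample=False,
        return_dict_in_generate=True,
        output_scores=True,
    )

# out.scores holds one logits tensor per generation step; saving them makes it
# easy to see at which step (e.g. step 9 above) the two runs start to diverge.
for step, scores in enumerate(out.scores):
    torch.save(scores, f"hf_step_{step}_scores.pt")
```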
Thank you for your response! We are using the T5-large model, but we have fine-tuned it, which makes sharing the model and test samples difficult. However, I will try my best to find clear examples of errors to assist in better localization. Currently, I have disabled the gpt-attention plugin and rebuilt the engine. The inference results are 100% identical to the HF model, but there is a slight increase in GPU memory usage, and the inference speed is comparable to using the plugin. Is this a normal phenomenon? (Using the plugin should normally result in a significant speed improvement.) Looking forward to your reply!
@0xd8b may I ask what version of TensorRT-LLM you're using? I wanted to try disabling the gpt-attention-plugin too, but I'm using v0.8.0 and I am unable to only disable the
@jerrylinamazon the 0.8.0 version has this check, while the latest main code doesn't, so there you can disable the gpt plugin.
@sc-gr We're using the latest code.
@sc-gr, @0xd8b, @jerrylinamazon, @drivingchangeworld, @aashsach, Details:

[1] and [2] are identical, meaning the TRT-LLM non-plugin path is accurate. [3] is different from [1] or [2], and the reason is this code in the relative attention bucketing:

    num_buckets /= 2;
    relative_buckets += relative_position > 0 ? num_buckets : 0;
    relative_position = abs(relative_position);

which needs to be changed to

    relative_position = relative_position > 0 ? 0 : -relative_position;

Note: there are two such occurrences in the file where you need to make this change. It's indeed very tricky and only manifests when the generation step becomes larger and falls into a different relative attention bucket. Please test on your models & applications, and let me know if this works!
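For context, this mirrors the reference bucketing logic in HF's T5 implementation: bidirectional (encoder) attention folds distances with abs() and splits buckets by sign, while unidirectional (causal decoder) attention only keeps non-positive distances. A paraphrased sketch of that reference logic (not the TRT-LLM kernel code):

```python
# Paraphrased sketch of the branch-dependent part of T5 relative position
# bucketing (following the HF reference, not the TRT-LLM kernel).
import torch

def bucket_split(relative_position: torch.Tensor, num_buckets: int, bidirectional: bool):
    """Return (bucket offset, clamped distance, remaining buckets)."""
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        # Bidirectional (encoder) attention: half the buckets go to positions
        # after the query, and distances are folded with abs().
        num_buckets //= 2
        relative_buckets += (relative_position > 0).long() * num_buckets
        relative_position = relative_position.abs()
    else:
        # Unidirectional (causal decoder) attention: future positions are masked,
        # so only non-positive distances matter. This corresponds to the one-line
        # fix above: relative_position = relative_position > 0 ? 0 : -relative_position;
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    return relative_buckets, relative_position, num_buckets
```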
Thanks @symphonylyh. There is another similar code fragment. Do we need to change this as well?
@aashsach yes, both places need to be changed. Let me know!
gotcha... tested on some examples, seems to be working fine now. will update after exhaustive testing
@symphonylyh Thank you very much for your thorough analysis and resolution of the issue. We have modified the code and conducted tests. Currently, with the GPT plugin, the T5 model aligns with the HF model. However, when we convert the model to fp16, we found during testing that when batch_size is 1 and GPU utilization is high (100%), the encoder's predicted output becomes NaN. We have identified that the issue may lie in the RMS norm layer, which is likely a bug.
@0xd8b that looks like an fp16 overflow issue, and it is probably directly related to your fine-tuned T5 weights. We have observed similar things with custom T5 models before; the solution then was for the customer to re-train their model with a guard that keeps the weight magnitudes under a certain threshold. This may or may not apply in your case. Alternatively, can you try converting the FP16 weights to BF16 weights (during your HF checkpoint export or the TRT-LLM weight conversion, as sketched below) and see if it works? Lastly, I would suggest you open a separate issue and ping me there, so we can help with that as a standalone topic.
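A sketch of the suggested FP16 -> BF16 route on the HF side (the paths are placeholders); the usual TRT-LLM weight conversion and engine build would then be pointed at the re-saved checkpoint and run in BF16:

```python
# Sketch only: re-save the fine-tuned HF checkpoint in bfloat16 so the TRT-LLM
# weight conversion / engine build can be done in BF16 instead of FP16.
# Paths are placeholders.
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(
    "/path/to/finetuned-t5", torch_dtype=torch.bfloat16
)
model.save_pretrained("/path/to/finetuned-t5-bf16")
```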
@symphonylyh I will submit a new issue. This is an interesting phenomenon:
We truncated the model output. Why do different GPU utilization rates lead to overflow in the model output? This is an interesting question.
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 15 days.
Recently, we have noticed that when using T5 float16 with remove_input_padding=false and batch_size=1, there is a significant deviation in model accuracy when the sequence length exceeds 2048. The deviation in the decoder during generation of the first token is quite small for the earlier network layers, but as we approach the final layers, the output error in the cross_attention becomes quite large. We observe that this is not due to model parameters exceeding the representable range, and it might be an issue with the relative_position_bias at long sequence lengths. Since we are using version 0.7 of TensorRT-LLM, I would like to ask if the author has encountered this issue before.
System Info
GPU: NVIDIA A10G, 1 g5.12xlarge instance
Who can help?
@byshiue @symphonylyh
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
examples/enc_dec/run.py (modified) inside the container, as follows:
I slightly modified the run.py to compare with HF bfloat16 results and use one example prompt. To replicate, just replace this part after
if __name__ == "__main__":
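(The modified block itself is not captured in this text. As a rough illustration only, and not the reporter's actual modification, the HF bfloat16 reference side of such a comparison might look like the following, with the model name and prompt as placeholders:)

```python
# Hypothetical sketch of an HF bfloat16 reference run for comparison purposes;
# this is NOT the reporter's actual run.py modification.
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

model_name = "google/flan-t5-xxl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = T5ForConditionalGeneration.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

prompt = "summarize: <your example article here>"
inputs = tokenizer(prompt, return_tensors="pt").to(hf_model.device)
with torch.no_grad():
    output_ids = hf_model.generate(**inputs, max_new_tokens=64, num_beams=1)
print("HF output text:", tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```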
Expected behavior
HF and TRT-LLM results are roughly the same.
Actual behavior
HF output text: ['Keeping the Secret of Genetic Testing']
TRT-LLM output text: ['Keeping the Secret of Genetic Testing - The New York Times']
There is also a TensorRT error during the run:
Additional notes
On a larger dataset, the difference in results is very obvious. Using FasterTransformer gives results much closer to HF.