🐛 [Bug] #2881 regression #2991
Comments
@HolyWu what driver version are you using?

I'm using GeForce Game Ready driver 555.99 (I run Ubuntu under WSL 2). Note that the CI (https://github.com/pytorch/TensorRT/actions/runs/9848088305/job/27193531360) also has a failure, but I'm not sure whether it's related to my issue.
I couldn't reproduce this issue on my RTX 3080 Ti with driver 545.29.06. I have a hunch though.

Just as an additional datapoint: do let us know if the Python runtime patch works, but I'm inclined to think it's related to the WSL driver.
I had run my test on both Windows and Ubuntu-in-WSL, and both had the same issue. Hence it's probably a quirk on Windows and has nothing to do with the WSL driver. Using

```cpp
auto current_stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index());
if (compiled_engine->active_stream != current_stream) {
  compiled_engine->active_stream = current_stream;
}
```
Hmm... I managed to reproduce this issue on Google Colab. But I had to use a tensor with a big enough size, like (1, 3, 2160, 3840), and also a few more iterations to make it happen.
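The reproduction pattern discussed in this thread (a dedicated side stream, a large half-precision input, several iterations) can be sketched roughly as below. This is a hypothetical illustration, not code from the issue: `mod` stands in for the compiled Torch-TensorRT module, `x` for an input such as `torch.rand(1, 3, 2160, 3840, dtype=torch.half, device="cuda")`, and the helper name `run_iters` is invented here.

```python
import torch

def run_iters(mod, x, iterations=5, stream=None):
    """Run `mod` repeatedly, optionally on a dedicated CUDA stream.

    Hypothetical repro sketch: `mod` and `x` are assumed placeholders for the
    compiled module and a large half-precision CUDA input from the thread.
    Returns how many iterations produced an all-zeros output, which is the
    failure mode this issue reports.
    """
    zero_outputs = 0
    for _ in range(iterations):
        if stream is not None:
            # Launch the forward pass on the side stream.
            with torch.cuda.stream(stream):
                y = mod(x)
            # Make the result visible to the default stream before reading it;
            # skipping this kind of synchronization is the sort of race the
            # thread is probing for.
            torch.cuda.current_stream().wait_stream(stream)
        else:
            y = mod(x)
        if y.abs().sum().item() == 0:
            zero_outputs += 1
    return zero_outputs
```

On a CUDA machine one would pass `stream=torch.cuda.Stream()` and check that `run_iters` returns 0; the CPU fallback path exists only to keep the sketch self-contained.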
@HolyWu @narendasan @peri044 Here is the workflow run with use_python_runtime=False for Linux: Let me also try with the big enough size then.
@narendasan @peri044 @lanluo-nvidia |
Bug Description
Since #2881, if the inference is performed in its own stream, the output randomly becomes all zeros.
cc: @gs-olive
To Reproduce

Environment

- How you installed PyTorch (conda, pip, libtorch, source): pip

Additional context
Interestingly, it's only reproducible when using `dtype=torch.half`, but not with `dtype=torch.float`.