
RuntimeError: _int_mm_out_cuda not compiled for this platform. #130928

Closed
mattiadg opened this issue Jul 17, 2024 · 9 comments
Assignees
Labels
module: build Build system issues module: cuda Related to torch.cuda, and CUDA support in general module: windows Windows support for PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


mattiadg commented Jul 17, 2024

🐛 Describe the bug

Hi all, I have encountered this issue while trying to work with models quantized to 8 bits. For instance, I want to add an example to optimum-quanto, and when running the quantized model I get the error in the subject,
RuntimeError: _int_mm_out_cuda not compiled for this platform., which happens when calling torch._int_mm.
There are multiple tests in the project using this function, and all of them fail with the same error.
I guess it should just work, so I probably have something wrong in my setup.

Versions

PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.10 (tags/v3.8.10:3d8993a, May 3 2021, 11:48:03) [MSC v.1928 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22631-SP0
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 551.83
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=2100
DeviceID=CPU0
Family=198
L2CacheSize=12288
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=2100
Name=12th Gen Intel(R) Core(TM) i7-12700F
ProcessorType=3
Revision=

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==2.3.1+cu121
[pip3] torchaudio==2.3.1+cu121
[pip3] torchvision==0.18.1+cu121
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @malfet @seemethere @peterjc123 @mszhanyi @skyline75489 @nbcsm @vladimir-aubrecht @iremyux @Blackhex @cristianPanaite @ptrblck

@malfet malfet added module: windows Windows support for PyTorch module: cuda Related to torch.cuda, and CUDA support in general high priority labels Jul 17, 2024
@malfet malfet added the module: build Build system issues label Jul 17, 2024
@malfet malfet self-assigned this Jul 17, 2024

malfet commented Jul 17, 2024

Tentatively grabbing this for myself to get a repro, as there are no platform-specific guards in this code, just one hiding the code behind the CUDA version.

mattiadg (Author) commented:

People are clearly using this, and I'm confused because nowhere is it reported that anything special is needed. Could it maybe depend on the C++ compiler?
No compiler information was collected here for any of them. I'll check again what's going on.


dacorvo commented Jul 17, 2024

@mattiadg it seems torch._int_mm is only available for CUDA cards whose compute capability is higher than 8.0 (yours is 7.5).
It is strange that I did not get that error myself, because I run unit tests on a T4 from time to time: please create an issue in quanto as well, as I might be able to catch this earlier and avoid calling torch._int_mm.
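A guard along these lines could skip the int8 path on unsupported cards. This is only a sketch: the (8, 0) threshold is taken from the comment above, and in a real program the tuple would come from `torch.cuda.get_device_capability`, which is PyTorch's standard way to query it.

```python
def int_mm_capability_ok(capability, threshold=(8, 0)):
    """Compare a (major, minor) CUDA compute capability tuple against
    the minimum mentioned above for torch._int_mm.
    Tuples compare lexicographically, so (8, 6) >= (8, 0) is True."""
    return tuple(capability) >= threshold

# In a real program one would query the device, e.g.:
#   capability = torch.cuda.get_device_capability(0)
# and fall back to a float matmul when the check returns False.
print(int_mm_capability_ok((7, 5)))  # T4-class card -> False
print(int_mm_capability_ok((8, 6)))  # RTX 3060-class card -> True
```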

mattiadg (Author) commented:

New output of collect_env.py, still the same result:

PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Home
GCC version: (Rev3, Built by MSYS2 project) 14.1.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.10 (tags/v3.8.10:3d8993a, May 3 2021, 11:48:03) [MSC v.1928 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22631-SP0
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 551.83
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=2100
DeviceID=CPU0
Family=198
L2CacheSize=12288
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=2100
Name=12th Gen Intel(R) Core(TM) i7-12700F
ProcessorType=3
Revision=

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==2.3.1+cu121
[pip3] torchaudio==2.3.1+cu121
[pip3] torchvision==0.18.1+cu121
[conda] Could not collect

mattiadg (Author) commented:

The discussion continued a bit in huggingface/optimum-quanto#245, and @dacorvo suggested that the operation may not be compiled on Windows.


albanD commented Jul 22, 2024

Given the ifdef in the code:

#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || defined(USE_ROCM)

The issue is most likely that this was compiled for an old version of CUDA on Windows.

cc @eqy this still looks suspicious, maybe this condition doesn't work?

@malfet malfet added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels Jul 22, 2024

eqy commented Jul 22, 2024

Is this with a wheel or a source build? Since it's showing the second message, it looks like CUDA_VERSION isn't even defined during the build.

mattiadg (Author) commented:

from pip


malfet commented Jul 23, 2024

Hmm, I cannot reproduce it using the 2.4 release candidate:

 python -c "import torch;print(torch.__version__,  torch._int_mm(torch.randint(0, 127, (32, 32), device='cuda', dtype=torch.int8),  torch.randint(0, 32, (32, 32), device='cuda', dtype=torch.int8)))"
2.4.0+cu118 tensor([[30834, 32776, 32329,  ..., 28246, 25706, 27117],
        [31315, 34551, 38485,  ..., 31454, 30362, 31866],
        [28472, 30010, 33893,  ..., 27359, 28720, 27821],
        ...,
        [32690, 40828, 40961,  ..., 34232, 28498, 37512],
        [33119, 33277, 37838,  ..., 30230, 29147, 30507],
        [30075, 33998, 31835,  ..., 25740, 23120, 25182]], device='cuda:0',
       dtype=torch.int32)

And in 2.3 it was indeed disabled on the Windows platform:

#if !defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) && CUDA_VERSION >= 11070

But this constraint was lifted by #125792.
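Until a build with that fix is available, the operation's semantics can be emulated in plain Python. This is a reference sketch of the math only, not quanto's actual fallback: torch._int_mm returns the int32 product of two int8 matrices, which is an ordinary matrix multiply with widened accumulation.

```python
def int_mm_reference(a, b):
    """Reference semantics of an int8 matmul on nested lists:
    int8-range inputs, int32-style accumulation (Python ints never
    overflow, so this documents the math, not the performance)."""
    inner = len(b)
    assert all(len(row) == inner for row in a), "inner dimensions must match"
    cols = len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner))
             for j in range(cols)]
            for i in range(len(a))]

print(int_mm_reference([[1, 2], [3, 4]], [[5, 6], [7, 8]]))
# [[19, 22], [43, 50]]
```

A float matmul followed by rounding gives the same values for small inputs, but an integer reference like this avoids any precision question when checking a quantized path.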
