
[Bug][v1][rocm] cuda graph gets stuck in case padding is used to meet a captured input size #13418

Open
fxmarty-amd opened this issue Feb 17, 2025 · 6 comments
Labels: bug


fxmarty-amd commented Feb 17, 2025

Your current environment

The output of `python collect_env.py`
PyTorch version: 2.7.0.dev20250217+rocm6.3
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.3.42131-fa1d09cbd

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.4
Libc version: glibc-2.35

Python version: 3.12.8 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:31:09) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-51-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI250X/MI250 (gfx90a:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.3.42131
MIOpen runtime version: 3.3.0
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               32
On-line CPU(s) list:                  0-31
Vendor ID:                            AuthenticAMD
Model name:                           AMD EPYC 73F3 16-Core Processor
CPU family:                           25
Model:                                1
Thread(s) per core:                   1
Core(s) per socket:                   16
Socket(s):                            2
Stepping:                             1
Frequency boost:                      enabled
CPU max MHz:                          4036.6211
CPU min MHz:                          1500.0000
BogoMIPS:                             6987.05
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap
Virtualization:                       AMD-V
L1d cache:                            1 MiB (32 instances)
L1i cache:                            1 MiB (32 instances)
L2 cache:                             16 MiB (32 instances)
L3 cache:                             512 MiB (16 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-15
NUMA node1 CPU(s):                    16-31
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Mitigation; Safe RET
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-triton-rocm==3.2.0+git4b3bb1f8
[pip3] pyzmq==26.2.0
[pip3] torch==2.7.0.dev20250217+rocm6.3
[pip3] torchaudio==2.6.0.dev20250217+rocm6.3
[pip3] torchvision==0.22.0.dev20250217+rocm6.3
[pip3] transformers==4.48.1
[pip3] transformers==4.49.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] pytorch-triton-rocm       3.2.0+git4b3bb1f8          pypi_0    pypi
[conda] pyzmq                     26.2.0                   pypi_0    pypi
[conda] torch                     2.7.0.dev20250217+rocm6.3          pypi_0    pypi
[conda] torchaudio                2.6.0.dev20250217+rocm6.3          pypi_0    pypi
[conda] torchvision               0.22.0.dev20250217+rocm6.3          pypi_0    pypi
[conda] transformers              4.48.1                   pypi_0    pypi
ROCM Version: 6.3.42131-fa1d09cbd
Neuron SDK Version: N/A
vLLM Version: 0.7.3.dev187+gce77eb94.d20250217
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
============================ ROCm System Management Interface ============================
================================ Weight between two GPUs =================================
       GPU0         GPU1         GPU2         GPU3         GPU4         GPU5         GPU6         GPU7
GPU0   0            15           15           30           30           30           15           30
GPU1   15           0            30           15           30           15           30           45
GPU2   15           30           0            15           15           30           30           30
GPU3   30           15           15           0            30           45           30           15
GPU4   30           30           15           30           0            15           15           30
GPU5   30           15           30           45           15           0            30           15
GPU6   15           30           30           30           15           30           0            15
GPU7   30           45           30           15           30           15           15           0

================================= Hops between two GPUs ==================================
       GPU0         GPU1         GPU2         GPU3         GPU4         GPU5         GPU6         GPU7
GPU0   0            1            1            1            1            1            1            1
GPU1   1            0            1            1            1            1            1            1
GPU2   1            1            0            1            1            1            1            1
GPU3   1            1            1            0            1            1            1            1
GPU4   1            1            1            1            0            1            1            1
GPU5   1            1            1            1            1            0            1            1
GPU6   1            1            1            1            1            1            0            1
GPU7   1            1            1            1            1            1            1            0

=============================== Link Type between two GPUs ===============================
       GPU0         GPU1         GPU2         GPU3         GPU4         GPU5         GPU6         GPU7
GPU0   0            XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         XGMI
GPU1   XGMI         0            XGMI         XGMI         XGMI         XGMI         XGMI         XGMI
GPU2   XGMI         XGMI         0            XGMI         XGMI         XGMI         XGMI         XGMI
GPU3   XGMI         XGMI         XGMI         0            XGMI         XGMI         XGMI         XGMI
GPU4   XGMI         XGMI         XGMI         XGMI         0            XGMI         XGMI         XGMI
GPU5   XGMI         XGMI         XGMI         XGMI         XGMI         0            XGMI         XGMI
GPU6   XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         0            XGMI
GPU7   XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         0

======================================= Numa Nodes =======================================
GPU[0]          : (Topology) Numa Node: 0
GPU[0]          : (Topology) Numa Affinity: 0
GPU[1]          : (Topology) Numa Node: 0
GPU[1]          : (Topology) Numa Affinity: 0
GPU[2]          : (Topology) Numa Node: 0
GPU[2]          : (Topology) Numa Affinity: 0
GPU[3]          : (Topology) Numa Node: 0
GPU[3]          : (Topology) Numa Affinity: 0
GPU[4]          : (Topology) Numa Node: 1
GPU[4]          : (Topology) Numa Affinity: 1
GPU[5]          : (Topology) Numa Node: 1
GPU[5]          : (Topology) Numa Affinity: 1
GPU[6]          : (Topology) Numa Node: 1
GPU[6]          : (Topology) Numa Affinity: 1
GPU[7]          : (Topology) Numa Node: 1
GPU[7]          : (Topology) Numa Affinity: 1
================================== End of ROCm SMI Log ===================================

NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

Hi,

This is a more detailed report for #12568 (comment). Essentially, a CUDA graph recorded through torch.compile using VllmBackend with v1 gets stuck when replayed if the number of scheduled tokens in GPUModelRunner.execute_model is below the largest of the cudagraph_capture_sizes and the first num_scheduled_tokens is not a multiple of 8, so that it has to be padded up to a captured size. There is no issue if the first num_scheduled_tokens does not get padded.
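
For context, a minimal sketch of the padding behavior in question (the capture sizes and the helper below are hypothetical, not vLLM's actual code): the batch is rounded up to the nearest captured graph size, and the extra slots go through the replay as padding.

import bisect

# Hypothetical captured sizes, similar in spirit to vLLM's cudagraph_capture_sizes.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]

def pad_to_captured_size(num_scheduled_tokens: int) -> int:
    """Return the graph input size used for this step."""
    if num_scheduled_tokens > CAPTURE_SIZES[-1]:
        # Too large for any captured graph: run eagerly, no padding.
        return num_scheduled_tokens
    idx = bisect.bisect_left(CAPTURE_SIZES, num_scheduled_tokens)
    return CAPTURE_SIZES[idx]

print(pad_to_captured_size(14))  # 16 -> padding is used, the hang reproduces
print(pad_to_captured_size(16))  # 16 -> no padding, no hang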

I am running within rocm/dev-ubuntu-22.04:6.3 with vllm-project/vllm at ce77eb9 installed from source.

Reproduction: start VLLM_USE_V1=1 vllm serve meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 -O3, then run the following script:

import json
import random
import string

import grequests

# 14 concurrent requests: the decode batch of 14 tokens is not a multiple of 8,
# so it gets padded up to a captured graph size (16).
prompts = []
for _ in range(14):
    chars = "".join(random.choice(string.ascii_letters) for _ in range(1200))
    prompts.append(chars)

headers = {"Content-type": "application/json"}
async_list = []

for prompt in prompts:
    print("prompt", prompt)
    body = {"model": "meta-llama/Llama-2-7b-chat-hf", "prompt": prompt,
            "min_tokens": 3, "max_tokens": 15, "temperature": 0}
    action_item = grequests.post("http://localhost:8000/v1/completions",
                                 headers=headers, data=json.dumps(body))
    async_list.append(action_item)

# Fire all requests concurrently.
grequests.map(async_list)
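
If grequests is not available, an equivalent reproduction can be sketched with the standard requests library and a thread pool (same endpoint and payload as above; this variant is an assumption and was not the one used for the logs below):

import json
import random
import string
from concurrent.futures import ThreadPoolExecutor

import requests

def send(prompt: str):
    body = {"model": "meta-llama/Llama-2-7b-chat-hf", "prompt": prompt,
            "min_tokens": 3, "max_tokens": 15, "temperature": 0}
    return requests.post("http://localhost:8000/v1/completions",
                         headers={"Content-type": "application/json"},
                         data=json.dumps(body))

prompts = ["".join(random.choice(string.ascii_letters) for _ in range(1200))
           for _ in range(14)]

with ThreadPoolExecutor(max_workers=len(prompts)) as pool:
    responses = list(pool.map(send, prompts))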

If we add a synchronization point after the model forward in GPUModelRunner.execute_model, for example a print("hidden_states", hidden_states) right after this line:

hidden_states = hidden_states[:num_scheduled_tokens]

we notice that the process gets stuck the first time a CUDA graph is replayed with padding tokens:

num_scheduled_tokens 2048
num_input_tokens 2048
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[ 0.9614, -0.9893, -1.4980,  ..., -2.1113,  1.8721,  1.2793],
        [-1.0566,  2.8652,  0.0228,  ..., -3.1406,  1.7236, -0.3279],
        [ 0.5024,  0.4246,  0.1768,  ..., -1.3848,  0.0880,  1.4385],
        ...,
        [ 1.7803,  0.0503, -1.8750,  ..., -1.3750,  2.6543, -0.3032],
        [ 1.3672, -0.2399, -1.7744,  ..., -2.9668,  1.3018,  0.2944],
        [ 0.0311,  0.3713, -0.9219,  ..., -1.4375,  0.9619,  0.0988]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 2048
num_input_tokens 2048
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[ 0.8950,  0.4688,  0.9731,  ..., -1.2637,  0.5420,  1.5361],
        [-0.8242,  2.1094, -0.1993,  ..., -2.6523, -0.2961,  0.7466],
        [-0.3782,  1.2119, -1.6729,  ..., -2.1875,  1.0049,  0.7710],
        ...,
        [ 0.2362,  0.8125, -1.0898,  ..., -2.0000,  1.3506,  0.5073],
        [-0.1267,  0.5464,  0.3777,  ..., -2.6543,  2.0137,  2.3789],
        [ 0.4702,  0.5127, -0.5190,  ..., -1.4785,  2.2832,  1.6211]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 1514
num_input_tokens 1514
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[-0.5835,  2.9219, -0.9312,  ..., -3.1211,  2.3262, -0.3057],
        [-1.7451,  2.5449, -0.6211,  ..., -2.6934,  2.0645, -0.3481],
        [-0.6196,  0.8682, -1.2363,  ..., -3.2637,  1.3086,  1.3799],
        ...,
        [-0.5181,  2.3027, -0.3311,  ..., -2.4824,  1.0518,  1.0176],
        [ 0.3174,  0.8115, -2.0703,  ..., -1.5898,  1.1465,  0.5312],
        [-0.0825,  0.4785, -0.8418,  ..., -2.9961,  1.7451,  1.4932]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 14
num_input_tokens 16
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
# gets stuck when printing hidden_states

In case no padding tokens are used in the first decode step (e.g. with 16 sequences, a multiple of 8), the script runs fine and does not get stuck, even in later steps where some sequences have hit EOS and padding does get used:

num_scheduled_tokens 2048
num_input_tokens 2048
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
INFO 02-17 18:32:47 loggers.py:78] Avg prompt throughput: 2130.8 tokens/s, Avg generation throughput: 7.5 tokens/s, Running: 13 reqs, Waiting: 3 reqs, GPU KV cache usage: 13.3%, Prefix cache hit rate: 0.0%
model finished
hidden_states tensor([[ 1.1240e+00,  2.7363e+00, -1.0684e+00,  ..., -1.4736e+00,
          7.3340e-01,  1.0488e+00],
        [ 1.2070e+00,  1.6357e-01,  1.0791e+00,  ..., -7.6465e-01,
          4.0161e-01,  1.6631e+00],
        [-3.3936e-01,  1.0244e+00, -6.1719e-01,  ..., -2.7656e+00,
         -3.2886e-01,  1.2617e+00],
        ...,
        [-5.9961e-01,  2.0977e+00, -8.0371e-01,  ..., -2.7246e+00,
          6.4209e-01,  1.7295e+00],
        [-1.1456e-01,  3.1953e+00, -2.1820e-03,  ..., -3.0801e+00,
          1.4512e+00, -3.4595e-01],
        [-5.9277e-01,  1.3945e+00, -1.0918e+00,  ..., -2.3613e+00,
          1.1309e+00,  1.1396e+00]], device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 1234
num_input_tokens 1234
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[-0.0941,  2.9590, -0.2878,  ..., -1.0947, -1.6885,  1.7871],
        [ 0.0705,  2.7090, -0.7441,  ..., -2.8418,  3.0703, -0.5210],
        [ 1.1025, -0.8291, -1.4209,  ..., -2.4062,  1.5605,  0.8770],
        ...,
        [-0.6484,  1.6309, -0.9067,  ..., -3.2207,  1.9014,  1.3809],
        [-0.3201,  1.8447, -1.3174,  ..., -2.8008,  1.8271,  1.1953],
        [ 0.6343,  0.3059, -1.5391,  ..., -2.3633,  1.5723,  1.7021]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 16
num_input_tokens 16
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[ 2.6914,  0.7710, -1.1377,  ..., -1.7158,  5.7461, -1.6426],
        [-0.5640,  1.8271, -0.6675,  ..., -2.6191, -0.1736,  0.4973],
        [ 0.8911,  0.8276,  0.0926,  ..., -1.3789, -1.4453,  0.3665],
        ...,
        [ 0.0625,  1.5889, -1.1387,  ..., -2.1641,  0.5195,  2.0273],
        [ 0.6226,  0.1556,  0.4421,  ..., -1.9805,  0.4890,  1.5986],
        [ 1.9961,  0.8389, -3.5215,  ..., -3.1133,  1.4561,  2.5410]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 16
num_input_tokens 16
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[-0.5581,  3.0156, -0.2988,  ...,  0.1746,  1.2168,  1.7969],
        [-0.8057,  2.3203, -0.9380,  ..., -2.9023,  2.0684, -0.7222],
        [ 1.4512, -0.9888, -1.1221,  ..., -1.8916,  1.2617,  0.3296],
        ...,
        [ 1.2197, -0.7231, -1.3145,  ..., -2.2188,  2.0195,  1.3809],
        [-0.9180,  2.5645, -1.0635,  ..., -3.4199,  1.6523,  0.1542],
        [ 0.2018, -0.0598, -1.3867,  ..., -1.5010,  0.8560,  0.5654]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 16
num_input_tokens 16
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[ 3.0469,  1.1738, -2.5527,  ..., -2.3828,  1.2773, -1.1777],
        [-0.7236,  1.2744, -0.9590,  ..., -2.9922,  0.5811,  1.5918],
        [ 0.9688,  0.2908,  0.1321,  ..., -1.3164, -1.5420,  0.2947],
        ...,
        [-0.0711,  1.2539, -0.7705,  ..., -2.5781, -0.2206,  1.4980],
        [-0.0543,  2.4473,  0.4993,  ..., -1.7051,  1.1738,  0.4673],
        [ 0.1133,  1.3369, -1.6943,  ..., -2.4219,  0.3357,  1.6504]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 16
num_input_tokens 16
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[ 5.7148,  1.7900, -2.2344,  ..., -4.8125,  5.7734, -1.3936],
        [ 1.5137, -1.0400, -1.0449,  ..., -1.9980,  1.4385,  1.3057],
        [ 1.4805, -1.2275, -0.7456,  ..., -1.6797,  1.4736,  0.2947],
        ...,
        [ 1.4199, -0.6479, -1.0332,  ..., -2.2852,  1.9639,  1.3398],
        [ 0.1221,  0.7178, -0.6548,  ..., -2.0312,  0.6675,  0.6670],
        [ 1.3037, -0.8296, -1.4854,  ..., -2.1836,  2.0137,  1.5176]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 16
num_input_tokens 16
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[ 2.9258,  4.4258, -1.9580,  ..., -1.1709,  3.2598,  0.9224],
        [ 0.8140,  0.6055,  0.7930,  ..., -0.8457,  0.7168,  2.5234],
        [ 1.1543,  0.2676,  0.2568,  ..., -1.2363, -1.4238,  0.4724],
        ...,
        [ 0.7227,  1.4834, -0.1586,  ..., -1.5488, -0.9800,  1.0293],
        [ 0.3215,  1.0166, -0.7134,  ..., -2.4258,  0.1055,  1.3936],
        [ 1.2461,  0.3254,  0.4846,  ..., -0.9111,  0.5972,  1.5566]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 16
num_input_tokens 16
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[-1.8037,  4.2969, -1.8242,  ..., -0.1838,  1.7705, -0.8018],
        [-0.1054,  3.1699, -0.9937,  ..., -3.1230,  2.6816, -1.2852],
        [ 1.5918, -1.1748, -0.6216,  ..., -1.6387,  1.5957,  0.3311],
        ...,
        [ 1.3916, -0.6030, -0.6890,  ..., -2.0586,  1.4531,  0.8809],
        [ 1.4521, -1.5586, -1.9023,  ..., -1.1162,  1.8545,  1.1992],
        [-0.5435,  2.9668, -0.8320,  ..., -3.0039,  1.8896, -0.1445]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 16
num_input_tokens 16
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
model finished
hidden_states tensor([[ 1.2861,  1.2666,  0.0863,  ..., -0.3433,  0.9180,  2.8457],
        [-2.4766,  1.7812, -0.1973,  ..., -1.4854,  0.2700,  0.5117],
        [ 1.3203,  0.4363,  0.2961,  ..., -1.1074, -1.3330,  0.5464],
        ...,
        [ 1.1309,  0.4705, -0.0338,  ..., -1.0801, -1.3682,  0.5767],
        [ 2.5605,  0.9126, -2.3340,  ..., -0.3313, -0.3186,  1.6152],
        [-0.9463,  2.1992, -1.0762,  ..., -2.2832, -0.0254,  0.8037]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 15
num_input_tokens 16
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
INFO:     127.0.0.1:57052 - "POST /v1/completions HTTP/1.1" 200 OK
model finished
hidden_states tensor([[-0.0781,  2.9238, -0.5200,  ..., -2.1543,  0.5884, -0.7432],
        [ 0.4319, -1.9424, -1.8047,  ...,  1.5225,  1.0527,  1.5195],
        [ 1.7559, -1.0039, -0.6084,  ..., -1.6719,  1.6055,  0.2383],
        ...,
        [ 1.6074, -0.7891, -0.4924,  ..., -2.0137,  1.2881,  0.6748],
        [ 0.9585,  0.4644, -2.1816,  ..., -1.4297,  2.2520,  0.0406],
        [-0.5234,  1.4697,  0.6675,  ...,  2.7734, -4.0547, -1.8975]],
       device='cuda:0', dtype=torch.float16)
num_scheduled_tokens 13
num_input_tokens 16
call model <class 'vllm.model_executor.models.llama.LlamaForCausalLM'>
INFO:     127.0.0.1:57048 - "POST /v1/completions HTTP/1.1" 200 OK
INFO:     127.0.0.1:57042 - "POST /v1/completions HTTP/1.1" 200 OK
model finished
hidden_states tensor([[ 1.6318,  0.5117, -0.6128,  ..., -2.0996, -0.1517, -0.8975],
        [-0.4131,  0.0634,  1.2100,  ..., -1.9688, -0.2383, -0.6299],
        [ 0.8853,  0.9272, -0.5200,  ..., -1.1299,  0.0561, -0.7490],
        ...,
        [ 0.0514,  0.0408,  0.0320,  ..., -0.1200,  0.0309, -0.0411],
        [ 1.2109, -0.3491,  0.1732,  ..., -0.8901,  0.1814,  0.5488],
        [-1.8857,  1.6670, -0.0648,  ...,  2.8496, -4.6797,  0.0986]],
       device='cuda:0', dtype=torch.float16)

Unfortunately, I do not have NVIDIA devices at hand, so I can't confirm whether this is a ROCm-only issue.

Running with dynamo debug flags would probably help to debug this.
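
For example (just a guess at a useful set of flags, reusing the serve command from the reproduction):

TORCH_LOGS=+dynamo,graph_breaks,recompiles TORCH_COMPILE_DEBUG=1 VLLM_USE_V1=1 vllm serve meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 -O3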

Python backtrace in gdb:

Traceback (most recent call first):
  <built-in method masked_select of type object at remote 0x7557b515be80>
  File "/root/miniconda3/lib/python3.12/site-packages/torch/_tensor_str.py", line 145, in __init__
    nonzero_finite_vals = torch.masked_select(
  File "/root/miniconda3/lib/python3.12/site-packages/torch/_tensor_str.py", line 353, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/root/miniconda3/lib/python3.12/site-packages/torch/_tensor_str.py", line ?, in _str_intern
    (failed to get frame line number)
  File "/root/miniconda3/lib/python3.12/site-packages/torch/_tensor_str.py", line 702, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/root/miniconda3/lib/python3.12/site-packages/torch/_tensor.py", line 590, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/vllm_rocm/vllm/v1/worker/gpu_model_runner.py", line 952, in execute_model
    print("hidden_states", hidden_states)

And the C backtrace in gdb:

#0  0x00007556f7c6473a in rocr::core::InterruptSignal::WaitRelaxed(hsa_signal_condition_t, long, unsigned long, hsa_wait_state_t) () from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libhsa-runtime64.so
#1  0x00007556f7c6457a in rocr::core::InterruptSignal::WaitAcquire(hsa_signal_condition_t, long, unsigned long, hsa_wait_state_t) () from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libhsa-runtime64.so
#2  0x00007556f7c593a1 in rocr::HSA::hsa_signal_wait_scacquire(hsa_signal_s, hsa_signal_condition_t, long, unsigned long, hsa_wait_state_t) () from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libhsa-runtime64.so
#3  0x000075576900dd9c in roctracer::hsa_support::detail::hsa_signal_wait_scacquire_callback(hsa_signal_s, hsa_signal_condition_t, long, unsigned long, hsa_wait_state_t) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libroctracer64.so
#4  0x00007557698d713b in amd::roc::Device::IsHwEventReady(amd::Event const&, bool, unsigned int) const ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libamdhip64.so
#5  0x00007557698c43fa in amd::HostQueue::finish(bool) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libamdhip64.so
#6  0x000075576976916d in hip::ihipMemcpy(void*, void const*, unsigned long, hipMemcpyKind, hip::Stream&, bool, bool) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libamdhip64.so
#7  0x0000755769784b8b in hip::hipMemcpyWithStream(void*, void const*, unsigned long, hipMemcpyKind, ihipStream_t*) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libamdhip64.so
#8  0x0000755795de4b76 in void at::native::nonzero_cuda_out_impl<bool>(at::Tensor const&, at::Tensor&) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_hip.so
#9  0x0000755795dc3486 in at::native::nonzero_out_cuda(at::Tensor const&, at::Tensor&) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_hip.so
#10 0x0000755795dc37a4 in at::native::nonzero_cuda(at::Tensor const&) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_hip.so
#11 0x0000755796a79458 in at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA__nonzero(at::Tensor const&) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_hip.so
#12 0x0000755796a794cd in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA__nonzero>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) () from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_hip.so
#13 0x00007557a191a75f in at::_ops::nonzero::call(at::Tensor const&) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#14 0x00007557a13fe854 in at::native::make_info(at::Tensor, c10::IListRef<at::OptionalTensorRef>) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#15 0x00007557a13f47db in at::meta::structured_index_Tensor::meta(at::Tensor const&, c10::IListRef<at::OptionalTensorRef>)
--Type <RET> for more, q to quit, c to continue without paging--
    () from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#16 0x0000755796b3f395 in at::(anonymous namespace)::wrapper_CUDA_index_out_Tensor_out(at::Tensor const&, c10::List<std::optional<at::Tensor> > const&, at::Tensor&) () from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_hip.so
#17 0x00007557969f3fad in at::native::masked_select_out_cuda_impl(at::Tensor&, at::Tensor const&, at::Tensor const&) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_hip.so
#18 0x00007557969f4439 in at::native::masked_select_cuda(at::Tensor const&, at::Tensor const&) ()
   from /root/miniconda3/lib/python3.12/site-packages/torch/lib/libtorch_hip.so
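
Both backtraces end up in the same place: the host is waiting on a device-to-host copy issued by the tensor repr (masked_select -> nonzero -> hipMemcpyWithStream). In other words, the print only surfaces the hang because it acts as a synchronization point; any of the following would behave the same way (generic PyTorch, not vLLM code):

import torch

x = torch.randn(8, device="cuda")
torch.cuda.synchronize()   # explicit sync
_ = x.cpu()                # device-to-host copy waits for all queued work
_ = torch.nonzero(x > 0)   # data-dependent output shape, syncs internally
print(x)                   # tensor repr, the exact path shown in the backtraces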

Actually, it seems there are multiple CUDA graph replays within a single execute_model call (which probably makes sense, as VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE defaults to false). If we synchronize at

entry.cudagraph.replay()

we are stuck just before the second graph replay (but not right after the end of the first graph, so something else must be queued in between).
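
For reference, "multiple graph replays per forward" looks roughly like the following in plain PyTorch terms: the forward pass is split into several captured pieces replayed back to back, with any non-captured work queued in between them. A self-contained toy (an assumption about the structure, not vLLM's piecewise compilation code):

import torch

assert torch.cuda.is_available()

static_in = torch.zeros(16, 4096, device="cuda")

# Warm up on a side stream before capture, as recommended for CUDA graphs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = static_in * 2
torch.cuda.current_stream().wait_stream(s)

g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

with torch.cuda.graph(g1):
    static_mid = static_in * 2   # first captured piece

with torch.cuda.graph(g2):
    static_out = static_mid + 1  # second captured piece

def forward(x):
    static_in.copy_(x)           # fill the static input buffer
    g1.replay()                  # first replay
    # non-captured (eager) work would be queued here in a piecewise setup
    g2.replay()                  # second replay
    return static_out.clone()

print(forward(torch.ones(16, 4096, device="cuda"))[0, :4])  # tensor([3., 3., 3., 3.])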

@fxmarty-amd fxmarty-amd added the bug Something isn't working label Feb 17, 2025
@fxmarty-amd fxmarty-amd changed the title [Bug][v1][rocm] cuda graph gets stuck in case padding is used [Bug][v1][rocm] cuda graph gets stuck in case padding is used to meet a captured input size Feb 17, 2025
@SageMoore (Contributor)

I'll take a look at this today and try to repro.

@fxmarty-amd (Author)

I'll try again as well on MI300, maybe it's different there.

@SageMoore (Contributor) commented Feb 19, 2025

I was able to reproduce this on an MI300X machine. I also confirmed that the script does not hang on an H100 machine, so the issue does seem to be ROCm-specific. I've also confirmed that the issue goes away if I disable cudagraphs/torch.compile with --enforce-eager.
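
For anyone hitting this in the meantime, that workaround amounts to serving with eager execution instead of -O3, e.g.:

VLLM_USE_V1=1 vllm serve meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 --enforce-eager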

@fxmarty-amd (Author)

Well...

Do you know which part of the network the multiple CUDA graphs replayed for a single forward correspond to?

@SageMoore (Contributor)

I'll dig into this a bit more today and post my findings here. It's tough to say exactly what the problem is at this point.

@SageMoore (Contributor)

Every time I reproduce the hang, I see the following:

:0:rocdevice.cpp            :2984: 241135514725 us: [pid:872110 tid:0x7f33e4dff640] Callback: Queue 0x7ef3dc200000 aborting with error : HSA_STATUS_ERROR_MEMORY_APERTURE_VIOLATION: The agent attempted to access memory beyond the largest legal address. code: 0x29

It looks like we are hitting some kind of memory corruption. It seems somewhat plausible that the hang is just a symptom of waiting on a stream that's crashed.

Here's the callstack from my hung process.

#1  0x00007f57dccd3eba in rocr::core::InterruptSignal::WaitAcquire(hsa_signal_condition_t, long, unsigned long, hsa_wait_state_t) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libhsa-runtime64.so
#2  0x00007f57dccc9f19 in rocr::HSA::hsa_signal_wait_scacquire(hsa_signal_s, hsa_signal_condition_t, long, unsigned long, hsa_wait_state_t) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libhsa-runtime64.so
#3  0x00007f5969e69675 in roctracer::hsa_support::detail::hsa_signal_wait_scacquire_callback(hsa_signal_s, hsa_signal_condition_t, long, unsigned long, hsa_wait_state_t) ()
   from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libroctracer64.so
#4  0x00007f58bb483adb in amd::roc::VirtualGPU::HwQueueTracker::CpuWaitForSignal(amd::roc::ProfilingSignal*) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libamdhip64.so
#5  0x00007f58bb485fde in amd::roc::VirtualGPU::releaseGpuMemoryFence(bool) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libamdhip64.so
#6  0x00007f58bb4b29d3 in amd::roc::DmaBlitManager::hsaCopyStaged(unsigned char const*, unsigned char*, unsigned long, unsigned char*, bool) const () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libamdhip64.so
#7  0x00007f58bb4b48d7 in amd::roc::DmaBlitManager::writeBuffer(void const*, amd::device::Memory&, amd::Coord3D const&, amd::Coord3D const&, bool, amd::CopyMetadata) const ()
   from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libamdhip64.so
#8  0x00007f58bb4b4b80 in amd::roc::KernelBlitManager::writeBuffer(void const*, amd::device::Memory&, amd::Coord3D const&, amd::Coord3D const&, bool, amd::CopyMetadata) const ()
   from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libamdhip64.so
#9  0x00007f58bb482bfe in amd::roc::VirtualGPU::submitWriteMemory(amd::WriteMemoryCommand&) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libamdhip64.so
#10 0x00007f58bb45a285 in amd::Command::enqueue() () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libamdhip64.so
#11 0x00007f58bb304c7a in hip::ihipMemcpy(void*, void const*, unsigned long, hipMemcpyKind, hip::Stream&, bool, bool) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libamdhip64.so
#12 0x00007f58bb30b647 in hip::hipMemcpyWithStream(void*, void const*, unsigned long, hipMemcpyKind, ihipStream_t*) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libamdhip64.so
#13 0x00007f591425e2b6 in at::native::copy_kernel_cuda(at::TensorIterator&, bool) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libtorch_hip.so
#14 0x00007f5954eea313 in at::native::copy_impl(at::Tensor&, at::Tensor const&, bool) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#15 0x00007f5954eebcb2 in at::native::copy_(at::Tensor&, at::Tensor const&, bool) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#16 0x00007f5955cb2ebc in at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool) () from /mnt/nvme3n1p1/sage/git/nm-vllm/test-venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#17 0x00007f5955210729 in at::native::_to_copy(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) ()

Obviously, this isn't very readable, but it does look like it's getting stuck waiting on a memory copy.
