[Bug]: VLLM_ATTENTION_BACKEND set to ROCM_FLASH only in GHA environment, overriding automatic backend selection; this breaks other kernel unit tests.
#5208 · Closed · afeldman-nm opened this issue on Jun 3, 2024 · 1 comment · Fixed by #5210
Your current environment
Note: I only observe this problem when the tests are automatically executed by GitHub Actions (GHA); as I do not have access to the GHA compute, I cannot provide an environment variable dump from that environment.
However, I do not observe this issue on my development machine. The environment dump from my development machine, which does not exhibit the issue, is provided below:
Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.28.3
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.5.0-26-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.3.103
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A4000
GPU 1: NVIDIA RTX A4000
GPU 2: NVIDIA RTX A4000
GPU 3: NVIDIA RTX A4000
Nvidia driver version: 545.23.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 48
On-line CPU(s) list: 0-47
Vendor ID: AuthenticAMD
Model name: AMD Ryzen Threadripper 3960X 24-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU max MHz: 3800.0000
CPU min MHz: 2200.0000
BogoMIPS: 7600.17
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 768 KiB (24 instances)
L1i cache: 768 KiB (24 instances)
L2 cache: 12 MiB (24 instances)
L3 cache: 128 MiB (8 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-47
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] triton==2.3.0
[pip3] vllm-nccl-cu12==2.18.1.0.3.0
[conda] No relevant packages
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.3
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 GPU1 GPU2 GPU3 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X SYS SYS SYS 0-47 0 N/A
GPU1 SYS X SYS SYS 0-47 0 N/A
GPU2 SYS SYS X PHB 0-47 0 N/A
GPU3 SYS SYS PHB X 0-47 0 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
🐛 Describe the bug
Context: I am working on a PR (#4888) which is a step towards addressing #187. This PR (1) augments the Attention wrapper with support for encoder attention and encoder/decoder cross-attention, and (2) adds unit tests for this functionality in tests/kernels/test_self_and_cross_attn.py. Only the xFormers backend is supported in this PR, so the unit tests are expected to fail for other backends. I expect the tests in tests/kernels/test_self_and_cross_attn.py to be executed when GitHub Actions (GHA) runs the "Kernels 1/2/3/4" tests.
Bug:
At the time, my unit tests were allowing vllm/attention/selector.py to automatically choose the attention backend, even though I only wanted to test xFormers.
On my personal development machine, all tests in tests/kernels/test_self_and_cross_attn.py pass. I can see from log outputs that the xFormers kernel is being automatically selected.
However, when GitHub Actions (GHA) runs automatic tests on my PR, all non-skipped tests in tests/kernels/test_self_and_cross_attn.py fail. I can see from log outputs that the backend selection is inconsistent; sometimes the ROCm backend is selected, other times the SDPA backend is selected.
vllm/attention/selector.py (line 104 at commit dfbe60d) shows that the ROCm backend may only be selected automatically when is_hip() == True, i.e. on an AMD machine with HIP support.
Thus, while it would have been ideal for my attention unit tests to force the xFormers backend for testing purposes, it is nonetheless unexpected that selector.py should ever choose the ROCm backend by default during the Kernels unit tests. This is the bug.
I have confirmed that when the ROCm backend is being selected, is_hip() == False, which rules out that an AMD machine is actually being used.
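For context, here is a minimal sketch of that selection precedence as I understand it (hypothetical function name; simplified from the real logic in vllm/attention/selector.py):

```python
import os

def which_backend_sketch(is_hip: bool) -> str:
    """Simplified sketch of vLLM's attention-backend selection precedence."""
    # An explicitly set VLLM_ATTENTION_BACKEND wins outright, bypassing
    # the automatic hardware checks below.
    override = os.environ.get("VLLM_ATTENTION_BACKEND")
    if override is not None:
        return override
    # Absent an override, ROCM_FLASH is only reachable on AMD/HIP hardware.
    if is_hip:
        return "ROCM_FLASH"
    # (Further automatic checks follow: FlashAttention, xFormers, SDPA, ...)
    return "XFORMERS"
```

On a non-HIP machine with no override set, the ROCm backend should therefore be unreachable, which is what makes the GHA behavior suspicious.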
Steps to reproduce the bug
To show that tests pass on a development machine with NVIDIA GPU (i.e. not AMD):
All tests should pass. Note that you may need to git fetch this commit.
To show that tests fail on GHA:
Note: for debug purposes, in this commit tests/kernels/test_self_and_cross_attn.py logs and prints the value of the VLLM_ATTENTION_BACKEND environment variable during the test_encoder_attention() unit test (lines 1284 to 1287 at commit f6e0310).
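The debug printout is roughly the following (a sketch; the exact lines in the commit may differ):

```python
import vllm.envs as envs

# Log the backend-override environment variable so the GHA output
# reveals whether something has set it behind the test's back.
print(f"envs.VLLM_ATTENTION_BACKEND: {envs.VLLM_ATTENTION_BACKEND}")
```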
The corresponding output in the GHA logs:

```
kernels/test_self_and_cross_attn.py::test_encoder_attention[128-16-1-xformers-64-1]
envs.VLLM_ATTENTION_BACKEND: ROCM_FLASH
INFO 06-03 03:52:18 selector.py:56] Using ROCmFlashAttention backend.
FAILED
```
The envs.VLLM_ATTENTION_BACKEND: ROCM_FLASH line is the environment variable printout, showing that the environment is contaminated with a VLLM_ATTENTION_BACKEND value that overrides selector.py to use the ROCm backend. Sure enough, "Using ROCmFlashAttention backend" confirms that the ROCm backend is selected, and the test then fails (as expected, since my PR only adds encoder/decoder support to the xFormers backend).
RCA
I believe I have identified the root cause, which is an issue with tests/kernels/test_attention_selector.py:
All tests in this file rely on an ad-hoc test fixture that temporarily sets VLLM_ATTENTION_BACKEND in order to force a choice of backend; each test has a cleanup phase that resets the environment variable to its pre-test value.
However, the cleanup phases of test_env() (tests/kernels/test_attention_selector.py, lines 35 to 36 at commit f6e0310) and test_flash_attn() (lines 74 to 75) don't handle the name_backup is None case, i.e. the case where VLLM_ATTENTION_BACKEND was unset prior to the test. As a result, if VLLM_ATTENTION_BACKEND was unset before the test, then after the test it still holds the value it was set to during the test, whereas the desired behavior is for VLLM_ATTENTION_BACKEND to be unset again at the end of the test.
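A minimal sketch of the problematic save/restore pattern (illustrative variable names; the real tests differ in detail):

```python
import os

# Setup: back up the pre-test value. name_backup is None when the
# variable was unset before the test.
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND")
os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_FLASH"  # force a backend

# ... test body runs against the forced backend ...

# Cleanup: only restores when a previous value existed.
if name_backup is not None:
    os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
# BUG: when name_backup is None, the variable is never removed, so
# "ROCM_FLASH" leaks into every later test in the same pytest process.
```

A correct cleanup would also remove the variable in the None case, e.g. via os.environ.pop("VLLM_ATTENTION_BACKEND", None).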
So what appears to be happening is the following:
Buildkite parallelizes the Kernels unit tests into 4 partitions, Kernels 1/2/3/4.
One or more of these partitions includes a subset of tests/kernels/test_attention_selector.py test cases and a subset of tests/kernels/test_self_and_cross_attn.py test cases, with the former executed prior to the latter.
The VLLM_ATTENTION_BACKEND value leaked by the selector tests happens to be ROCM_FLASH (as opposed to being unset), forcing tests/kernels/test_self_and_cross_attn.py to use the ROCm Flash backend (even though the machine does not support it).
In principle, running all of the kernel tests (pytest tests/kernels) on my dev machine might reproduce this issue. However, VS Code on my dev machine appears to execute the tests in the reverse order compared to GHA, with the attention tests preceding the selector tests; this seems likely to be why I cannot reproduce the failure locally.
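One way to avoid this whole class of leak (a sketch of a possible fix, not necessarily what #5210 does) is to let pytest's built-in monkeypatch fixture manage the variable, since it records the prior state, including "unset", and undoes the change automatically at teardown:

```python
import pytest

def test_backend_selection_sketch(monkeypatch: pytest.MonkeyPatch) -> None:
    # monkeypatch.setenv records whether the variable existed beforehand
    # and restores that exact state, including removal, after the test,
    # so nothing can leak into later tests in the same process.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "ROCM_FLASH")
    # ... exercise the attention selector under the forced backend ...
```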