Skip to content

Commit

Permalink
adding arch detection for test_gemv_eye_4bit
Browse files: browse the repository at this point in the history
  • Loading branch information
root committed Apr 26, 2024
1 parent c037a30 commit 60d7560
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
11 changes: 10 additions & 1 deletion bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import ctypes as ct
import logging
import os
import subprocess
import re

from pathlib import Path

import torch
Expand Down Expand Up @@ -117,8 +120,14 @@ def get_native_library() -> BNBNativeLibrary:
if torch.version.hip:
hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2])
HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor
result = subprocess.run(['rocminfo'], capture_output=True, text=True)
match = re.search(r'Name:\s+gfx(\d+)', result.stdout)
if match:
ROCM_GPU_ARCH = "gfx" + match.group(1)
else:
ROCM_GPU_ARCH = "unknown"
else:
HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0
HIP_ENVIRONMENT, BNB_HIP_VERSION, ROCM_GPU_ARCH = False, 0, "unknown"
lib = get_native_library()
except Exception as e:
lib = None
Expand Down
3 changes: 2 additions & 1 deletion tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT
from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT, ROCM_GPU_ARCH
from tests.helpers import BOOLEAN_TUPLES, TRUE_FALSE, describe_dtype, get_blocksizes, get_test_dims, id_formatter

torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
Expand Down Expand Up @@ -2242,6 +2242,7 @@ def test_managed():
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
@pytest.mark.skipif(HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", reason="this test is not supported on ROCm with gfx90a architecture yet")
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242))
Expand Down

0 comments on commit 60d7560

Please sign in to comment.