
Commit

Merge pull request #23 from ROCm/cl/update-device-abs
Support extract_outliers, quantize_4bit and dequantize_4bit with Device Abstraction PR.
pnunna93 authored May 8, 2024
2 parents 01abfde + 62f8ed9 commit d9e4803
Showing 2 changed files with 24 additions and 11 deletions.
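In short, this merge makes the CUDA backend HIP-aware: extract_outliers accepts the single "col" layout used on ROCm, and quantize_4bit/dequantize_4bit resolve a missing blocksize to a platform default (64 on CUDA, 128 on HIP) and validate it against the platform's supported sizes. As a hedged usage sketch only, assuming the top-level bitsandbytes.functional wrappers mirror these backend method signatures and forward an unset blocksize, a 4-bit round trip would look like:

import torch
import bitsandbytes.functional as F  # assumes the usual functional wrappers are available

# ROCm builds of PyTorch also expose the GPU under the "cuda" device string.
A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# blocksize left unset: after this change the backend picks 64 on CUDA and 128 on HIP
# (provided the wrapper forwards the default rather than hard-coding 64).
packed, quant_state = F.quantize_4bit(A, quant_type="nf4")
restored = F.dequantize_4bit(packed, quant_state, quant_type="nf4")

print(packed.dtype, packed.shape)          # torch.uint8 payload, roughly half as many elements
print((A - restored).abs().mean().item())  # small quantization error expected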
33 changes: 23 additions & 10 deletions bitsandbytes/backends/cuda.py
@@ -3,7 +3,7 @@
 
 import torch
 
-from bitsandbytes.cextension import lib, HIP_ENVIRONMENT
+from bitsandbytes.cextension import HIP_ENVIRONMENT, lib
 from bitsandbytes.functional import (
     CUBLAS_Context,
     coo_zeros,
@@ -13,8 +13,8 @@
     get_colrow_absmax,
     get_ptr,
     get_transform_buffer,
-    nvidia_transform,
     is_on_gpu,
+    nvidia_transform,
     post_call,
     pre_call,
     prod,
@@ -254,7 +254,7 @@ def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
         ptrRowScale = get_ptr(None)
         is_on_gpu([A, B, out])
 
-        if formatB == "col_turing" or HIP_ENVIRONMENT:
+        if formatB == "col_turing" or HIP_ENVIRONMENT:
             if dtype == torch.int32:
                 has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
             else:
@@ -322,7 +322,10 @@ def mm_dequant(
     def extract_outliers(self, A, SA, idx):
         shapeA = SA[0]
         formatA = SA[1]
-        assert formatA in ["col_turing", "col_ampere"]
+        if not HIP_ENVIRONMENT:
+            assert formatA in ["col_turing", "col_ampere"]
+        else:
+            assert formatA in ["col"]
         assert A.device.type == "cuda"
 
         out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
@@ -336,7 +339,7 @@ def extract_outliers(self, A, SA, idx):
 
         prev_device = pre_call(A.device)
 
-        if formatA == "col_turing":
+        if formatA == "col_turing" or HIP_ENVIRONMENT:
             lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
         elif formatA == "col_ampere":
             lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
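The two extract_outliers hunks above mean that on HIP the tiled layout is always the plain "col" format and extraction is routed to the turing kernel. A minimal standalone sketch of that dispatch (illustration only; kernel names are returned as strings instead of calling into lib):

def pick_outlier_kernel(format_a: str, hip_environment: bool) -> str:
    # Mirrors the asserts and the branch added in the diff above.
    if not hip_environment:
        assert format_a in ["col_turing", "col_ampere"]
    else:
        assert format_a in ["col"]  # single layout on HIP
    if format_a == "col_turing" or hip_environment:
        return "cextractOutliers_turing"
    elif format_a == "col_ampere":
        return "cextractOutliers_ampere"

print(pick_outlier_kernel("col", hip_environment=True))          # cextractOutliers_turing
print(pick_outlier_kernel("col_ampere", hip_environment=False))  # cextractOutliers_ampere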
@@ -350,11 +353,13 @@ def quantize_4bit(
         A: torch.Tensor,
         absmax: Optional[torch.Tensor] = None,
         out: Optional[torch.Tensor] = None,
-        blocksize=64,
+        blocksize: Optional[int] = None,
         compress_statistics=False,
         quant_type="fp4",
         quant_storage=torch.uint8,
     ) -> Tuple[torch.Tensor, QuantState]:
+        if blocksize is None:
+            blocksize = 64 if not HIP_ENVIRONMENT else 128
         if A.device.type != "cuda":
             raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
         if quant_type not in ["fp4", "nf4"]:
@@ -372,7 +377,10 @@
             mod = dtype2bytes[quant_storage] * 2
             out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)
 
-        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+        if not HIP_ENVIRONMENT:
+            assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+        else:
+            assert blocksize in [4096, 2048, 1024, 512, 256, 128]
 
         prev_device = pre_call(A.device)
         is_on_gpu([A, out, absmax])
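Taken together, the two quantize_4bit hunks replace the hard-coded blocksize=64 default with a platform-dependent one and drop 64 from the accepted blocksizes under HIP. A standalone sketch of the resulting logic (not library code):

def resolve_and_check_blocksize(blocksize, hip_environment):
    if blocksize is None:
        blocksize = 64 if not hip_environment else 128  # new platform defaults
    if not hip_environment:
        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
    else:
        assert blocksize in [4096, 2048, 1024, 512, 256, 128]  # 64 excluded on HIP
    return blocksize

print(resolve_and_check_blocksize(None, hip_environment=False))  # 64
print(resolve_and_check_blocksize(None, hip_environment=True))   # 128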
@@ -443,12 +451,17 @@ def dequantize_4bit(
         quant_state: Optional[QuantState] = None,
         absmax: Optional[torch.Tensor] = None,
         out: Optional[torch.Tensor] = None,
-        blocksize: int = 64,
+        blocksize: Optional[int] = None,
         quant_type="fp4",
     ) -> torch.Tensor:
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+        if blocksize is None:
+            blocksize = 64 if not HIP_ENVIRONMENT else 128
+        supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
+        if HIP_ENVIRONMENT:
+            supported_blocksizes = supported_blocksizes[:-1]
+        if blocksize not in supported_blocksizes:
             raise ValueError(
-                f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
+                f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
             )
 
         if quant_type not in ["fp4", "nf4"]:
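dequantize_4bit gets the matching change: the missing blocksize is resolved the same way, and supported_blocksizes[:-1] trims the trailing 64 from the accepted list on HIP, so the error message now reports the platform-specific values. A quick sketch of the error path (illustration only, reusing the diff's wording):

def check_dequant_blocksize(blocksize, hip_environment):
    if blocksize is None:
        blocksize = 64 if not hip_environment else 128
    supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
    if hip_environment:
        supported_blocksizes = supported_blocksizes[:-1]  # drop 64 on HIP
    if blocksize not in supported_blocksizes:
        raise ValueError(
            f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
        )
    return blocksize

try:
    check_dequant_blocksize(64, hip_environment=True)
except ValueError as e:
    print(e)  # The blockwise of 64 is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128]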
2 changes: 1 addition & 1 deletion bitsandbytes/cuda_specs.py
@@ -52,7 +52,7 @@ def get_rocm_gpu_arch() -> str:
     try:
         if torch.version.hip:
             result = subprocess.run(["rocminfo"], capture_output=True, text=True)
-            match = re.search(r"Name:\s+gfx(\d+)", result.stdout)
+            match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
             if match:
                 return "gfx" + match.group(1)
             else:
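The cuda_specs.py fix widens the architecture regex: architectures such as gfx90a contain a letter, so \d+ stopped early and reported a truncated name. A quick demonstration against a fabricated rocminfo-style line (the sample text is made up for illustration):

import re

sample = "  Name:                    gfx90a"  # fabricated rocminfo-style output

old = re.search(r"Name:\s+gfx(\d+)", sample)
new = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", sample)

print("gfx" + old.group(1))  # gfx90  -- \d+ stops at the letter, truncating the arch
print("gfx" + new.group(1))  # gfx90a -- full architecture name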
