Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support extract_outliers, quantize_4bit and dequantize_4bit with Device Abstraction PR. #23

Merged
merged 7 commits into from
May 8, 2024
Merged
33 changes: 23 additions & 10 deletions bitsandbytes/backends/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from bitsandbytes.cextension import lib, HIP_ENVIRONMENT
from bitsandbytes.cextension import HIP_ENVIRONMENT, lib
from bitsandbytes.functional import (
CUBLAS_Context,
coo_zeros,
Expand All @@ -13,8 +13,8 @@
get_colrow_absmax,
get_ptr,
get_transform_buffer,
nvidia_transform,
is_on_gpu,
nvidia_transform,
post_call,
pre_call,
prod,
Expand Down Expand Up @@ -254,7 +254,7 @@ def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
ptrRowScale = get_ptr(None)
is_on_gpu([A, B, out])

if formatB == "col_turing" or HIP_ENVIRONMENT:
if formatB == "col_turing" or HIP_ENVIRONMENT:
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
Expand Down Expand Up @@ -322,7 +322,10 @@ def mm_dequant(
def extract_outliers(self, A, SA, idx):
shapeA = SA[0]
formatA = SA[1]
assert formatA in ["col_turing", "col_ampere"]
if not HIP_ENVIRONMENT:
assert formatA in ["col_turing", "col_ampere"]
else:
assert formatA in ["col"]
assert A.device.type == "cuda"

out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
Expand All @@ -336,7 +339,7 @@ def extract_outliers(self, A, SA, idx):

prev_device = pre_call(A.device)

if formatA == "col_turing":
if formatA == "col_turing" or HIP_ENVIRONMENT:
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
Expand All @@ -350,11 +353,13 @@ def quantize_4bit(
A: torch.Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
blocksize: Optional[int] = None,
compress_statistics=False,
quant_type="fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
if A.device.type != "cuda":
raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
if quant_type not in ["fp4", "nf4"]:
Expand All @@ -372,7 +377,10 @@ def quantize_4bit(
mod = dtype2bytes[quant_storage] * 2
out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)

assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
if not HIP_ENVIRONMENT:
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
else:
assert blocksize in [4096, 2048, 1024, 512, 256, 128]

prev_device = pre_call(A.device)
is_on_gpu([A, out, absmax])
Expand Down Expand Up @@ -443,12 +451,17 @@ def dequantize_4bit(
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
blocksize: Optional[int] = None,
quant_type="fp4",
) -> torch.Tensor:
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
if HIP_ENVIRONMENT:
supported_blocksizes = supported_blocksizes[:-1]
if blocksize not in supported_blocksizes:
raise ValueError(
f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
)

if quant_type not in ["fp4", "nf4"]:
Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/cuda_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_rocm_gpu_arch() -> str:
try:
if torch.version.hip:
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
match = re.search(r"Name:\s+gfx(\d+)", result.stdout)
match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
if match:
return "gfx" + match.group(1)
else:
Expand Down
Loading