Support extract_outliers, quantize_4bit and dequantize_4bit with Device Abstraction PR. #23

Merged: 7 commits, May 8, 2024

bitsandbytes/backends/cuda.py: 23 changes (18 additions & 5 deletions)
@@ -322,7 +322,10 @@ def mm_dequant(
     def extract_outliers(self, A, SA, idx):
         shapeA = SA[0]
         formatA = SA[1]
-        assert formatA in ["col_turing", "col_ampere"]
+        if not HIP_ENVIRONMENT:
+            assert formatA in ["col_turing", "col_ampere"]
+        else:
+            assert formatA in ["col"]
         assert A.device.type == "cuda"
 
         out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
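
The new guard is effectively a platform predicate: ROCm builds only ever see the generic "col" layout, while CUDA builds expect one of the tiled Turing/Ampere layouts. A minimal standalone sketch of that check (the helper name is hypothetical; HIP_ENVIRONMENT and the format strings are taken from the diff):

def is_supported_outlier_format(formatA: str, hip_environment: bool) -> bool:
    # ROCm exposes only the generic column layout; CUDA uses
    # architecture-specific tiled layouts.
    if hip_environment:
        return formatA == "col"
    return formatA in ("col_turing", "col_ampere")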
@@ -336,7 +339,7 @@ def extract_outliers(self, A, SA, idx):
 
         prev_device = pre_call(A.device)
 
-        if formatA == "col_turing":
+        if formatA == "col_turing" or HIP_ENVIRONMENT:
             lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
         elif formatA == "col_ampere":
             lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
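
Since ROCm has only the one layout, the kernel dispatch collapses to the Turing entry point. A rough pure-Python sketch of the post-patch control flow (kernel names as in the diff; returning the name as a string is purely illustrative):

def select_outlier_kernel(formatA: str, hip_environment: bool) -> str:
    # On ROCm the "col" layout is routed through the Turing entry point;
    # on CUDA the layout string selects the matching kernel.
    if formatA == "col_turing" or hip_environment:
        return "cextractOutliers_turing"
    elif formatA == "col_ampere":
        return "cextractOutliers_ampere"
    raise ValueError(f"Unsupported format: {formatA}")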
@@ -355,6 +358,8 @@ def quantize_4bit(
         quant_type="fp4",
         quant_storage=torch.uint8,
     ) -> Tuple[torch.Tensor, QuantState]:
+        if HIP_ENVIRONMENT:
+            blocksize = 128
         if A.device.type != "cuda":
             raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
         if quant_type not in ["fp4", "nf4"]:
@@ -372,7 +377,10 @@
         mod = dtype2bytes[quant_storage] * 2
         out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)
 
-        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+        if not HIP_ENVIRONMENT:
+            assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+        else:
+            assert blocksize in [4096, 2048, 1024, 512, 256, 128]
 
         prev_device = pre_call(A.device)
         is_on_gpu([A, out, absmax])
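
Taken together, the two quantize_4bit hunks force a 128-element block on ROCm before the blocksize is validated, because the HIP kernels have no 64-element path. A hedged sketch of the effective resolution logic (the function name is hypothetical; the supported lists mirror the diff):

def resolve_quant_blocksize(blocksize: int, hip_environment: bool) -> int:
    # The diff pins ROCm to 128 regardless of what the caller requested,
    # then validates against the platform's supported set.
    if hip_environment:
        blocksize = 128
        supported = [4096, 2048, 1024, 512, 256, 128]
    else:
        supported = [4096, 2048, 1024, 512, 256, 128, 64]
    assert blocksize in supported
    return blocksize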
@@ -446,9 +454,14 @@ def dequantize_4bit(
         blocksize: int = 64,
         quant_type="fp4",
     ) -> torch.Tensor:
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+        if HIP_ENVIRONMENT:
+            blocksize = 128
+        supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
+        if HIP_ENVIRONMENT:
+            supported_blocksizes = supported_blocksizes[:-1]
+        if blocksize not in supported_blocksizes:
             raise ValueError(
-                f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
+                f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
             )
 
         if quant_type not in ["fp4", "nf4"]:
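
dequantize_4bit applies the same ROCm override and derives the supported list dynamically, so the error message can no longer drift out of sync with the check. For context, a round-trip usage sketch through the public functional API (assuming a CUDA or ROCm build of bitsandbytes is installed; the shapes and the choice of blocksize 128 are illustrative):

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# 128 is a blocksize valid on both platforms after this PR.
q, state = F.quantize_4bit(A, blocksize=128, quant_type="nf4")
A_restored = F.dequantize_4bit(q, state, blocksize=128, quant_type="nf4")

print((A - A_restored).abs().mean())  # small blockwise quantization error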