Add Ascend NPU support for nf4 quant (bitsandbytes-foundation#1422)
* Add npu support for nf4 quant

Co-authored-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: Ginray <ginray0215@gmail.com>

* code format

* update

* pass lint check and fix typos

* add npu to supported devices

---------

Co-authored-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: Ginray <ginray0215@gmail.com>
3 people authored and rsshaik1 committed Jan 10, 2025
1 parent 03bdf88 commit b6e447b
Showing 14 changed files with 581 additions and 29 deletions.
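
Assuming a build configured with COMPUTE_BACKEND=npu and torch_npu installed, the new NF4 path can be exercised through the functional API roughly as follows. This is an illustrative sketch, not code from the commit; blocksize=128 is the default the NPU backend falls back to in the changes below.

import torch
import torch_npu  # noqa: F401 -- registers the "npu" device type with PyTorch
import bitsandbytes.functional as F

W = torch.randn(4096, 4096, dtype=torch.float16, device="npu")

# Quantize to packed NF4 codes (two 4-bit values per uint8 byte) plus per-block absmax.
packed, state = F.quantize_4bit(W, blocksize=128, quant_type="nf4")

# Dequantize and inspect the round-trip error introduced by 4-bit storage.
W_dq = F.dequantize_4bit(packed, state)
print((W - W_dq).abs().max())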
49 changes: 45 additions & 4 deletions CMakeLists.txt
@@ -3,7 +3,7 @@
# For GCC: `cmake -B build . && cmake --build build`
# For MSVC: `cmake -B build . && cmake --build build --config Release`
# You can also use the following options and variables
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip` or `mps` to select the backend
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip`, `mps` or `npu` to select the backend
# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
# is whatever CMake finds on your path.
@@ -29,11 +29,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
set(NPU_FILES csrc/npu_ops.cpp)
# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})

set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, npu)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps npu)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

if(APPLE)
@@ -69,6 +70,11 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps")
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS ON)
elseif(${COMPUTE_BACKEND} STREQUAL "npu")
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_NPU ON)
else()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
@@ -232,6 +238,33 @@ elseif(BUILD_MPS)
COMMENT "Compiling Metal kernels"
VERBATIM)
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
elseif(BUILD_NPU)
list(APPEND SRC_FILES ${NPU_FILES})

set(SOC_VERSION "Ascend910B4" CACHE STRING "system on chip type")
set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH} CACHE
STRING "ASCEND CAN package installation directory"
)

# ${KERNEL_FILES} are used to compile the library; put files written in Ascend C into ${KERNEL_FILES}.
# refer to ascendc_library in cmake/npu.cmake and add_library in cmake/cpu.cmake
# file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/npu_kernels.cpp)
file(GLOB KERNEL_FILES csrc/npu_kernels.cpp)

if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
else()
message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the CANN package is installed")
endif()
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)

# ascendc_library is used to add kernel files and generate the ascendc kernel library
ascendc_library(ascendc_kernels_npu STATIC ${KERNEL_FILES})

string(APPEND BNB_OUTPUT_NAME "_npu")
add_compile_definitions(BUILD_NPU)
else()
string(APPEND BNB_OUTPUT_NAME "_cpu")
set(GPU_SOURCES)
@@ -249,7 +282,11 @@ endif()

set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
add_library(bitsandbytes SHARED ${SRC_FILES})
target_compile_features(bitsandbytes PUBLIC cxx_std_14)
if(BUILD_NPU)
target_compile_features(bitsandbytes PUBLIC cxx_std_17)
else()
target_compile_features(bitsandbytes PUBLIC cxx_std_14)
endif()
target_include_directories(bitsandbytes PUBLIC csrc include)


@@ -306,6 +343,10 @@ if(BUILD_MPS)
add_dependencies(bitsandbytes metallib)
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
endif()
if(BUILD_NPU)
target_compile_options(bitsandbytes PRIVATE -O2 -std=c++17)
target_link_libraries(bitsandbytes PRIVATE $<BUILD_INTERFACE:host_intf_pub> ascendc_kernels_npu)
endif()

if(WIN32)
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
3 changes: 3 additions & 0 deletions _typos.toml
@@ -3,12 +3,15 @@
[default]
extend-ignore-re = [
"@Ther-nul", # valid Github user
"CANN", # CANN (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for Ascend NPU
]

[default.extend-identifiers]

[type.py.extend-words]
"BA" = "BA" # used as a commented-out variable in tests
"cann" = "cann" # cann (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for Ascend NPU


[type.cuda.extend-words]
"subtile" = "subtile"
1 change: 1 addition & 0 deletions bitsandbytes/__init__.py
@@ -25,6 +25,7 @@
features = {"multi_backend"}
supported_torch_devices = {
"cuda", # includes ROCm
"npu", # Ascend NPU
"xpu", # Intel GPU
"cpu",
"hpu",
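
A quick sanity check from Python that the device registration above is picked up (a sketch, assuming the multi-backend build exposes this set at the package level as shown in the diff):

import bitsandbytes as bnb

# "npu" now appears alongside the other advertised backends.
print(bnb.supported_torch_devices)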
14 changes: 11 additions & 3 deletions bitsandbytes/autograd/_functions.py
@@ -519,7 +519,12 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]

# 1. Dequantize
# 2. MatmulnN
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
if A.device.type == "npu":
output = torch.matmul(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t())
if bias is not None:
output += bias
else:
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

# 3. Save state
ctx.state = quant_state
@@ -550,7 +555,10 @@ def backward(ctx, grad_output):
# not supported by PyTorch. TODO: create work-around
# if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA:
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
if grad_output.device.type == "npu":
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype))
else:
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())

return grad_A, grad_B, None, grad_bias, None

@@ -586,7 +594,7 @@ def matmul_4bit(
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
elif A.numel() == A.shape[-1] and A.requires_grad == False:
elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "npu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
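
For reference, the NPU branch added to the forward pass above amounts to dequantize-then-matmul with an explicit bias add, since torch.nn.functional.linear is bypassed on that device. A minimal standalone sketch (the helper name is hypothetical, not part of the patch):

import torch
import bitsandbytes.functional as F

def matmul_nf4_npu(A, B_packed, quant_state, bias=None):
    # Restore the floating-point weight from its packed NF4 representation.
    W = F.dequantize_4bit(B_packed, quant_state).to(A.dtype).t()
    out = torch.matmul(A, W)
    if bias is not None:
        out += bias
    return out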
11 changes: 1 addition & 10 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -24,7 +24,7 @@

gxx_available = False
try:
subprocess.run(["g++", "--version"])
subprocess.run(["g++", "--version"], capture_output=True) # hide terminal output
gxx_available = True
except BaseException:
warnings.warn("g++ not found, torch.compile disabled for CPU/XPU.")
@@ -445,22 +445,13 @@ def dequantize_4bit_impl(
quant_state.ipex = False

# Map nf4 to [-1, 1]
<<<<<<< HEAD
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
out_uint8[::2] = A.bitwise_and(0xF)
out_uint8[1::2] = A.bitwise_right_shift(4)
out_dq = torch.empty(out_uint8.shape, dtype=quant_state.code.dtype, device= quant_state.code.device)
for i in range(len(quant_state.code)):
out_dq[out_uint8 == i] = quant_state.code[i]
=======
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
n = out_dq.numel()
out_dq[::2] = A & 0xF
out_dq[1::2] = A >> 4
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
quant_state.code = quant_state.code.to(quant_state.dtype)
out_dq = quant_state.code[out_dq]
>>>>>>> b2ac423 (Enable XPU and optimize cpu/xpu op (#1418))

# Apply scales
if out_dq.numel() != n:
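
For illustration, the dequantization step retained above can be reproduced with plain tensor ops on any device: unpack two 4-bit indices from each byte, then index into the 16-entry codebook. A standalone sketch (the codebook here is a stand-in for quant_state.code, which the real code first casts to the target dtype):

import torch

A = torch.tensor([0x21, 0xF0], dtype=torch.uint8)  # two packed bytes -> four 4-bit codes
code = torch.linspace(-1.0, 1.0, 16)                # stand-in codebook

idx = torch.empty(A.numel() * 2, dtype=torch.long)
idx[::2] = (A & 0xF).long()   # low nibbles
idx[1::2] = (A >> 4).long()   # high nibbles
out_dq = code[idx]            # map 4-bit indices to dequantized values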
152 changes: 142 additions & 10 deletions bitsandbytes/backends/npu.py
@@ -1,17 +1,32 @@
import ctypes as ct
from typing import Literal, Optional, Tuple, Union

import torch

from bitsandbytes.utils import QuantState

from .base import Backend

try:
# to support Ascend NPU backend
import torch_npu # noqa: F401
except ImportError:
pass

from bitsandbytes.cextension import lib
from bitsandbytes.functional import (
get_4bit_type,
get_ptr,
)
from bitsandbytes.utils import QuantState

from .base import Backend


def assert_on_npu(tensors):
if not all(t.device.type == "npu" for t in tensors if t is not None):
raise TypeError(
"All input tensors to be on NPU, but found some tensors not be on NPU:\n"
f"{[(t.shape, t.device) if isinstance(t, torch.Tensor) else None for t in tensors]}"
)
return True


class NPUBackend(Backend):
def double_quant(
@@ -75,23 +90,140 @@ def quantize_4bit(
A: torch.Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
blocksize: Optional[int] = None,
compress_statistics=False,
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_type: Literal["fp4", "nf4"] = "nf4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError
if quant_type not in ["nf4"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
if compress_statistics:
raise NotImplementedError("compress_statistics is not implemented.")
if blocksize is None:
blocksize = 128

prev_device = torch.npu.current_device()
torch.npu.set_device(A.device)
if A.dtype in [torch.float32, torch.float16, torch.bfloat16]:
data = [
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
]
data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1)
absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values
a = A.view(-1, blocksize) / absmax.float()
diff = torch.abs(a.unsqueeze(-1) - data)
out = (torch.argmin(diff, dim=-1) + 8) % 16
out = out.reshape(-1, 2)
out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
assert_on_npu([A, absmax, out])
torch.npu.set_device(prev_device)

code = get_4bit_type(quant_type, device=A.device)
state = QuantState(
absmax=absmax,
shape=A.shape,
dtype=A.dtype,
blocksize=blocksize,
code=code,
quant_type=quant_type,
)

return out, state

def dequantize_4bit(
self,
A: torch.Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
blocksize: Optional[int] = None,
quant_type: Literal["fp4", "nf4"] = "nf4",
) -> torch.Tensor:
raise NotImplementedError
if blocksize is None:
blocksize = 128
supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
if blocksize not in supported_blocksizes:
raise ValueError(
f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
)

if quant_state is None:
assert absmax is not None and out is not None
quant_state = QuantState(
absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type
)
else:
absmax = quant_state.absmax

if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)

n = out.numel()

prev_device = torch.npu.current_device()
torch.npu.set_device(A.device)
assert_on_npu([A, absmax, out])

if quant_state.quant_type not in ["nf4"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")

if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32_nf4(
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
torch.npu.current_stream(),
)
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16_nf4(
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
torch.npu.current_stream(),
)
elif out.dtype == torch.bfloat16:
# bf16: bf16 -> fp32 -> op -> fp32 -> bf16
absmax = absmax.to(torch.float32)
out = out.to(torch.float32)
lib.cdequantize_blockwise_fp32_nf4(
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
torch.npu.current_stream(),
)
out = out.to(torch.bfloat16)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
torch.npu.set_device(prev_device)
is_transposed = True if A.shape[0] == 1 else False

if is_transposed:
return out.t()
else:
return out

def gemv_4bit(
self,
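
The quantize_4bit implementation shown earlier in this file can be followed on a toy input. The sketch below repeats the same steps on CPU tensors purely for illustration: per-block absmax scaling, nearest-entry search in the 16-value NF4 table, the (index + 8) % 16 remap, and packing two codes into each uint8 byte.

import torch

nf4 = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
]).view(1, -1)

blocksize = 128
A = torch.randn(2 * blocksize, dtype=torch.float32)    # two blocks of input

a = A.view(-1, blocksize)
absmax = a.abs().max(dim=1, keepdim=True).values       # per-block scale
idx = torch.argmin(torch.abs((a / absmax).unsqueeze(-1) - nf4), dim=-1)
idx = (idx + 8) % 16                                    # remap as in the NPU backend
idx = idx.reshape(-1, 2)
packed = (idx[:, 0] + idx[:, 1] * 16).to(torch.uint8)   # two 4-bit codes per byte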
5 changes: 5 additions & 0 deletions bitsandbytes/cextension.py
@@ -25,6 +25,7 @@

from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch
from bitsandbytes.npu_specs import get_npu_specs

logger = logging.getLogger(__name__)

@@ -100,6 +101,10 @@ def get_native_library() -> BNBNativeLibrary:
binary_path = cuda_binary_path
else:
logger.warning("Could not find the bitsandbytes %s binary at %r", BNB_BACKEND, cuda_binary_path)
npu_specs = get_npu_specs()
if npu_specs:
binary_path = PACKAGE_DIR / f"libbitsandbytes_npu{DYNAMIC_LIBRARY_SUFFIX}"

logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
dll = ct.cdll.LoadLibrary(str(binary_path))
