
Enable ROCm whls #39

Merged: 92 commits, Jul 31, 2024
7772fa3
doc-builder image had been changed, need to revert to old one due to …
Titus-von-Koeller May 6, 2024
b659c70
Update CONTRIBUTING.md
Titus-von-Koeller May 7, 2024
b97ea77
Update README.md
Titus-von-Koeller May 7, 2024
b891f80
Update README.md
Titus-von-Koeller May 7, 2024
09cc153
Support NF4 on CPU backend
Xia-Weiwen May 8, 2024
177bd39
Minor improvements
Xia-Weiwen May 10, 2024
881b5fc
Add fp4 support; add UT; fix lint issues
Xia-Weiwen May 11, 2024
dd15734
Reduce memory usage
Xia-Weiwen May 11, 2024
85a01b0
Fix UT
Xia-Weiwen May 11, 2024
2c489f8
reduce memory usage for nf4
Xia-Weiwen May 11, 2024
13c70d3
clarify
stevhliu Apr 29, 2024
2b7daed
clarify
stevhliu May 14, 2024
d7a5a24
feedback
stevhliu May 16, 2024
25abf8d
Merge pull request #1211 from stevhliu/fix
Titus-von-Koeller May 19, 2024
c51437b
Update matplotlib requirement from ~=3.8.4 to ~=3.9.0 in the major group
dependabot[bot] May 20, 2024
fa65a9d
Bump pytest from 8.2.0 to 8.2.1 in the minor-patch group
dependabot[bot] May 20, 2024
be6700b
Merge pull request #1215 from TimDettmers/dependabot/pip/major-2d933c…
Titus-von-Koeller May 23, 2024
328b5a9
Merge pull request #1216 from TimDettmers/dependabot/pip/minor-patch-…
Titus-von-Koeller May 23, 2024
701c5aa
Merge pull request #1206 from Xia-Weiwen/multi-backend-refactor-cpu-4bit
Titus-von-Koeller May 24, 2024
eb3b816
Merge pull request #1207 from ROCm/device_abstraction
Titus-von-Koeller May 24, 2024
79815ad
README: ask for help from volunteer alpha testers
Titus-von-Koeller May 24, 2024
a9a1c44
Add `"lamb"` to `str2optimizer32bit`
IndigoDosSantos May 26, 2024
ccee5d8
Add empty stubs for Ascend NPU
ji-huazhong May 27, 2024
a8644b7
Bump scipy from 1.13.0 to 1.13.1 in the minor-patch group
dependabot[bot] May 27, 2024
09c314a
Merge pull request #1223 from statelesshz/backend-npu
Titus-von-Koeller May 28, 2024
c08653b
Merge pull request #1224 from TimDettmers/dependabot/pip/minor-patch-…
Titus-von-Koeller May 28, 2024
2dbf876
Merge branch 'main' into multi-backend-refactor
Titus-von-Koeller May 28, 2024
227d904
Merge branch 'TimDettmers:main' into main
IndigoDosSantos May 28, 2024
2e46eef
Sorted alphabetically for better overview
IndigoDosSantos May 28, 2024
7a338db
Update functional.py
IndigoDosSantos May 28, 2024
2fb212b
FIX Prevent __getstate__ from mutating Params4bit
BenjaminBossan May 29, 2024
36fe1a0
fix blocksize
jiqing-feng May 29, 2024
ed99b3c
FIX Make Int8Params deepcopy-able
BenjaminBossan May 30, 2024
3c8c18a
Merge pull request #1231 from BenjaminBossan/fix-8bit-deepcopy
Titus-von-Koeller May 30, 2024
d9b1125
Merge pull request #1230 from BenjaminBossan/fix-4bit-getstate
Titus-von-Koeller May 30, 2024
1f2ca43
Merge pull request #1222 from EtienneDosSantos/main
Titus-von-Koeller May 30, 2024
dba8376
Merge pull request #1228 from jiqing-feng/4bit
Titus-von-Koeller May 30, 2024
b22ae26
fix for faulty #1222 ("Add `"lamb"` to `str2optimizer32bit`") (#1240)
younesbelkada Jun 5, 2024
517eaf2
CPU: add torch.compile for F.double_quant and F.quantize_4bit (#1238)
Xia-Weiwen Jun 6, 2024
5891465
Add build job for rocm
pnunna93 Jun 19, 2024
d03a680
Add rocm build script
pnunna93 Jun 19, 2024
ec9000f
Copy shared obj file into output_dir
pnunna93 Jun 20, 2024
9b8c1da
upload build artifacts and enable wheels build
pnunna93 Jun 20, 2024
1413c5f
Remove cuda build temporarily
pnunna93 Jun 20, 2024
195ae61
Bump the minor-patch group across 1 directory with 2 updates (#1253)
dependabot[bot] Jun 21, 2024
193120d
cleanup docs-build breaking install instructs (#1244)
Titus-von-Koeller Jun 21, 2024
dada530
cpu install guide (#1227)
jiqing-feng Jun 21, 2024
c79b1e9
provide temp flag for outside libs to detect multi-backend preview (#…
Titus-von-Koeller Jun 21, 2024
1bfecc8
CPU/XPU: disable torch.compile if g++ is not available (#1251)
Xia-Weiwen Jul 10, 2024
0859784
Create build job for ROCm (#1255)
pnunna93 Jul 12, 2024
1935a45
fix broken <source> links in autodoc API reference (#1275)
Titus-von-Koeller Jul 12, 2024
85e0127
Fix CUDA 12.5 build issue (#1273)
HennerM Jul 12, 2024
6866a4a
Bump scipy from 1.13.1 to 1.14.0 in the minor-patch group (#1266)
dependabot[bot] Jul 12, 2024
8c6ab69
update repo owner
Titus-von-Koeller Jul 12, 2024
7be1143
update repo owner
Titus-von-Koeller Jul 12, 2024
6948f0b
Fix Windows CUDA build compatibility with newest MSVC (#1276)
matthewdouglas Jul 15, 2024
f2b2310
Update matplotlib requirement from ~=3.9.0 to ~=3.9.1 in the major gr…
dependabot[bot] Jul 15, 2024
39b42e7
Fixed tests for cpu only platforms (#1259)
galqiwi Jul 15, 2024
9e75374
fix QLoRA mem bug: delete useless buffered activation (#1270)
Ther-nullptr Jul 16, 2024
0bdd57c
Add CUDA 12.5 and update 12.4 builds (#1284)
matthewdouglas Jul 21, 2024
5212a0f
Edenzzzz's fix for min_8bit_size functionality in Optimizer base clas…
Titus-von-Koeller Jul 22, 2024
a3f55ce
Fixed optim update error with non-contiguous grads/params (#1187)
Edenzzzz Jul 22, 2024
e3ae243
Bump pytest from 8.2.2 to 8.3.1 in the minor-patch group (#1287)
dependabot[bot] Jul 22, 2024
7fed393
Fix restoration of quant_storage for CPU offloading (#1279)
matthewdouglas Jul 23, 2024
1571110
remove unnecessary version mention
Titus-von-Koeller Jul 23, 2024
ce53caf
release 0.43.2
Titus-von-Koeller Jul 23, 2024
a7c08af
bump version tag to next dev
Titus-von-Koeller Jul 23, 2024
9b72679
Changelog: add explanation r. QLoRA mem savings
Titus-von-Koeller Jul 23, 2024
056011a
merge `main` into `multi-backend-refactor`
Titus-von-Koeller Jul 26, 2024
81375f8
docs: add more details to Intel install
Titus-von-Koeller Jul 27, 2024
7800734
Changelog: add explanation r. QLoRA mem savings
Titus-von-Koeller Jul 23, 2024
24f7b65
docs: cleanup of compilation instructions
Titus-von-Koeller Jul 27, 2024
e3b2780
docs: CHANGELOG.md fix
Titus-von-Koeller Jul 27, 2024
0b53d31
Merge remote-tracking branch 'upstream/main' into multi-backend-refactor
Titus-von-Koeller Jul 27, 2024
c8b4b33
fix dtype mismatch (#1285)
jiqing-feng Jul 27, 2024
fd655b0
Add ROCm version to .so filename
pnunna93 Jul 29, 2024
6b77f4c
Add rocm_version to whls build
pnunna93 Jul 29, 2024
78324b3
Revert "Remove cuda build temporarily"
pnunna93 Jul 29, 2024
c146b8b
Add rocm_version env var
pnunna93 Jul 29, 2024
953a383
Merge remote-tracking branch 'upstream/multi-backend-refactor' into e…
pnunna93 Jul 29, 2024
d6c3df4
Remove thrush header files
pnunna93 Jul 30, 2024
7e9a65c
Print node info
pnunna93 Jul 30, 2024
cdb209a
print cuda node info
pnunna93 Jul 30, 2024
77e1499
Revert "print cuda node info"
pnunna93 Jul 30, 2024
7c91909
Revert "Print node info"
pnunna93 Jul 30, 2024
b78b340
Add rocm arch to compile command
pnunna93 Jul 30, 2024
a62b9d4
Rename .so files to rocm
pnunna93 Jul 30, 2024
9059bff
Update default gpu arch
pnunna93 Jul 30, 2024
c5a406a
Skip cpu based igemmlt int tests on ROCm
pnunna93 Jul 30, 2024
9cbb5e1
Update Documentation
pnunna93 Jul 30, 2024
3580624
Update upstream repo name
pnunna93 Jul 30, 2024
3bde1b7
Update docs
pnunna93 Jul 30, 2024
21 changes: 21 additions & 0 deletions .github/scripts/build-rocm.sh
@@ -0,0 +1,21 @@
#!/bin/bash
declare build_arch
declare build_os
declare rocm_version

set -xeuo pipefail
bnb_rocm_arch="gfx90a;gfx942;gfx1100"
if [ "${build_os:0:6}" == ubuntu ]; then
image=rocm/dev-ubuntu-22.04:${rocm_version}-complete
echo "Using image $image"
docker run --rm --platform "linux/$build_arch" -i \
-w /src -v "$PWD:/src" "$image" sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \
&& cmake --build ."
fi

output_dir="output/${build_os}/${build_arch}"
mkdir -p "${output_dir}"
(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}")
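Contributors can reproduce the CI build locally, since the script reads its three inputs from the environment under the same names the workflow exports. A minimal dry-run sketch with hypothetical example values (the actual build additionally requires Docker):

```shell
# Hypothetical local values; the variable names mirror the workflow's env block.
build_os="ubuntu-latest"
build_arch="x86_64"
rocm_version="6.1.2"

# The script derives its ROCm dev container tag from rocm_version:
image="rocm/dev-ubuntu-22.04:${rocm_version}-complete"
echo "build would run inside: $image"

# Actual invocation (requires Docker):
#   build_os="$build_os" build_arch="$build_arch" rocm_version="$rocm_version" \
#     bash .github/scripts/build-rocm.sh
```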
5 changes: 4 additions & 1 deletion .github/workflows/build_documentation.yml
@@ -13,6 +13,9 @@ jobs:
with:
commit_sha: ${{ github.sha }}
package: bitsandbytes
repo_owner: TimDettmers
repo_owner: bitsandbytes-foundation
# avoid /src suffix leading to wrong links, like bitsandbytes/blob/main/src/bitsandbytes/nn/
version_tag_suffix: '' # defaults to '/src'
custom_container: huggingface/transformers-doc-builder
secrets:
hf_token: ${{ secrets.HUGGINGFACE_PUSH }}
7 changes: 5 additions & 2 deletions .github/workflows/build_pr_documentation.yml
@@ -9,10 +9,13 @@ concurrency:

jobs:
build:
if: github.repository == 'TimDettmers/bitsandbytes'
if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: bitsandbytes
repo_owner: TimDettmers
repo_owner: bitsandbytes-foundation
# avoid /src suffix leading to wrong links, like bitsandbytes/blob/main/src/bitsandbytes/nn/
version_tag_suffix: '' # defaults to '/src'
custom_container: huggingface/transformers-doc-builder
38 changes: 34 additions & 4 deletions .github/workflows/python-package.yml
@@ -63,12 +63,10 @@ jobs:
os: [ubuntu-latest, windows-latest]
arch: [x86_64, aarch64]
cuda_version:
["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.0"]
["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.0"]
exclude:
- os: windows-latest # This probably requires arm64 Windows agents
arch: aarch64
- os: windows-latest # The Jimver/cuda-toolkit is action used for Windows builds is not updated for 12.4 yet.
cuda_version: "12.4.0"
- os: ubuntu-latest # Temporary. Takes too long, not ready yet.
arch: aarch64
runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents
@@ -79,7 +77,7 @@ jobs:
if: startsWith(matrix.os, 'ubuntu')
uses: docker/setup-qemu-action@v2
# Windows: We install Cuda on the agent (slow)
- uses: Jimver/cuda-toolkit@v0.2.14
- uses: Jimver/cuda-toolkit@v0.2.16
if: startsWith(matrix.os, 'windows')
id: cuda-toolkit
with:
@@ -103,10 +101,42 @@ jobs:
name: shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }}
path: output/*
retention-days: 7
build-shared-libs-rocm:
strategy:
matrix:
os: [ubuntu-latest]
arch: [x86_64]
rocm_version:
["6.1.2"]
runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents
steps:
- uses: actions/checkout@v4
- name: Set up Docker multiarch
if: startsWith(matrix.os, 'ubuntu')
uses: docker/setup-qemu-action@v2
- name: Clean up disk space
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Build C++
run: bash .github/scripts/build-rocm.sh
env:
build_os: ${{ matrix.os }}
build_arch: ${{ matrix.arch }}
rocm_version: ${{ matrix.rocm_version }}
- name: Upload build artifact
uses: actions/upload-artifact@v4
with:
name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }}
path: output/*
retention-days: 7
build-wheels:
needs:
- build-shared-libs
- build-shared-libs-cuda
- build-shared-libs-rocm
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
26 changes: 26 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,29 @@
### 0.43.2

This release is quite significant, as the QLoRA bug fix has big implications for higher `seqlen` and batch sizes.

For each sequence (i.e. batch size increase of one) we expect memory savings of:
- 405B: 39GB for `seqlen=1024`, and 4888GB for `seqlen=128,000`
- 70B: 10.1GB for `seqlen=1024`, and 1258GB for `seqlen=128,000`

The savings arise because activations are unnecessary for frozen parameters, yet memory for them was still erroneously allocated due to the now-fixed bug.
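Since the wrongly retained activations grow linearly with sequence length, the quoted figures can be sanity-checked by simple extrapolation (the small residual against the quoted 128k numbers presumably comes from rounding in the original estimate):

```python
# Extrapolate the seqlen=1024 savings linearly to seqlen=128,000
# and compare against the figures quoted in the changelog entry above.
savings_at_1024 = {"405B": 39.0, "70B": 10.1}     # GB per extra sequence
quoted_at_128k = {"405B": 4888.0, "70B": 1258.0}  # GB per extra sequence

extrapolated = {m: gb * (128_000 / 1024) for m, gb in savings_at_1024.items()}
for model in savings_at_1024:
    rel_err = abs(extrapolated[model] - quoted_at_128k[model]) / quoted_at_128k[model]
    print(f"{model}: {extrapolated[model]:.0f} GB extrapolated (rel. err. {rel_err:.2%})")
    assert rel_err < 0.01  # linear scaling matches the quoted numbers within 1%
```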

#### Improvements:

- docs: FSDP+QLoRA and CPU install guide (#1211 #1227, thanks @stevhliu)
- Add CUDA 12.5 and update 12.4 builds (#1284)

#### Bug Fixes

- 4bit getstate and 8bit deepcopy (#1230 #1231, thanks @BenjaminBossan)
- missing optimizers in `str2optimizer32bit` (#1222, thanks @EtienneDosSantos)
- CUDA 12.5 build issue (#1273, thanks @HennerM)
- fix for min_8bit_size functionality in Optimizer base classes (#1286, thanks @Edenzzzz)
- QLoRA mem bug (#1270, thanks @Ther-nullptr)
- tests for cpu only platforms (#1259, thanks @galqiwi)
- restoration of quant_storage for CPU offloading (#1279)
- optim update error with non-contiguous grads/params (deepspeed) (#1187)

### 0.43.1

#### Improvements:
Expand Down
14 changes: 11 additions & 3 deletions CMakeLists.txt
@@ -77,6 +77,13 @@ endif()


if(BUILD_CUDA)
# NVCC normally will only work with MSVC up to 1939. VS2022 17.10+ starts using versions 1940+.
# Workaround: use --allow-unsupported-compiler
# This needs to be added *before* we try to enable the CUDA language so CMake's compiler check passes.
if(MSVC AND MSVC_VERSION VERSION_GREATER_EQUAL 1940)
string(APPEND CMAKE_CUDA_FLAGS " --allow-unsupported-compiler")
endif()

enable_language(CUDA) # This will fail if CUDA is not found
find_package(CUDAToolkit REQUIRED)

@@ -178,7 +185,7 @@ elseif(BUILD_HIP)
set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH})
else()
if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx940;gfx941;gfx942")
set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100")
elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
endif()
@@ -187,12 +194,14 @@ elseif(BUILD_HIP)

list(APPEND SRC_FILES ${HIP_FILES})

string(APPEND BNB_OUTPUT_NAME "_hip")
string(APPEND BNB_OUTPUT_NAME "_rocm")

# get hip version
execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION)
string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}")

string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}")
if(NO_CUBLASLT OR HIP_VERSION VERSION_LESS "6.1")
string(APPEND BNB_OUTPUT_NAME "_nohipblaslt")
endif()
@@ -229,7 +238,6 @@ if(WIN32)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()

# Weird MSVC hacks
if(MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast")
endif()
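The net effect of the CMake changes above is that ROCm builds now emit shared objects suffixed `_rocm<major><minor>` instead of `_hip`. A small Python stand-in for the CMake string operations (the function name is ours, for illustration):

```python
import re

def rocm_suffix(hipconfig_version: str, no_hipblaslt: bool = False) -> str:
    """Mirror the CMake logic: take major.minor from `hipconfig --version`,
    drop the dot, and append '_nohipblaslt' for HIP < 6.1 (or when forced)."""
    major, minor = map(int, re.match(r"(\d+)\.(\d+)", hipconfig_version).groups())
    suffix = f"_rocm{major}{minor}"
    if no_hipblaslt or (major, minor) < (6, 1):
        suffix += "_nohipblaslt"
    return suffix

print(rocm_suffix("6.1.40093-bd86f1708"))  # -> _rocm61
print(rocm_suffix("6.0.2"))                # -> _rocm60_nohipblaslt
```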
13 changes: 1 addition & 12 deletions CONTRIBUTING.md
@@ -9,23 +9,12 @@ We actively welcome your pull requests.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>
5. Make sure your code lints, install the [pre-commit hooks as documented here](https://huggingface.co/docs/bitsandbytes/main/en/contributing#setup-pre-commit-hooks).

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to bitsandbytes, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
12 changes: 11 additions & 1 deletion README.md
@@ -12,8 +12,18 @@ There are ongoing efforts to support further hardware backends, i.e. Intel CPU +

**[https://huggingface.co/docs/bitsandbytes/main](https://huggingface.co/docs/bitsandbytes/main)**

## ALPHA TESTERS WANTED: `multi-backend-refactor` AMD GPU + Intel CPU/GPU specific BNB backend implementations

We're in the process of a complex refactor in order to allow the support of additional hardware backends, other than CUDA, in BNB. The efforts around this are already quite far along, and there's plenty of functionality already in place that needs users to take a hands-on approach! Mac support will likely also see progress soon. However, we recommend waiting two weeks until the device abstraction has further consolidated (**breaking changes upcoming**).

Currently, you still need to compile from source after checking out the `multi-backend-refactor` branch (instructions WIP, but [the current docs on the compilation from source](https://huggingface.co/docs/bitsandbytes/main/en/installation#compile-from-source) are a good starting point); [feel free to share tips / input in this GitHub discussion](https://github.com/TimDettmers/bitsandbytes/discussions/1219). We'll soon enable nightly releases to make this much easier for you!

Please give feedback to us in [this dedicated Github Discussion space](https://github.com/TimDettmers/bitsandbytes/discussions/categories/catch-all-alpha-testing-the-multi-backend-refactor)!

We're super excited about these recent developments and grateful for any constructive input or support that you can give to help us make this a reality. BNB is a community project and we're excited for your collaboration 🤗

## License

The majority of bitsandbytes is licensed under MIT, however small portions of the project are available under separate license terms, as the parts adapted from Pytorch are licensed under the BSD license.
`bitsandbytes` is MIT licensed.

We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization.
5 changes: 5 additions & 0 deletions _typos.toml
@@ -1,5 +1,10 @@
[files]

[default]
extend-ignore-re = [
"@Ther-nul", # valid Github user
]

[default.extend-identifiers]

[type.py.extend-words]
15 changes: 13 additions & 2 deletions bitsandbytes/__init__.py
@@ -16,9 +16,17 @@
)
from .backends import register_backend
from .backends.cpu import CPUBackend
from .backends.npu import NPUBackend
from .cextension import lib
from .nn import modules

# NOTE: this is a temporary flag to allow outside libraries to employ conditional logic while the refactor is still in
# alpha/beta: sth like `if getattr(bitsandbytes, "is_multi_backend_refactor_preview", False): do sth`
# the getattr() call above would default to False and any string evaluates to True. This way we have temporary thing
# that we can remove in Transformers with the next release after the official BNB multi-platform release; then
# eventually making it the new default (e.g. just remove if statement and dedent in Transformers)
is_multi_backend_refactor_preview = "TO BE REMOVED ONCE MERGED TO `main`" # bool evals to True for str

# Always register the CPU backend.
register_backend("cpu", CPUBackend())

@@ -49,11 +57,14 @@

register_backend("xpu", XPUBackend())

# Register Ascend NPU backend, if available.
if hasattr(torch, "npu") and torch.npu.is_available():
register_backend("npu", NPUBackend())

# TODO: Other potential backends:
# XLA - Google TPU / PJRT runtime
# HPU - Habana / Intel Gaudi
# IPU - Graphcore
# NPU - Ascend
# Note that we may not map 1:1 with a device type, e.g. SYCL, XLA
# In this case, it will be up to each backend to dispatch as needed

@@ -63,4 +74,4 @@
"optim.optimizer.MockArgs": False,
}

__version__ = "0.43.2.dev"
__version__ = "0.43.3.dev"
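The temporary preview flag introduced above is meant to be probed with `getattr`, exactly as its comment describes. A minimal stand-alone sketch, using `SimpleNamespace` objects in place of the real imported `bitsandbytes` module:

```python
from types import SimpleNamespace

def multi_backend_preview_enabled(bnb_module) -> bool:
    # On releases that predate the flag, getattr falls back to False;
    # on the preview branch the attribute is a non-empty (truthy) string.
    return bool(getattr(bnb_module, "is_multi_backend_refactor_preview", False))

old_release = SimpleNamespace()  # simulates bitsandbytes without the flag
preview = SimpleNamespace(
    is_multi_backend_refactor_preview="TO BE REMOVED ONCE MERGED TO `main`"
)

print(multi_backend_preview_enabled(old_release))  # False
print(multi_backend_preview_enabled(preview))      # True
```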
7 changes: 4 additions & 3 deletions bitsandbytes/autograd/_functions.py
@@ -524,7 +524,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

if any(ctx.needs_input_grad[:2]):
ctx.tensors = (A, B)
ctx.tensors = (None, B)
else:
ctx.tensors = (None, None)

@@ -537,7 +537,7 @@ def backward(ctx, grad_output):
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
A, B = ctx.tensors
_, B = ctx.tensors

grad_A, grad_B, grad_bias = None, None, None

@@ -575,7 +575,8 @@ def matmul_4bit(
bias=None,
):
assert quant_state is not None
if A.numel() == A.shape[-1] and A.requires_grad == False:
if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
# CPU backend does not require A to be a vector
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
20 changes: 17 additions & 3 deletions bitsandbytes/backends/cpu.py
@@ -6,9 +6,12 @@

from .base import Backend
from .cpu_xpu_common import (
dequantize_4bit_impl,
double_quant_impl,
gemm_4bit_impl,
igemmlt_impl,
mm_dequant_impl,
quantize_4bit_impl,
)

Tensor = torch.Tensor
@@ -132,7 +135,11 @@ def quantize_4bit(
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError("Not yet implemented for CPU backend")
if blocksize is None:
blocksize = 64
assert_on_cpu([A, absmax, out])
assert quant_storage == torch.uint8, "CPU backend only supports uint8 quant_storage"
return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)

def dequantize_4bit(
self,
@@ -143,7 +150,10 @@ def dequantize_4bit(
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
raise NotImplementedError("Not yet implemented for CPU backend")
if blocksize is None:
blocksize = 64
assert_on_cpu([A, absmax, out])
return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

def gemv_4bit(
self,
@@ -154,7 +164,11 @@
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
raise NotImplementedError("Not yet implemented for CPU backend")
assert_on_cpu([A, B, out])
if state is None:
raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")

return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)

def dequantize_blockwise(
self,
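For readers unfamiliar with the scheme these CPU entry points implement, here is a toy pure-Python model of blockwise absmax quantization with a codebook lookup (illustrative only: the real NF4/FP4 codebooks, the two-codes-per-byte packing, and the `torch` tensor handling in `cpu_xpu_common` all differ):

```python
def quantize_blockwise(values, codebook, blocksize=64):
    """Scale each block by its absmax, then store nearest-codebook indices."""
    absmax, codes = [], []
    for i in range(0, len(values), blocksize):
        block = values[i:i + blocksize]
        scale = max(abs(v) for v in block) or 1.0  # avoid divide-by-zero
        absmax.append(scale)
        for v in block:
            codes.append(min(range(len(codebook)),
                             key=lambda j: abs(codebook[j] - v / scale)))
    return codes, absmax

def dequantize_blockwise(codes, absmax, codebook, blocksize=64):
    return [codebook[c] * absmax[i // blocksize] for i, c in enumerate(codes)]

# Toy symmetric 15-level codebook (not the real NF4 table).
cb = [i / 7.0 for i in range(-7, 8)]
data = [0.5, -1.0, 0.25, 0.9] * 16           # exactly one 64-element block
codes, scales = quantize_blockwise(data, cb)
restored = dequantize_blockwise(codes, scales, cb)
max_err = max(abs(a - b) for a, b in zip(data, restored))
print(len(scales), round(max_err, 3))        # one absmax per block, small error
```

The roundtrip error is bounded by half the largest gap between adjacent codebook entries, times the block's absmax; this is the quantization error that `blocksize` trades off against storage for the `absmax` values.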