From 67e7ee3bc2b0b30ebe520f6e844a11ba5c76cc70 Mon Sep 17 00:00:00 2001 From: Steven Liu Date: Tue, 26 Mar 2024 10:06:07 -0700 Subject: [PATCH 01/24] first draft --- docs/source/fsdp_qlora.md | 106 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 docs/source/fsdp_qlora.md diff --git a/docs/source/fsdp_qlora.md b/docs/source/fsdp_qlora.md new file mode 100644 index 000000000..47922cfcc --- /dev/null +++ b/docs/source/fsdp_qlora.md @@ -0,0 +1,106 @@ +# FSDP-QLoRA + +FSDP-QLoRA combines data parallelism (FSDP enables sharding model parameters, optimizer states, and gradients across GPUs), 4-bit quantization, and LoRA to train LLMs up to 70B parameters on a dual 24GB GPU system. This technique was released by [Answer.AI](https://www.answer.ai/posts/2024-03-06-fsdp-qlora) in collaboration with bitsandbytes to make training LLMs more efficient and accessible for everyone. + +This guide provides a brief guide on how bitsandbytes supports storing quantized weights to enable FSDP-QLoRA, and how to run training with the Hugging Face libraries. + +> [!TIP] +> Other changes required for bitsandbytes to support FSDP-QLoRA, such as reconstructing the weights from the quantization metadata and preventing quantizing already quantized weights when they're moved from a CPU to GPU, are documented in this [Pull Request](https://github.com/TimDettmers/bitsandbytes/pull/970) and described in the [Enabling 70B Finetuning on Consumer GPUs](https://www.answer.ai/posts/2024-03-14-fsdp-qlora-deep-dive) blog post. We highly recommend reading these resources for a better understanding of FSDP-QLoRA! + +## Quantized data storage + +FSDP only supports sharding float data types which can be problematic because quantized weights are typically stored as integer data types (uint8). bitsandbytes doesn't have this problem because it uses `StoreChar` to read and write quantized weights regardless of the data type storage. This makes it simple to add a `quant_storage` parameter to the [`~nn.Linear4bit`] and [`~nn.Params4bit`] classes and set it to `torch.uint8` to maintain backward compatibility with the codebase. + +```py +import torch +import bitsandbytes as bnb + +model = bnb.nn.Linear4bit( + input_features, + output_features, + quant_type="fp4", + quant_storage=torch.uint8, +) +``` + +With the `quant_storage` parameter, you can select any of the FSDP supported data types to shard [`~nn.Linear4bit`] with such as bfloat16, float16 or float32. + +## Training + +bitsandbytes is deeply integrated with the Hugging Face ecosystem, making it easy to use with libraries like [Transformers](https://hf/co/docs/transformers), [PEFT](https://hf/co/docs/peft), and [TRL](https://hf/co/docs/trl). + +Before you begin, make sure you have the latest libraries installed. + +```bash +pip install -U bitsandbytes accelerate transformers peft trl +``` + +> [!TIP] +> PEFT provides a configuration file ([fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml)), launch command ([run_peft_qlora_fsdp.sh](https://github.com/huggingface/peft/blob/main/examples/sft/run_peft_qlora_fsdp.sh)), and training script ([train.py](https://github.com/huggingface/peft/blob/main/examples/sft/train.py)) for FSDP-QLoRA. To learn more, check out the [Use PEFT QLoRA and FSDP for finetuning large models on multiple GPUs](https://huggingface.co/docs/peft/main/en/accelerate/fsdp#use-peft-qlora-and-fsdp-for-finetuning-large-models-on-multiple-gpus) documentation. + +The important change that enables FSDP-QLoRA training is the `bnb_4bit_quant_storage` parameter in the [`~transformers.BitsAndBytesConfig`] class. This allows you to set the storage data type of the quantized weights to a float data type. + +```py +from transformers import BitsAndBytesConfig + +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_storage=torch.bfloat16, +) +``` + +Pass the [`~transformers.BitsAndBytesConfig`] to a model to set it up for FSDP-QLoRA. You should set the `torch_dtype` parameter to match `bnb_4bit_quant_storage` so that the [`~nn.Linear4bit`] layers are wrapped identically to the `Linear` layers. If the storage types do not match, then each [`~nn.Linear4bit`] layer is wrapped individually. + +```py +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-70b", + quantization_config=bnb_config, + torch_dtype=torch.bfloat16, +) +``` + +Configure the [`~peft.LoraConfig`] class for QLoRA training by setting `target_modules="all-linear"`. + +```py +from peft import LoraConfig + +peft_config = LoraConfig( + lora_alpha=16, + lora_dropout=0.1, + r=64, + bias="none", + task_type="CAUSAL_LM", + target_modules="all-linear", +) +``` + +Now you can pass everything to the [`~trl.SFTTrainer`] for training. + +```py +from trl import SFTTrainer + +trainer = SFTTrainer( + model=model, + train_dataset=dataset, + peft_config=peft_config, + dataset_text_field="text", + max_seq_length=max_seq_length, + tokenizer=tokenizer, + args=training_arguments, +) +trainer.train() +``` + +## Resources + +To learn more about FSDP and QLoRA, check out the following resources: + +- The [AnswerDotAI/fsdp_qlora](https://github.com/AnswerDotAI/fsdp_qlora) repository. +- The introductory [You can now train a 70b language model at home](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) blog post by Answer.AI. +- For an introduction to FSDP, read the [Introducing PyTorch Fully Sharded Data Parallel (FSDP) API](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api) blog post. +- For more details about QLoRA, take a look at the [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post. From e3376abfd4f7923e3a66b13a8f039fbf21ae7f85 Mon Sep 17 00:00:00 2001 From: Steven Liu Date: Tue, 26 Mar 2024 11:01:11 -0700 Subject: [PATCH 02/24] toctree --- docs/source/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 2184cce8c..fdfe19ee4 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -12,6 +12,8 @@ title: 8-bit optimizers - local: algorithms title: Algorithms + - local: fsdp_qlora + title: FSDP-QLoRA - local: integrations title: Integrations - local: errors From c17fb8eb4f4b0139229beda0e109e9aab91af957 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Fri, 29 Mar 2024 11:34:09 -0400 Subject: [PATCH 03/24] Fix 4bit quantization with blocksize=4096 --- bitsandbytes/functional.py | 7 ++++--- csrc/ops.cu | 2 +- tests/test_functional.py | 28 +++++++++++++++++++++++----- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index bb6a04892..f915223ca 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1087,11 +1087,12 @@ def get_4bit_type(typename, device=None, blocksize=64): if data is None: raise NotImplementedError(f"Typename {typename} not supported") - data = Tensor(data) - data /= data.abs().max() + data = torch.tensor(data, device=device) + data.div_(data.abs().max()) + assert data.numel() == 16 - return data.to(device) + return data def quantize_fp4( diff --git a/csrc/ops.cu b/csrc/ops.cu index 796211fed..3a6ffdda8 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -58,7 +58,7 @@ template void quantizeBlockwise(floa num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; if(blocksize == 4096) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 2048) kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 1024) diff --git a/tests/test_functional.py b/tests/test_functional.py index b9f1a6ead..1cca04511 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1928,7 +1928,9 @@ def test_bench_dequantization(): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) -def test_fp4_quant(dtype): +@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) +@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) +def test_4bit_quant(dtype, quant_type, blocksize): vals = list(product([0, 1], repeat=4)) code = {} @@ -1953,8 +1955,8 @@ def test_fp4_quant(dtype): code[idx] = result A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) - qa, SA = F.quantize_fp4(A1, blocksize=64) - A2 = F.dequantize_fp4(qa, SA) + qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) + A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) err = (A1 - A2).abs().float() relerr = (err / (A1.abs().float() + 1e-8)).mean() @@ -1962,8 +1964,24 @@ def test_fp4_quant(dtype): err = err.mean() assert A2.dtype == dtype - assert err.item() < 0.1 - assert relerr.item() < 0.28 + + # With larger block sizes, we can expect this to blow up. + # At blocksize>=1024, don't even bother looking at relerr. + if blocksize <= 64: + assert err.item() < 0.1 + assert relerr.item() < 0.28 + elif blocksize <= 256: + assert err.item() < 0.11 + assert relerr.item() < 0.30 + elif blocksize <= 512: + assert err.item() < 0.12 + assert relerr.item() < 0.31 + elif quant_type == "fp4": + # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56 + assert err.item() < 0.08 + math.log2(blocksize) * 4e-2 + else: + # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 + assert err.item() < math.log2(blocksize) * 8e-2 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) From a471456911168b3ac798ff99967606013c71cc50 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Fri, 29 Mar 2024 11:35:54 -0400 Subject: [PATCH 04/24] fix formatting for install_cuda.py --- install_cuda.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/install_cuda.py b/install_cuda.py index a5d09356d..cf7c8ee71 100644 --- a/install_cuda.py +++ b/install_cuda.py @@ -77,9 +77,7 @@ def main(): download_path = "/tmp" # default download path if len(sys.argv) < 2: - print( - "Usage: python install_cuda.py [user/system] [download_path]" - ) + print("Usage: python install_cuda.py [user/system] [download_path]") sys.exit(1) version = sys.argv[1] @@ -100,9 +98,7 @@ def main(): elif version in cuda_versions: install_cuda(version, base_path, download_path) else: - print( - f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}" - ) + print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}") sys.exit(1) From 494de206ce029cf7d03a12eeb7d72368d04d7458 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Apr 2024 12:22:35 +0200 Subject: [PATCH 05/24] Bump the minor-patch group with 1 update (#1162) Bumps the minor-patch group with 1 update: [lion-pytorch](https://github.com/lucidrains/lion-pytorch). Updates `lion-pytorch` from 0.1.2 to 0.1.4 - [Release notes](https://github.com/lucidrains/lion-pytorch/releases) - [Commits](https://github.com/lucidrains/lion-pytorch/compare/0.1.2...0.1.4) --- updated-dependencies: - dependency-name: lion-pytorch dependency-type: direct:production update-type: version-update:semver-patch dependency-group: minor-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-ci.txt | 2 +- requirements-dev.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-ci.txt b/requirements-ci.txt index 61f92018a..4df975993 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,6 +1,6 @@ # Requirements used for GitHub actions pytest==8.1.1 einops==0.7.0 -lion-pytorch==0.1.2 +lion-pytorch==0.1.4 scipy==1.10.1; python_version < "3.9" scipy==1.12.0; python_version >= "3.9" diff --git a/requirements-dev.txt b/requirements-dev.txt index fc5449ba7..291a51cb1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ setuptools>=63 pytest~=8.1.1 einops~=0.7.0 wheel~=0.43.0 -lion-pytorch~=0.1.2 +lion-pytorch~=0.1.4 scipy~=1.12.0 pandas~=2.2.1 matplotlib~=3.8.3 From bed0860b8e11ea4a15d729e60f694c46eefe7fd4 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 2 Apr 2024 06:31:03 -0400 Subject: [PATCH 06/24] Tests: improve memory usage (#1147) --- tests/conftest.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 17ffd281c..59146963d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import gc + import pytest import torch @@ -20,6 +22,13 @@ def pytest_runtest_call(item): raise +@pytest.hookimpl(trylast=True) +def pytest_runtest_teardown(item, nextitem): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + @pytest.fixture(scope="session") def requires_cuda() -> bool: cuda_available = torch.cuda.is_available() From 2965c765a7d95de35484d374e2ce0159858010b3 Mon Sep 17 00:00:00 2001 From: Titus <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Tue, 2 Apr 2024 15:27:07 +0200 Subject: [PATCH 07/24] CHANGELOG.md: mention accuracy changes when quantizing post v0.42 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 397dceb77..b671145a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -357,6 +357,10 @@ Bug fixes: - Addressed a race condition in kEstimateQuantiles, enhancing the reliability of quantile estimation in concurrent environments (@pnunna93, #1061). - Fixed various minor issues, including typos in code comments and documentation, to improve code clarity and prevent potential confusion (@Brian Vaughan, #1063). +#### Backwards Compatibility +- After upgrading from `v0.42` to `v0.43`, when using 4bit quantization, models may generate slightly different outputs (approximately up to the 2nd decimal place) due to a fix in the code. For anyone interested in the details, [see this comment](https://github.com/TimDettmers/bitsandbytes/discussions/1094#discussioncomment-8984069). + + #### Internal and Build System Enhancements: - Implemented several enhancements to the internal and build systems, including adjustments to the CI workflows, portability improvements, and build artifact management. These changes contribute to a more robust and flexible development process, ensuring the library's ongoing quality and maintainability (@rickardp, @akx, @wkpark, @matthewdouglas; #949, #1053, #1045, #1037). From bfe21182631e8f9575e4b992e70719e01c256901 Mon Sep 17 00:00:00 2001 From: Titus <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 4 Apr 2024 19:31:40 +0200 Subject: [PATCH 08/24] README: include download badges --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 43eadf5a3..2cf630dcb 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # `bitsandbytes` +[![Downloads](https://static.pepy.tech/badge/bitsandbytes)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/month)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/week)](https://pepy.tech/project/bitsandbytes) + The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 & 4-bit quantization functions. The library includes quantization primitives for 8-bit & 4-bit operations, through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit` and 8-bit optimizers through `bitsandbytes.optim` module. From b2a85a434802a08a6b2b876832234037cd3222a1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Apr 2024 14:30:08 +0000 Subject: [PATCH 09/24] Update matplotlib requirement from ~=3.8.3 to ~=3.8.4 in the major group Updates the requirements on [matplotlib](https://github.com/matplotlib/matplotlib) to permit the latest version. Updates `matplotlib` to 3.8.4 - [Release notes](https://github.com/matplotlib/matplotlib/releases) - [Commits](https://github.com/matplotlib/matplotlib/compare/v3.8.3...v3.8.4) --- updated-dependencies: - dependency-name: matplotlib dependency-type: direct:development dependency-group: major ... Signed-off-by: dependabot[bot] --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 291a51cb1..2c4f3e35a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,4 +6,4 @@ wheel~=0.43.0 lion-pytorch~=0.1.4 scipy~=1.12.0 pandas~=2.2.1 -matplotlib~=3.8.3 +matplotlib~=3.8.4 From c0ad874a6ec867312cb262fe577b537ca1733f9a Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 8 Apr 2024 17:14:18 -0400 Subject: [PATCH 10/24] Build workflow: Add CUDA 12.4 to build matrix --- .github/workflows/python-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index ba5961f72..17e1618a7 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -63,7 +63,7 @@ jobs: os: [ubuntu-latest, windows-latest] arch: [x86_64, aarch64] cuda_version: - ["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2"] + ["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.0"] exclude: - os: windows-latest # This probably requires arm64 Windows agents arch: aarch64 From ebac8625b333af2ddcfc65aad85ddc7b2c433cee Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 8 Apr 2024 17:27:35 -0400 Subject: [PATCH 11/24] Exclude Windows from CUDA 12.4.0 build for now --- .github/workflows/python-package.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 17e1618a7..72e1b099a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -67,6 +67,8 @@ jobs: exclude: - os: windows-latest # This probably requires arm64 Windows agents arch: aarch64 + - os: windows-latest # The Jimver/cuda-toolkit is action used for Windows builds is not updated for 12.4 yet. + cuda_version: "12.4.0" - os: ubuntu-latest # Temporary. Takes too long, not ready yet. arch: aarch64 runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents From 6be3d0f98653b6d7fca97e5ea57fe173bdf0171b Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Tue, 9 Apr 2024 02:51:19 -0700 Subject: [PATCH 12/24] [docs] Install from source (#1149) * split build from source off * validated compilers --- docs/source/installation.mdx | 43 +++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index d0dd7ba76..caf22488f 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,9 +1,17 @@ # Installation -bitsandbytes is only supported on CUDA GPUs for CUDA versions **11.0 - 12.3**. Select your operating system below to see the installation instructions. +bitsandbytes is only supported on CUDA GPUs for CUDA versions **11.0 - 12.3**. - - +The latest version of bitsandbytes (v0.43.0) builds on: + +| OS | CUDA | Compiler | +|---|---|---| +| Linux | 11.7 - 12.3 | GCC 11.4 | +| | 12.4+ | GCC 13.2 | +| Windows | 11.7 - 12.4 | MSVC 19.38+ (VS2022 17.8.0+) | + +> [!TIP] +> MacOS support is still a work in progress! Subscribe to this [issue](https://github.com/TimDettmers/bitsandbytes/issues/1020) to get notified about discussions and to track the integration progress. For Linux systems, make sure your hardware meets the following requirements to use bitsandbytes features. @@ -23,13 +31,26 @@ pip install bitsandbytes ## Compile from source +For Linux and Windows systems, you can compile bitsandbytes from source. Installing from source allows for more build options with different CMake configurations. + + + + To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. Make sure you have a compiler installed to compile C++ (gcc, make, headers, etc.). For example, to install a compiler and CMake on Ubuntu: ```bash apt-get install -y build-essential cmake ``` -You should also install CUDA Toolkit by following the [NVIDIA CUDA Installation Guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) guide from NVIDIA. +You should also install CUDA Toolkit by following the [NVIDIA CUDA Installation Guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) guide from NVIDIA. The current expected CUDA Toolkit version is **11.1+** and it is recommended to install **GCC >= 7.3** and required to have at least **GCC >= 6**. + +Refer to the following table if you're using another CUDA Toolkit version. + +| CUDA Toolkit | GCC | +|---|---| +| >= 11.4.1 | >= 11 | +| >= 12.0 | >= 12 | +| >= 12.4 | >= 13 | Now to install the bitsandbytes package from source, run the following commands: @@ -49,7 +70,13 @@ pip install . Windows systems require Visual Studio with C++ support as well as an installation of the CUDA SDK. -You'll need to build bitsandbytes from source. To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA. +To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA. + +Refer to the following table if you're using another CUDA Toolkit version. + +| CUDA Toolkit | MSVC | +|---|---| +| >= 11.6 | 19.30+ (VS2022) | ```bash git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/ @@ -61,12 +88,6 @@ python -m build --wheel Big thanks to [wkpark](https://github.com/wkpark), [Jamezo97](https://github.com/Jamezo97), [rickardp](https://github.com/rickardp), [akx](https://github.com/akx) for their amazing contributions to make bitsandbytes compatible with Windows. - - - -> [!TIP] -> MacOS support is still a work in progress! Subscribe to this [issue](https://github.com/TimDettmers/bitsandbytes/issues/1020) to get notified about discussions and to track the integration progress. - From c54053d303a8dffddb5c6530a6523645529f8883 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Apr 2024 19:26:32 +0200 Subject: [PATCH 13/24] Bump scipy from 1.12.0 to 1.13.0 in the minor-patch group (#1170) Bumps the minor-patch group with 1 update: [scipy](https://github.com/scipy/scipy). Updates `scipy` from 1.12.0 to 1.13.0 - [Release notes](https://github.com/scipy/scipy/releases) - [Commits](https://github.com/scipy/scipy/compare/v1.12.0...v1.13.0) --- updated-dependencies: - dependency-name: scipy dependency-type: direct:production update-type: version-update:semver-minor dependency-group: minor-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-ci.txt | 2 +- requirements-dev.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-ci.txt b/requirements-ci.txt index 4df975993..906c6643e 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -3,4 +3,4 @@ pytest==8.1.1 einops==0.7.0 lion-pytorch==0.1.4 scipy==1.10.1; python_version < "3.9" -scipy==1.12.0; python_version >= "3.9" +scipy==1.13.0; python_version >= "3.9" diff --git a/requirements-dev.txt b/requirements-dev.txt index 2c4f3e35a..4ee8bf543 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,6 +4,6 @@ pytest~=8.1.1 einops~=0.7.0 wheel~=0.43.0 lion-pytorch~=0.1.4 -scipy~=1.12.0 +scipy~=1.13.0 pandas~=2.2.1 matplotlib~=3.8.4 From 7449d713eb65caa186ed0f8d6c763b58bb3e61f9 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 10 Apr 2024 10:50:01 +0200 Subject: [PATCH 14/24] [`Core`] Change 8-bit serialization weight format format (#1164) * change 8-bit serialization weight format format * precimmit * pre-commit * fix * Update bitsandbytes/nn/modules.py Co-authored-by: Aarni Koskela * Update bitsandbytes/nn/modules.py Co-authored-by: Aarni Koskela * Update bitsandbytes/utils.py Co-authored-by: Aarni Koskela * address feedback * lint --------- Co-authored-by: Aarni Koskela --- bitsandbytes/nn/modules.py | 29 +++++++++++++++++++++++++---- bitsandbytes/utils.py | 4 ++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ec14e5940..24a155ab1 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -14,7 +14,11 @@ from bitsandbytes.autograd._functions import get_tile_inds, undo_layout from bitsandbytes.functional import QuantState from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import OutlierTracer +from bitsandbytes.utils import ( + INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, + LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, + OutlierTracer, +) T = TypeVar("T", bound="torch.nn.Module") @@ -619,6 +623,16 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k return weight_format = state_dict.pop(f"{prefix}weight_format", "row") + if isinstance(weight_format, torch.Tensor): + weight_format = weight_format.item() + + # For new weights format storage type, we explicitly check + # if weights_format is on the mapping + if isinstance(weight_format, int) and weight_format not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING: + raise ValueError(f"Expected supported weight format - got {weight_format}") + elif isinstance(weight_format, int) and weight_format in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING: + weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format] + if weight_format != "row": tile_indices = get_tile_inds(weight_format, weight.device) state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices) @@ -711,13 +725,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): if not self.state.has_fp16_weights: if param_from_weight is not None: destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() - destination[format_name] = "row" + destination[format_name] = torch.tensor(0, dtype=torch.uint8) elif param_from_state is not None and not layout_reordered: destination[key_name] = param_from_state if keep_vars else param_from_state.detach() - destination[format_name] = "row" + destination[format_name] = torch.tensor(0, dtype=torch.uint8) elif param_from_state is not None: destination[key_name] = param_from_state if keep_vars else param_from_state.detach() - destination[format_name] = self.state.formatB + weights_format = self.state.formatB + # At this point `weights_format` is an str + if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING: + raise ValueError(f"Unrecognized weights format {weights_format}") + + weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format] + + destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8) def _load_from_state_dict( self, diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 0229e59e2..a88ddf5f9 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -198,3 +198,7 @@ def unpack_tensor_to_dict(tensor_data): unpacked_dict = json.loads(json_str) return unpacked_dict + + +LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3} +INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()} From d62516f290fb529a69bc2fda767b2a87bfd9d72f Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 10 Apr 2024 22:10:53 -0400 Subject: [PATCH 15/24] (backends) Stub out additional backends; move more functions to backend interface --- bitsandbytes/__init__.py | 46 ++++- bitsandbytes/backends/base.py | 185 +++++++++++++++--- bitsandbytes/backends/cpu.py | 164 ++++++++++++++++ bitsandbytes/backends/cuda.py | 343 +++++++++++++++++++++++++++++++++- bitsandbytes/backends/mps.py | 164 ++++++++++++++++ bitsandbytes/backends/rocm.py | 12 ++ bitsandbytes/backends/xpu.py | 164 ++++++++++++++++ bitsandbytes/functional.py | 303 +++++------------------------- 8 files changed, 1085 insertions(+), 296 deletions(-) create mode 100644 bitsandbytes/backends/cpu.py create mode 100644 bitsandbytes/backends/mps.py create mode 100644 bitsandbytes/backends/rocm.py create mode 100644 bitsandbytes/backends/xpu.py diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 019a4f6ab..fcc31b220 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import torch + from . import research, utils from .autograd._functions import ( MatmulLtState, @@ -12,15 +14,49 @@ matmul_cublas, mm_cublas, ) +from .backends import register_backend +from .backends.cpu import CPUBackend from .cextension import lib from .nn import modules -if lib and lib.compiled_with_cuda: - from .backends import register_backend - from .backends.cuda import CUDABackend - from .optim import adam +# Always register the CPU backend. +register_backend("cpu", CPUBackend()) + +# Register either CUDA or ROCm backend, if available. +# Only one of these backends can be used at a time, since the torch.device semantics are +# the same for both torch+rocm and torch+cuda (e.g. device name is "cuda") +if torch.cuda.is_available(): + # TODO: Consider deferring loading of cextension - should backend class implement that? + + if torch.version.cuda: + from .backends.cuda import CUDABackend + + register_backend("cuda", CUDABackend()) + elif torch.version.hip: + from .backends.rocm import ROCmBackend + + register_backend("cuda", ROCmBackend()) + +# Register MPS backend, if available. +if torch.backends.mps.is_available() and torch.backends.mps.is_built(): + from .backends.mps import MPSBackend + + register_backend("mps", MPSBackend()) + +# Register Intel XPU backend, if available. +if hasattr(torch, "xpu") and torch.xpu.is_available(): + from .backends.xpu import XPUBackend + + register_backend("xpu", XPUBackend()) + +# TODO: Other potential backends: +# XLA - Google TPU / PJRT runtime +# HPU - Habana / Intel Gaudi +# IPU - Graphcore +# NPU - Ascend +# Note that we may not map 1:1 with a device type, e.g. SYCL, XLA +# In this case, it will be up to each backend to dispatch as needed - register_backend("cuda", CUDABackend()) __pdoc__ = { "libbitsandbytes": False, "optim.optimizer.Optimizer8bit": False, diff --git a/bitsandbytes/backends/base.py b/bitsandbytes/backends/base.py index 8232d17c1..2e73c3d6e 100644 --- a/bitsandbytes/backends/base.py +++ b/bitsandbytes/backends/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple, Union import torch @@ -12,11 +12,11 @@ class Backend(ABC): @abstractmethod def double_quant( self, - A, - col_stats=None, - row_stats=None, - out_col=None, - out_row=None, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, threshold=0.0, ): raise NotImplementedError @@ -24,36 +24,50 @@ def double_quant( @abstractmethod def transform( self, - A, - to_order, + A: torch.Tensor, + to_order: str, from_order="row", - out=None, + out: Optional[torch.Tensor] = None, transpose=False, - state=None, + state: Optional[Tuple[torch.Size, str]] = None, ld=None, ): raise NotImplementedError @abstractmethod - def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + def igemmlt( + self, + A: torch.Tensor, + B: torch.Tensor, + SA: Tuple[torch.Size, str], + SB: Tuple[torch.Size, str], + out: Optional[torch.Tensor] = None, + Sout: Optional[Tuple[torch.Size, str]] = None, + dtype=torch.int32, + ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]: raise NotImplementedError @abstractmethod def mm_dequant( self, - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None, - ): + A: torch.Tensor, + quant_state: Tuple[torch.Size, str], + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats: Optional[torch.Tensor] = None, + new_col_stats: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: raise NotImplementedError @abstractmethod - def extract_outliers(self, A, SA, idx): + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: raise NotImplementedError @abstractmethod @@ -64,7 +78,7 @@ def quantize_4bit( out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, - quant_type="fp4", + quant_type: Literal["fp4", "nf4"] = "fp4", quant_storage=torch.uint8, ) -> Tuple[torch.Tensor, QuantState]: """ @@ -102,7 +116,7 @@ def dequantize_4bit( absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, - quant_type="fp4", + quant_type: Literal["fp4", "nf4"] = "fp4", ) -> torch.Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -131,3 +145,128 @@ def dequantize_4bit( Dequantized tensor. """ raise NotImplementedError + + @abstractmethod + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + @abstractmethod + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + """ + Performs an in-place optimizer update with one or two optimizer states. + + Args: + optimizer_name (`str`): The name of the optimizer, e.g. `adam` + g (`torch.Tensor`): Gradient tensor. + p (`torch.Tensor`): Parameter tensor. + state1 (`torch.Tensor`): Optimizer state 1. + state2 (`torch.Tensor`, optional): Optimizer state 2. + beta1 (`float`): Optimizer beta1. + beta2 (`float`): Optimizer beta2. + eps (`float`): Optimizer epsilon. + step (`int`): Current optimizer step. + lr (`float`): The learning rate. + qmap1 (`torch.Tensor`): Quantization map for the first state. + qmap2 (`torch.Tensor`, optional): Quantization map for the second state. + absmax1 (`torch.Tensor`): Max value for the first state update. + absmax2 (`torch.Tensor`, optional): Max value for the second state update. + weight_decay (`float`, optional): Weight decay. Defaults to 0.0. + gnorm_scale (`float`, optional): The factor to rescale the gradient to the max clip value. Defaults to 1.0. + skip_zeros (`bool`, optional): Whether to skip zero-valued gradients or not. Defaults to False. + """ + raise NotImplementedError + + @abstractmethod + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + """ + Performs an in-place optimizer update with one or two optimizer states. + + Universal optimizer update for 32-bit state and 32/16-bit gradients/weights + + Args: + optimizer_name (`str`): The name of the optimizer, e.g. `adam` + g (`torch.Tensor`): Gradient tensor. + p (`torch.Tensor`): Parameter tensor. + state1 (`torch.Tensor`): Optimizer state 1. + beta1 (`float`): Optimizer beta1. + eps (`float`): Optimizer epsilon. + step (`int`): Current optimizer step. + lr (`float`): The learning rate. + state2 (`torch.Tensor`, optional): Optimizer state 2. Defaults to None. + beta2 (`float`, optional): Optimizer beta2. Defaults to 0.0. + weight_decay (`float`, optional): Defaults to 0.0. + gnorm_scale (`float`, optional): The factor to rescale the gradient to the max clip value. Defaults to 1.0. + unorm_vec (`torch.Tensor`, optional): The tensor for the update norm. Defaults to None. + max_unorm (`float`, optional): The maximum update norm relative to the weight norm.. Defaults to 0.0. + skip_zeros (`bool`, optional): Whether to skip zero-valued gradients or not. Defaults to False. + """ + raise NotImplementedError diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py new file mode 100644 index 000000000..830ebfadd --- /dev/null +++ b/bitsandbytes/backends/cpu.py @@ -0,0 +1,164 @@ +from typing import Literal, Optional, Tuple, Union + +import torch + +from bitsandbytes.utils import QuantState + +from .base import Backend + + +class CPUBackend(Backend): + def double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ): + raise NotImplementedError + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + raise NotImplementedError + + def igemmlt( + self, + A: torch.Tensor, + B: torch.Tensor, + SA: Tuple[torch.Size, str], + SB: Tuple[torch.Size, str], + out: Optional[torch.Tensor] = None, + Sout: Optional[Tuple[torch.Size, str]] = None, + dtype=torch.int32, + ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]: + raise NotImplementedError + + def mm_dequant( + self, + A: torch.Tensor, + quant_state: Tuple[torch.Size, str], + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats: Optional[torch.Tensor] = None, + new_col_stats: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type: Literal["fp4", "nf4"] = "fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type: Literal["fp4", "nf4"] = "fp4", + ) -> torch.Tensor: + raise NotImplementedError + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + raise NotImplementedError + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index c76bcaebd..93755b05f 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -1,5 +1,5 @@ import ctypes as ct -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple import torch @@ -23,9 +23,69 @@ from .base import Backend +if lib and lib.compiled_with_cuda: + """C FUNCTIONS FOR OPTIMIZERS""" + str2optimizer32bit = { + "adam": ( + lib.cadam32bit_grad_fp32, + lib.cadam32bit_grad_fp16, + lib.cadam32bit_grad_bf16, + ), + "momentum": ( + lib.cmomentum32bit_grad_32, + lib.cmomentum32bit_grad_16, + ), + "rmsprop": ( + lib.crmsprop32bit_grad_32, + lib.crmsprop32bit_grad_16, + ), + "lion": ( + lib.clion32bit_grad_fp32, + lib.clion32bit_grad_fp16, + lib.clion32bit_grad_bf16, + ), + "adagrad": ( + lib.cadagrad32bit_grad_32, + lib.cadagrad32bit_grad_16, + ), + } + + str2optimizer8bit_blockwise = { + "adam": ( + lib.cadam_8bit_blockwise_grad_fp32, + lib.cadam_8bit_blockwise_grad_fp16, + lib.cadam_8bit_blockwise_grad_bf16, + ), + "momentum": ( + lib.cmomentum_8bit_blockwise_grad_fp32, + lib.cmomentum_8bit_blockwise_grad_fp16, + ), + "rmsprop": ( + lib.crmsprop_8bit_blockwise_grad_fp32, + lib.crmsprop_8bit_blockwise_grad_fp16, + ), + "lion": ( + lib.clion_8bit_blockwise_grad_fp32, + lib.clion_8bit_blockwise_grad_fp16, + lib.clion_8bit_blockwise_grad_bf16, + ), + "adagrad": ( + lib.cadagrad_8bit_blockwise_grad_fp32, + lib.cadagrad_8bit_blockwise_grad_fp16, + ), + } + class CUDABackend(Backend): - def double_quant(self, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): + def double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ): device = A.device assert A.dtype == torch.half assert device.type == "cuda" @@ -114,7 +174,16 @@ def double_quant(self, A, col_stats=None, row_stats=None, out_col=None, out_row= return out_row, out_col, row_stats, col_stats, coo_tensor - def transform(self, A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): prev_device = pre_call(A.device) if state is None: state = (A.shape, from_order) @@ -166,7 +235,16 @@ def transform(self, A, to_order, from_order="row", out=None, transpose=False, st return out, new_state - def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + def igemmlt( + self, + A: torch.Tensor, + B: torch.Tensor, + SA: Tuple[torch.Size, str], + SB: Tuple[torch.Size, str], + out: Optional[torch.Tensor] = None, + Sout: Optional[Tuple[torch.Size, str]] = None, + dtype=torch.int32, + ): shapeA = SA[0] shapeB = SB[0] dimsA = len(shapeA) @@ -260,7 +338,15 @@ def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): return out, Sout def mm_dequant( - self, A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None + self, + A: torch.Tensor, + quant_state: Tuple[torch.Size, str], + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats: Optional[torch.Tensor] = None, + new_col_stats: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, ): assert A.dtype == torch.int32 if bias is not None: @@ -297,7 +383,7 @@ def mm_dequant( return out - def extract_outliers(self, A, SA, idx): + def extract_outliers(self, A: torch.Tensor, SA: Tuple[torch.Size, str], idx: torch.Tensor): shapeA = SA[0] formatA = SA[1] assert formatA in ["col_turing", "col_ampere"] @@ -330,7 +416,7 @@ def quantize_4bit( out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, - quant_type="fp4", + quant_type: Literal["fp4", "nf4"] = "fp4", quant_storage=torch.uint8, ) -> Tuple[torch.Tensor, QuantState]: if A.device.type != "cuda": @@ -395,7 +481,7 @@ def quantize_4bit( if compress_statistics: offset = absmax.mean() absmax -= offset - qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) + qabsmax, state2 = self.quantize_blockwise(absmax, blocksize=256) del absmax state = QuantState( absmax=qabsmax, @@ -422,7 +508,7 @@ def dequantize_4bit( absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, - quant_type="fp4", + quant_type: Literal["fp4", "nf4"] = "fp4", ) -> torch.Tensor: if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError( @@ -442,7 +528,7 @@ def dequantize_4bit( absmax = quant_state.absmax if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax = self.dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset if absmax.dtype != torch.float32: absmax = absmax.float() @@ -526,3 +612,240 @@ def dequantize_4bit( return out.t() else: return out + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ): + prev_device = pre_call(A.device) + + if state is None: + raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()") + + if A.numel() != A.shape[-1]: + raise ValueError( + 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]', + ) + + Bshape = state.shape + bout = Bshape[0] + absmax = state.absmax + if state.nested: + absmax = self.dequantize_blockwise(state.absmax, state.state2) + absmax += state.offset + + if out is None: + if len(A.shape) == 3: + out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) + else: + out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) + + n = 1 + m = Bshape[0] + k = Bshape[1] + lda = Bshape[0] + ldc = Bshape[0] + ldb = (A.shape[-1] + 1) // 2 + is_on_gpu([B, A, out, absmax, state.code]) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + + inference_args = [ + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ] + + if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16(*inference_args) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16(*inference_args) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32(*inference_args) + else: + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + + else: + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + + post_call(prev_device) + + return out + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + # TODO: Move from bnb.functional + return dequantize_blockwise( + A, + quant_state=quant_state, + absmax=absmax, + code=code, + out=out, + blocksize=blocksize, + nested=nested, + ) + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + # TODO: Move from bnb.functional + return quantize_blockwise( + A, + absmax=absmax, + code=code, + out=out, + blocksize=blocksize, + nested=nested, + ) + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + optim_func = None + prev_device = pre_call(g.device) + is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) + if g.dtype == torch.float32 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][0] + elif g.dtype == torch.float16 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): + optim_func = str2optimizer8bit_blockwise[optimizer_name][2] + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) + post_call(prev_device) + + is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) + + prev_device = pre_call(g.device) + optim_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + post_call(prev_device) + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + param_norm = 0.0 + if max_unorm > 0.0: + param_norm = torch.norm(p.data.float()) + + optim_func = None + if g.dtype == torch.float32: + optim_func = str2optimizer32bit[optimizer_name][0] + elif g.dtype == torch.float16: + optim_func = str2optimizer32bit[optimizer_name][1] + elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: + optim_func = str2optimizer32bit[optimizer_name][2] + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) + + is_on_gpu([g, p, state1, state2, unorm_vec]) + prev_device = pre_call(g.device) + optim_func( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + post_call(prev_device) diff --git a/bitsandbytes/backends/mps.py b/bitsandbytes/backends/mps.py new file mode 100644 index 000000000..5b7eda0c7 --- /dev/null +++ b/bitsandbytes/backends/mps.py @@ -0,0 +1,164 @@ +from typing import Literal, Optional, Tuple, Union + +import torch + +from bitsandbytes.utils import QuantState + +from .base import Backend + + +class MPSBackend(Backend): + def double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ): + raise NotImplementedError + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + raise NotImplementedError + + def igemmlt( + self, + A: torch.Tensor, + B: torch.Tensor, + SA: Tuple[torch.Size, str], + SB: Tuple[torch.Size, str], + out: Optional[torch.Tensor] = None, + Sout: Optional[Tuple[torch.Size, str]] = None, + dtype=torch.int32, + ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]: + raise NotImplementedError + + def mm_dequant( + self, + A: torch.Tensor, + quant_state: Tuple[torch.Size, str], + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats: Optional[torch.Tensor] = None, + new_col_stats: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type: Literal["fp4", "nf4"] = "fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type: Literal["fp4", "nf4"] = "fp4", + ) -> torch.Tensor: + raise NotImplementedError + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + raise NotImplementedError + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError diff --git a/bitsandbytes/backends/rocm.py b/bitsandbytes/backends/rocm.py new file mode 100644 index 000000000..d74f10ead --- /dev/null +++ b/bitsandbytes/backends/rocm.py @@ -0,0 +1,12 @@ +from .cuda import CUDABackend + + +class ROCmBackend(CUDABackend): + """ + Backend for AMD ROCm implementation. + + The interface is largely the same as the CUDA implementation, so only any + differences need to be implemented here. + """ + + pass diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py new file mode 100644 index 000000000..3976c4d5a --- /dev/null +++ b/bitsandbytes/backends/xpu.py @@ -0,0 +1,164 @@ +from typing import Literal, Optional, Tuple, Union + +import torch + +from bitsandbytes.utils import QuantState + +from .base import Backend + + +class XPUBackend(Backend): + def double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ): + raise NotImplementedError + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + raise NotImplementedError + + def igemmlt( + self, + A: torch.Tensor, + B: torch.Tensor, + SA: Tuple[torch.Size, str], + SB: Tuple[torch.Size, str], + out: Optional[torch.Tensor] = None, + Sout: Optional[Tuple[torch.Size, str]] = None, + dtype=torch.int32, + ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]: + raise NotImplementedError + + def mm_dequant( + self, + A: torch.Tensor, + quant_state: Tuple[torch.Size, str], + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats: Optional[torch.Tensor] = None, + new_col_stats: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type: Literal["fp4", "nf4"] = "fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type: Literal["fp4", "nf4"] = "fp4", + ) -> torch.Tensor: + raise NotImplementedError + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + raise NotImplementedError + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6bb02944d..a180cf0ce 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -27,31 +27,6 @@ def prod(iterable): if lib and lib.compiled_with_cuda: """C FUNCTIONS FOR OPTIMIZERS""" - str2optimizer32bit = { - "adam": ( - lib.cadam32bit_grad_fp32, - lib.cadam32bit_grad_fp16, - lib.cadam32bit_grad_bf16, - ), - "momentum": ( - lib.cmomentum32bit_grad_32, - lib.cmomentum32bit_grad_16, - ), - "rmsprop": ( - lib.crmsprop32bit_grad_32, - lib.crmsprop32bit_grad_16, - ), - "lion": ( - lib.clion32bit_grad_fp32, - lib.clion32bit_grad_fp16, - lib.clion32bit_grad_bf16, - ), - "adagrad": ( - lib.cadagrad32bit_grad_32, - lib.cadagrad32bit_grad_16, - ), - } - str2optimizer8bit = { "adam": ( lib.cadam_static_8bit_grad_32, @@ -79,31 +54,6 @@ def prod(iterable): ), } - str2optimizer8bit_blockwise = { - "adam": ( - lib.cadam_8bit_blockwise_grad_fp32, - lib.cadam_8bit_blockwise_grad_fp16, - lib.cadam_8bit_blockwise_grad_bf16, - ), - "momentum": ( - lib.cmomentum_8bit_blockwise_grad_fp32, - lib.cmomentum_8bit_blockwise_grad_fp16, - ), - "rmsprop": ( - lib.crmsprop_8bit_blockwise_grad_fp32, - lib.crmsprop_8bit_blockwise_grad_fp16, - ), - "lion": ( - lib.clion_8bit_blockwise_grad_fp32, - lib.clion_8bit_blockwise_grad_fp16, - lib.clion_8bit_blockwise_grad_bf16, - ), - "adagrad": ( - lib.cadagrad_8bit_blockwise_grad_fp32, - lib.cadagrad_8bit_blockwise_grad_fp16, - ), - } - class GlobalPageManager: _instance = None @@ -1167,82 +1117,24 @@ def optimizer_update_32bit( max_unorm: float = 0.0, skip_zeros=False, ) -> None: - """ - Performs an inplace optimizer update with one or two optimizer states. - - Universal optimizer update for 32-bit state and 32/16-bit gradients/weights. - - Parameters - ---------- - optimizer_name : str - The name of the optimizer: {adam}. - g : torch.Tensor - Gradient tensor. - p : torch.Tensor - Parameter tensor. - state1 : torch.Tensor - Optimizer state 1. - beta1 : float - Optimizer beta1. - eps : float - Optimizer epsilon. - weight_decay : float - Weight decay. - step : int - Current optimizer step. - lr : float - The learning rate. - state2 : torch.Tensor - Optimizer state 2. - beta2 : float - Optimizer beta2. - gnorm_scale : float - The factor to rescale the gradient to the max clip value. - unorm_vec : torch.Tensor - The tensor for the update norm. - max_unorm : float - The maximum update norm relative to the weight norm. - skip_zeros : bool - Whether to skip zero-valued gradients or not (default: False). - """ - - param_norm = 0.0 - if max_unorm > 0.0: - param_norm = torch.norm(p.data.float()) - - optim_func = None - if g.dtype == torch.float32: - optim_func = str2optimizer32bit[optimizer_name][0] - elif g.dtype == torch.float16: - optim_func = str2optimizer32bit[optimizer_name][1] - elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: - optim_func = str2optimizer32bit[optimizer_name][2] - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) - - is_on_gpu([g, p, state1, state2, unorm_vec]) - prev_device = pre_call(g.device) - optim_func( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), + ensure_backend_is_available(g.device.type) + return backends[g.device.type].optimizer_update_32bit( + optimizer_name=optimizer_name, + g=g, + p=p, + state1=state1, + beta1=beta1, + eps=eps, + step=step, + lr=lr, + state2=state2, + beta2=beta2, + weight_decay=weight_decay, + gnorm_scale=gnorm_scale, + unorm_vec=unorm_vec, + max_unorm=max_unorm, + skip_zeros=skip_zeros, ) - post_call(prev_device) def optimizer_update_8bit( @@ -1397,48 +1289,26 @@ def optimizer_update_8bit_blockwise( gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: - optim_func = None - prev_device = pre_call(g.device) - is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) - if g.dtype == torch.float32 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][0] - elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif ( - g.dtype == torch.bfloat16 - and state1.dtype == torch.uint8 - and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 - ): - optim_func = str2optimizer8bit_blockwise[optimizer_name][2] - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) - post_call(prev_device) - - is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) - - prev_device = pre_call(g.device) - optim_func( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), + ensure_backend_is_available(g.device.type) + return backends[g.device.type].optimizer_update_8bit_blockwise( + optimizer_name=optimizer_name, + g=g, + p=p, + state1=state1, + state2=state2, + beta1=beta1, + beta2=beta2, + eps=eps, + step=step, + lr=lr, + qmap1=qmap1, + qmap2=qmap2, + absmax1=absmax1, + absmax2=absmax2, + weight_decay=weight_decay, + gnorm_scale=gnorm_scale, + skip_zeros=skip_zeros, ) - post_call(prev_device) def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5): @@ -1593,98 +1463,15 @@ def gemv_4bit( transposed_B=False, state=None, ): - prev_device = pre_call(A.device) - # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) - if state is None: - raise ValueError("state cannot None. gem_4bit( ) requires the state from quantize_4bit( )") - - if A.numel() != A.shape[-1]: - raise ValueError( - 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]', - ) - - Bshape = state.shape - bout = Bshape[0] - absmax = state.absmax - if state.nested: - absmax = dequantize_blockwise(state.absmax, state.state2) - absmax += state.offset - - if out is None: - if len(A.shape) == 3: - out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) - else: - out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) - - n = 1 - m = Bshape[0] - k = Bshape[1] - lda = Bshape[0] - ldc = Bshape[0] - ldb = (A.shape[-1] + 1) // 2 - is_on_gpu([B, A, out, absmax, state.code]) - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - - if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - ) - else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") - - else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") - - post_call(prev_device) - - return out + ensure_backend_is_available(A.device.type) + return backends[A.device.type].gemv_4bit( + A, + B, + out=out, + transposed_A=transposed_A, + transposed_B=transposed_B, + state=state, + ) def igemm( From 4743ff0d43e04e4cc3e5d8b9e7cd016c0defa36d Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:28:33 +0000 Subject: [PATCH 16/24] CHANGELOG: to reverse chron order + mdformat --- CHANGELOG.md | 491 ++++++++++++++++++++++++++++----------------------- 1 file changed, 269 insertions(+), 222 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b671145a8..a243237a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,372 +1,419 @@ -### 0.0.21 -- Ampere, RTX 30 series GPUs now compatible with the library. +### 0.43.0 -### 0.0.22: +#### Improvements and New Features: -- Fixed an error where a `reset_parameters()` call on the `StableEmbedding` would lead to an error in older PyTorch versions (from 1.7.0). +- QLoRA + FSDP official support is now live! https://github.com/TimDettmers/bitsandbytes/pull/970 by @warner-benjamin and team - with FSDP you can train very large models (70b scale) on multiple 24GB consumer-type GPUs. See https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html for more details. +- Introduced improvements to the CI process for enhanced performance and efficiency during builds, specifically enabling more effective cross-compilation on Linux platforms. This was accomplished by deprecating Make and migrating to Cmake, as well as implementing new corresponding workflows. Huge thanks go to @wkpark, @rickardp, @matthewdouglas and @younesbelkada; #1055, #1050, #1111. +- Windows should be officially supported in bitsandbytes if you install the library from source. See: https://huggingface.co/docs/bitsandbytes/main/en/index for more details +- Updated installation instructions to provide more comprehensive guidance for users. This includes clearer explanations and additional tips for various setup scenarios, making the library more accessible to a broader audience (@rickardp, #1047). +- Enhanced the library's compatibility and setup process, including fixes for CPU-only installations and improvements in CUDA setup error messaging. This effort aims to streamline the installation process and improve user experience across different platforms and setups (@wkpark, @akx, #1038, #996, #1012). +- Setup a new documentation at https://huggingface.co/docs/bitsandbytes/main with extensive new sections and content to help users better understand and utilize the library. Especially notable are the new API docs. (big thanks to @stevhliu and @mishig25 from HuggingFace #1012). The API docs have been also addressed in #1075. -### 0.0.23: +#### Bug Fixes: -Bugs: - - Unified quantization API: each quantization function now returns `Q, S` where `Q` is the quantized tensor and `S` the quantization state which may hold absolute max values, a quantization map or more. For dequantization all functions now accept the inputs `Q, S` so that `Q` is dequantized with the quantization state `S`. - - Fixed an issue where the CUDA 11.1 binary was not compiled with the right headers +- Addressed a race condition in kEstimateQuantiles, enhancing the reliability of quantile estimation in concurrent environments (@pnunna93, #1061). +- Fixed various minor issues, including typos in code comments and documentation, to improve code clarity and prevent potential confusion (@Brian Vaughan, #1063). -API changes: - - Block-wise quantization for optimizers now enabled by default +#### Backwards Compatibility -Features: - - Block-wise quantization routines now support CPU Tensors. +- After upgrading from `v0.42` to `v0.43`, when using 4bit quantization, models may generate slightly different outputs (approximately up to the 2nd decimal place) due to a fix in the code. For anyone interested in the details, [see this comment](https://github.com/TimDettmers/bitsandbytes/discussions/1094#discussioncomment-8984069). +#### Internal and Build System Enhancements: -### 0.0.24: +- Implemented several enhancements to the internal and build systems, including adjustments to the CI workflows, portability improvements, and build artifact management. These changes contribute to a more robust and flexible development process, ensuring the library's ongoing quality and maintainability (@rickardp, @akx, @wkpark, @matthewdouglas; #949, #1053, #1045, #1037). -- Fixed a bug where a float/half conversion led to a compilation error for CUDA 11.1 on Turning GPUs. -- removed Apex dependency for bnb LAMB +#### Contributors: -### 0.0.25: +This release is made possible thanks to the many active contributors that submitted PRs and many others who contributed to discussions, reviews, and testing. Your efforts greatly enhance the library's quality and user experience. It's truly inspiring to work with such a dedicated and competent group of volunteers and professionals! + +We give a special thanks to @TimDettmers for managing to find a little bit of time for valuable consultations on critical topics, despite preparing for and touring the states applying for professor positions. We wish him the utmost success! + +We also extend our gratitude to the broader community for your continued support, feedback, and engagement, which play a crucial role in driving the library's development forward. + +### 0.42.0 Features: - - Added `skip_zeros` for block-wise and 32-bit optimizers. This ensures correct updates for sparse gradients and sparse models. - - Added support for Kepler GPUs. (#4) - - Added Analysis Adam to track 8-bit vs 32-bit quantization errors over time. - - Make compilation more user friendly. + +- 4-bit serialization now supported. This enables 4-bit load/store. Thank you @poedator #753 +- the bitsandbytes library now has a version attribute: `bitsandbytes.__version__` @rasbt #710 Bug fixes: - - fixed "undefined symbol: \_\_fatbinwrap_38" error for P100 GPUs on CUDA 10.1 (#5) -Docs: - - Added docs with instructions to compile from source. +- Fixed bugs in dynamic exponent data type creation. Thank you @RossM, @KohakuBlueleaf, @ArrowM #659 #227 #262 #152 +- Fixed an issue where 4-bit serialization would fail for layers without double quantization #868. Thank you, @poedator +- Fixed an issue where calling .to() or .cuda() on a 4-bit layer twice would result in an error #867. Thank you, @jph00 +- Fixed a bug where a missing access permission in a path searched for CUDA would lead to an error @osma #677 +- Fixed a bug where the GOOGLE_VM_CONFIG_LOCK_FILE variable could cause errors in colab environments @akrentsel @xaptronic #715 #883 #622 +- Fixed a bug where kgetColRowStats (LLM.int8()) would fail for certain dimensions @LucQueen @905 +- Fixed a bug where the adjusted regular Embedding layer was not available via bnb.nn.Embedding @neel04 #563 +- Fixed added missing scipy requirement @dulalbert #525 +### 0.41.3 -### 0.26.0: +Bug fixes: -Features: - - Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer. - - Added AdamW (copy of Adam with weight decay init 1e-2). #10 - - Introduced ModuleConfig overrides which can be seamlessly be used at initialization time of a module. - - Added `bnb.nn.Embedding` layer which runs at 32-bit but without the layernorm. This works well if you need to fine-tune pretrained models that do not have a embedding layer norm. #19 +- Fixed an issue where 4-bit serialization would fail for layers without double quantization #868. Thank you, @poedator +- Fixed an issue where calling .to() or .cuda() on a 4-bit layer twice would result in an error #867. Thank you, @jph00 -Bug fixes: - - Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13 - - Fixed an unsafe use of eval. #8 - - Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15 +### 0.41.2 -Docs: - - Added instructions how to solve "\_\_fatbinwrap_" errors. +Feature: +- 4-bit serialization now supported. This enables 4-bit load/store. Thank you @poedator #753 -### 0.30.0 +### 0.41.1 -#### 8-bit Inference Update +Bug fixes: + +- Fixed bugs in dynamic exponent data type creation. Thank you @RossM, @KohakuBlueleaf, @ArrowM #659 #227 #262 #152 + +### 0.41.0 Features: - - Added 8-bit matrix multiplication form cuBLAS, and cuBLASLt as well as multiple GEMM kernels (GEMM, GEMMEx, GEMMLt) - - Added 8-bit Linear layers with 8-bit Params that perform memory efficient inference with an option for 8-bit mixed precision matrix decomposition for inference without performance degradation - - Added quantization methods for "fake" quantization as well as optimized kernels vector-wise quantization and equalization as well as optimized cuBLASLt transformations - - CPU only build now available (Thank you, @mryab) -Deprecated: - - Pre-compiled release for CUDA 9.2, 10.0, 10.2 no longer available +- Added precompiled CUDA 11.8 binaries to support H100 GPUs without compilation #571 +- CUDA SETUP now no longer looks for libcuda and libcudart and relies PyTorch CUDA libraries. To manually override this behavior see: how_to_use_nonpytorch_cuda.md. Thank you @rapsealk -### 0.31.0 +Bug fixes: -#### 8-bit Inference and Packaging Update +- Fixed a bug where the default type of absmax was undefined which leads to errors if the default type is different than torch.float32. # 553 +- Fixed a missing scipy dependency in requirements.txt. #544 +- Fixed a bug, where a view operation could cause an error in 8-bit layers. +- Fixed a bug where CPU bitsandbytes would during the import. #593 Thank you @bilelomrani +- Fixed a but where a non-existent LD_LIBRARY_PATH variable led to a failure in python -m bitsandbytes #588 +- Removed outdated get_cuda_lib_handle calls that lead to errors. #595 Thank you @ihsanturk +- Fixed bug where read-permission was assumed for a file. #497 +- Fixed a bug where prefetchAsync lead to errors on GPUs that do not support unified memory but not prefetching (Maxwell, SM52). #470 #451 #453 #477 Thank you @jllllll and @stoperro -Features: - - added direct outlier extraction. This enables outlier extraction without fp16 weights without performance degradation. - - Added automatic CUDA SETUP procedure and packaging all binaries into a single bitsandbytes package. +Documentation: -### 0.32.0 +- Improved documentation for GPUs that do not support 8-bit matmul. #529 +- Added description and pointers for the NF4 data type. #543 -#### 8-bit Inference Performance Enhancements +User experience: -We added performance enhancements for small models. This makes small models about 2x faster for LLM.int8() inference. +- Improved handling of default compute_dtype for Linear4bit Layers, so that compute_dtype = input_dtype if the input data type is stable enough (float32, bfloat16, but not float16). + +Performance: + +- improved 4-bit inference performance for A100 GPUs. This degraded performance for A40/RTX3090 and RTX 4090 GPUs slightly. + +### 0.40.2 + +Bug fixes: + +- Fixed a but where a non-existent LD_LIBRARY_PATH variable led to a failure in python -m bitsandbytes #588 +- Removed outdated get_cuda_lib_handle calls that lead to errors. #595 Thank you @ihsanturk +- Fixed bug where read-permission was assumed for a file. #497 +- Fixed a bug where prefetchAsync lead to errors on GPUs that do not support unified memory but not prefetching (Maxwell, SM52). #470 #451 #453 #477 Thank you @jllllll and @stoperro + +### 0.40.1 Features: - - Int32 dequantization now supports fused biases. - - Linear8bitLt now uses a fused bias implementation. - - Change `.data.storage().data_ptr()` to `.data.data_ptr()` to enhance inference performance. + +- Added precompiled CUDA 11.8 binaries to support H100 GPUs without compilation #571 +- CUDA SETUP now no longer looks for libcuda and libcudart and relies PyTorch CUDA libraries. To manually override this behavior see: how_to_use_nonpytorch_cuda.md. Thank you @rapsealk Bug fixes: - - Now throws and error if LLM.int8() is used on a GPU that is not supported. - - Enhances error messaging if CUDA SETUP fails. +- Fixed a bug where the default type of absmax was undefined which leads to errors if the default type is different than torch.float32. # 553 +- Fixed a missing scipy dependency in requirements.txt. #544 +- Fixed a bug, where a view operation could cause an error in 8-bit layers. +- Fixed a bug where CPU bitsandbytes would during the import. #593 Thank you @bilelomrani -### 0.33.0 +Documentation: -#### Various bug fixes +- Improved documentation for GPUs that do not support 8-bit matmul. #529 +- Added description and pointers for the NF4 data type. #543 + +### 0.40.0 Features: - - CPU quantization now supports a variable `blocksize` variable to enhance quantization speed or precision. + +- Added 4-bit inference kernels for batch size=1. Currently support are the NF4, FP4 data types. +- Added support for quantizations of bfloat16 input data. Bug fixes: - - fixed an issue in CPU quantization where tensors with more than 2^31 elements would fail 19a7adca7a6c9bf7061a384d7e9d9b13676a1a88 - - fixed a bug where cpu binaries would fail if no GPU would be detected eab4d8232d558f2e6bd7f7cc3d00e2e6e94f4e80 - - fixed an issue where cpu binaries cause additional stdout messages 92a3363096e10ad6a5c4e944af898bd1186d806a - - fixed an import of bnb.utils 2e630b55f51d454f3bd723dffda68a07ef93190c -We thank @mryab, @mbrukman, @chessgecko, @dbaranchuk for pull request with bug fixes and new features. +- Added `device` variable for bitsandbytes layers to be compatible with PyTorch layers. +Deprecated: -### 0.34.0 +- Binaries for CUDA 11.2, 11.6 no longer ship with `pip install bitsandbytes` and need to be compiled from source. -#### Bug fixes and memory efficient backprop +### 0.39.0 Features: - - Linear8bitLt layer now supports `memory_efficient_backward=True` which enables backprop of gradients through frozen weights. + +- 4-bit matrix multiplication for Float4 and NormalFloat4 data types. +- Added 4-bit quantization routines +- Doubled quantization routines for 4-bit quantization +- Paged optimizers for Adam and Lion. +- bfloat16 gradient / weight support for Adam and Lion with 8 or 32-bit states. Bug fixes: - - fixed an issue where too many threads were created in blockwise quantization on the CPU for large tensors +- Fixed a bug where 8-bit models consumed twice the memory as expected after serialization -### 0.35.0 +Deprecated: -#### CUDA 11.8 support and bug fixes +- Kepler binaries (GTX 700s and Tesla K40/K80) are not longer provided via pip and need to be compiled from source. Kepler support might be fully removed in the future. + +### 0.38.1 Features: - - CUDA 11.8 support added and binaries added to the PyPI release. -Bug fixes: - - fixed a bug where too long directory names would crash the CUDA SETUP #35 (thank you @tomaarsen) - - fixed a bug where CPU installations on Colab would run into an error #34 (thank you @tomaarsen) - - fixed an issue where the default CUDA version with fast-DreamBooth was not supported #52 +- Added Int8 SwitchBack layers +- Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`) -### 0.35.1 +### 0.38.0 + +#### 8-bit Lion, Load/Store 8-bit Models directly from/to HF Hub Features: - - Added CUDA instruction generator to fix some installations. + +- Support for 32 and 8-bit Lion has been added. Thank you @lucidrains +- Support for serialization of Linear8bitLt layers (LLM.int8()). This allows to store and load 8-bit weights directly from the HuggingFace Hub. Thank you @myrab +- New bug report features `python -m bitsandbytes` now gives extensive debugging details to debug CUDA setup failures. Bug fixes: - - Fixed a problem where warning messages would be displayed even though everything worked correctly. -### 0.35.2 +- Fixed a bug where some bitsandbytes methods failed in a model-parallel setup on multiple GPUs. Thank you @tonylins +- Fixed a bug where cudart.so libraries could not be found in newer PyTorch releases. -Bug fixes: - - Fixed a bug where the CUDA setup failed due to a wrong function call. +Improvements: -### 0.35.3 +- Improved the CUDA Setup procedure by doing a more extensive search for CUDA libraries -Bug fixes: - - Fixed a bug in the CUDA Setup which led to an incomprehensible error if no GPU was detected. +Deprecated: -### 0.35.4 +- Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0. +- Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0 -Bug fixes: - - Fixed a bug in the CUDA Setup failed with the cuda runtime was found, but not the cuda library. - - Fixed a bug where not finding the cuda runtime led to an incomprehensible error. +### 0.37.0 + +#### Int8 Matmul + backward support for all GPUs + +Features: +- Int8 MatmulLt now supports backward through inversion of the ColTuring/ColAmpere format. Slow, but memory efficient. Big thanks to @borzunov +- Int8 now supported on all GPUs. On devices with compute capability \< 7.5, the Int weights are cast to 16/32-bit for the matrix multiplication. Contributed by @borzunov + +Improvements: + +- Improved logging for the CUDA detection mechanism. ### 0.36.0 #### Improvements, Ada/Hopper support, fake k-bit quantization. Features: - - CUDA 11.8 and 12.0 support added - - support for Ada and Hopper GPUs added (compute capability 8.9 and 9.0) - - support for fake k-bit block-wise quantization for Int, Float, quantile quantization, and dynamic exponent data types added - - Added CUDA instruction generator to fix some installations. - - Added additional block sizes for quantization {64, 128, 256, 512, 1024} - - Added SRAM Quantile algorithm to quickly estimate less than 256 quantiles - - Added option to suppress the bitsandbytes welcome message (@Cyberes) + +- CUDA 11.8 and 12.0 support added +- support for Ada and Hopper GPUs added (compute capability 8.9 and 9.0) +- support for fake k-bit block-wise quantization for Int, Float, quantile quantization, and dynamic exponent data types added +- Added CUDA instruction generator to fix some installations. +- Added additional block sizes for quantization {64, 128, 256, 512, 1024} +- Added SRAM Quantile algorithm to quickly estimate less than 256 quantiles +- Added option to suppress the bitsandbytes welcome message (@Cyberes) Regression: - - Compute capability 3.0 removed: GTX 600s and 700s series is no longer supported (except GTX 780 and GTX 780 Ti) + +- Compute capability 3.0 removed: GTX 600s and 700s series is no longer supported (except GTX 780 and GTX 780 Ti) Bug fixes: - - fixed a bug where too long directory names would crash the CUDA SETUP #35 (@tomaarsen) - - fixed a bug where CPU installations on Colab would run into an error #34 (@tomaarsen) - - fixed an issue where the default CUDA version with fast-DreamBooth was not supported #52 - - fixed a bug where the CUDA setup failed due to a wrong function call. - - fixed a bug in the CUDA Setup which led to an incomprehensible error if no GPU was detected. - - fixed a bug in the CUDA Setup failed with the cuda runtime was found, but not the cuda library. - - fixed a bug where not finding the cuda runtime led to an incomprehensible error. - - fixed a bug where with missing CUDA the default was an error instead of the loading the CPU library - - fixed a bug where the CC version of the GPU was not detected appropriately (@BlackHC) - - fixed a bug in CPU quantization which lead to errors when the input buffer exceeded 2^31 elements + +- fixed a bug where too long directory names would crash the CUDA SETUP #35 (@tomaarsen) +- fixed a bug where CPU installations on Colab would run into an error #34 (@tomaarsen) +- fixed an issue where the default CUDA version with fast-DreamBooth was not supported #52 +- fixed a bug where the CUDA setup failed due to a wrong function call. +- fixed a bug in the CUDA Setup which led to an incomprehensible error if no GPU was detected. +- fixed a bug in the CUDA Setup failed with the cuda runtime was found, but not the cuda library. +- fixed a bug where not finding the cuda runtime led to an incomprehensible error. +- fixed a bug where with missing CUDA the default was an error instead of the loading the CPU library +- fixed a bug where the CC version of the GPU was not detected appropriately (@BlackHC) +- fixed a bug in CPU quantization which lead to errors when the input buffer exceeded 2^31 elements Improvements: - - multiple improvements in formatting, removal of unused imports, and slight performance improvements (@tomaarsen) - - StableEmbedding layer now has device and dtype parameters to make it 1:1 replaceable with regular Embedding layers (@lostmsu) - - runtime performance of block-wise quantization slightly improved - - added error message for the case multiple libcudart.so are installed and bitsandbytes picks the wrong one +- multiple improvements in formatting, removal of unused imports, and slight performance improvements (@tomaarsen) +- StableEmbedding layer now has device and dtype parameters to make it 1:1 replaceable with regular Embedding layers (@lostmsu) +- runtime performance of block-wise quantization slightly improved +- added error message for the case multiple libcudart.so are installed and bitsandbytes picks the wrong one -### 0.37.0 +### 0.35.4 -#### Int8 Matmul + backward support for all GPUs +Bug fixes: -Features: - - Int8 MatmulLt now supports backward through inversion of the ColTuring/ColAmpere format. Slow, but memory efficient. Big thanks to @borzunov - - Int8 now supported on all GPUs. On devices with compute capability < 7.5, the Int weights are cast to 16/32-bit for the matrix multiplication. Contributed by @borzunov +- Fixed a bug in the CUDA Setup failed with the cuda runtime was found, but not the cuda library. +- Fixed a bug where not finding the cuda runtime led to an incomprehensible error. -Improvements: - - Improved logging for the CUDA detection mechanism. +### 0.35.3 -### 0.38.0 +Bug fixes: -#### 8-bit Lion, Load/Store 8-bit Models directly from/to HF Hub +- Fixed a bug in the CUDA Setup which led to an incomprehensible error if no GPU was detected. -Features: - - Support for 32 and 8-bit Lion has been added. Thank you @lucidrains - - Support for serialization of Linear8bitLt layers (LLM.int8()). This allows to store and load 8-bit weights directly from the HuggingFace Hub. Thank you @myrab - - New bug report features `python -m bitsandbytes` now gives extensive debugging details to debug CUDA setup failures. +### 0.35.2 Bug fixes: - - Fixed a bug where some bitsandbytes methods failed in a model-parallel setup on multiple GPUs. Thank you @tonylins - - Fixed a bug where cudart.so libraries could not be found in newer PyTorch releases. -Improvements: - - Improved the CUDA Setup procedure by doing a more extensive search for CUDA libraries +- Fixed a bug where the CUDA setup failed due to a wrong function call. -Deprecated: - - Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0. - - Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0 +### 0.35.1 +Features: -### 0.38.1 +- Added CUDA instruction generator to fix some installations. -Features: - - Added Int8 SwitchBack layers - - Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`) +Bug fixes: +- Fixed a problem where warning messages would be displayed even though everything worked correctly. -### 0.39.0 +### 0.35.0 +#### CUDA 11.8 support and bug fixes Features: - - 4-bit matrix multiplication for Float4 and NormalFloat4 data types. - - Added 4-bit quantization routines - - Doubled quantization routines for 4-bit quantization - - Paged optimizers for Adam and Lion. - - bfloat16 gradient / weight support for Adam and Lion with 8 or 32-bit states. + +- CUDA 11.8 support added and binaries added to the PyPI release. Bug fixes: - - Fixed a bug where 8-bit models consumed twice the memory as expected after serialization -Deprecated: - - Kepler binaries (GTX 700s and Tesla K40/K80) are not longer provided via pip and need to be compiled from source. Kepler support might be fully removed in the future. +- fixed a bug where too long directory names would crash the CUDA SETUP #35 (thank you @tomaarsen) +- fixed a bug where CPU installations on Colab would run into an error #34 (thank you @tomaarsen) +- fixed an issue where the default CUDA version with fast-DreamBooth was not supported #52 +### 0.34.0 -### 0.40.0 +#### Bug fixes and memory efficient backprop Features: - - Added 4-bit inference kernels for batch size=1. Currently support are the NF4, FP4 data types. - - Added support for quantizations of bfloat16 input data. + +- Linear8bitLt layer now supports `memory_efficient_backward=True` which enables backprop of gradients through frozen weights. Bug fixes: - - Added `device` variable for bitsandbytes layers to be compatible with PyTorch layers. -Deprecated: - - Binaries for CUDA 11.2, 11.6 no longer ship with `pip install bitsandbytes` and need to be compiled from source. +- fixed an issue where too many threads were created in blockwise quantization on the CPU for large tensors +### 0.33.0 -### 0.40.1 +#### Various bug fixes Features: - - Added precompiled CUDA 11.8 binaries to support H100 GPUs without compilation #571 - - CUDA SETUP now no longer looks for libcuda and libcudart and relies PyTorch CUDA libraries. To manually override this behavior see: how_to_use_nonpytorch_cuda.md. Thank you @rapsealk + +- CPU quantization now supports a variable `blocksize` variable to enhance quantization speed or precision. Bug fixes: - - Fixed a bug where the default type of absmax was undefined which leads to errors if the default type is different than torch.float32. # 553 - - Fixed a missing scipy dependency in requirements.txt. #544 - - Fixed a bug, where a view operation could cause an error in 8-bit layers. - - Fixed a bug where CPU bitsandbytes would during the import. #593 Thank you @bilelomrani -Documentation: - - Improved documentation for GPUs that do not support 8-bit matmul. #529 - - Added description and pointers for the NF4 data type. #543 +- fixed an issue in CPU quantization where tensors with more than 2^31 elements would fail 19a7adca7a6c9bf7061a384d7e9d9b13676a1a88 +- fixed a bug where cpu binaries would fail if no GPU would be detected eab4d8232d558f2e6bd7f7cc3d00e2e6e94f4e80 +- fixed an issue where cpu binaries cause additional stdout messages 92a3363096e10ad6a5c4e944af898bd1186d806a +- fixed an import of bnb.utils 2e630b55f51d454f3bd723dffda68a07ef93190c -### 0.40.2 +We thank @mryab, @mbrukman, @chessgecko, @dbaranchuk for pull request with bug fixes and new features. -Bug fixes: - - Fixed a but where a non-existent LD_LIBRARY_PATH variable led to a failure in python -m bitsandbytes #588 - - Removed outdated get_cuda_lib_handle calls that lead to errors. #595 Thank you @ihsanturk - - Fixed bug where read-permission was assumed for a file. #497 - - Fixed a bug where prefetchAsync lead to errors on GPUs that do not support unified memory but not prefetching (Maxwell, SM52). #470 #451 #453 #477 Thank you @jllllll and @stoperro +### 0.32.0 +#### 8-bit Inference Performance Enhancements -### 0.41.0 +We added performance enhancements for small models. This makes small models about 2x faster for LLM.int8() inference. Features: - - Added precompiled CUDA 11.8 binaries to support H100 GPUs without compilation #571 - - CUDA SETUP now no longer looks for libcuda and libcudart and relies PyTorch CUDA libraries. To manually override this behavior see: how_to_use_nonpytorch_cuda.md. Thank you @rapsealk + +- Int32 dequantization now supports fused biases. +- Linear8bitLt now uses a fused bias implementation. +- Change `.data.storage().data_ptr()` to `.data.data_ptr()` to enhance inference performance. Bug fixes: - - Fixed a bug where the default type of absmax was undefined which leads to errors if the default type is different than torch.float32. # 553 - - Fixed a missing scipy dependency in requirements.txt. #544 - - Fixed a bug, where a view operation could cause an error in 8-bit layers. - - Fixed a bug where CPU bitsandbytes would during the import. #593 Thank you @bilelomrani - - Fixed a but where a non-existent LD_LIBRARY_PATH variable led to a failure in python -m bitsandbytes #588 - - Removed outdated get_cuda_lib_handle calls that lead to errors. #595 Thank you @ihsanturk - - Fixed bug where read-permission was assumed for a file. #497 - - Fixed a bug where prefetchAsync lead to errors on GPUs that do not support unified memory but not prefetching (Maxwell, SM52). #470 #451 #453 #477 Thank you @jllllll and @stoperro -Documentation: - - Improved documentation for GPUs that do not support 8-bit matmul. #529 - - Added description and pointers for the NF4 data type. #543 +- Now throws and error if LLM.int8() is used on a GPU that is not supported. +- Enhances error messaging if CUDA SETUP fails. -User experience: - - Improved handling of default compute_dtype for Linear4bit Layers, so that compute_dtype = input_dtype if the input data type is stable enough (float32, bfloat16, but not float16). +### 0.31.0 -Performance: - - improved 4-bit inference performance for A100 GPUs. This degraded performance for A40/RTX3090 and RTX 4090 GPUs slightly. +#### 8-bit Inference and Packaging Update -### 0.41.1 +Features: -Bug fixes: - - Fixed bugs in dynamic exponent data type creation. Thank you @RossM, @KohakuBlueleaf, @ArrowM #659 #227 #262 #152 +- added direct outlier extraction. This enables outlier extraction without fp16 weights without performance degradation. +- Added automatic CUDA SETUP procedure and packaging all binaries into a single bitsandbytes package. -### 0.41.2 +### 0.30.0 -Feature: - - 4-bit serialization now supported. This enables 4-bit load/store. Thank you @poedator #753 +#### 8-bit Inference Update -### 0.41.3 +Features: + +- Added 8-bit matrix multiplication form cuBLAS, and cuBLASLt as well as multiple GEMM kernels (GEMM, GEMMEx, GEMMLt) +- Added 8-bit Linear layers with 8-bit Params that perform memory efficient inference with an option for 8-bit mixed precision matrix decomposition for inference without performance degradation +- Added quantization methods for "fake" quantization as well as optimized kernels vector-wise quantization and equalization as well as optimized cuBLASLt transformations +- CPU only build now available (Thank you, @mryab) + +Deprecated: + +- Pre-compiled release for CUDA 9.2, 10.0, 10.2 no longer available + +### 0.26.0: + +Features: + +- Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer. +- Added AdamW (copy of Adam with weight decay init 1e-2). #10 +- Introduced ModuleConfig overrides which can be seamlessly be used at initialization time of a module. +- Added `bnb.nn.Embedding` layer which runs at 32-bit but without the layernorm. This works well if you need to fine-tune pretrained models that do not have a embedding layer norm. #19 Bug fixes: - - Fixed an issue where 4-bit serialization would fail for layers without double quantization #868. Thank you, @poedator - - Fixed an issue where calling .to() or .cuda() on a 4-bit layer twice would result in an error #867. Thank you, @jph00 -### 0.42.0 +- Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13 +- Fixed an unsafe use of eval. #8 +- Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15 + +Docs: + +- Added instructions how to solve "\_\_fatbinwrap\_" errors. + +### 0.0.25: Features: - - 4-bit serialization now supported. This enables 4-bit load/store. Thank you @poedator #753 - - the bitsandbytes library now has a version attribute: `bitsandbytes.__version__` @rasbt #710 + +- Added `skip_zeros` for block-wise and 32-bit optimizers. This ensures correct updates for sparse gradients and sparse models. +- Added support for Kepler GPUs. (#4) +- Added Analysis Adam to track 8-bit vs 32-bit quantization errors over time. +- Make compilation more user friendly. Bug fixes: - - Fixed bugs in dynamic exponent data type creation. Thank you @RossM, @KohakuBlueleaf, @ArrowM #659 #227 #262 #152 - - Fixed an issue where 4-bit serialization would fail for layers without double quantization #868. Thank you, @poedator - - Fixed an issue where calling .to() or .cuda() on a 4-bit layer twice would result in an error #867. Thank you, @jph00 - - Fixed a bug where a missing access permission in a path searched for CUDA would lead to an error @osma #677 - - Fixed a bug where the GOOGLE_VM_CONFIG_LOCK_FILE variable could cause errors in colab environments @akrentsel @xaptronic #715 #883 #622 - - Fixed a bug where kgetColRowStats (LLM.int8()) would fail for certain dimensions @LucQueen @905 - - Fixed a bug where the adjusted regular Embedding layer was not available via bnb.nn.Embedding @neel04 #563 - - Fixed added missing scipy requirement @dulalbert #525 -### 0.43.0 +- fixed "undefined symbol: \_\_fatbinwrap_38" error for P100 GPUs on CUDA 10.1 (#5) -#### Improvements and New Features: -- QLoRA + FSDP official support is now live! https://github.com/TimDettmers/bitsandbytes/pull/970 by @warner-benjamin and team - with FSDP you can train very large models (70b scale) on multiple 24GB consumer-type GPUs. See https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html for more details. -- Introduced improvements to the CI process for enhanced performance and efficiency during builds, specifically enabling more effective cross-compilation on Linux platforms. This was accomplished by deprecating Make and migrating to Cmake, as well as implementing new corresponding workflows. Huge thanks go to @wkpark, @rickardp, @matthewdouglas and @younesbelkada; #1055, #1050, #1111. -- Windows should be officially supported in bitsandbytes if you install the library from source. See: https://huggingface.co/docs/bitsandbytes/main/en/index for more details -- Updated installation instructions to provide more comprehensive guidance for users. This includes clearer explanations and additional tips for various setup scenarios, making the library more accessible to a broader audience (@rickardp, #1047). -- Enhanced the library's compatibility and setup process, including fixes for CPU-only installations and improvements in CUDA setup error messaging. This effort aims to streamline the installation process and improve user experience across different platforms and setups (@wkpark, @akx, #1038, #996, #1012). -- Setup a new documentation at https://huggingface.co/docs/bitsandbytes/main with extensive new sections and content to help users better understand and utilize the library. Especially notable are the new API docs. (big thanks to @stevhliu and @mishig25 from HuggingFace #1012). The API docs have been also addressed in #1075. +Docs: -#### Bug Fixes: -- Addressed a race condition in kEstimateQuantiles, enhancing the reliability of quantile estimation in concurrent environments (@pnunna93, #1061). -- Fixed various minor issues, including typos in code comments and documentation, to improve code clarity and prevent potential confusion (@Brian Vaughan, #1063). +- Added docs with instructions to compile from source. -#### Backwards Compatibility -- After upgrading from `v0.42` to `v0.43`, when using 4bit quantization, models may generate slightly different outputs (approximately up to the 2nd decimal place) due to a fix in the code. For anyone interested in the details, [see this comment](https://github.com/TimDettmers/bitsandbytes/discussions/1094#discussioncomment-8984069). +### 0.0.24: +- Fixed a bug where a float/half conversion led to a compilation error for CUDA 11.1 on Turning GPUs. +- removed Apex dependency for bnb LAMB -#### Internal and Build System Enhancements: -- Implemented several enhancements to the internal and build systems, including adjustments to the CI workflows, portability improvements, and build artifact management. These changes contribute to a more robust and flexible development process, ensuring the library's ongoing quality and maintainability (@rickardp, @akx, @wkpark, @matthewdouglas; #949, #1053, #1045, #1037). +### 0.0.23: -#### Contributors: -This release is made possible thanks to the many active contributors that submitted PRs and many others who contributed to discussions, reviews, and testing. Your efforts greatly enhance the library's quality and user experience. It's truly inspiring to work with such a dedicated and competent group of volunteers and professionals! +Bugs: -We give a special thanks to @TimDettmers for managing to find a little bit of time for valuable consultations on critical topics, despite preparing for and touring the states applying for professor positions. We wish him the utmost success! +- Unified quantization API: each quantization function now returns `Q, S` where `Q` is the quantized tensor and `S` the quantization state which may hold absolute max values, a quantization map or more. For dequantization all functions now accept the inputs `Q, S` so that `Q` is dequantized with the quantization state `S`. +- Fixed an issue where the CUDA 11.1 binary was not compiled with the right headers -We also extend our gratitude to the broader community for your continued support, feedback, and engagement, which play a crucial role in driving the library's development forward. +API changes: + +- Block-wise quantization for optimizers now enabled by default + +Features: + +- Block-wise quantization routines now support CPU Tensors. + +### 0.0.22: + +- Fixed an error where a `reset_parameters()` call on the `StableEmbedding` would lead to an error in older PyTorch versions (from 1.7.0). + +### 0.0.21 + +- Ampere, RTX 30 series GPUs now compatible with the library. From 0c33c0d45ec7b61bb1f1817582937fb957dd6be0 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:32:37 +0000 Subject: [PATCH 17/24] ignore CHANGELOG reordering + formatting commit --- .git-blame-ignore-revs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index d953c93dd..648e437f4 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -12,3 +12,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 # Reformat with ruff-format 5a4263f4dc05fe8f78f4111beab9f68a81deeab1 + +# CHANGELOG: to reverse chron order + mdformat +4743ff0d43e04e4cc3e5d8b9e7cd016c0defa36d From f92c5362b2d5267e122d4d9085838c3fd2fc59b3 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:42:32 +0000 Subject: [PATCH 18/24] CHANGELOG: add v0.43.1 --- CHANGELOG.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a243237a2..476a6e316 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,23 @@ +### 0.43.1 + +#### Improvements and New Features: + +- Improved the serialization format for 8-bit weights; this change is fully backwards compatible. (#1164, thanks to @younesbelkada for the contributions and @akx for the review). +- Added CUDA 12.4 support to the Linux x86-64 build workflow, expanding the library's compatibility with the latest CUDA versions. (#1171, kudos to @matthewdouglas for this addition). +- Docs enhancement: Improved the instructions for installing the library from source. (#1149, special thanks to @stevhliu for the enhancements). + +#### Bug Fixes + +- Fix 4bit quantization with blocksize = 4096, where an illegal memory access was encountered. (#1160, thanks @matthewdouglas for fixing and @YLGH for reporting) + +#### Internal Improvements: + +- Tests: improve memory usage (#1147, thanks @matthewdouglas) +- Add CUDA 12.4 to docs/install helper (#1136, thanks @matthewdouglas) +- Minor type/doc fixes (#1128, thanks @akx) +- Reformat Python code with Ruff (#1081, thanks @akx) +- Rework of CUDA/native-library setup and diagnostics (#1041, thanks @akx) + ### 0.43.0 #### Improvements and New Features: From 4a6fb352cfb90b17820391f0db18aeda98774f0a Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 11 Apr 2024 17:44:18 +0000 Subject: [PATCH 19/24] bump version to 0.43.1 --- bitsandbytes/__init__.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 78c99355b..2182de1d3 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -21,4 +21,4 @@ "optim.optimizer.MockArgs": False, } -__version__ = "0.44.0.dev" +__version__ = "0.43.1" diff --git a/setup.py b/setup.py index a51b3867c..a3bd5fc34 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def has_ext_modules(self): setup( name="bitsandbytes", - version="0.44.0.dev", + version="0.43.1", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="k-bit optimizers and matrix multiplication routines.", From 7b0c4cd3ad396c70bafda621f7a17332f40ee962 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 11 Apr 2024 18:42:45 +0000 Subject: [PATCH 20/24] small fix in changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 476a6e316..c456fa9e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ ### 0.43.1 -#### Improvements and New Features: +#### Improvements: - Improved the serialization format for 8-bit weights; this change is fully backwards compatible. (#1164, thanks to @younesbelkada for the contributions and @akx for the review). - Added CUDA 12.4 support to the Linux x86-64 build workflow, expanding the library's compatibility with the latest CUDA versions. (#1171, kudos to @matthewdouglas for this addition). From 127788a96e123bb2e95ff9dbcc78672e4849cddc Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 11 Apr 2024 18:43:28 +0000 Subject: [PATCH 21/24] bump version to next dev --- bitsandbytes/__init__.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 2182de1d3..51cbde208 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -21,4 +21,4 @@ "optim.optimizer.MockArgs": False, } -__version__ = "0.43.1" +__version__ = "0.43.2.dev" diff --git a/setup.py b/setup.py index a3bd5fc34..f8d6a92a1 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def has_ext_modules(self): setup( name="bitsandbytes", - version="0.43.1", + version="0.43.2.dev", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="k-bit optimizers and matrix multiplication routines.", From 6cecb65a56ace6f60c113929bfd7de120aaa2ab9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 17 Apr 2024 09:53:31 +0200 Subject: [PATCH 22/24] Update pandas requirement from ~=2.2.1 to ~=2.2.2 in the major group (#1182) Updates the requirements on [pandas](https://github.com/pandas-dev/pandas) to permit the latest version. Updates `pandas` to 2.2.2 - [Release notes](https://github.com/pandas-dev/pandas/releases) - [Commits](https://github.com/pandas-dev/pandas/compare/v2.2.1...v2.2.2) --- updated-dependencies: - dependency-name: pandas dependency-type: direct:development dependency-group: major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 4ee8bf543..6dad70563 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,5 +5,5 @@ einops~=0.7.0 wheel~=0.43.0 lion-pytorch~=0.1.4 scipy~=1.13.0 -pandas~=2.2.1 +pandas~=2.2.2 matplotlib~=3.8.4 From ffd7d0db6a660c97b60a2c9605309ee4b5cd40e3 Mon Sep 17 00:00:00 2001 From: Titus <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Wed, 17 Apr 2024 16:22:52 +0200 Subject: [PATCH 23/24] (docs) integrations: fix omission in bf16 related warning (#1183) * (docs) integrations: fix omission in bf16 related warning * (docs) integrations: further clarifications to prior fix * (docs) integrations: fix punctuation Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * (docs) integrations: fix omitted code formatting --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/integrations.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/integrations.mdx b/docs/source/integrations.mdx index 4badece49..42b8edf03 100644 --- a/docs/source/integrations.mdx +++ b/docs/source/integrations.mdx @@ -12,7 +12,7 @@ With Transformers, it's very easy to load any model in 4 or 8-bit and quantize t For example, to load and quantize a model to 4-bits and use the bfloat16 data type for compute: > [!WARNING] -> bfloat16 is the optimal compute data type if your hardware supports it. The default is float32 for backward compatibility and numerical stability, but it can often lead to numerical instabilities. bfloat16 provides the best of both worlds, numerical stability equivalent to float32, but combined with the memory footprint and significant computation speedup of a 16-bit data type. Make sure to check if your hardware supports bfloat16 and if it does, configure it using the `bnb_4bit_compute_dtype` parameter in [`~transformers.BitsAndBytesConfig`]! +> bfloat16 is the ideal `compute_dtype` if your hardware supports it. While the default `compute_dtype`, float32, ensures backward compatibility (due to wide-ranging hardware support) and numerical stability, it is large and slows down computations. In contrast, float16 is smaller and faster but can lead to numerical instabilities. bfloat16 combines the best aspects of both; it offers the numerical stability of float32 and the reduced memory footprint and speed of a 16-bit data type. Check if your hardware supports bfloat16 and configure it using the `bnb_4bit_compute_dtype` parameter in [`~transformers.BitsAndBytesConfig`]! ```py from transformers import AutoModelForCausalLM, BitsAndBytesConfig From 5b9ef7757b0d210873cc6483da6748b575c1376a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 30 Apr 2024 18:29:43 +0200 Subject: [PATCH 24/24] Bump the minor-patch group with 2 updates (#1192) Bumps the minor-patch group with 2 updates: [pytest](https://github.com/pytest-dev/pytest) and [einops](https://github.com/arogozhnikov/einops). Updates `pytest` from 8.1.1 to 8.2.0 - [Release notes](https://github.com/pytest-dev/pytest/releases) - [Changelog](https://github.com/pytest-dev/pytest/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pytest-dev/pytest/compare/8.1.1...8.2.0) Updates `einops` from 0.7.0 to 0.8.0 - [Release notes](https://github.com/arogozhnikov/einops/releases) - [Commits](https://github.com/arogozhnikov/einops/compare/v0.7.0...v0.8.0) --- updated-dependencies: - dependency-name: pytest dependency-type: direct:production update-type: version-update:semver-minor dependency-group: minor-patch - dependency-name: einops dependency-type: direct:production update-type: version-update:semver-minor dependency-group: minor-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-ci.txt | 4 ++-- requirements-dev.txt | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements-ci.txt b/requirements-ci.txt index 906c6643e..24e2db324 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,6 +1,6 @@ # Requirements used for GitHub actions -pytest==8.1.1 -einops==0.7.0 +pytest==8.2.0 +einops==0.8.0 lion-pytorch==0.1.4 scipy==1.10.1; python_version < "3.9" scipy==1.13.0; python_version >= "3.9" diff --git a/requirements-dev.txt b/requirements-dev.txt index 6dad70563..0334896be 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # Requirements used for local development setuptools>=63 -pytest~=8.1.1 -einops~=0.7.0 +pytest~=8.2.0 +einops~=0.8.0 wheel~=0.43.0 lion-pytorch~=0.1.4 scipy~=1.13.0