├── tests ├── __init__.py ├── test_cuda_setup_evaluator.py ├── conftest.py ├── test_triton.py ├── helpers.py └── test_generation.py ├── bitsandbytes ├── py.typed ├── autograd │ └── __init__.py ├── backends │ ├── __init__.py │ ├── cpu │ │ └── __init__.py │ ├── cuda │ │ └── __init__.py │ ├── hpu │ │ ├── __init__.py │ │ └── ops.py │ ├── xpu │ │ └── __init__.py │ ├── default │ │ └── __init__.py │ ├── triton │ │ ├── __init__.py │ │ └── kernels_8bit_quant.py │ └── utils.py ├── triton │ ├── __init__.py │ ├── triton_utils.py │ ├── dequantize_rowwise.py │ ├── quantize_rowwise.py │ ├── quantize_columnwise_and_transpose.py │ └── quantize_global.py ├── diagnostics │ ├── __init__.py │ ├── utils.py │ └── main.py ├── research │ ├── autograd │ │ └── __init__.py │ ├── nn │ │ ├── __init__.py │ │ └── modules.py │ └── __init__.py ├── __main__.py ├── consts.py ├── nn │ └── __init__.py ├── optim │ ├── __init__.py │ └── sgd.py ├── __init__.py └── cuda_specs.py ├── .gitattributes ├── .github ├── FUNDING.yml ├── dependabot.yml.disabled ├── workflows │ ├── lint.yml │ ├── upload_pr_documentation.yml │ ├── stale.yml.disabled │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ ├── tests-pr.yml │ └── tests-nightly.yml ├── scripts │ ├── build-cpu.sh │ ├── build-xpu.sh │ ├── auditwheel_show.py │ ├── set_platform_tag.py │ ├── build-rocm.sh │ ├── build-xpu-windows.bat │ └── build-cuda.sh └── ISSUE_TEMPLATE │ ├── feature-request.yml │ └── bug-report.yml ├── MANIFEST.in ├── .editorconfig ├── csrc ├── common.h ├── common_hip.cuh ├── xpu_ops.h ├── xpu_kernels.h ├── common.cuh ├── mps_ops.mm ├── mps_kernels.metal ├── xpu_ops.cpp ├── kernels.cuh └── kernels_hip.cuh ├── benchmarking ├── switchback │ ├── plot_with_info.pdf │ ├── README.md │ ├── make_plot_with_jsonl.py │ └── speed_benchmark.py ├── optimizer_benchmark.py ├── int8 │ └── int8_benchmark.py └── inference_benchmark.py ├── .vscode ├── extensions.json └── settings.json ├── NOTICE.md ├── docs └── source │ ├── quickstart.mdx │ ├── faqs.mdx │ ├── reference │ ├── optim │ │ ├── rmsprop.mdx │ │ ├── lars.mdx │ │ ├── adagrad.mdx │ │ ├── sgd.mdx │ │ ├── lamb.mdx │ │ ├── lion.mdx │ │ ├── ademamix.mdx │ │ ├── adamw.mdx │ │ ├── optim_overview.mdx │ │ └── adam.mdx │ ├── nn │ │ ├── embeddings.mdx │ │ ├── linear4bit.mdx │ │ └── linear8bit.mdx │ └── functional.mdx │ ├── index.mdx │ ├── errors.mdx │ ├── contributing.mdx │ ├── _toctree.yml │ ├── explanations │ ├── resources.mdx │ └── optimizers.mdx │ ├── optimizers.mdx │ ├── integrations.mdx │ └── fsdp_qlora.md ├── check_bnb_install.py ├── _typos.toml ├── .git-blame-ignore-revs ├── examples ├── int8_inference_huggingface.py └── compile_inference.py ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── .clang-format ├── LICENSE ├── setup.py ├── scripts └── stale.py ├── install_cuda.sh ├── .gitignore ├── CODE_OF_CONDUCT.md ├── install_cuda.py └── pyproject.toml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/autograd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/backends/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/triton/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.bat text eol=crlf 2 | -------------------------------------------------------------------------------- /bitsandbytes/backends/cpu/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/backends/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/backends/hpu/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/backends/xpu/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/diagnostics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/backends/default/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/backends/triton/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitsandbytes/research/autograd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | open_collective: bitsandbytes 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CMakeLists.txt 2 | graft csrc 3 | graft include 4 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | [*] 2 | trim_trailing_whitespace = true 3 | insert_final_newline = true 4 | -------------------------------------------------------------------------------- /bitsandbytes/research/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import LinearFP8Global, LinearFP8Mixed 2 | -------------------------------------------------------------------------------- /bitsandbytes/__main__.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | from bitsandbytes.diagnostics.main import main 3 | 4 | main() 5 | -------------------------------------------------------------------------------- /csrc/common.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | typedef enum DataType_t { 4 | General8bit = 0, 5 | FP4 = 1, 6 | NF4 = 2, 7 | } DataType_t; 8 | -------------------------------------------------------------------------------- /benchmarking/switchback/plot_with_info.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bitsandbytes-foundation/bitsandbytes/HEAD/benchmarking/switchback/plot_with_info.pdf -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.python", 4 | "charliermarsh.ruff", 5 | "twxs.cmake" 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /bitsandbytes/research/__init__.py: -------------------------------------------------------------------------------- 1 | from . import nn 2 | from .autograd._functions import ( 3 | matmul_fp8_global, 4 | matmul_fp8_mixed, 5 | switchback_bnb, 6 | ) 7 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "ruff.fixAll": true, 3 | "ruff.lint.run": "onType", 4 | "editor.codeActionsOnSave": { 5 | "source.fixAll": "always" 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /NOTICE.md: -------------------------------------------------------------------------------- 1 | The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: PyTorch is licensed under the BSD license. 2 | -------------------------------------------------------------------------------- /.github/dependabot.yml.disabled: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | groups: 8 | major: 9 | update-types: [major] 10 | minor-patch: 11 | update-types: [minor, patch] 12 | -------------------------------------------------------------------------------- /docs/source/quickstart.mdx: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | ## How does it work? 4 | 5 | ... work in progress ... 6 | 7 | (Community contributions would we very welcome!) 8 | 9 | ## Minimal examples 10 | 11 | The following code illustrates the steps above. 12 | 13 | ```py 14 | code examples will soon follow 15 | ``` 16 | -------------------------------------------------------------------------------- /csrc/common_hip.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __GFX9__ 4 | #define BNB_WARP_SIZE 64 5 | #else 6 | #define BNB_WARP_SIZE 32 7 | #endif 8 | 9 | // These are set based on current BNB support for CDNA 2 & RDNA 3. 
Update as needed for future archs 10 | #define BNB_MAX_THREADS_PER_CU 2048 11 | #define BNB_BF16_AVAILABLE true 12 | -------------------------------------------------------------------------------- /bitsandbytes/triton/triton_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | 4 | @functools.lru_cache(None) 5 | def is_triton_available(): 6 | try: 7 | from torch.utils._triton import has_triton, has_triton_package 8 | 9 | return has_triton_package() and has_triton() 10 | except Exception: 11 | return False 12 | -------------------------------------------------------------------------------- /bitsandbytes/diagnostics/utils.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | 3 | HEADER_WIDTH = 60 4 | 5 | 6 | def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "=") -> None: 7 | txt = f" {txt} " if txt else "" 8 | print(txt.center(width, filler)) 9 | 10 | 11 | def print_dedented(text): 12 | print("\n".join(textwrap.dedent(text).strip().split("\n"))) 13 | -------------------------------------------------------------------------------- /benchmarking/switchback/README.md: -------------------------------------------------------------------------------- 1 | Steps: 2 | 3 | 1. Run `python speed_benchmark/speed_benchmark.py` which times operations and writes their time to `speed_benchmark/info_a100_py2.jsonl` (change the name of the jsonl to a different name for your profiling). 4 | 2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces the `speed_benchmark/plot_with_info.pdf`. Again make sure you change the jsonl which is being processed. 5 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | Lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v4 15 | with: 16 | python-version: "3.12" 17 | - uses: pre-commit/action@v3.0.0 18 | env: 19 | RUFF_OUTPUT_FORMAT: github 20 | -------------------------------------------------------------------------------- /bitsandbytes/consts.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import platform 3 | 4 | DYNAMIC_LIBRARY_SUFFIX = { 5 | "Darwin": ".dylib", 6 | "Linux": ".so", 7 | "Windows": ".dll", 8 | }.get(platform.system(), ".so") 9 | 10 | PACKAGE_DIR = Path(__file__).parent 11 | PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" 12 | NONPYTORCH_DOC_URL = "https://github.com/TimDettmers/bitsandbytes/blob/main/docs/source/nonpytorchcuda.mdx" 13 | -------------------------------------------------------------------------------- /check_bnb_install.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import bitsandbytes as bnb 4 | 5 | p = torch.nn.Parameter(torch.rand(10, 10).cuda()) 6 | a = torch.rand(10, 10).cuda() 7 | 8 | p1 = p.data.sum().item() 9 | 10 | adam = bnb.optim.Adam([p]) 11 | 12 | out = a * p 13 | loss = out.sum() 14 | loss.backward() 15 | adam.step() 16 | 17 | p2 = p.data.sum().item() 18 | 19 | assert p1 != p2 20 | print("SUCCESS!") 21 | print("Installation was successful!") 22 | -------------------------------------------------------------------------------- 
/docs/source/faqs.mdx: -------------------------------------------------------------------------------- 1 | # FAQs 2 | 3 | Please submit your questions in [this Github Discussion thread](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1013) if you feel that they will likely affect a lot of other users and that they haven't been sufficiently covered in the documentation. 4 | 5 | We'll pick the most generally applicable ones and post the QAs here or integrate them into the general documentation (also feel free to submit doc PRs, please). 6 | -------------------------------------------------------------------------------- /.github/scripts/build-cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare build_arch 3 | declare build_os 4 | 5 | set -xeuo pipefail 6 | 7 | pip install cmake==3.28.3 8 | 9 | if [ "${build_os:0:5}" == macos ] && [ "${build_arch}" == aarch64 ]; then 10 | cmake -DCMAKE_OSX_ARCHITECTURES=arm64 -DCOMPUTE_BACKEND=cpu . 11 | else 12 | cmake -DCOMPUTE_BACKEND=cpu . 13 | fi 14 | cmake --build . --config Release 15 | 16 | output_dir="output/${build_os}/${build_arch}" 17 | mkdir -p "${output_dir}" 18 | (shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") 19 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | permissions: 10 | contents: read 11 | pull-requests: write # Allows posting comments on pull requests 12 | 13 | jobs: 14 | build: 15 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 16 | with: 17 | package_name: bitsandbytes 18 | secrets: 19 | hf_token: ${{ secrets.HUGGINGFACE_PUSH }} 20 | comment_bot_token: ${{ secrets.GITHUB_TOKEN }} 21 | -------------------------------------------------------------------------------- /_typos.toml: -------------------------------------------------------------------------------- 1 | [files] 2 | # Skip these files in typo checks 3 | extend-exclude = [ 4 | "csrc/xpu_ops.h", 5 | "csrc/xpu_ops.cpp", 6 | "csrc/xpu_kernels.h", 7 | "csrc/xpu_kernels.cpp" 8 | ] 9 | 10 | [default] 11 | extend-ignore-re = [ 12 | "@Ther-nul", # valid Github user 13 | ] 14 | extend-ignore-identifiers-re = [ 15 | ".*arange.*", 16 | ".*ARANGE.*", 17 | ] 18 | 19 | [type.py.extend-words] 20 | "BA" = "BA" # used as a commented-out variable in tests 21 | 22 | [type.cuda.extend-words] 23 | "subtile" = "subtile" 24 | "subtiles" = "subtiles" 25 | "transation" = "transation" # TODO: is this transition, transaction, translation..? 26 | -------------------------------------------------------------------------------- /docs/source/reference/optim/rmsprop.mdx: -------------------------------------------------------------------------------- 1 | # RMSprop 2 | 3 | RMSprop is an adaptive learning rate optimizer that is very similar to [`Adagrad`]. RMSprop stores a *weighted average* of the squared past gradients for each parameter and uses it to scale their learning rate. This allows the learning rate to be automatically lower or higher depending on the magnitude of the gradient, and it prevents the learning rate from diminishing. 
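As a quick usage sketch, the 8-bit variant is a drop-in replacement for `torch.optim.RMSprop`; the toy model, CUDA device, and hyperparameter values below are illustrative assumptions rather than recommendations.

```py
import torch

import bitsandbytes as bnb

model = torch.nn.Linear(128, 64).cuda()

# 8-bit RMSprop; `alpha` is the smoothing constant for the squared-gradient average
optimizer = bnb.optim.RMSprop8bit(model.parameters(), lr=1e-3, alpha=0.99)

loss = model(torch.randn(16, 128, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```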
4 | 5 | ## RMSprop[[api-class]] 6 | 7 | [[autodoc]] bitsandbytes.optim.RMSprop 8 | 9 | ## RMSprop8bit 10 | 11 | [[autodoc]] bitsandbytes.optim.RMSprop8bit 12 | 13 | ## RMSprop32bit 14 | 15 | [[autodoc]] bitsandbytes.optim.RMSprop32bit 16 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # ran black and isort for coherent code formatting 2 | bfa0e33294f2b1dc25e65a33be2397f989824298 3 | 4 | # reran black with linelength 80 for greater readability 5 | ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 6 | 7 | # Remove f-prefix from strings that don't use formatting 8 | 7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6 9 | 10 | # format tests/linear_4bit.py 11 | 34735ba89de8235ea9da6ef409f814dcea9e2038 12 | 13 | # Reformat with ruff-format 14 | 5a4263f4dc05fe8f78f4111beab9f68a81deeab1 15 | 16 | # CHANGELOG: to reverse chron order + mdformat 17 | 4743ff0d43e04e4cc3e5d8b9e7cd016c0defa36d 18 | 19 | # Apply clang-format 20 | 4955d136ae083c2be1236d8915913166e1790aad 21 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml.disabled: -------------------------------------------------------------------------------- 1 | name: Stale Bot 2 | 3 | on: 4 | schedule: 5 | - cron: "0 15 * * *" 6 | 7 | jobs: 8 | close_stale_issues: 9 | name: Close Stale Issues 10 | if: github.repository == 'TimDettmers/bitsandbytes' 11 | runs-on: ubuntu-latest 12 | env: 13 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 14 | steps: 15 | - uses: actions/checkout@v3 16 | 17 | - name: Setup Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: 3.8 21 | 22 | - name: Install requirements 23 | run: | 24 | pip install PyGithub 25 | - name: Close stale issues 26 | run: | 27 | python scripts/stale.py 28 | -------------------------------------------------------------------------------- /docs/source/reference/optim/lars.mdx: -------------------------------------------------------------------------------- 1 | # LARS 2 | 3 | [LARS (Layer-wise Adaptive Rate Scaling)](https:/hf.co/papers/1708.03888) is an optimizer designed for training with large batch sizes to accelerate training. LARS uses a separate learning rate for each *layer* instead of each parameter. The learning rate is calculated from a *trust ratio* between the weight and gradient norm in a layer. This helps calibrate a stable update size. 
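A minimal construction sketch follows; the toy model, CUDA device, learning rate, and momentum value are illustrative assumptions, and LARS is intended to be used with a non-zero momentum.

```py
import torch

import bitsandbytes as bnb

model = torch.nn.Linear(128, 64).cuda()

# 8-bit LARS for large-batch training; the layer-wise trust ratio is handled internally
optimizer = bnb.optim.LARS8bit(model.parameters(), lr=0.1, momentum=0.9)
```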
4 | 5 | ## LARS[[api-class]] 6 | 7 | [[autodoc]] bitsandbytes.optim.LARS 8 | - __init__ 9 | 10 | ## LARS8bit 11 | 12 | [[autodoc]] bitsandbytes.optim.LARS8bit 13 | - __init__ 14 | 15 | ## LARS32bit 16 | 17 | [[autodoc]] bitsandbytes.optim.LARS32bit 18 | - __init__ 19 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - doc-builder* 8 | - v*-release 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.sha }} 15 | package: bitsandbytes 16 | repo_owner: bitsandbytes-foundation 17 | # avoid /src suffix leading to wrong links, like bitsandbytes/blob/main/src/bitsandbytes/nn/ 18 | version_tag_suffix: '' # defaults to '/src' 19 | custom_container: huggingface/transformers-doc-builder 20 | secrets: 21 | hf_token: ${{ secrets.HUGGINGFACE_PUSH }} 22 | -------------------------------------------------------------------------------- /bitsandbytes/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from .modules import ( 6 | Embedding, 7 | Embedding4bit, 8 | Embedding8bit, 9 | EmbeddingFP4, 10 | EmbeddingNF4, 11 | Int8Params, 12 | Linear4bit, 13 | Linear8bitLt, 14 | LinearFP4, 15 | LinearNF4, 16 | OutlierAwareLinear, 17 | Params4bit, 18 | StableEmbedding, 19 | SwitchBackLinearBnb, 20 | ) 21 | from .triton_based_modules import ( 22 | StandardLinear, 23 | SwitchBackLinear, 24 | SwitchBackLinearGlobal, 25 | SwitchBackLinearVectorwise, 26 | ) 27 | -------------------------------------------------------------------------------- /docs/source/reference/nn/embeddings.mdx: -------------------------------------------------------------------------------- 1 | # Embedding 2 | 3 | The embedding class is used to store and retrieve word embeddings from their indices. There are two types of embeddings in bitsandbytes, the standard PyTorch [`Embedding`] class and the [`StableEmbedding`] class. 4 | 5 | The [`StableEmbedding`] class was introduced in the [8-bit Optimizers via Block-wise Quantization](https://hf.co/papers/2110.02861) paper to reduce gradient variance as a result of the non-uniform distribution of input tokens. This class is designed to support quantization. 6 | 7 | ## Embedding 8 | 9 | [[autodoc]] bitsandbytes.nn.Embedding 10 | - __init__ 11 | 12 | ## StableEmbedding 13 | 14 | [[autodoc]] bitsandbytes.nn.StableEmbedding 15 | - __init__ 16 | -------------------------------------------------------------------------------- /docs/source/reference/optim/adagrad.mdx: -------------------------------------------------------------------------------- 1 | # AdaGrad 2 | 3 | [AdaGrad (Adaptive Gradient)](https://jmlr.org/papers/v12/duchi11a.html) is an adaptive learning rate optimizer. AdaGrad stores a sum of the squared past gradients for each parameter and uses it to scale their learning rate. This allows the learning rate to be automatically lower or higher depending on the magnitude of the gradient, eliminating the need to manually tune the learning rate. 
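A brief usage sketch is shown below; the toy model, CUDA device, and learning rate are placeholder assumptions, and the 32-bit variant can be substituted by swapping the class name.

```py
import torch

import bitsandbytes as bnb

model = torch.nn.Linear(128, 64).cuda()

# Adagrad with 8-bit optimizer state; use Adagrad32bit for the full-precision state equivalent
optimizer = bnb.optim.Adagrad8bit(model.parameters(), lr=1e-2)
```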
4 | 5 | ## Adagrad[[api-class]] 6 | 7 | [[autodoc]] bitsandbytes.optim.Adagrad 8 | - __init__ 9 | 10 | ## Adagrad8bit 11 | 12 | [[autodoc]] bitsandbytes.optim.Adagrad8bit 13 | - __init__ 14 | 15 | ## Adagrad32bit 16 | 17 | [[autodoc]] bitsandbytes.optim.Adagrad32bit 18 | - __init__ 19 | -------------------------------------------------------------------------------- /docs/source/reference/optim/sgd.mdx: -------------------------------------------------------------------------------- 1 | # SGD 2 | 3 | Stochastic gradient descent (SGD) is a basic gradient descent optimizer to minimize loss given a set of model parameters and updates the parameters in the opposite direction of the gradient. The update is performed on a randomly sampled mini-batch of data from the dataset. 4 | 5 | bitsandbytes also supports momentum and Nesterov momentum to accelerate SGD by adding a weighted average of past gradients to the current gradient. 6 | 7 | ## SGD[[api-class]] 8 | 9 | [[autodoc]] bitsandbytes.optim.SGD 10 | - __init__ 11 | 12 | ## SGD8bit 13 | 14 | [[autodoc]] bitsandbytes.optim.SGD8bit 15 | - __init__ 16 | 17 | ## SGD32bit 18 | 19 | [[autodoc]] bitsandbytes.optim.SGD32bit 20 | - __init__ 21 | -------------------------------------------------------------------------------- /examples/int8_inference_huggingface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import LlamaForCausalLM, LlamaTokenizer 3 | 4 | MAX_NEW_TOKENS = 128 5 | model_name = "meta-llama/Llama-2-7b-hf" 6 | 7 | text = "Hamburg is in which country?\n" 8 | tokenizer = LlamaTokenizer.from_pretrained(model_name) 9 | input_ids = tokenizer(text, return_tensors="pt").input_ids 10 | 11 | max_memory = f"{int(torch.cuda.mem_get_info()[0] / 1024**3) - 2}GB" 12 | 13 | n_gpus = torch.cuda.device_count() 14 | max_memory = {i: max_memory for i in range(n_gpus)} 15 | 16 | model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory) 17 | 18 | generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) 19 | print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) 20 | -------------------------------------------------------------------------------- /docs/source/reference/optim/lamb.mdx: -------------------------------------------------------------------------------- 1 | # LAMB 2 | 3 | [LAMB (Layerwise adaptive large batch optimization)](https://hf.co/papers/1904.00962) is an adaptive optimizer designed for training with large batch sizes to accelerate training, combining ideas from [`LARS`] and [`Adam`] to automatically scale the learning rate for each layer: 4 | 5 | - calculates a *trust ratio* between the weight and gradient norm in a layer and clips the ratio to prevent overly large or small updates 6 | - updates weights with the first and second-moments 7 | 8 | ## LAMB[[api-class]] 9 | 10 | [[autodoc]] bitsandbytes.optim.LAMB 11 | - __init__ 12 | 13 | ## LAMB8bit 14 | 15 | [[autodoc]] bitsandbytes.optim.LAMB8bit 16 | - __init__ 17 | 18 | ## LAMB32bit 19 | 20 | [[autodoc]] bitsandbytes.optim.LAMB32bit 21 | - __init__ 22 | -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | 6 | concurrency: 7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 8 | cancel-in-progress: 
true 9 | 10 | jobs: 11 | build: 12 | if: github.repository == 'bitsandbytes-foundation/bitsandbytes' 13 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 14 | with: 15 | commit_sha: ${{ github.event.pull_request.head.sha }} 16 | pr_number: ${{ github.event.number }} 17 | package: bitsandbytes 18 | repo_owner: bitsandbytes-foundation 19 | # avoid /src suffix leading to wrong links, like bitsandbytes/blob/main/src/bitsandbytes/nn/ 20 | version_tag_suffix: '' # defaults to '/src' 21 | custom_container: huggingface/transformers-doc-builder 22 | -------------------------------------------------------------------------------- /docs/source/reference/nn/linear4bit.mdx: -------------------------------------------------------------------------------- 1 | # 4-bit quantization 2 | 3 | [QLoRA](https://hf.co/papers/2305.14314) is a finetuning method that quantizes a model to 4-bits and adds a set of low-rank adaptation (LoRA) weights to the model and tuning them through the quantized weights. This method also introduces a new data type, 4-bit NormalFloat (`LinearNF4`) in addition to the standard Float4 data type (`LinearFP4`). `LinearNF4` is a quantization data type for normally distributed data and can improve performance. 4 | 5 | ## Linear4bit 6 | 7 | [[autodoc]] bitsandbytes.nn.Linear4bit 8 | - __init__ 9 | 10 | ## LinearFP4 11 | 12 | [[autodoc]] bitsandbytes.nn.LinearFP4 13 | - __init__ 14 | 15 | ## LinearNF4 16 | 17 | [[autodoc]] bitsandbytes.nn.LinearNF4 18 | - __init__ 19 | 20 | ## Params4bit 21 | 22 | [[autodoc]] bitsandbytes.nn.Params4bit 23 | - __init__ 24 | -------------------------------------------------------------------------------- /.github/scripts/build-xpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare build_os 3 | 4 | set -xeuo pipefail 5 | 6 | # We currently only build XPU on Linux. 7 | if [ "${build_os:0:6}" == ubuntu ]; then 8 | # TODO: We might want to pre-build this as our own customized image in the future. 9 | image=intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu22.04 10 | echo "Using image $image" 11 | docker run --rm -i \ 12 | -w /src -v "$PWD:/src" "$image" sh -c \ 13 | "apt-get update \ 14 | && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 15 | cmake bison intel-fw-gpu intel-ocloc \ 16 | && cmake -DCOMPUTE_BACKEND=xpu . \ 17 | && cmake --build . 
--config Release" 18 | fi 19 | 20 | output_dir="output/${build_os}/x86_64" 21 | mkdir -p "${output_dir}" 22 | (shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.14.3 4 | hooks: 5 | - id: ruff 6 | args: 7 | - --fix 8 | - id: ruff-format 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v5.0.0 11 | hooks: 12 | - id: check-merge-conflict 13 | - id: check-yaml 14 | - id: end-of-file-fixer 15 | - id: fix-byte-order-marker 16 | - id: trailing-whitespace 17 | - id: mixed-line-ending 18 | args: 19 | - --fix=lf 20 | exclude: '\.bat$' 21 | - repo: https://github.com/crate-ci/typos 22 | rev: v1.26.0 23 | hooks: 24 | - id: typos 25 | - repo: https://github.com/pre-commit/mirrors-clang-format 26 | rev: v20.1.6 27 | hooks: 28 | - id: clang-format 29 | types_or: [c++, c, cuda] 30 | files: ^csrc/ 31 | -------------------------------------------------------------------------------- /.github/scripts/auditwheel_show.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | 4 | 5 | def main(): 6 | ap = argparse.ArgumentParser() 7 | ap.add_argument("wheels", nargs="*") 8 | args = ap.parse_args() 9 | if not args.wheels: 10 | ap.error("At least one wheel must be provided.") 11 | for whl in args.wheels: 12 | print(f"### `{whl}`") 13 | 14 | audit_wheel_output = subprocess.run( 15 | ["auditwheel", "show", whl], 16 | capture_output=True, 17 | text=True, 18 | errors="backslashreplace", 19 | ) 20 | 21 | if audit_wheel_output.stdout: 22 | print(audit_wheel_output.stdout) 23 | 24 | if audit_wheel_output.stderr: 25 | print(f"**Error:**\n```\n{audit_wheel_output.stderr}\n```") 26 | 27 | print("---") 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /docs/source/reference/optim/lion.mdx: -------------------------------------------------------------------------------- 1 | # Lion 2 | 3 | [Lion (Evolved Sign Momentum)](https://hf.co/papers/2302.06675) is a unique optimizer that uses the sign of the gradient to determine the update direction of the momentum. This makes Lion more memory-efficient and faster than [`AdamW`] which tracks and store the first and second-order moments. 4 | 5 | ## Lion[[api-class]] 6 | 7 | [[autodoc]] bitsandbytes.optim.Lion 8 | - __init__ 9 | 10 | ## Lion8bit 11 | 12 | [[autodoc]] bitsandbytes.optim.Lion8bit 13 | - __init__ 14 | 15 | ## Lion32bit 16 | 17 | [[autodoc]] bitsandbytes.optim.Lion32bit 18 | - __init__ 19 | 20 | ## PagedLion 21 | 22 | [[autodoc]] bitsandbytes.optim.PagedLion 23 | - __init__ 24 | 25 | ## PagedLion8bit 26 | 27 | [[autodoc]] bitsandbytes.optim.PagedLion8bit 28 | - __init__ 29 | 30 | ## PagedLion32bit 31 | 32 | [[autodoc]] bitsandbytes.optim.PagedLion32bit 33 | - __init__ 34 | -------------------------------------------------------------------------------- /docs/source/reference/optim/ademamix.mdx: -------------------------------------------------------------------------------- 1 | # AdEMAMix 2 | 3 | [AdEMAMix](https://hf.co/papers/2409.03137) is a variant of the [`Adam`] optimizer. 
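A minimal construction sketch is shown below; the toy model, CUDA device, and learning rate are illustrative assumptions, and the full argument list is documented in the class references that follow.

```py
import torch

import bitsandbytes as bnb

model = torch.nn.Linear(128, 64).cuda()

# AdEMAMix with 8-bit optimizer state; only the learning rate is set explicitly here
optimizer = bnb.optim.AdEMAMix8bit(model.parameters(), lr=1e-4)
```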
4 | 5 | bitsandbytes also supports paged optimizers which take advantage of CUDAs unified memory to transfer memory from the GPU to the CPU when GPU memory is exhausted. 6 | 7 | ## AdEMAMix[[api-class]] 8 | 9 | [[autodoc]] bitsandbytes.optim.AdEMAMix 10 | - __init__ 11 | 12 | ## AdEMAMix8bit 13 | 14 | [[autodoc]] bitsandbytes.optim.AdEMAMix8bit 15 | - __init__ 16 | 17 | ## AdEMAMix32bit 18 | 19 | [[autodoc]] bitsandbytes.optim.AdEMAMix32bit 20 | - __init__ 21 | 22 | ## PagedAdEMAMix 23 | 24 | [[autodoc]] bitsandbytes.optim.PagedAdEMAMix 25 | - __init__ 26 | ## PagedAdEMAMix8bit 27 | 28 | [[autodoc]] bitsandbytes.optim.PagedAdEMAMix8bit 29 | - __init__ 30 | 31 | ## PagedAdEMAMix32bit 32 | 33 | [[autodoc]] bitsandbytes.optim.PagedAdEMAMix32bit 34 | - __init__ 35 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to bitsandbytes 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, install the [pre-commit hooks as documented here](https://huggingface.co/docs/bitsandbytes/main/en/contributing#setup-pre-commit-hooks). 13 | 14 | ## Issues 15 | We use GitHub issues to track public bugs. Please ensure your description is 16 | clear and has sufficient instructions to be able to reproduce the issue. 17 | 18 | ## License 19 | By contributing to bitsandbytes, you agree that your contributions will be licensed 20 | under the LICENSE file in the root directory of this source tree. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a proposal/request for a new feature 3 | labels: ["feature"] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request 11 | description: | 12 | A clear and concise description of the feature proposal. 13 | 14 | - type: textarea 15 | id: motivation 16 | validations: 17 | required: true 18 | attributes: 19 | label: Motivation 20 | description: | 21 | Please outline the motivation for the proposal. Is your feature request related to a problem? 22 | 23 | - type: textarea 24 | id: contribution 25 | validations: 26 | required: true 27 | attributes: 28 | label: Your contribution 29 | description: | 30 | Is there any way that you could help, e.g. by submitting a PR? 31 | -------------------------------------------------------------------------------- /docs/source/reference/nn/linear8bit.mdx: -------------------------------------------------------------------------------- 1 | # LLM.int8() 2 | [LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that aims to make large language model inference more accessible without significant degradation. Unlike naive 8-bit quantization, which can result in loss of critical information and accuracy, LLM.int8() dynamically adapts to ensure sensitive components of the computation retain higher precision when needed. 
The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are multiplied in 8-bit before being dequantized back to 16-bits. The outputs from the 16-bit and 8-bit multiplication are combined to produce the final output. 3 | 4 | [Further Resources](../../explanations/resources#llm-int8) 5 | 6 | ## Linear8bitLt 7 | 8 | [[autodoc]] bitsandbytes.nn.Linear8bitLt 9 | - __init__ 10 | 11 | ## Int8Params 12 | 13 | [[autodoc]] bitsandbytes.nn.Int8Params 14 | - __init__ 15 | -------------------------------------------------------------------------------- /bitsandbytes/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit 7 | from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit 8 | from .adamw import ( 9 | AdamW, 10 | AdamW8bit, 11 | AdamW32bit, 12 | PagedAdamW, 13 | PagedAdamW8bit, 14 | PagedAdamW32bit, 15 | ) 16 | from .ademamix import AdEMAMix, AdEMAMix8bit, AdEMAMix32bit, PagedAdEMAMix, PagedAdEMAMix8bit, PagedAdEMAMix32bit 17 | from .lamb import LAMB, LAMB8bit, LAMB32bit 18 | from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS 19 | from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit 20 | from .optimizer import GlobalOptimManager 21 | from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit 22 | from .sgd import SGD, SGD8bit, SGD32bit 23 | -------------------------------------------------------------------------------- /.github/scripts/set_platform_tag.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import platform 3 | import sys 4 | 5 | 6 | def get_platform_tag(architecture): 7 | system = platform.system() 8 | 9 | if system == "Linux": 10 | tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64" 11 | elif system == "Darwin": 12 | tag = "macosx_14_0_arm64" 13 | elif system == "Windows": 14 | tag = "win_amd64" if architecture == "x86_64" else "win_arm64" 15 | else: 16 | sys.exit(f"Unsupported system: {system}") 17 | 18 | return tag 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser(description="Determine platform tag.") 23 | parser.add_argument("arch", type=str, help="Architecture (e.g., x86_64, aarch64)") 24 | args = parser.parse_args() 25 | 26 | tag = get_platform_tag(args.arch) 27 | 28 | print(tag) # This will be captured by the GitHub Actions workflow 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /docs/source/index.mdx: -------------------------------------------------------------------------------- 1 | # bitsandbytes 2 | 3 | bitsandbytes enables accessible large language models via k-bit quantization for PyTorch. bitsandbytes provides three main features for dramatically reducing memory consumption for inference and training: 4 | 5 | * 8-bit optimizers uses block-wise quantization to maintain 32-bit performance at a small fraction of the memory cost. 6 | * LLM.int8() or 8-bit quantization enables large language model inference with only half the required memory and without any performance degradation. 
This method is based on vector-wise quantization to quantize most features to 8-bits and separately treating outliers with 16-bit matrix multiplication. 7 | * QLoRA or 4-bit quantization enables large language model training with several memory-saving techniques that don't compromise performance. This method quantizes a model to 4-bits and inserts a small set of trainable low-rank adaptation (LoRA) weights to allow training. 8 | 9 | # License 10 | 11 | bitsandbytes is MIT licensed. 12 | -------------------------------------------------------------------------------- /docs/source/reference/optim/adamw.mdx: -------------------------------------------------------------------------------- 1 | # AdamW 2 | 3 | [AdamW](https://hf.co/papers/1711.05101) is a variant of the [`Adam`] optimizer that separates weight decay from the gradient update based on the observation that the weight decay formulation is different when applied to [`SGD`] and [`Adam`]. 4 | 5 | bitsandbytes also supports paged optimizers which take advantage of CUDAs unified memory to transfer memory from the GPU to the CPU when GPU memory is exhausted. 6 | 7 | ## AdamW[[api-class]] 8 | 9 | [[autodoc]] bitsandbytes.optim.AdamW 10 | - __init__ 11 | 12 | ## AdamW8bit 13 | 14 | [[autodoc]] bitsandbytes.optim.AdamW8bit 15 | - __init__ 16 | 17 | ## AdamW32bit 18 | 19 | [[autodoc]] bitsandbytes.optim.AdamW32bit 20 | - __init__ 21 | 22 | ## PagedAdamW 23 | 24 | [[autodoc]] bitsandbytes.optim.PagedAdamW 25 | - __init__ 26 | ## PagedAdamW8bit 27 | 28 | [[autodoc]] bitsandbytes.optim.PagedAdamW8bit 29 | - __init__ 30 | 31 | ## PagedAdamW32bit 32 | 33 | [[autodoc]] bitsandbytes.optim.PagedAdamW32bit 34 | - __init__ 35 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: LLVM 3 | AlignAfterOpenBracket: BlockIndent 4 | BinPackArguments: true 5 | BinPackParameters: true 6 | BracedInitializerIndentWidth: 4 7 | ColumnLimit: 120 8 | Cpp11BracedListStyle: true 9 | IndentWidth: 4 10 | IndentWrappedFunctionNames: true 11 | PointerAlignment: Left 12 | SeparateDefinitionBlocks: Always 13 | Standard: c++17 14 | StatementMacros: 15 | - 'MAKE_PreconditionOptimizer32bit1State' 16 | - 'MAKE_PreconditionOptimizer32bit2State' 17 | - 'MAKE_PreconditionStatic8bit1State' 18 | - 'MAKE_PreconditionStatic8bit2State' 19 | - 'MAKE_Optimizer32bit1State' 20 | - 'MAKE_optimizerStatic8bit1State' 21 | - 'MAKE_optimizerStatic8bit2State' 22 | - 'MAKE_OptimizerStatic8bit1StateBlockwise' 23 | - 'MAKE_OptimizerStatic8bit2StateBlockwise' 24 | - 'MAKE_kQuantizeBlockwise' 25 | - 'MAKE_BLOCKWISE8' 26 | - 'MAKE_ELEMENTWISE_FUNC' 27 | - 'CMAKE_ELEMENTWISE_FUNC' 28 | - 'MAKE_FUNC8' 29 | - 'MAKE_FUNC32' 30 | - 'MAKE_CBLOCKWISE8' 31 | - 'MAKE_CFUNC8' 32 | - 'MAKE_CFUNC32' 33 | 34 | UseTab: Never 35 | 36 | ... 
37 | -------------------------------------------------------------------------------- /examples/compile_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch._dynamo 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 4 | 5 | # torch._dynamo.config.suppress_errors = True 6 | 7 | torch.set_float32_matmul_precision("high") 8 | 9 | quantization_config = BitsAndBytesConfig(load_in_8bit=True) 10 | 11 | # torch._dynamo.config.capture_dynamic_output_shape_ops = True 12 | 13 | model_id = "google/gemma-2-2b-it" 14 | # model_id = "Qwen/Qwen2.5-7B" 15 | 16 | tokenizer = AutoTokenizer.from_pretrained(model_id) 17 | model = AutoModelForCausalLM.from_pretrained( 18 | model_id, 19 | quantization_config=quantization_config, 20 | device_map="auto", 21 | torch_dtype=torch.bfloat16, 22 | ) 23 | 24 | input_text = "Write me a poem about Machine Learning." 25 | input_ids = tokenizer(input_text, return_tensors="pt").to(model.device) 26 | 27 | # model.forward = torch.compile(model.forward, fullgraph=True) 28 | 29 | model = torch.compile(model) 30 | 31 | outputs = model.generate(**input_ids, max_new_tokens=32) 32 | print(tokenizer.decode(outputs[0])) 33 | -------------------------------------------------------------------------------- /tests/test_cuda_setup_evaluator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path 4 | from bitsandbytes.cuda_specs import CUDASpecs 5 | 6 | 7 | @pytest.fixture 8 | def cuda120_spec() -> CUDASpecs: 9 | return CUDASpecs( 10 | cuda_version_string="120", 11 | highest_compute_capability=(8, 6), 12 | cuda_version_tuple=(12, 0), 13 | ) 14 | 15 | 16 | @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") 17 | def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): 18 | monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) 19 | assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" 20 | 21 | 22 | @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") 23 | def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): 24 | monkeypatch.setenv("BNB_CUDA_VERSION", "110") 25 | assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" 26 | assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve bitsandbytes 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info 8 | description: Please share your relevant system information with us 9 | placeholder: platform, python version, hardware, ... 10 | validations: 11 | required: true 12 | 13 | - type: textarea 14 | id: reproduction 15 | validations: 16 | required: true 17 | attributes: 18 | label: Reproduction 19 | description: | 20 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. 21 | Please provide the simplest reproducer as possible so that we can quickly fix the issue. 
22 | 23 | placeholder: | 24 | Reproducer: 25 | 26 | - type: textarea 27 | id: expected-behavior 28 | validations: 29 | required: true 30 | attributes: 31 | label: Expected behavior 32 | description: "A clear and concise description of what you would expect to happen." 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/scripts/build-rocm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare build_arch 3 | declare build_os 4 | declare rocm_version 5 | 6 | set -xeuo pipefail 7 | bnb_rocm_arch="gfx90a;gfx942;gfx1100;gfx1101" 8 | 9 | # ROCm 6.4+ - Add gfx1150/gfx1151/gfx1200/gfx1201. Note we assume >=6.4.4. 10 | [[ "${rocm_version}" == 6.4.* || "${rocm_version}" == 7.* ]] && bnb_rocm_arch="${bnb_rocm_arch};gfx1150;gfx1151;gfx1200;gfx1201" 11 | 12 | # ROCm 7.0+ - Add gfx950 13 | [[ "${rocm_version}" == 7.* ]] && bnb_rocm_arch="${bnb_rocm_arch};gfx950" 14 | 15 | if [ "${build_os:0:6}" == ubuntu ]; then 16 | image=rocm/dev-ubuntu-22.04:${rocm_version}-complete 17 | echo "Using image $image" 18 | docker run --rm --platform "linux/$build_arch" -i \ 19 | -w /src -v "$PWD:/src" "$image" sh -c \ 20 | "apt-get update \ 21 | && pip install cmake==3.31.6 \ 22 | && cmake -DCOMPUTE_BACKEND=hip -DCMAKE_BUILD_TYPE=MinSizeRel -DCMAKE_HIP_FLAGS=\"--offload-compress\" -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \ 23 | && cmake --build ." 24 | fi 25 | 26 | output_dir="output/${build_os}/${build_arch}" 27 | mkdir -p "${output_dir}" 28 | (shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") 29 | -------------------------------------------------------------------------------- /docs/source/reference/optim/optim_overview.mdx: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | [8-bit optimizers](https://hf.co/papers/2110.02861) reduce the memory footprint of 32-bit optimizers without any performance degradation which means you can train large models with many parameters faster. At the core of 8-bit optimizers is block-wise quantization which enables quantization accuracy, computational efficiency, and stability. 
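In practice, an 8-bit optimizer can be swapped in for its `torch.optim` counterpart with a one-line change; the sketch below assumes a toy model on a CUDA device and illustrative hyperparameters.

```py
import torch

import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# 32-bit baseline:
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

# 8-bit replacement with block-wise quantized optimizer state:
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
```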
4 | 5 | bitsandbytes provides 8-bit optimizers through the base [`Optimizer8bit`] class, and additionally provides [`Optimizer2State`] and [`Optimizer1State`] for 2-state (for example, [`Adam`]) and 1-state (for example, [`Adagrad`]) optimizers respectively. To provide custom optimizer hyperparameters, use the [`GlobalOptimManager`] class to configure the optimizer. 6 | 7 | ## Optimizer8bit 8 | 9 | [[autodoc]] bitsandbytes.optim.optimizer.Optimizer8bit 10 | - __init__ 11 | 12 | ## Optimizer2State 13 | 14 | [[autodoc]] bitsandbytes.optim.optimizer.Optimizer2State 15 | - __init__ 16 | 17 | ## Optimizer1State 18 | 19 | [[autodoc]] bitsandbytes.optim.optimizer.Optimizer1State 20 | - __init__ 21 | 22 | ## Utilities 23 | 24 | [[autodoc]] bitsandbytes.optim.optimizer.GlobalOptimManager 25 | -------------------------------------------------------------------------------- /docs/source/reference/optim/adam.mdx: -------------------------------------------------------------------------------- 1 | # Adam 2 | 3 | [Adam (Adaptive moment estimation)](https://hf.co/papers/1412.6980) is an adaptive learning rate optimizer, combining ideas from [`SGD`] with momentum and [`RMSprop`] to automatically scale the learning rate: 4 | 5 | - a weighted average of the past gradients to provide direction (first-moment) 6 | - a weighted average of the *squared* past gradients to adapt the learning rate to each parameter (second-moment) 7 | 8 | bitsandbytes also supports paged optimizers which take advantage of CUDAs unified memory to transfer memory from the GPU to the CPU when GPU memory is exhausted. 9 | 10 | ## Adam[[api-class]] 11 | 12 | [[autodoc]] bitsandbytes.optim.Adam 13 | - __init__ 14 | 15 | ## Adam8bit 16 | 17 | [[autodoc]] bitsandbytes.optim.Adam8bit 18 | - __init__ 19 | 20 | ## Adam32bit 21 | 22 | [[autodoc]] bitsandbytes.optim.Adam32bit 23 | - __init__ 24 | 25 | ## PagedAdam 26 | 27 | [[autodoc]] bitsandbytes.optim.PagedAdam 28 | - __init__ 29 | 30 | ## PagedAdam8bit 31 | 32 | [[autodoc]] bitsandbytes.optim.PagedAdam8bit 33 | - __init__ 34 | 35 | ## PagedAdam32bit 36 | 37 | [[autodoc]] bitsandbytes.optim.PagedAdam32bit 38 | - __init__ 39 | -------------------------------------------------------------------------------- /.github/scripts/build-xpu-windows.bat: -------------------------------------------------------------------------------- 1 | set INTEL_DLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/75d4eb97-914a-4a95-852c-7b9733d80f74/intel-deep-learning-essentials-2025.1.3.8_offline.exe 2 | set INTEL_DLE_TMP=%RUNNER_TEMP%\intel_dle 3 | set INTEL_DLE_LOG=%RUNNER_TEMP%\intel_dle_log.txt 4 | 5 | echo ::group::Intel Deep Learning Essentials Installation 6 | curl -o intel-dle-installer.exe %INTEL_DLE_URL% 7 | start /wait "Intel DLE Install" intel-dle-installer.exe -f %INTEL_DLE_TMP% -l %INTEL_DLE_LOG% --silent -a --eula=accept -p=NEED_VS2022_INTEGRATION=0 8 | type %INTEL_DLE_LOG% 9 | if ERRORLEVEL 1 ( 10 | echo Failed to install Intel Deep Learning Essentials 11 | exit /b 1 12 | ) 13 | echo ::endgroup:: 14 | 15 | echo ::group::Build Environment Setup 16 | call "%ProgramFiles(x86)%\Intel\oneAPI\setvars.bat" 17 | cmake -G Ninja -DCOMPUTE_BACKEND=xpu -DCMAKE_BUILD_TYPE=Release . 18 | if ERRORLEVEL 1 ( 19 | echo Failed to setup environment 20 | exit /b 1 21 | ) 22 | echo ::endgroup:: 23 | 24 | echo ::group::Building with XPU backend 25 | cmake --build . 
--config Release 26 | if ERRORLEVEL 1 ( 27 | echo Build failed 28 | exit /b 1 29 | ) 30 | echo ::endgroup:: 31 | 32 | set output_dir=output\%build_os%\x86_64 33 | if not exist "%output_dir%" mkdir "%output_dir%" 34 | copy bitsandbytes\*.dll "%output_dir%\" 2>nul 35 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import random 3 | 4 | import numpy as np 5 | import pytest 6 | import torch 7 | 8 | 9 | def _set_seed(): 10 | torch.manual_seed(0) 11 | torch.cuda.manual_seed_all(0) 12 | torch.mps.manual_seed(0) 13 | np.random.seed(0) 14 | random.seed(0) 15 | 16 | 17 | def pytest_runtest_call(item): 18 | try: 19 | _set_seed() 20 | item.runtest() 21 | except AssertionError as ae: 22 | if str(ae) == "Torch not compiled with CUDA enabled": 23 | pytest.skip("Torch not compiled with CUDA enabled") 24 | raise 25 | except RuntimeError as re: 26 | # CUDA-enabled Torch build, but no CUDA-capable device found 27 | if "Found no NVIDIA driver on your system" in str(re): 28 | pytest.skip("No NVIDIA driver found") 29 | raise 30 | 31 | 32 | @pytest.hookimpl(trylast=True) 33 | def pytest_runtest_teardown(item, nextitem): 34 | gc.collect() 35 | if torch.cuda.is_available(): 36 | torch.cuda.empty_cache() 37 | elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): 38 | torch.mps.empty_cache() 39 | 40 | 41 | @pytest.fixture(scope="session") 42 | def requires_cuda() -> bool: 43 | cuda_available = torch.cuda.is_available() 44 | if not cuda_available: 45 | pytest.skip("CUDA is required") 46 | return cuda_available 47 | -------------------------------------------------------------------------------- /csrc/xpu_ops.h: -------------------------------------------------------------------------------- 1 | #ifndef xpu_ops_H 2 | #define xpu_ops_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | template 16 | static inline void sycl_kernel_submit(sycl::nd_range range, sycl::queue q, ker_t ker) { 17 | auto cgf = [&](::sycl::handler& cgh) 18 | [[sycl::reqd_sub_group_size(subgroup_size)]] { cgh.parallel_for(range, ker); }; 19 | q.submit(cgf); 20 | } 21 | 22 | template 23 | static inline void sycl_comp_kernel_submit(sycl::nd_range range, sycl::queue q, ker_t ker) { 24 | auto cgf = [&](::sycl::handler& cgh) [[sycl::reqd_sub_group_size(subgroup_size)]] { 25 | ker.sycl_ker_local_memory_creation(cgh); 26 | cgh.parallel_for(range, ker); 27 | }; 28 | q.submit(cgf); 29 | } 30 | 31 | template 32 | void dequantizeBlockwise( 33 | float* code, unsigned char* A, float* absmax, T* out, int workgroup_size, const int n, sycl::queue* stream 34 | ); 35 | template 36 | void gemv_4bit_inference( 37 | int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, 38 | int blocksize, sycl::queue* stream 39 | ); 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /docs/source/reference/functional.mdx: -------------------------------------------------------------------------------- 1 | # Overview 2 | The `bitsandbytes.functional` API provides the low-level building blocks for the library's features. 3 | 4 | ## When to Use `bitsandbytes.functional` 5 | 6 | * When you need direct control over quantized operations and their parameters. 7 | * To build custom layers or operations leveraging low-bit arithmetic. 
8 | * To integrate with other ecosystem tooling. 9 | * For experimental or research purposes requiring non-standard quantization or performance optimizations. 10 | 11 | ## LLM.int8() 12 | [[autodoc]] functional.int8_linear_matmul 13 | 14 | [[autodoc]] functional.int8_mm_dequant 15 | 16 | [[autodoc]] functional.int8_vectorwise_dequant 17 | 18 | [[autodoc]] functional.int8_vectorwise_quant 19 | 20 | ## 4-bit 21 | [[autodoc]] functional.dequantize_4bit 22 | 23 | [[autodoc]] functional.dequantize_fp4 24 | 25 | [[autodoc]] functional.dequantize_nf4 26 | 27 | [[autodoc]] functional.gemv_4bit 28 | 29 | [[autodoc]] functional.quantize_4bit 30 | 31 | [[autodoc]] functional.quantize_fp4 32 | 33 | [[autodoc]] functional.quantize_nf4 34 | 35 | [[autodoc]] functional.QuantState 36 | 37 | ## Dynamic 8-bit Quantization 38 | 39 | Primitives used in the 8-bit optimizer quantization. 40 | 41 | For more details see [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561) 42 | 43 | [[autodoc]] functional.dequantize_blockwise 44 | 45 | [[autodoc]] functional.quantize_blockwise 46 | 47 | ## Utility 48 | [[autodoc]] functional.get_ptr 49 | -------------------------------------------------------------------------------- /docs/source/errors.mdx: -------------------------------------------------------------------------------- 1 | # Troubleshoot 2 | 3 | ## No kernel image available 4 | 5 | This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. 6 | 7 | To solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME`` as well as ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``? 8 | 9 | If you are feeling lucky, you can also try to compile the library from source. This can be still problematic if your PATH variables have multiple cuda versions. As such, it is recommended to figure out path conflicts before you proceed with compilation. 10 | 11 | ## `fatbinwrap` 12 | 13 | This error occurs if there is a mismatch between CUDA versions in the C++ library and the CUDA part. Make sure you have right CUDA in your `$PATH` and `$LD_LIBRARY_PATH` variable. In the conda base environment you can find the library under: 14 | 15 | ```bash 16 | ls $CONDA_PREFIX/lib/*cudart* 17 | ``` 18 | Make sure this path is appended to the `LD_LIBRARY_PATH` so bnb can find the CUDA runtime environment library (cudart). 19 | 20 | If this does not fix the issue, please try compilation from source next. 21 | 22 | If this does not work, please open an issue and paste the printed environment if you call `make` and the associated error when running bnb. 23 | -------------------------------------------------------------------------------- /docs/source/contributing.mdx: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | ## Setup 4 | 5 | ### Setup pre-commit hooks 6 | - Install pre-commit hooks with `pip install pre-commit`. 7 | - Run `pre-commit install` once to install the hooks, so they will be run on every commit. 8 | - If the hooks introduce changes, they'll be visible with `git diff`. 
Review them and `git add` them if everything is fine, then re-run the commit; it should pass now. 9 | - If you want to manually trigger the hooks, you can run `pre-commit run --all-files`. 10 | 11 | Now all the pre-commit hooks will run automatically when you try to commit, and if they introduce changes, you need to re-add the changed files before you can commit and push. 12 | 13 | ### Ignore formatting revs 14 | - Run `git config blame.ignoreRevsFile .git-blame-ignore-revs`. This makes `git blame` aware of commits that are recorded as purely formatting-related. 15 | 16 | ## Doc-string syntax 17 | 18 | We follow NumPy doc-string conventions, with the only notable difference being that we use Markdown instead of reStructuredText (reST) for markup within the doc-strings. 19 | 20 | Please see the existing documentation to see how to generate autodocs. 21 | 22 | ## Documentation 23 | - [guideline for documentation syntax](https://github.com/huggingface/doc-builder#readme) 24 | - images should be uploaded via a PR to the `bitsandbytes/` directory [here](https://huggingface.co/datasets/huggingface/documentation-images) 25 | - find the documentation build for each PR via a link posted to the PR, such as https://moon-ci-docs.huggingface.co/docs/bitsandbytes/pr_1012/en/introduction 26 | -------------------------------------------------------------------------------- /bitsandbytes/backends/hpu/ops.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | import math 3 | 4 | import torch 5 | 6 | from ..._ops import register_kernel 7 | from ..utils import GAUDI_SW_VER 8 | 9 | 10 | # Convert between the standard 4-bit compression format and the IPEX compression format. 11 | # Needed for backward compatibility with older versions of the Gaudi SW stack. 12 | def _reverse_4bit_compress_format(weight: torch.Tensor): 13 | out_1 = (weight & 0xF0) >> 4 14 | out_2 = (weight & 0xF) << 4 15 | out = out_1 | out_2 16 | return out 17 | 18 | 19 | @register_kernel("bitsandbytes::dequantize_4bit", "hpu") 20 | def _( 21 | A: torch.Tensor, 22 | absmax: torch.Tensor, 23 | blocksize: int, 24 | quant_type: str, 25 | shape: Sequence[int], 26 | dtype: torch.dtype, 27 | ) -> torch.Tensor: 28 | torch._check_is_size(blocksize) 29 | torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}") 30 | torch._check( 31 | A.dtype in [torch.bfloat16, torch.uint8], 32 | lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}", 33 | ) 34 | 35 | # Enable non uint8 dtype 36 | if A.dtype != torch.uint8: 37 | A = A.view(torch.uint8) 38 | 39 | A = A.reshape(-1) 40 | 41 | if GAUDI_SW_VER and (GAUDI_SW_VER.major < 1 or GAUDI_SW_VER.minor < 22): 42 | A = _reverse_4bit_compress_format(A) 43 | 44 | # HPU dequantization function for NF4 quantized tensors.
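    # The exact semantics of torch.ops.hpu.dequantize_nf4 are defined by the Gaudi SW stack rather
    # than this repository; presumably it maps each packed 4-bit code through the NF4 lookup table,
    # scales it by the per-block absmax, and returns a flat tensor of `out_dtype`, which is then
    # reshaped back to the requested `shape` below.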
45 | out_dq = torch.ops.hpu.dequantize_nf4( 46 | A, 47 | absmax.to(dtype), 48 | blocksize, 49 | out_shape=(math.prod(shape),), 50 | out_dtype=dtype, 51 | ) 52 | 53 | output = out_dq.reshape(shape) 54 | 55 | return output 56 | -------------------------------------------------------------------------------- /csrc/xpu_kernels.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #ifndef xpu_kernels 5 | #define xpu_kernels 6 | 7 | template class kDequantizeBlockwise { 8 | public: 9 | SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; 10 | 11 | kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int n_) 12 | : code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), n(n_) {} 13 | 14 | private: 15 | float* code; 16 | uint8_t* A; 17 | float* absmax; 18 | T* out; 19 | const int blocksize; 20 | const int n; 21 | }; 22 | 23 | template class kgemv_4bit_inference { 24 | public: 25 | SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; 26 | 27 | kgemv_4bit_inference( 28 | int M_, int N_, int K_, T* A_, unsigned char* B_, float* absmax_, const float* datatype_, T* out_, int lda_, 29 | int ldb_, int ldc_, int blocksize_ 30 | ) 31 | : M(M_), N(N_), K(K_), A(A_), B(B_), absmax(absmax_), datatype(datatype_), out(out_), lda(lda_), ldb(ldb_), 32 | ldc(ldc_), blocksize(blocksize_), quant_map() {} 33 | 34 | void sycl_ker_local_memory_creation(sycl::handler& cgh) { quant_map = sycl::local_accessor(16, cgh); } 35 | 36 | private: 37 | int M; 38 | int N; 39 | int K; 40 | T* A; 41 | unsigned char* B; 42 | float* absmax; 43 | const float* datatype; 44 | T* out; 45 | int lda; 46 | int ldb; 47 | int ldc; 48 | int blocksize; 49 | sycl::local_accessor quant_map; 50 | }; 51 | 52 | #endif 53 | -------------------------------------------------------------------------------- /csrc/common.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // TODO: Let's make some of these constexpr and put in a namespace. 4 | 5 | #define BNB_CC_PASCAL 600 6 | #define BNB_CC_PASCAL_X2 620 7 | #define BNB_CC_VOLTA 700 8 | #define BNB_CC_VOLTA_XAVIER 720 9 | #define BNB_CC_TURING 750 10 | #define BNB_CC_AMPERE 800 11 | #define BNB_CC_AMPERE2 860 12 | #define BNB_CC_AMPERE2_ORIN 870 13 | #define BNB_CC_ADA 890 14 | #define BNB_CC_HOPPER 900 15 | #define BNB_CC_BLACKWELL 1000 16 | 17 | #define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA) 18 | #define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER) 19 | #define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE) 20 | #define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA) 21 | 22 | #define BNB_WARP_SIZE 32 23 | 24 | // The maximum number of resident threads per SM varies by arch. 25 | // For A100/H100 and all prior to Turing, it is 2048, which allows 26 | // for 2 full blocks of 1024 threads per SM. 27 | // Reference: 28 | // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability 29 | #if __CUDA_ARCH__ == 750 30 | #define BNB_MAX_THREADS_PER_SM 1024 31 | #elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890 32 | #define BNB_MAX_THREADS_PER_SM 1536 33 | #else 34 | #define BNB_MAX_THREADS_PER_SM 2048 35 | #endif 36 | 37 | // Maximum resident warps per SM is always directly related to the number of threads. 
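// For example: 2048 / 32 = 64 warps on most architectures, 1536 / 32 = 48 warps on sm_86-sm_89,
// and 1024 / 32 = 32 warps on sm_75.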
38 | #define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE)) 39 | 40 | // Maximum resident blocks per SM may vary. 41 | #if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 42 | #define BNB_MAX_BLOCKS_PER_SM 16 43 | #else 44 | #define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2) 45 | #endif 46 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from distutils.errors import DistutilsModuleError 6 | import os 7 | from warnings import warn 8 | 9 | from setuptools import find_packages, setup 10 | from setuptools.command.build_py import build_py 11 | from setuptools.dist import Distribution 12 | 13 | 14 | # Tested with wheel v0.29.0 15 | class BinaryDistribution(Distribution): 16 | def has_ext_modules(self): 17 | return True 18 | 19 | 20 | class ExtBuildPy(build_py): 21 | def run(self): 22 | if os.environ.get("BNB_SKIP_CMAKE", "").lower() in ("1", "true", "yes"): 23 | print("skipping CMake build") 24 | else: 25 | # build_cmake needs to be called prior to build_py, as the latter 26 | # collects the files output into the package directory. 27 | try: 28 | self.run_command("build_cmake") 29 | except DistutilsModuleError: 30 | warn( 31 | "scikit-build-core not installed, CMake will not be invoked automatically. " 32 | "Please install scikit-build-core or run CMake manually to build extensions." 33 | ) 34 | super().run() 35 | 36 | 37 | cmdclass = {"build_py": ExtBuildPy} 38 | 39 | setup_kwargs = { 40 | "version": "0.49.1.dev0", 41 | "packages": find_packages(), 42 | "distclass": BinaryDistribution, 43 | "cmdclass": {"build_py": ExtBuildPy}, 44 | } 45 | 46 | if os.environ.get("BNB_SKIP_CMAKE", "").lower() not in ("1", "true", "yes"): 47 | setup_kwargs["cmake_source_dir"] = "." 
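# For example (untested sketch): a source install that skips the native build can be performed with
#   BNB_SKIP_CMAKE=1 pip install -e .
# in which case a prebuilt libbitsandbytes shared library is assumed to already be present in the
# package directory.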
48 | 49 | setup(**setup_kwargs) 50 | -------------------------------------------------------------------------------- /docs/source/_toctree.yml: -------------------------------------------------------------------------------- 1 | - title: Get started 2 | sections: 3 | - local: index 4 | title: bitsandbytes 5 | - local: installation 6 | title: Installation 7 | - local: quickstart 8 | title: Quickstart 9 | 10 | - title: Usage Guides 11 | sections: 12 | - local: optimizers 13 | title: 8-bit optimizers 14 | - local: fsdp_qlora 15 | title: FSDP-QLoRA 16 | - local: integrations 17 | title: Integrations 18 | - local: errors 19 | title: Troubleshoot 20 | - local: contributing 21 | title: Contribute 22 | - local: faqs 23 | title: FAQs 24 | - title: Explanation 25 | sections: 26 | - local: explanations/optimizers 27 | title: 8-bit optimizers 28 | - local: explanations/resources 29 | title: Papers, resources & how to cite 30 | - title: API reference 31 | sections: 32 | - title: Functional 33 | local: reference/functional 34 | - title: Optimizers 35 | sections: 36 | - local: reference/optim/optim_overview 37 | title: Overview 38 | - local: reference/optim/adagrad 39 | title: AdaGrad 40 | - local: reference/optim/adam 41 | title: Adam 42 | - local: reference/optim/adamw 43 | title: AdamW 44 | - local: reference/optim/ademamix 45 | title: AdEMAMix 46 | - local: reference/optim/lamb 47 | title: LAMB 48 | - local: reference/optim/lars 49 | title: LARS 50 | - local: reference/optim/lion 51 | title: Lion 52 | - local: reference/optim/rmsprop 53 | title: RMSprop 54 | - local: reference/optim/sgd 55 | title: SGD 56 | - title: Modules 57 | sections: 58 | - local: reference/nn/linear8bit 59 | title: LLM.int8() 60 | - local: reference/nn/linear4bit 61 | title: 4-bit quantizer 62 | - local: reference/nn/embeddings 63 | title: Embedding 64 | -------------------------------------------------------------------------------- /benchmarking/optimizer_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extracted from tests/test_optim.py 3 | 4 | Usage: pytest benchmarking/optimizer_benchmark.py 5 | """ 6 | 7 | import time 8 | 9 | import pytest 10 | from tests.helpers import describe_dtype, id_formatter 11 | import torch 12 | 13 | import bitsandbytes as bnb 14 | 15 | str2optimizers = {"paged_adamw": (torch.optim.AdamW, bnb.optim.PagedAdamW)} 16 | 17 | 18 | @pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1")) 19 | @pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype) 20 | @pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name")) 21 | @pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode")) 22 | @pytest.mark.benchmark 23 | def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): 24 | layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) 25 | layers1 = layers1.to(gtype) 26 | layers1 = layers1.cuda() 27 | 28 | large_tensor = None 29 | if mode == "torch": 30 | optim = str2optimizers[optim_name][0](layers1.parameters()) 31 | else: 32 | optim = str2optimizers[optim_name][1](layers1.parameters()) 33 | # 12 GB 34 | large_tensor = torch.empty((int(4.5e9),), device="cuda") 35 | 36 | torch.cuda.synchronize() 37 | time.sleep(5) 38 | 39 | num_batches = 5 40 | batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype) 41 | lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda() 42 | 43 | for i in range(num_batches): 44 | print(i) 45 | b = 
batches[i] 46 | if i == 2: 47 | torch.cuda.synchronize() 48 | t0 = time.time() 49 | 50 | out1 = layers1(b) 51 | 52 | loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean() 53 | loss1.backward() 54 | optim.step() 55 | torch.cuda.synchronize() 56 | print(mode, time.time() - t0) 57 | -------------------------------------------------------------------------------- /benchmarking/int8/int8_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic benchmark for text generation. 3 | 4 | Usage: python benchmarking/int8/int8_benchmark.py 5 | """ 6 | 7 | import time 8 | 9 | import torch 10 | from torch.profiler import ProfilerActivity, profile 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 12 | 13 | MAX_NEW_TOKENS = 128 14 | model_name = "meta-llama/Llama-3.1-8B" 15 | 16 | text = "Below is a question. I need an answer.\n\nExplain machine learning: " 17 | tokenizer = AutoTokenizer.from_pretrained(model_name) 18 | input_ids = tokenizer([text] * 8, return_tensors="pt").input_ids.to(0) 19 | 20 | model = AutoModelForCausalLM.from_pretrained( 21 | model_name, 22 | device_map="auto", 23 | quantization_config=BitsAndBytesConfig( 24 | load_in_8bit=True, 25 | llm_int8_threshold=6.0, 26 | ), 27 | attn_implementation="sdpa", 28 | torch_dtype=torch.float16, 29 | ) 30 | 31 | print(model) 32 | 33 | # warmup 34 | print("Warmup...") 35 | for i in range(3): 36 | generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS) 37 | 38 | print("Profiler starting...") 39 | with profile( 40 | activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], 41 | with_modules=True, 42 | with_stack=True, 43 | ) as prof: 44 | model.generate(input_ids, max_new_tokens=1) 45 | 46 | print( 47 | prof.key_averages().table( 48 | sort_by="cpu_time_total", 49 | max_name_column_width=50, 50 | top_level_events_only=True, 51 | row_limit=50, 52 | ) 53 | ) 54 | 55 | torch.cuda.synchronize() 56 | 57 | 58 | print("Generating...") 59 | num = 0 60 | time_1 = time.time() 61 | for i in range(5): 62 | generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS) 63 | num += len(generated_ids[0]) 64 | 65 | print("=" * 40) 66 | print(f"Example:\n{tokenizer.decode(generated_ids[0])}") 67 | print("=" * 40) 68 | print(f"Speed: {num / (time.time() - time_1)}token/s") 69 | -------------------------------------------------------------------------------- /csrc/mps_ops.mm: -------------------------------------------------------------------------------- 1 | #import 2 | 3 | #define HLF_MAX 65504 4 | #define TH 1024 5 | #define NUM 4 6 | #define NUM_BLOCK 4096 7 | 8 | static inline MPSGraph* get_graph() { 9 | static MPSGraph* cur = nil; 10 | if (!cur) { 11 | cur = [[MPSGraph alloc] init]; 12 | } 13 | return cur; 14 | } 15 | 16 | static inline id get_device() { 17 | NSError* error = nil; 18 | static id device = nil; 19 | if (!device) { 20 | device = MTLCreateSystemDefaultDevice(); 21 | } 22 | if (!device) { 23 | NSLog(@"Failed to get MPS device"); 24 | abort(); 25 | } 26 | return device; 27 | } 28 | 29 | static inline id get_library() { 30 | NSError* error = nil; 31 | static id library = nil; 32 | if (!library) { 33 | library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error]; 34 | } 35 | if (!library) { 36 | NSLog(@"Failed to load bitsandbytes.metallib"); 37 | abort(); 38 | } 39 | return library; 40 | } 41 | 42 | /*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n) 43 | { 
44 | id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 45 | dataType:MPSDataTypeInt8 axis:0 name:@"out"]; return out; 46 | }*/ 47 | 48 | // MPSGraph function for quantize 49 | extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) { 50 | id device = get_device(); 51 | id library = get_library(); 52 | static id kernel = nil; 53 | if (!kernel) { 54 | kernel = [library newFunctionWithName:@"quantize"]; 55 | if (!kernel) { 56 | NSLog(@"Failed to load bitsandbytes.metallib"); 57 | abort(); 58 | } 59 | } 60 | NSLog(@"Not implemented"); 61 | return nil; 62 | } 63 | -------------------------------------------------------------------------------- /bitsandbytes/backends/utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | from packaging import version 4 | import torch 5 | 6 | try: 7 | import triton.language as tl # noqa: F401 8 | 9 | import triton # noqa: F401 10 | 11 | triton_available = True 12 | except ImportError: 13 | triton_available = False 14 | 15 | 16 | _NF4_QUANT_TABLE = torch.tensor( 17 | [ 18 | -1.0, 19 | -0.6961928009986877, 20 | -0.5250730514526367, 21 | -0.39491748809814453, 22 | -0.28444138169288635, 23 | -0.18477343022823334, 24 | -0.09105003625154495, 25 | 0.0, 26 | 0.07958029955625534, 27 | 0.16093020141124725, 28 | 0.24611230194568634, 29 | 0.33791524171829224, 30 | 0.44070982933044434, 31 | 0.5626170039176941, 32 | 0.7229568362236023, 33 | 1.0, 34 | ], 35 | dtype=torch.float32, 36 | device="xpu" 37 | if hasattr(torch, "xpu") and torch.xpu.is_available() 38 | else "cpu", # Only cpu/xpu use this table for now. 39 | ) 40 | _FP4_QUANT_TABLE = torch.tensor( 41 | [ 42 | 0.0000, 43 | 0.0052, 44 | 0.6667, 45 | 1.0000, 46 | 0.3333, 47 | 0.5000, 48 | 0.1667, 49 | 0.2500, 50 | 0.0000, 51 | -0.0052, 52 | -0.6667, 53 | -1.0000, 54 | -0.3333, 55 | -0.5000, 56 | -0.1667, 57 | -0.2500, 58 | ], 59 | dtype=torch.float32, 60 | device="xpu" 61 | if hasattr(torch, "xpu") and torch.xpu.is_available() 62 | else "cpu", # Only cpu/xpu use this table for now. 63 | ) 64 | CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE} 65 | 66 | 67 | def get_gaudi_sw_version(): 68 | """ 69 | Returns the installed version of Gaudi SW. 70 | """ 71 | output = subprocess.run( 72 | "pip list | grep habana-torch-plugin", 73 | shell=True, 74 | text=True, 75 | capture_output=True, 76 | ) 77 | # If grep return nothing 78 | if not output.stdout.strip(): 79 | return None 80 | 81 | return version.parse(output.stdout.split("\n")[0].split()[-1]) 82 | 83 | 84 | GAUDI_SW_VER = get_gaudi_sw_version() 85 | -------------------------------------------------------------------------------- /.github/scripts/build-cuda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare build_arch 3 | declare build_os 4 | declare cuda_version 5 | declare cuda_targets 6 | 7 | set -xeuo pipefail 8 | 9 | if [[ -v cuda_targets ]]; then 10 | build_capability="${cuda_targets}" 11 | elif [ "${build_arch}" = "aarch64" ]; then 12 | build_capability="75;80;90" 13 | 14 | # CUDA 12.8-12.9: Add sm100/sm120 15 | [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;90;100;120" 16 | 17 | # CUDA 13.0+: Add sm100/sm110/sm120 18 | [[ "${cuda_version}" == 13.*.* ]] && build_capability="75;80;90;100;110;120;121" 19 | else 20 | # By default, target Pascal through Hopper. 
21 | build_capability="60;70;75;80;86;89;90" 22 | 23 | # CUDA 12.8+: Add sm100 and sm120; remove < sm70 to align with PyTorch 2.8+cu128 minimum 24 | [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="70;75;80;86;89;90;100;120" 25 | 26 | # CUDA 13.0+: Remove < sm75 to align with PyTorch 2.9+cu130 minimum 27 | [[ "${cuda_version}" == 13.*.* ]] && build_capability="75;80;86;89;90;100;120" 28 | fi 29 | 30 | [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja 31 | 32 | if [ "${build_os:0:6}" == ubuntu ]; then 33 | # We'll use Rocky Linux 8 in order to maintain manylinux 2.24 compatibility. 34 | image="nvidia/cuda:${cuda_version}-devel-rockylinux8" 35 | echo "Using image $image" 36 | 37 | docker run -i -w /src -v "$PWD:/src" "$image" bash -c \ 38 | "dnf -y --refresh update --security \ 39 | && dnf -y install cmake gcc-toolset-11 --setopt=install_weak_deps=False --setopt=tsflags=nodocs \ 40 | && source scl_source enable gcc-toolset-11 \ 41 | && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \ 42 | && cmake --build . --config Release" 43 | else 44 | pip install cmake==3.28.3 45 | cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S . 46 | cmake --build . --config Release 47 | fi 48 | 49 | 50 | output_dir="output/${build_os}/${build_arch}" 51 | mkdir -p "${output_dir}" 52 | (shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") 53 | -------------------------------------------------------------------------------- /bitsandbytes/triton/dequantize_rowwise.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from bitsandbytes.triton.triton_utils import is_triton_available 6 | 7 | if not is_triton_available(): 8 | 9 | def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): 10 | return None 11 | else: 12 | import triton 13 | import triton.language as tl 14 | 15 | # rowwise quantize 16 | 17 | # TODO: autotune this better. 
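    # Sketch of what the kernel below does: each Triton program handles one row; values are loaded
    # with a power-of-two index range (P2 >= row width) masked to the true width (BLOCK_SIZE), then
    # multiplied by the row's absmax and 1/127 to undo the int8 rowwise quantization.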
18 | @triton.autotune( 19 | configs=[ 20 | triton.Config({}, num_stages=1, num_warps=8), 21 | triton.Config({}, num_stages=2, num_warps=8), 22 | triton.Config({}, num_stages=4, num_warps=8), 23 | triton.Config({}, num_stages=8, num_warps=8), 24 | triton.Config({}, num_stages=1), 25 | triton.Config({}, num_stages=2), 26 | triton.Config({}, num_stages=4), 27 | triton.Config({}, num_stages=8), 28 | triton.Config({}, num_warps=1), 29 | triton.Config({}, num_warps=2), 30 | triton.Config({}, num_warps=4), 31 | triton.Config({}, num_warps=8), 32 | ], 33 | key=["n_elements"], 34 | ) 35 | @triton.jit 36 | def _dequantize_rowwise( 37 | x_ptr, 38 | state_x, 39 | output_ptr, 40 | inv_127, 41 | n_elements, 42 | BLOCK_SIZE: tl.constexpr, 43 | P2: tl.constexpr, 44 | ): 45 | pid = tl.program_id(axis=0) 46 | block_start = pid * BLOCK_SIZE 47 | arange = tl.arange(0, P2) 48 | offsets = block_start + arange 49 | row_mask = arange < BLOCK_SIZE 50 | x = tl.load(x_ptr + offsets, mask=row_mask) 51 | max_val = tl.load(state_x + pid) 52 | output = max_val * x * inv_127 53 | tl.store(output_ptr + offsets, output, mask=row_mask) 54 | 55 | def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): 56 | output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) 57 | 58 | P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) 59 | 60 | assert x.is_cuda and output.is_cuda 61 | n_elements = output.numel() 62 | grid = lambda meta: (x.shape[0],) 63 | _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) 64 | return output 65 | -------------------------------------------------------------------------------- /bitsandbytes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import importlib 8 | import sys 9 | 10 | import torch 11 | 12 | from . import _ops, research, utils 13 | from .autograd._functions import ( 14 | MatmulLtState, 15 | matmul, 16 | matmul_4bit, 17 | ) 18 | from .backends.cpu import ops as cpu_ops 19 | from .backends.default import ops as default_ops 20 | from .nn import modules 21 | from .optim import adam 22 | 23 | # This is a signal for integrations with transformers/diffusers. 24 | # Eventually we may remove this but it is currently required for compatibility. 25 | features = {"multi_backend"} 26 | supported_torch_devices = { 27 | "cpu", 28 | "cuda", # NVIDIA/AMD GPU 29 | "xpu", # Intel GPU 30 | "hpu", # Intel Gaudi 31 | "npu", # Ascend NPU 32 | "mps", # Apple Silicon 33 | } 34 | 35 | if torch.cuda.is_available(): 36 | from .backends.cuda import ops as cuda_ops 37 | 38 | if hasattr(torch, "xpu") and torch.xpu.is_available(): 39 | from .backends.xpu import ops as xpu_ops 40 | 41 | if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"): 42 | # In case not automatically imported 43 | import habana_frameworks.torch 44 | 45 | if hasattr(torch, "hpu") and torch.hpu.is_available(): 46 | from .backends.hpu import ops as hpu_ops 47 | 48 | 49 | def _import_backends(): 50 | """ 51 | Discover and autoload all available backends installed as separate packages. 52 | Packages with an entrypoint for "bitsandbytes.backends" will be loaded. 
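    For example, a hypothetical external backend package (names below are illustrative, not part of
    this repository) could register itself via packaging metadata along these lines:

        [project.entry-points."bitsandbytes.backends"]
        my_backend = "bnb_my_backend:register_ops"

    The loaded entry point is called with no arguments, so `register_ops` would perform the
    backend's kernel registration when bitsandbytes is imported.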
53 | Inspired by PyTorch implementation: https://pytorch.org/tutorials/prototype/python_extension_autoload.html 54 | """ 55 | from importlib.metadata import entry_points 56 | 57 | extensions = entry_points(group="bitsandbytes.backends") 58 | 59 | for ext in extensions: 60 | try: 61 | entry = ext.load() 62 | entry() 63 | except Exception as e: 64 | raise RuntimeError(f"bitsandbytes: failed to load backend {ext.name}: {e}") from e 65 | 66 | 67 | _import_backends() 68 | 69 | __pdoc__ = { 70 | "libbitsandbytes": False, 71 | "optim.optimizer.Optimizer8bit": False, 72 | "optim.optimizer.MockArgs": False, 73 | } 74 | 75 | __version__ = "0.49.1.dev0" 76 | -------------------------------------------------------------------------------- /bitsandbytes/triton/quantize_rowwise.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from bitsandbytes.triton.triton_utils import is_triton_available 6 | 7 | if not is_triton_available(): 8 | 9 | def quantize_rowwise(x: torch.Tensor): 10 | return None 11 | else: 12 | import triton 13 | import triton.language as tl 14 | 15 | # rowwise quantize 16 | 17 | # TODO: autotune this better. 18 | @triton.autotune( 19 | configs=[ 20 | triton.Config({}, num_stages=1, num_warps=8), 21 | triton.Config({}, num_stages=2, num_warps=8), 22 | triton.Config({}, num_stages=4, num_warps=8), 23 | triton.Config({}, num_stages=8, num_warps=8), 24 | triton.Config({}, num_stages=1), 25 | triton.Config({}, num_stages=2), 26 | triton.Config({}, num_stages=4), 27 | triton.Config({}, num_stages=8), 28 | triton.Config({}, num_warps=1), 29 | triton.Config({}, num_warps=2), 30 | triton.Config({}, num_warps=4), 31 | triton.Config({}, num_warps=8), 32 | ], 33 | key=["n_elements"], 34 | ) 35 | @triton.jit 36 | def _quantize_rowwise( 37 | x_ptr, 38 | output_ptr, 39 | output_maxs, 40 | n_elements, 41 | BLOCK_SIZE: tl.constexpr, 42 | P2: tl.constexpr, 43 | ): 44 | pid = tl.program_id(axis=0) 45 | block_start = pid * BLOCK_SIZE 46 | arange = tl.arange(0, P2) 47 | offsets = block_start + arange 48 | row_mask = arange < BLOCK_SIZE 49 | x = tl.load(x_ptr + offsets, mask=row_mask) 50 | 51 | abs_x = tl.abs(x) 52 | max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) 53 | output = tl.libdevice.llrint(127.0 * (x / max_val)) 54 | tl.store(output_ptr + offsets, output, mask=row_mask) 55 | tl.store(output_maxs + pid, max_val) 56 | 57 | def quantize_rowwise(x: torch.Tensor): 58 | output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) 59 | output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) 60 | 61 | P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) 62 | 63 | assert x.is_cuda and output.is_cuda 64 | n_elements = output.numel() 65 | grid = lambda meta: (x.shape[0],) 66 | _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) 67 | return output, output_maxs 68 | -------------------------------------------------------------------------------- /scripts/stale.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team, the AllenNLP library authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Script to close stale issue. Taken in part from the AllenNLP repository. 16 | https://github.com/allenai/allennlp. 17 | """ 18 | 19 | from datetime import datetime as dt, timezone 20 | import os 21 | 22 | from github import Github 23 | 24 | # All labels that we don't want to touch 25 | LABELS_TO_EXEMPT = [ 26 | "feature-request", 27 | ] 28 | 29 | 30 | def main(): 31 | g = Github(os.environ["GITHUB_TOKEN"]) 32 | repo = g.get_repo("TimDettmers/bitsandbytes") 33 | open_issues = repo.get_issues(state="open") 34 | 35 | for issue in open_issues: 36 | comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True) 37 | last_comment = comments[0] if len(comments) > 0 else None 38 | if ( 39 | last_comment is not None 40 | and last_comment.user.login == "github-actions[bot]" 41 | and (dt.now(timezone.utc) - issue.updated_at).days > 7 42 | and (dt.now(timezone.utc) - issue.created_at).days >= 30 43 | and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) 44 | ): 45 | issue.edit(state="closed") 46 | elif ( 47 | (dt.now(timezone.utc) - issue.updated_at).days > 23 48 | and (dt.now(timezone.utc) - issue.created_at).days >= 30 49 | and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) 50 | ): 51 | issue.create_comment( 52 | "This issue has been automatically marked as stale because it has not had " 53 | "recent activity. 
If you think this still needs to be addressed " 54 | "please comment on this thread.\n\n", 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /install_cuda.sh: -------------------------------------------------------------------------------- 1 | URL118=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run 2 | URL120=https://developer.download.nvidia.com/compute/cuda/12.0.1/local_installers/cuda_12.0.1_525.85.12_linux.run 3 | URL121=https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run 4 | URL122=https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run 5 | URL123=https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run 6 | URL124=https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run 7 | URL125=https://developer.download.nvidia.com/compute/cuda/12.5.1/local_installers/cuda_12.5.1_555.42.06_linux.run 8 | URL126=https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run 9 | 10 | CUDA_VERSION=$1 11 | BASE_PATH=$2 12 | EXPORT_BASHRC=$3 13 | 14 | if [[ -n "$CUDA_VERSION" ]]; then 15 | if [[ "$CUDA_VERSION" -eq "118" ]]; then 16 | URL=$URL118 17 | FOLDER=cuda-11.8 18 | elif [[ "$CUDA_VERSION" -eq "120" ]]; then 19 | URL=$URL120 20 | FOLDER=cuda-12.0 21 | elif [[ "$CUDA_VERSION" -eq "121" ]]; then 22 | URL=$URL121 23 | FOLDER=cuda-12.1 24 | elif [[ "$CUDA_VERSION" -eq "122" ]]; then 25 | URL=$URL122 26 | FOLDER=cuda-12.2 27 | elif [[ "$CUDA_VERSION" -eq "123" ]]; then 28 | URL=$URL123 29 | FOLDER=cuda-12.3 30 | elif [[ "$CUDA_VERSION" -eq "124" ]]; then 31 | URL=$URL124 32 | FOLDER=cuda-12.4 33 | elif [[ "$CUDA_VERSION" -eq "125" ]]; then 34 | URL=$URL125 35 | FOLDER=cuda-12.5 36 | elif [[ "$CUDA_VERSION" -eq "126" ]]; then 37 | URL=$URL126 38 | FOLDER=cuda-12.6 39 | else 40 | echo "argument error: No cuda version passed as input. Choose among versions 118 to 126" 41 | fi 42 | else 43 | echo "argument error: No cuda version passed as input. 
Choose among versions 118 to 126" 44 | fi 45 | 46 | FILE=$(basename $URL) 47 | 48 | if [[ -n "$CUDA_VERSION" ]]; then 49 | echo $URL 50 | echo $FILE 51 | wget $URL 52 | bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent 53 | if [ "$EXPORT_BASHRC" -eq "1" ]; then 54 | echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc 55 | echo "export PATH=\$PATH:$BASE_PATH/$FOLDER/bin" >> ~/.bashrc 56 | source ~/.bashrc 57 | fi 58 | else 59 | echo "" 60 | fi 61 | -------------------------------------------------------------------------------- /bitsandbytes/research/nn/modules.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import bitsandbytes as bnb 7 | 8 | T = TypeVar("T", bound="torch.nn.Module") 9 | 10 | 11 | class LinearFP8Mixed(nn.Linear): 12 | def __init__(self, input_features, output_features, bias=True): 13 | super().__init__(input_features, output_features, bias) 14 | self.bw_code = None 15 | self.fw_code = None 16 | array = [4096, 2048, 1024, 512, 256, 128, 64, 0] 17 | for i, k in enumerate(array): 18 | if input_features > array[i + 1]: 19 | self.bsz = k 20 | break 21 | for i, k in enumerate(array): 22 | if output_features > array[i + 1]: 23 | self.bsz2 = k 24 | break 25 | 26 | def forward(self, x: torch.Tensor): 27 | if self.fw_code is None: 28 | self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) 29 | self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) 30 | 31 | out = bnb.research.matmul_fp8_mixed( 32 | x, 33 | self.weight.t(), 34 | fw_code=self.fw_code, 35 | bw_code=self.bw_code, 36 | bsz=self.bsz, 37 | bsz2=self.bsz2, 38 | ) 39 | if self.bias is not None: 40 | out += self.bias 41 | 42 | return out 43 | 44 | 45 | class LinearFP8Global(nn.Linear): 46 | def __init__(self, input_features, output_features, bias=True): 47 | super().__init__(input_features, output_features, bias) 48 | self.bw_code = None 49 | self.fw_code = None 50 | array = [4096, 2048, 1024, 512, 256, 128, 64, 0] 51 | for i, k in enumerate(array): 52 | if input_features > array[i + 1]: 53 | self.bsz = k 54 | break 55 | for i, k in enumerate(array): 56 | if output_features > array[i + 1]: 57 | self.bsz2 = k 58 | break 59 | 60 | def forward(self, x: torch.Tensor): 61 | if self.fw_code is None: 62 | self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) 63 | self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) 64 | 65 | out = bnb.matmul_fp8_global( 66 | x, 67 | self.weight.t(), 68 | fw_code=self.fw_code, 69 | bw_code=self.bw_code, 70 | bsz=self.bsz, 71 | bsz2=self.bsz2, 72 | ) 73 | if self.bias is not None: 74 | out += self.bias 75 | 76 | return out 77 | -------------------------------------------------------------------------------- /tests/test_triton.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from bitsandbytes.nn import Linear8bitLt 5 | from bitsandbytes.nn.triton_based_modules import SwitchBackLinear 6 | from bitsandbytes.triton.triton_utils import is_triton_available 7 | from tests.helpers import TRUE_FALSE 8 | 9 | 10 | @pytest.mark.skipif( 11 | not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, 12 | reason="This test requires triton and a GPU with compute capability 8.0 or higher.", 13 | ) 14 | 
@pytest.mark.deprecated 15 | @pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE) 16 | def test_switchback(vector_wise_quantization): 17 | for dim in [83]: 18 | for batch in [13]: 19 | standard = torch.nn.Linear(dim, 4 * dim).cuda().half() 20 | switchback = ( 21 | SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() 22 | ) 23 | baseline = Linear8bitLt(dim, 4 * dim).cuda().half() 24 | switchback.weight.data.copy_(standard.weight) 25 | switchback.bias.data.copy_(standard.bias) 26 | baseline.weight.data.copy_(standard.weight) 27 | baseline.bias.data.copy_(standard.bias) 28 | 29 | x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True) 30 | x2 = x1.clone().detach().requires_grad_(True) 31 | x3 = x1.clone().detach().requires_grad_(True) 32 | 33 | out_standard = standard(x1) 34 | (2**10 * out_standard.abs().mean()).backward() 35 | 36 | print(x2.dtype) 37 | out_sb = switchback(x2) 38 | (2**10 * out_sb.abs().mean()).backward() 39 | 40 | out_baseline = baseline(x3) 41 | (2**10 * out_baseline.abs().mean()).backward() 42 | 43 | err_sb = (out_standard - out_sb).abs().mean() 44 | err_baseline = (out_standard - out_baseline).abs().mean() 45 | print("OUT", err_sb, err_baseline) 46 | assert err_sb < 2 * err_baseline 47 | 48 | err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean() 49 | err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean() 50 | 51 | print("GW2", err_sb, err_baseline) 52 | assert err_sb < 2 * err_baseline 53 | 54 | err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean() 55 | err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean() 56 | 57 | print("GW1", err_sb, err_baseline) 58 | assert err_sb < 2 * err_baseline 59 | 60 | err_sb = (x1.grad - x2.grad).abs().mean() 61 | err_baseline = (x1.grad - x3.grad).abs().mean() 62 | 63 | print("GX1", err_sb, err_baseline) 64 | assert err_sb < 2 * err_baseline 65 | -------------------------------------------------------------------------------- /bitsandbytes/triton/quantize_columnwise_and_transpose.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from bitsandbytes.triton.triton_utils import is_triton_available 6 | 7 | if not is_triton_available(): 8 | 9 | def quantize_columnwise_and_transpose(x: torch.Tensor): 10 | return None 11 | else: 12 | import triton 13 | import triton.language as tl 14 | 15 | # This kernel does fused columnwise quantization and transpose. 16 | 17 | # TODO: autotune this better. 
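    # Sketch of what the fused kernel below does: each Triton program walks one column of the
    # (M, N) input (stride N), computes that column's absmax, quantizes to int8 via
    # round(127 * x / absmax), and writes the result contiguously as one row of the transposed
    # (N, M) output, storing the per-column absmax in fp16.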
18 | @triton.autotune( 19 | configs=[ 20 | triton.Config({}, num_stages=1), 21 | triton.Config({}, num_stages=2), 22 | triton.Config({}, num_stages=4), 23 | triton.Config({}, num_stages=8), 24 | triton.Config({}, num_stages=16), 25 | triton.Config({}, num_stages=1, num_warps=8), 26 | triton.Config({}, num_stages=2, num_warps=8), 27 | triton.Config({}, num_stages=4, num_warps=8), 28 | triton.Config({}, num_stages=8, num_warps=8), 29 | triton.Config({}, num_stages=16, num_warps=8), 30 | triton.Config({}, num_warps=1), 31 | triton.Config({}, num_warps=2), 32 | triton.Config({}, num_warps=4), 33 | triton.Config({}, num_warps=8), 34 | ], 35 | key=["n_elements"], 36 | ) 37 | @triton.jit 38 | def _quantize_columnwise_and_transpose( 39 | x_ptr, 40 | output_ptr, 41 | output_maxs, 42 | n_elements, 43 | M: tl.constexpr, 44 | N: tl.constexpr, 45 | BLOCK_SIZE: tl.constexpr, 46 | P2: tl.constexpr, 47 | ): 48 | pid = tl.program_id(axis=0) 49 | block_start = pid 50 | p2_arange = tl.arange(0, P2) 51 | p2_arange_mask = p2_arange < M 52 | arange = p2_arange * N 53 | offsets = block_start + arange 54 | x = tl.load(x_ptr + offsets, mask=p2_arange_mask) 55 | abs_x = tl.abs(x) 56 | max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) 57 | output = tl.libdevice.llrint(127.0 * (x / max_val)) 58 | 59 | new_start = pid * M 60 | new_offsets = new_start + p2_arange 61 | tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) 62 | tl.store(output_maxs + pid, max_val) 63 | 64 | def quantize_columnwise_and_transpose(x: torch.Tensor): 65 | M, N = x.shape 66 | output = torch.empty(N, M, device=x.device, dtype=torch.int8) 67 | output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) 68 | 69 | P2 = int(2 ** (math.ceil(math.log2(M)))) 70 | 71 | assert x.is_cuda and output.is_cuda 72 | n_elements = output.numel() 73 | grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 74 | _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) 75 | return output, output_maxs 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | *.dll 7 | *.dylib 8 | *.o 9 | *.obj 10 | *.air 11 | *.metallib 12 | 13 | # CMake generated files 14 | CMakeCache.txt 15 | CMakeScripts/ 16 | cmake_install.cmake 17 | Makefile 18 | CMakeFiles/ 19 | *.sln 20 | *.vcxproj* 21 | *.xcodeproj/ 22 | bitsandbytes.dir/ 23 | Debug/ 24 | Release/ 25 | cmake-build-*/ 26 | 27 | # IDE local files 28 | .vs/ 29 | .idea/ 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | pip-wheel-metadata/ 46 | share/python-wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # PyInstaller 53 | # Usually these files are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .nox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | *.py,cover 73 | .hypothesis/ 74 | .pytest_cache/ 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | local_settings.py 83 | db.sqlite3 84 | db.sqlite3-journal 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # IPython 103 | profile_default/ 104 | ipython_config.py 105 | 106 | # pyenv 107 | .python-version 108 | 109 | # pipenv 110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 113 | # install all needed dependencies. 114 | #Pipfile.lock 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # vim 154 | *.swp 155 | 156 | dependencies 157 | cuda_build 158 | output/ 159 | -------------------------------------------------------------------------------- /.github/workflows/tests-pr.yml: -------------------------------------------------------------------------------- 1 | name: PR Tests 2 | 3 | on: 4 | pull_request: 5 | types: [opened, synchronize, reopened] 6 | branches: [main] 7 | paths: 8 | - ".github/workflows/test-runner.yml" 9 | - ".github/workflows/tests-pr.yml" 10 | - ".github/scripts/build-cpu.sh" 11 | - ".github/scripts/build-cuda.sh" 12 | - "bitsandbytes/**" 13 | - "csrc/**" 14 | - "include/**" 15 | - "tests/**" 16 | - "CMakeLists.txt" 17 | - "setup.py" 18 | - "pyproject.toml" 19 | 20 | concurrency: 21 | group: ${{ github.workflow }}-${{ github.event.pull_request.number }} 22 | cancel-in-progress: true 23 | 24 | jobs: 25 | test-cpu: 26 | name: CPU 27 | if: github.repository == 'bitsandbytes-foundation/bitsandbytes' 28 | strategy: 29 | fail-fast: false 30 | matrix: 31 | platform: [linux-x64, linux-aarch64, macos] 32 | # default runners don't have AVX-512 support, but icelake does 33 | cpu_type: ["", icelake] 34 | torch_version: ["2.3.1", "2.9.1"] 35 | 36 | exclude: 37 | # aarch64 minimum torch version is 2.5.1 38 | - platform: linux-aarch64 39 | torch_version: "2.3.1" 40 | # icelake only applies to linux-x64 41 | - platform: linux-aarch64 42 | cpu_type: icelake 43 | - platform: macos 44 | cpu_type: icelake 45 | 46 | include: 47 | # Add aarch64 with torch 2.5.1 instead of 2.3.1 48 | - platform: linux-aarch64 49 | cpu_type: "" 50 | torch_version: "2.5.1" 51 | 52 | uses: ./.github/workflows/test-runner.yml 53 | with: 54 | platform: ${{ matrix.platform }} 55 | 
backend: cpu 56 | torch_version: ${{ matrix.torch_version }} 57 | pypi_index: "https://download.pytorch.org/whl/cpu" 58 | cpu_type: ${{ matrix.cpu_type }} 59 | 60 | test-cuda: 61 | name: CUDA 62 | if: github.repository == 'bitsandbytes-foundation/bitsandbytes' 63 | strategy: 64 | fail-fast: false 65 | matrix: 66 | platform: [linux-x64] 67 | gpu_type: [T4, L40S] 68 | cuda_version: ["11.8.0", "12.8.1", "13.0.2"] 69 | 70 | include: 71 | # Map CUDA version to torch version and PyPI index 72 | - cuda_version: "11.8.0" 73 | torch_version: "2.3.1" 74 | pypi_index: "https://download.pytorch.org/whl/cu118" 75 | - cuda_version: "12.8.1" 76 | torch_version: "2.8.0" 77 | pypi_index: "https://download.pytorch.org/whl/cu128" 78 | - cuda_version: "13.0.2" 79 | torch_version: "2.9.1" 80 | pypi_index: "https://download.pytorch.org/whl/cu130" 81 | 82 | # Windows CUDA test - single configuration 83 | - platform: windows 84 | gpu_type: T4 85 | cuda_version: "11.8.0" 86 | torch_version: "2.7.1" 87 | pypi_index: "https://download.pytorch.org/whl/cu118" 88 | 89 | uses: ./.github/workflows/test-runner.yml 90 | with: 91 | platform: ${{ matrix.platform }} 92 | backend: cuda 93 | cuda_version: ${{ matrix.cuda_version }} 94 | gpu_type: ${{ matrix.gpu_type }} 95 | torch_version: ${{ matrix.torch_version }} 96 | pypi_index: ${{ matrix.pypi_index }} 97 | -------------------------------------------------------------------------------- /csrc/mps_kernels.metal: -------------------------------------------------------------------------------- 1 | #include 2 | using namespace metal; 3 | 4 | #define HLF_MAX 65504 5 | #define TH 1024 6 | #define NUM 4 7 | #define NUM_BLOCK 4096 8 | 9 | template 10 | static unsigned char quantize_scalar( 11 | float rand, 12 | device float* code, 13 | float x) 14 | { 15 | int pivot = 127; 16 | int upper_pivot = 255; 17 | int lower_pivot = 0; 18 | 19 | float lower = -1.0f; 20 | float upper = 1.0f; 21 | 22 | float val = code[pivot]; 23 | // i>>=1 = {32, 16, 8, 4, 2, 1} 24 | for(int i = 64; i > 0; i>>=1) 25 | { 26 | if(x > val) 27 | { 28 | lower_pivot = pivot; 29 | lower = val; 30 | pivot+=i; 31 | } 32 | else 33 | { 34 | upper_pivot = pivot; 35 | upper = val; 36 | pivot-=i; 37 | } 38 | val = code[pivot]; 39 | } 40 | 41 | if(upper_pivot == 255) 42 | upper = code[upper_pivot]; 43 | if(lower_pivot == 0) 44 | lower = code[lower_pivot]; 45 | 46 | if(!STOCHASTIC) 47 | { 48 | if(x > val) 49 | { 50 | float midpoint = (upper+val)*0.5f; 51 | if(x > midpoint) 52 | { 53 | return upper_pivot; 54 | } 55 | else 56 | return pivot; 57 | } 58 | else 59 | { 60 | float midpoint = (lower+val)*0.5f; 61 | if(x < midpoint) 62 | return lower_pivot; 63 | else 64 | return pivot; 65 | } 66 | } 67 | else 68 | { 69 | if(x > val) 70 | { 71 | float dist_to_upper = fabs(upper-x); 72 | float dist_full = upper-val; 73 | if(rand >= dist_to_upper/dist_full) return upper_pivot; 74 | else return pivot; 75 | } 76 | else 77 | { 78 | float dist_to_lower = fabs(lower-x); 79 | float dist_full = val-lower; 80 | if(rand >= dist_to_lower/dist_full) return lower_pivot; 81 | else return pivot; 82 | } 83 | } 84 | } 85 | 86 | kernel void quantize(device float* code [[buffer(0)]], 87 | device float* A [[buffer(1)]], 88 | device uchar* out [[buffer(2)]], 89 | constant uint& n [[buffer(3)]], 90 | uint id [[thread_position_in_grid]]) { 91 | const uint n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); 92 | uint valid_items = (id / NUM_BLOCK + 1 == (n + NUM_BLOCK - 1) / NUM_BLOCK) ? 
n - (id / NUM_BLOCK * NUM_BLOCK) : NUM_BLOCK; 93 | const uint base_idx = (id / NUM_BLOCK * NUM_BLOCK); 94 | 95 | float vals[NUM]; 96 | uchar qvals[NUM]; 97 | 98 | for (uint i = base_idx; i < n_full; i += ((n + NUM_BLOCK - 1) / NUM_BLOCK) * NUM_BLOCK) { 99 | valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; 100 | 101 | threadgroup_barrier(mem_flags::mem_threadgroup); 102 | 103 | for (uint j = 0; j < valid_items; j++) { 104 | vals[j] = A[i + j]; 105 | } 106 | 107 | for (uint j = 0; j < valid_items; j++) { 108 | qvals[j] = quantize_scalar(0.0f, code, vals[j]); 109 | } 110 | 111 | threadgroup_barrier(mem_flags::mem_threadgroup); 112 | 113 | for (uint j = 0; j < valid_items; j++) { 114 | out[i + j] = qvals[j]; 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /.github/workflows/tests-nightly.yml: -------------------------------------------------------------------------------- 1 | name: Nightly Tests 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | # Every day at 02:15 AM UTC 7 | - cron: "15 2 * * *" 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | test-cpu: 15 | name: CPU 16 | if: github.repository == 'bitsandbytes-foundation/bitsandbytes' 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | platform: [linux-x64, linux-aarch64, macos, windows] 21 | # default runners don't have AVX-512 support, but icelake does 22 | cpu_type: ["", icelake] 23 | torch_version: ["2.3.1", "2.8.0", "2.9.1"] 24 | 25 | exclude: 26 | # aarch64 minimum torch version is 2.5.1 27 | - platform: linux-aarch64 28 | torch_version: "2.3.1" 29 | # icelake only applies to linux-x64 30 | - platform: linux-aarch64 31 | cpu_type: icelake 32 | - platform: macos 33 | cpu_type: icelake 34 | - platform: windows 35 | cpu_type: icelake 36 | 37 | include: 38 | # Add aarch64 with torch 2.5.1 39 | - platform: linux-aarch64 40 | cpu_type: "" 41 | torch_version: "2.5.1" 42 | 43 | uses: ./.github/workflows/test-runner.yml 44 | with: 45 | platform: ${{ matrix.platform }} 46 | backend: cpu 47 | torch_version: ${{ matrix.torch_version }} 48 | pypi_index: "https://download.pytorch.org/whl/cpu" 49 | cpu_type: ${{ matrix.cpu_type }} 50 | 51 | test-cuda: 52 | name: CUDA 53 | if: github.repository == 'bitsandbytes-foundation/bitsandbytes' 54 | strategy: 55 | fail-fast: false 56 | matrix: 57 | # Linux x64 cross-product 58 | platform: [linux-x64] 59 | gpu_type: [T4, L40S] 60 | cuda_version: ["11.8.0", "12.6.3", "12.8.1", "13.0.2"] 61 | 62 | include: 63 | # Map CUDA version to torch version and PyPI index 64 | - cuda_version: "11.8.0" 65 | torch_version: "2.3.1" 66 | pypi_index: "https://download.pytorch.org/whl/cu118" 67 | - cuda_version: "12.6.3" 68 | torch_version: "2.7.1" 69 | pypi_index: "https://download.pytorch.org/whl/cu126" 70 | - cuda_version: "12.8.1" 71 | torch_version: "2.8.0" 72 | pypi_index: "https://download.pytorch.org/whl/cu128" 73 | - cuda_version: "13.0.2" 74 | torch_version: "2.9.1" 75 | pypi_index: "https://download.pytorch.org/whl/cu130" 76 | 77 | # Windows CUDA Tests - T4 GPU (CUDA 11.8 only, multiple torch versions) 78 | - platform: windows 79 | gpu_type: T4 80 | cuda_version: "11.8.0" 81 | torch_version: "2.3.1" 82 | pypi_index: "https://download.pytorch.org/whl/cu118" 83 | - platform: windows 84 | gpu_type: T4 85 | cuda_version: "11.8.0" 86 | torch_version: "2.6.0" 87 | pypi_index: "https://download.pytorch.org/whl/cu118" 88 | - platform: windows 89 | gpu_type: T4 90 | cuda_version: 
"11.8.0" 91 | torch_version: "2.7.1" # Note: this is the last PyTorch release supporting CUDA 11.8. 92 | pypi_index: "https://download.pytorch.org/whl/cu118" 93 | 94 | uses: ./.github/workflows/test-runner.yml 95 | with: 96 | platform: ${{ matrix.platform }} 97 | backend: cuda 98 | cuda_version: ${{ matrix.cuda_version }} 99 | gpu_type: ${{ matrix.gpu_type }} 100 | torch_version: ${{ matrix.torch_version }} 101 | pypi_index: ${{ matrix.pypi_index }} 102 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from io import BytesIO 3 | from itertools import product 4 | import os 5 | import random 6 | from typing import Any 7 | 8 | import torch 9 | 10 | from bitsandbytes.cextension import HIP_ENVIRONMENT 11 | 12 | test_dims_rng = random.Random(42) 13 | 14 | 15 | TRUE_FALSE = (True, False) 16 | BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3)) # all combinations of (bool, bool, bool) 17 | BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool) 18 | 19 | 20 | @functools.cache 21 | def get_available_devices(no_cpu=False): 22 | if "BNB_TEST_DEVICE" in os.environ: 23 | # If the environment variable is set, use it directly. 24 | device = os.environ["BNB_TEST_DEVICE"] 25 | return [] if no_cpu and device == "cpu" else [device] 26 | 27 | devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else [] 28 | 29 | if hasattr(torch, "accelerator"): 30 | # PyTorch 2.6+ - determine accelerator using agnostic API. 
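        # (torch.accelerator.current_accelerator() reports the active device type, e.g. "cuda",
        # "xpu", or "mps"; older PyTorch versions fall back to the per-backend checks below.)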
31 | if torch.accelerator.is_available(): 32 | devices += [str(torch.accelerator.current_accelerator())] 33 | else: 34 | if torch.cuda.is_available(): 35 | devices += ["cuda"] 36 | 37 | if torch.backends.mps.is_available(): 38 | devices += ["mps"] 39 | 40 | if hasattr(torch, "xpu") and torch.xpu.is_available(): 41 | devices += ["xpu"] 42 | 43 | custom_backend_name = torch._C._get_privateuse1_backend_name() 44 | custom_backend_module = getattr(torch, custom_backend_name, None) 45 | custom_backend_is_available_fn = getattr(custom_backend_module, "is_available", None) 46 | 47 | if custom_backend_is_available_fn and custom_backend_module.is_available(): 48 | devices += [custom_backend_name] 49 | 50 | return devices 51 | 52 | 53 | def torch_save_to_buffer(obj): 54 | buffer = BytesIO() 55 | torch.save(obj, buffer) 56 | buffer.seek(0) 57 | return buffer 58 | 59 | 60 | def torch_load_from_buffer(buffer): 61 | buffer.seek(0) 62 | obj = torch.load(buffer, weights_only=False) 63 | buffer.seek(0) 64 | return obj 65 | 66 | 67 | def get_test_dims(min: int, max: int, *, n: int) -> list[int]: 68 | return [test_dims_rng.randint(min, max) for _ in range(n)] 69 | 70 | 71 | def format_with_label(label: str, value: Any) -> str: 72 | if isinstance(value, bool): 73 | formatted = "T" if value else "F" 74 | elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value): 75 | formatted = "".join("T" if b else "F" for b in value) 76 | elif isinstance(value, torch.dtype): 77 | formatted = describe_dtype(value) 78 | else: 79 | formatted = str(value) 80 | return f"{label}={formatted}" 81 | 82 | 83 | def id_formatter(label: str): 84 | """ 85 | Return a function that formats the value given to it with the given label. 86 | """ 87 | return lambda value: format_with_label(label, value) 88 | 89 | 90 | DTYPE_NAMES = { 91 | torch.bfloat16: "bf16", 92 | torch.bool: "bool", 93 | torch.float16: "fp16", 94 | torch.float32: "fp32", 95 | torch.float64: "fp64", 96 | torch.int32: "int32", 97 | torch.int64: "int64", 98 | torch.int8: "int8", 99 | } 100 | 101 | 102 | def describe_dtype(dtype: torch.dtype) -> str: 103 | return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2] 104 | 105 | 106 | def is_supported_on_hpu( 107 | quant_type: str = "nf4", dtype: torch.dtype = torch.bfloat16, quant_storage: torch.dtype = torch.uint8 108 | ) -> bool: 109 | """ 110 | Check if the given quant_type, dtype and quant_storage are supported on HPU. 
111 | """ 112 | if quant_type == "fp4" or dtype == torch.float16 or quant_storage not in (torch.uint8, torch.bfloat16): 113 | return False 114 | return True 115 | -------------------------------------------------------------------------------- /bitsandbytes/diagnostics/main.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import platform 3 | import sys 4 | import traceback 5 | 6 | import torch 7 | 8 | from bitsandbytes import __version__ as bnb_version 9 | from bitsandbytes.cextension import BNB_BACKEND 10 | from bitsandbytes.consts import PACKAGE_GITHUB_URL 11 | from bitsandbytes.cuda_specs import get_cuda_specs 12 | from bitsandbytes.diagnostics.cuda import ( 13 | print_diagnostics, 14 | ) 15 | from bitsandbytes.diagnostics.utils import print_dedented, print_header 16 | 17 | _RELATED_PACKAGES = [ 18 | "accelerate", 19 | "diffusers", 20 | "numpy", 21 | "pip", 22 | "peft", 23 | "safetensors", 24 | "transformers", 25 | "triton", 26 | "trl", 27 | ] 28 | 29 | 30 | def sanity_check(): 31 | from bitsandbytes.optim import Adam 32 | 33 | p = torch.nn.Parameter(torch.rand(10, 10).cuda()) 34 | a = torch.rand(10, 10).cuda() 35 | p1 = p.data.sum().item() 36 | adam = Adam([p]) 37 | out = a * p 38 | loss = out.sum() 39 | loss.backward() 40 | adam.step() 41 | p2 = p.data.sum().item() 42 | assert p1 != p2 43 | 44 | 45 | def get_package_version(name: str) -> str: 46 | try: 47 | version = importlib.metadata.version(name) 48 | except importlib.metadata.PackageNotFoundError: 49 | version = "not found" 50 | return version 51 | 52 | 53 | def show_environment(): 54 | """Simple utility to print out environment information.""" 55 | 56 | print(f"Platform: {platform.platform()}") 57 | if platform.system() == "Linux": 58 | print(f" libc: {'-'.join(platform.libc_ver())}") 59 | 60 | print(f"Python: {platform.python_version()}") 61 | 62 | print(f"PyTorch: {torch.__version__}") 63 | print(f" CUDA: {torch.version.cuda or 'N/A'}") 64 | print(f" HIP: {torch.version.hip or 'N/A'}") 65 | print(f" XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}") 66 | 67 | print("Related packages:") 68 | for pkg in _RELATED_PACKAGES: 69 | version = get_package_version(pkg) 70 | print(f" {pkg}: {version}") 71 | 72 | 73 | def main(): 74 | print_header(f"bitsandbytes v{bnb_version}") 75 | show_environment() 76 | print_header("") 77 | 78 | cuda_specs = get_cuda_specs() 79 | 80 | if cuda_specs: 81 | print_diagnostics(cuda_specs) 82 | 83 | # TODO: There's a lot of noise in this; needs improvement. 84 | # print_cuda_runtime_diagnostics() 85 | 86 | if not torch.cuda.is_available(): 87 | print(f"PyTorch says {BNB_BACKEND} is not available. Possible reasons:") 88 | print(f"1. {BNB_BACKEND} driver not installed") 89 | print("2. Using a CPU-only PyTorch build") 90 | print("3. No GPU detected") 91 | 92 | else: 93 | print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") 94 | 95 | try: 96 | sanity_check() 97 | print("SUCCESS!") 98 | return 99 | except RuntimeError as e: 100 | if "not available in CPU-only" in str(e): 101 | print( 102 | f"WARNING: {__package__} is currently running as CPU-only!\n" 103 | "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" 104 | f"If you think that this is so erroneously,\nplease report an issue!", 105 | ) 106 | else: 107 | raise e 108 | except Exception: 109 | traceback.print_exc() 110 | 111 | print_dedented( 112 | f""" 113 | Above we output some debug information. 
114 | Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose 115 | WARNING: Please be sure to sanitize sensitive info from the output before posting it. 116 | """, 117 | ) 118 | sys.exit(1) 119 | -------------------------------------------------------------------------------- /install_cuda.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | from urllib.request import urlretrieve 5 | 6 | cuda_versions = { 7 | "118": "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run", 8 | "120": "https://developer.download.nvidia.com/compute/cuda/12.0.1/local_installers/cuda_12.0.1_525.85.12_linux.run", 9 | "121": "https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run", 10 | "122": "https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run", 11 | "123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run", 12 | "124": "https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run", 13 | "125": "https://developer.download.nvidia.com/compute/cuda/12.5.1/local_installers/cuda_12.5.1_555.42.06_linux.run", 14 | "126": "https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run", 15 | } 16 | 17 | 18 | def install_cuda(version, base_path, download_path): 19 | formatted_version = f"{version[:-1]}.{version[-1]}" 20 | folder = f"cuda-{formatted_version}" 21 | install_path = os.path.join(base_path, folder) 22 | 23 | if os.path.exists(install_path): 24 | print(f"Removing existing CUDA version {version} at {install_path}...") 25 | subprocess.run(["rm", "-rf", install_path], check=True) 26 | 27 | url = cuda_versions[version] 28 | filename = url.split("/")[-1] 29 | filepath = os.path.join(download_path, filename) 30 | 31 | if not os.path.exists(filepath): 32 | print(f"Downloading CUDA version {version} from {url}...") 33 | urlretrieve(url, filepath) 34 | else: 35 | print(f"Installer for CUDA version {version} already downloaded.") 36 | 37 | # Make the installer executable 38 | subprocess.run(["chmod", "+x", filepath], check=True) 39 | 40 | # Install CUDA 41 | print(f"Installing CUDA version {version}...") 42 | install_command = [ 43 | "bash", 44 | filepath, 45 | "--no-drm", 46 | "--no-man-page", 47 | "--override", 48 | "--toolkitpath=" + install_path, 49 | "--toolkit", 50 | "--silent", 51 | ] 52 | 53 | print(f"Running command: {' '.join(install_command)}") 54 | 55 | try: 56 | subprocess.run(install_command, check=True) 57 | except subprocess.CalledProcessError as e: 58 | print(f"Installation failed for CUDA version {version}: {e}") 59 | return 60 | finally: 61 | # Delete the installer file 62 | os.remove(filepath) 63 | 64 | print(f"CUDA version {version} installed at {install_path}") 65 | 66 | 67 | def main(): 68 | user_base_path = os.path.expanduser("~/cuda") 69 | system_base_path = "/usr/local/cuda" 70 | base_path = user_base_path # default to user-specific installation 71 | download_path = "/tmp" # default download path 72 | 73 | if len(sys.argv) < 2: 74 | print("Usage: python install_cuda.py [user/system] [download_path]") 75 | sys.exit(1) 76 | 77 | version = sys.argv[1] 78 | if len(sys.argv) > 2: 79 | base_path = system_base_path if sys.argv[2] == "system" else user_base_path 80 | if 
len(sys.argv) > 3: 81 | download_path = sys.argv[3] 82 | 83 | if not os.path.exists(base_path): 84 | os.makedirs(base_path) 85 | if not os.path.exists(download_path): 86 | os.makedirs(download_path) 87 | 88 | # Install CUDA version(s) 89 | if version == "all": 90 | for ver in cuda_versions: 91 | install_cuda(ver, base_path, download_path) 92 | elif version in cuda_versions: 93 | install_cuda(version, base_path, download_path) 94 | else: 95 | print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}") 96 | sys.exit(1) 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /docs/source/explanations/resources.mdx: -------------------------------------------------------------------------------- 1 | # Papers, related resources & how to cite 2 | 3 | The below academic work is ordered in reverse chronological order. 4 | 5 | ## [SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression (Jun 2023)](https://arxiv.org/abs/2306.03078) 6 | 7 | Authors: Tim Dettmers, Ruslan Svirschevski, Vage Egiazarian, Denis Kuznedelev, Elias Frantar, Saleh Ashkboos, Alexander Borzunov, Torsten Hoefler, Dan Alistarh 8 | 9 | - [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1666076553665744896) 10 | 11 | ``` 12 | @article{dettmers2023spqr, 13 | title={SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression}, 14 | author={Dettmers, Tim and Svirschevski, Ruslan and Egiazarian, Vage and Kuznedelev, Denis and Frantar, Elias and Ashkboos, Saleh and Borzunov, Alexander and Hoefler, Torsten and Alistarh, Dan}, 15 | journal={arXiv preprint arXiv:2306.03078}, 16 | year={2023} 17 | } 18 | ``` 19 | 20 | ## [QLoRA: Efficient Finetuning of Quantized LLMs (May 2023)](https://arxiv.org/abs/2305.14314) 21 | Authors: Tim Dettmers, Artidoro Pagnoni, Ari Holtzman, Luke Zettlemoyer 22 | 23 | - [Video](https://www.youtube.com/watch?v=y9PHWGOa8HA&ab_channel=LondonMachineLearningMeetup) 24 | - [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1661379354507476994) 25 | 26 | ``` 27 | @article{dettmers2023qlora, 28 | title={Qlora: Efficient finetuning of quantized llms}, 29 | author={Dettmers, Tim and Pagnoni, Artidoro and Holtzman, Ari and Zettlemoyer, Luke}, 30 | journal={arXiv preprint arXiv:2305.14314}, 31 | year={2023} 32 | } 33 | ``` 34 | 35 | ## [The case for 4-bit precision: k-bit Inference Scaling Laws (Dec 2022)](https://arxiv.org/abs/2212.09720) 36 | Authors: Tim Dettmers, Luke Zettlemoyer 37 | 38 | - [Video](https://www.youtube.com/watch?v=odlQa6AE1gY&ab_channel=TheInsideView) 39 | - [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1605209171758284805) 40 | 41 | ``` 42 | @inproceedings{dettmers2023case, 43 | title={The case for 4-bit precision: k-bit inference scaling laws}, 44 | author={Dettmers, Tim and Zettlemoyer, Luke}, 45 | booktitle={International Conference on Machine Learning}, 46 | pages={7750--7774}, 47 | year={2023}, 48 | organization={PMLR} 49 | } 50 | ``` 51 | 52 | ## [LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Nov 2022)](https://arxiv.org/abs/2208.07339) [[llm-int8]] 53 | Authors: Tim Dettmers, Mike Lewis, Younes Belkada, Luke Zettlemoyer 54 | 55 | - [LLM.int8() Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) 56 | - [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/) 57 | - [Introduction to Weight 
Quantization](https://towardsdatascience.com/introduction-to-weight-quantization-2494701b9c0c) 58 | - [Poster](https://twitter.com/Tim_Dettmers/status/1598351301942951937) 59 | 60 | ``` 61 | @article{dettmers2022llm, 62 | title={Llm. int8 (): 8-bit matrix multiplication for transformers at scale}, 63 | author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke}, 64 | journal={arXiv preprint arXiv:2208.07339}, 65 | year={2022} 66 | } 67 | ``` 68 | 69 | ## [8-bit Optimizers via Block-wise Quantization (Oct 2021)](https://arxiv.org/abs/2110.02861) 70 | Authors: Tim Dettmers, Mike Lewis, Sam Shleifer, Luke Zettlemoyer 71 | 72 | - [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) 73 | - [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1446472128979562499) 74 | 75 | ``` 76 | @article{DBLP:journals/corr/abs-2110-02861, 77 | author = {Tim Dettmers and 78 | Mike Lewis and 79 | Sam Shleifer and 80 | Luke Zettlemoyer}, 81 | title = {8-bit Optimizers via Block-wise Quantization}, 82 | journal = {CoRR}, 83 | volume = {abs/2110.02861}, 84 | year = {2021}, 85 | url = {https://arxiv.org/abs/2110.02861}, 86 | eprinttype = {arXiv}, 87 | eprint = {2110.02861}, 88 | timestamp = {Thu, 21 Oct 2021 16:20:08 +0200}, 89 | biburl = {https://dblp.org/rec/journals/corr/abs-2110-02861.bib}, 90 | bibsource = {dblp computer science bibliography, https://dblp.org} 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /bitsandbytes/cuda_specs.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from functools import lru_cache 3 | import logging 4 | import re 5 | import subprocess 6 | from typing import Optional 7 | 8 | import torch 9 | 10 | 11 | @dataclasses.dataclass(frozen=True) 12 | class CUDASpecs: 13 | highest_compute_capability: tuple[int, int] 14 | cuda_version_string: str 15 | cuda_version_tuple: tuple[int, int] 16 | 17 | @property 18 | def has_imma(self) -> bool: 19 | return torch.version.hip or self.highest_compute_capability >= (7, 5) 20 | 21 | 22 | def get_compute_capabilities() -> list[tuple[int, int]]: 23 | return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count())) 24 | 25 | 26 | @lru_cache(None) 27 | def get_cuda_version_tuple() -> Optional[tuple[int, int]]: 28 | """Get CUDA/HIP version as a tuple of (major, minor).""" 29 | try: 30 | if torch.version.cuda: 31 | version_str = torch.version.cuda 32 | elif torch.version.hip: 33 | version_str = torch.version.hip 34 | else: 35 | return None 36 | 37 | parts = version_str.split(".") 38 | if len(parts) >= 2: 39 | return tuple(map(int, parts[:2])) 40 | return None 41 | except (AttributeError, ValueError, IndexError): 42 | return None 43 | 44 | 45 | def get_cuda_version_string() -> Optional[str]: 46 | """Get CUDA/HIP version as a string.""" 47 | version_tuple = get_cuda_version_tuple() 48 | if version_tuple is None: 49 | return None 50 | major, minor = version_tuple 51 | return f"{major * 10 + minor}" 52 | 53 | 54 | def get_cuda_specs() -> Optional[CUDASpecs]: 55 | """Get CUDA/HIP specifications.""" 56 | if not torch.cuda.is_available(): 57 | return None 58 | 59 | try: 60 | compute_capabilities = get_compute_capabilities() 61 | if not compute_capabilities: 62 | return None 63 | 64 | version_tuple = get_cuda_version_tuple() 65 | if version_tuple is None: 66 | return None 67 | 68 | version_string = get_cuda_version_string() 69 | if version_string is None: 70 | return None 
71 | 72 | return CUDASpecs( 73 | highest_compute_capability=compute_capabilities[-1], 74 | cuda_version_string=version_string, 75 | cuda_version_tuple=version_tuple, 76 | ) 77 | except Exception: 78 | return None 79 | 80 | 81 | def get_rocm_gpu_arch() -> str: 82 | """Get ROCm GPU architecture.""" 83 | logger = logging.getLogger(__name__) 84 | try: 85 | if torch.version.hip: 86 | result = subprocess.run(["rocminfo"], capture_output=True, text=True) 87 | match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) 88 | if match: 89 | return "gfx" + match.group(1) 90 | else: 91 | return "unknown" 92 | else: 93 | return "unknown" 94 | except Exception as e: 95 | logger.error(f"Could not detect ROCm GPU architecture: {e}") 96 | if torch.cuda.is_available(): 97 | logger.warning( 98 | """ 99 | ROCm GPU architecture detection failed despite ROCm being available. 100 | """, 101 | ) 102 | return "unknown" 103 | 104 | 105 | def get_rocm_warpsize() -> int: 106 | """Get ROCm warp size.""" 107 | logger = logging.getLogger(__name__) 108 | try: 109 | if torch.version.hip: 110 | result = subprocess.run(["rocminfo"], capture_output=True, text=True) 111 | match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout) 112 | if match: 113 | return int(match.group(1)) 114 | else: 115 | # default to 64 to be safe 116 | return 64 117 | else: 118 | # nvidia cards always use 32 warp size 119 | return 32 120 | except Exception as e: 121 | logger.error(f"Could not detect ROCm warp size: {e}. Defaulting to 64. (some 4-bit functions may not work!)") 122 | if torch.cuda.is_available(): 123 | logger.warning( 124 | """ 125 | ROCm warp size detection failed despite ROCm being available. 126 | """, 127 | ) 128 | return 64 129 | -------------------------------------------------------------------------------- /bitsandbytes/triton/quantize_global.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from bitsandbytes.triton.triton_utils import is_triton_available 4 | 5 | if not is_triton_available(): 6 | 7 | def quantize_global_transpose(input): 8 | return None 9 | 10 | def quantize_global(x: torch.Tensor): 11 | return None 12 | else: 13 | import triton 14 | import triton.language as tl 15 | 16 | # global quantize 17 | @triton.autotune( 18 | configs=[ 19 | triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), 20 | triton.Config({"BLOCK_SIZE": 2048}, num_stages=1), 21 | ], 22 | key=["n_elements"], 23 | ) 24 | @triton.jit 25 | def _quantize_global( 26 | x_ptr, 27 | absmax_inv_ptr, 28 | output_ptr, 29 | n_elements, 30 | BLOCK_SIZE: tl.constexpr, 31 | ): 32 | pid = tl.program_id(axis=0) 33 | block_start = pid * BLOCK_SIZE 34 | offsets = block_start + tl.arange(0, BLOCK_SIZE) 35 | mask = offsets < n_elements 36 | x = tl.load(x_ptr + offsets, mask=mask) 37 | absmax_inv = tl.load(absmax_inv_ptr) 38 | output = tl.libdevice.llrint(127.0 * (x * absmax_inv)) 39 | tl.store(output_ptr + offsets, output, mask=mask) 40 | 41 | def quantize_global(x: torch.Tensor): 42 | absmax = x.abs().max().unsqueeze(0) 43 | absmax_inv = 1.0 / absmax 44 | output = torch.empty(*x.shape, device="cuda", dtype=torch.int8) 45 | assert x.is_cuda and output.is_cuda 46 | n_elements = output.numel() 47 | grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 48 | _quantize_global[grid](x, absmax_inv, output, n_elements) 49 | return output, absmax 50 | 51 | # global quantize and transpose 52 | @triton.autotune( 53 | configs=[ 54 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 
8}, num_warps=4), 55 | triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), 56 | # ... 57 | ], 58 | key=["M", "N"], 59 | ) 60 | @triton.jit 61 | def _quantize_global_transpose( 62 | A, 63 | absmax_inv_ptr, 64 | B, 65 | stride_am, 66 | stride_an, 67 | stride_bn, 68 | stride_bm, 69 | M, 70 | N, 71 | BLOCK_M: tl.constexpr, 72 | BLOCK_N: tl.constexpr, 73 | GROUP_M: tl.constexpr, 74 | ): 75 | pid = tl.program_id(0) 76 | grid_m = (M + BLOCK_M - 1) // BLOCK_M 77 | grid_n = (N + BLOCK_N - 1) // BLOCK_N 78 | 79 | width = GROUP_M * grid_n 80 | group_id = pid // width 81 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M) 82 | pid_m = group_id * GROUP_M + (pid % group_size) 83 | pid_n = (pid % width) // group_size 84 | 85 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 86 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 87 | A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) 88 | mask = (rm < M)[:, None] & (rn < N)[None, :] 89 | a = tl.load(A, mask=mask) 90 | absmax_inv = tl.load(absmax_inv_ptr) 91 | 92 | # rematerialize to save registers 93 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 94 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 95 | B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) 96 | mask = (rm < M)[:, None] & (rn < N)[None, :] 97 | 98 | output = tl.libdevice.llrint(127.0 * (a * absmax_inv)) 99 | 100 | tl.store(B, output, mask=mask) 101 | 102 | def quantize_global_transpose(input): 103 | absmax = input.abs().max().unsqueeze(0) 104 | absmax_inv = 1.0 / absmax 105 | M, N = input.shape 106 | out = torch.empty(N, M, device="cuda", dtype=torch.int8) 107 | 108 | assert out.size(0) == N and out.size(1) == M 109 | assert input.stride(0) == 1 or input.stride(1) == 1 110 | assert out.stride(0) == 1 or out.stride(1) == 1 111 | 112 | grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) 113 | _quantize_global_transpose[grid]( 114 | input, 115 | absmax_inv, 116 | out, 117 | input.stride(0), 118 | input.stride(1), 119 | out.stride(0), 120 | out.stride(1), 121 | M, 122 | N, 123 | ) 124 | return out, absmax 125 | -------------------------------------------------------------------------------- /tests/test_generation.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import math 3 | 4 | import pytest 5 | import torch 6 | 7 | from tests.helpers import TRUE_FALSE, describe_dtype, id_formatter 8 | 9 | transformers = pytest.importorskip("transformers") 10 | 11 | 12 | def get_4bit_config(): 13 | return transformers.BitsAndBytesConfig( 14 | load_in_4bit=True, 15 | load_in_8bit=False, 16 | llm_int8_threshold=6.0, 17 | llm_int8_has_fp16_weight=False, 18 | bnb_4bit_compute_dtype=torch.float16, 19 | bnb_4bit_use_double_quant=True, 20 | bnb_4bit_quant_type="nf4", 21 | ) 22 | 23 | 24 | def get_model_and_tokenizer(config): 25 | model_name_or_path, quant_type = config 26 | bnb_config = get_4bit_config() 27 | if quant_type == "16bit": 28 | bnb_config.load_in_4bit = False 29 | else: 30 | bnb_config.bnb_4bit_quant_type = quant_type 31 | model = transformers.AutoModelForCausalLM.from_pretrained( 32 | model_name_or_path, 33 | quantization_config=bnb_config, 34 | max_memory={0: "48GB"}, 35 | device_map="auto", 36 | torch_dtype=torch.bfloat16, 37 | ).eval() 38 | 39 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) 40 | 41 | return model, tokenizer 42 | 43 | 44 | def get_prompt_for_generation_eval(text, add_roles=True): 45 | description = ( 46 | "A chat 
between a curious human and an artificial intelligence assistant. " 47 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 48 | ) 49 | if add_roles: 50 | prompt = f"{description} ### Human: {text} ### Assistant:" 51 | else: 52 | prompt = f"{description} {text}" 53 | return prompt 54 | 55 | 56 | def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval): 57 | text = prompt_func(text) 58 | inputs = tokenizer(text, return_tensors="pt").to("cuda:0") 59 | outputs = model.generate(inputs=inputs["input_ids"], generation_config=generation_config) 60 | return tokenizer.decode(outputs[0], skip_special_tokens=True) 61 | 62 | 63 | models = ["bigscience/bloom-1b7"] 64 | dtypes = ["nf4", "fp4"] 65 | 66 | 67 | @pytest.fixture(scope="session", params=product(models, dtypes)) 68 | def model_and_tokenizer(request): 69 | model, tokenizer = get_model_and_tokenizer(request.param) 70 | yield request.param, model, tokenizer 71 | del model 72 | 73 | 74 | @pytest.mark.parametrize("DQ", TRUE_FALSE, ids=id_formatter("dq")) 75 | @pytest.mark.parametrize("inference_kernel", TRUE_FALSE, ids=id_formatter("inference_kernel")) 76 | @pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) 77 | @pytest.mark.slow 78 | def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): 79 | fixture_config, model, tokenizer = model_and_tokenizer 80 | 81 | generation_config = transformers.GenerationConfig( 82 | max_new_tokens=20, 83 | do_sample=True, 84 | top_p=0.9, 85 | temperature=0.7, 86 | ) 87 | generation_config.max_new_tokens = 20 88 | 89 | # text = 'Please write down the first 50 digits of pi.' 90 | # text = get_prompt_for_generation_eval(text) 91 | # text += ' Sure, here the first 50 digits of pi: 3.14159' 92 | n_cases = 6 93 | text = "3.14159" 94 | if hasattr(model.config, "quantization_config"): 95 | model.config.quantization_config.bnb_4bit_compute_dtype = dtype 96 | model.config.quantization_config.bnb_4bit_use_double_quant = DQ 97 | 98 | if not inference_kernel: 99 | text = [text] * n_cases 100 | inputs = tokenizer(text, return_tensors="pt").to("cuda:0") 101 | x = inputs["input_ids"] 102 | outputs = [] 103 | if inference_kernel: 104 | for i in range(n_cases): 105 | output = model.generate(x, generation_config=generation_config) 106 | textout = tokenizer.decode(output[0], skip_special_tokens=True) 107 | outputs.append(textout) 108 | else: 109 | outputs = model.generate(x, generation_config=generation_config) 110 | outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs] 111 | 112 | assert len(outputs) == n_cases 113 | failure_count = 0 114 | for i in range(n_cases): 115 | if outputs[i][: len(str(math.pi))] != str(math.pi): 116 | failure_count += 1 117 | failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4 118 | if failure_count > failure_max: 119 | print(math.pi) 120 | for out in outputs: 121 | print(out) 122 | raise ValueError(f"Failure count: {failure_count}/{n_cases}") 123 | -------------------------------------------------------------------------------- /docs/source/optimizers.mdx: -------------------------------------------------------------------------------- 1 | # 8-bit optimizers 2 | 3 | With 8-bit optimizers, large models can be finetuned with 75% less GPU memory without losing any accuracy compared to training with standard 32-bit optimizers. 
The reduced memory requirements means 8-bit optimizers are 4x faster than a standard optimizer, and no hyperparameter tuning is required. 4 | 5 | This guide will show you how to use 8-bit optimizers. 6 | 7 | > [!WARNING] 8 | > 8-bit optimizers reduce memory usage and accelerate optimization on a wide range of tasks. However, since 8-bit optimizers only reduce memory proportional to the number of parameters, models that use large amounts of activation memory, such as convolutional networks, don't really benefit from 8-bit optimizers. 8-bit optimizers are most beneficial for training or finetuning models with many parameters on highly memory-constrained GPUs. 9 | 10 | 8-bit optimizers are a drop-in replacement for regular optimizers which means they also accept the same arguments as a regular optimizer. For NLP models, it is recommended to use the [`~nn.StableEmbedding`] class to improve stability and results. 11 | 12 | ```diff 13 | import bitsandbytes as bnb 14 | 15 | - adam = torch.optim.Adam(...) 16 | + adam = bnb.optim.Adam8bit(...) 17 | 18 | # recommended for NLP models 19 | - before: torch.nn.Embedding(...) 20 | + bnb.nn.StableEmbedding(...) 21 | ``` 22 | 23 | By default, all parameter tensors with less than 4096 elements are kept at 32-bits even if you initialize those parameters with 8-bit optimizers. This is done because small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm). 24 | 25 | You can change this value with the `min_8bit_size` parameter. For example, if you want to optimize parameters to 8-bits only if the minimum size is 16384 values (it is recommended to use multiples of 4096): 26 | 27 | ```py 28 | import bitsandbytes as bnb 29 | 30 | adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) 31 | ``` 32 | 33 | Other parameters you can configure include the learning rate (`lr`), the decay rates (`betas`), the number of bits of the optimizer state (`optim_bits`), and percentile clipping (`percentile_clipping`) which can increase stability. For example, to initialize a 32-bit [`~bitsandbytes.optim.Adam`] optimizer with 5th percentile clipping: 34 | 35 | ```py 36 | import bitsandbytes as bnb 37 | 38 | adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5) 39 | ``` 40 | 41 | ## Optimize unstable parameters 42 | 43 | To optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, use the [`~bitsandbytes.optim.GlobalOptimManager`] class to override the specific hyperparameters for a particular layer. You'll need to: 44 | 45 | 1. Register the parameters while they're on the CPU. 46 | 47 | ```py 48 | import torch 49 | import bitsandbytes as bnb 50 | 51 | mng = bnb.optim.GlobalOptimManager.get_instance() 52 | 53 | model = MyModel() 54 | mng.register_parameters(model.parameters()) 55 | ``` 56 | 57 | 2. Override the config with the new desired hyperparameters. For example, let's override the `model.fc1.weight` layer to use 32-bit Adam. 58 | 59 | > [!TIP] 60 | > Check the optimizer API documentation for more information about other hyperparameters you can override. 
 61 | 62 | ```py 63 | model = model.cuda() 64 | # use 8-bit optimizer states for all parameters 65 | adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) 66 | 67 | # override the config so the parameter model.fc1.weight uses 32-bit Adam 68 | mng.override_config(model.fc1.weight, "optim_bits", 32) 69 | ``` 70 | 71 | You can also override multiple layers at once by passing them as a list and the new hyperparameters as a dictionary. For example, let's override the `model.special.weight` and `model.also_special.weight` layers to use sparse optimization and a lower learning and decay rate. 72 | 73 | ```py 74 | mng.override_config([model.special.weight, model.also_special.weight], 75 | key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)}) 76 | ``` 77 | 78 | For a specific layer, we recommend overriding locally in each module. Pass the module, the parameter, and its attribute name to the [`~bitsandbytes.optim.GlobalOptimManager`]: 79 | 80 | ```py 81 | class MyModule(torch.nn.Module): 82 | def __init__(self, d_in, d_out): 83 | super(MyModule, self).__init__() 84 | self.linear = torch.nn.Linear(d_in, d_out) 85 | # optimization will happen in 32-bit and 86 | # learning rate will be set to 0.0001 independent of the main learning rate 87 | config = {'optim_bits': 32, 'lr': 0.0001} 88 | GlobalOptimManager.get_instance().register_module_override(self, 'weight', config) 89 | 90 | ``` 91 | 92 | ## Next steps 93 | 94 | For more conceptual details and explanation about 8-bit optimizers, take a look at the [8-bit optimizers](./explanations/optimizers) guide. 95 | -------------------------------------------------------------------------------- /csrc/xpu_ops.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | template 5 | void dequantizeBlockwise( 6 | float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, sycl::queue* stream 7 | ) { 8 | auto& queue = *stream; 9 | const int workgroup_size = 128; 10 | const int num_per_th = 4; 11 | const int tile_size = workgroup_size * num_per_th; 12 | if (DATA_TYPE > 0) { 13 | const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2); 14 | sycl::range<1> local_range{(size_t)workgroup_size}; 15 | sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; 16 | kDequantizeBlockwise kfn(code, A, absmax, out, blocksize / 2, n); 17 | sycl_kernel_submit( 18 | sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn 19 | ); 20 | } else { 21 | const int workgroup_num = (n + tile_size - 1) / tile_size; 22 | sycl::range<1> local_range{(size_t)workgroup_size}; 23 | sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; 24 | kDequantizeBlockwise kfn(code, A, absmax, out, blocksize, n); 25 | sycl_kernel_submit( 26 | sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn 27 | ); 28 | } 29 | } 30 | 31 | template 32 | void gemv_4bit_inference( 33 | int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, 34 | int blocksize, sycl::queue* stream 35 | ) { 36 | 37 | auto& queue = *stream; 38 | 39 | const size_t GROUP_SIZE = 128; // workgroup_size 40 | const size_t SUBG_SIZE = 32; // subgroup_size 41 | const size_t NUM_PER_THREAD = GROUP_SIZE / SUBG_SIZE; 42 | size_t workgroup_num = (n + NUM_PER_THREAD - 1) / NUM_PER_THREAD; 43 | 44 | kgemv_4bit_inference kfn( 45 | m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc,
blocksize 46 | ); 47 | 48 | sycl_comp_kernel_submit( 49 | sycl::nd_range<1>(sycl::range<1>(GROUP_SIZE * workgroup_num), sycl::range<1>(GROUP_SIZE)), queue, kfn 50 | ); 51 | } 52 | 53 | //============================================================== 54 | // TEMPLATE DEFINITIONS 55 | //============================================================== 56 | 57 | template void dequantizeBlockwise( 58 | float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream 59 | ); 60 | template void dequantizeBlockwise( 61 | float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream 62 | ); 63 | template void dequantizeBlockwise( 64 | float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream 65 | ); 66 | 67 | template void dequantizeBlockwise( 68 | float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream 69 | ); 70 | template void dequantizeBlockwise( 71 | float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream 72 | ); 73 | template void dequantizeBlockwise( 74 | float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream 75 | ); 76 | 77 | template void dequantizeBlockwise( 78 | float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, 79 | sycl::queue* stream 80 | ); 81 | template void dequantizeBlockwise( 82 | float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, 83 | sycl::queue* stream 84 | ); 85 | template void dequantizeBlockwise( 86 | float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n, 87 | sycl::queue* stream 88 | ); 89 | 90 | template void gemv_4bit_inference( 91 | int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda, 92 | int ldb, int ldc, int blocksize, sycl::queue* stream 93 | ); 94 | template void gemv_4bit_inference( 95 | int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype, 96 | sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream 97 | ); 98 | template void gemv_4bit_inference( 99 | int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, 100 | int ldc, int blocksize, sycl::queue* stream 101 | ); 102 | -------------------------------------------------------------------------------- /docs/source/explanations/optimizers.mdx: -------------------------------------------------------------------------------- 1 | # 8-bit optimizers 2 | 3 | Stateful optimizers maintain gradient statistics over time, for example, the exponentially smoothed sum (SGD with momentum) or squared sum (Adam) of past gradient values. This state can be used to accelerate optimization compared to plain stochastic gradient descent, but uses memory that might otherwise be allocated to model parameters. As a result, this limits the maximum size of models that can be trained in practice. Now take a look at the biggest models that can be trained with 8-bit optimizers. 4 | 5 |
 6 | 7 | 8 | Depending on your GPU size, you can train a much larger model with an 8-bit optimizer. 9 | 10 |
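To get a feel for the numbers, consider Adam, which keeps two state values (the first and second gradient moments) per model parameter. The back-of-the-envelope sketch below is illustrative only: it assumes a 1B-parameter model and ignores the small per-block `absmax` overhead that the 8-bit format adds on top of roughly 1 byte per state value.

```py
# Rough optimizer-state memory for Adam: 2 state values per parameter.
num_params = 1_000_000_000  # assumed model size, purely for illustration

bytes_32bit = num_params * 2 * 4  # 4 bytes per 32-bit state value
bytes_8bit = num_params * 2 * 1   # ~1 byte per 8-bit state value (absmax overhead ignored)

print(f"32-bit Adam state: {bytes_32bit / 1e9:.0f} GB")  # 8 GB
print(f"8-bit Adam state:  {bytes_8bit / 1e9:.0f} GB")   # 2 GB
```

This factor-of-4 reduction in optimizer state memory is what allows noticeably larger models to fit on the same GPU.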
11 | 12 | bitsandbytes optimizers use 8-bit statistics, while maintaining the performance levels of using 32-bit optimizer states. 13 | 14 | To overcome the resulting computational, quantization and stability challenges, 8-bit optimizers have three components: 15 | 16 | 1. Block-wise quantization: divides input tensors into smaller blocks that are independently quantized, isolating outliers and distributing the error more equally over all bits. Each block is processed in parallel across cores, yielding faster optimization and high precision quantization. 17 | 2. Dynamic quantization: quantizes both small and large values with high precision. 18 | 3. Stable embedding layer: improves stability during optimization for models with word embeddings. 19 | 20 | With these components, performing an optimizer update with 8-bit states is straightforward. The 8-bit optimizer states are dequantized to 32-bit before you perform the update, and then the states are quantized back to 8-bit for storage. 21 | 22 | The 8-bit to 32-bit conversion happens element-by-element in registers, meaning no slow copies to GPU memory or additional temporary memory are needed to perform quantization and dequantization. For GPUs, this makes 8-bit optimizers much faster than regular 32-bit optimizers. 23 | 24 |
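The minimal sketch below illustrates that update flow for a single optimizer state tensor. It is a simplification for illustration only: it uses a plain linear absmax mapping and an arbitrary block size of 256, whereas the real 8-bit optimizers use fused kernels and the dynamic (non-linear) quantization map described above.

```py
import torch

BLOCK = 256  # illustrative block size

def quantize_blockwise(x: torch.Tensor):
    # one absmax scale per block; real 8-bit optimizers use a dynamic map, not a linear one
    blocks = x.reshape(-1, BLOCK)
    absmax = blocks.abs().max(dim=1, keepdim=True).values.clamp(min=1e-12)
    return torch.round(blocks / absmax * 127).to(torch.int8), absmax

def dequantize_blockwise(q: torch.Tensor, absmax: torch.Tensor):
    return (q.float() / 127 * absmax).reshape(-1)

# simplified update of Adam's first moment for one parameter tensor
grad = torch.randn(4096)
state_q, state_absmax = quantize_blockwise(torch.zeros(4096))  # state stored in 8-bit

m = dequantize_blockwise(state_q, state_absmax)   # 8-bit -> 32-bit
m = 0.9 * m + 0.1 * grad                          # update in 32-bit
state_q, state_absmax = quantize_blockwise(m)     # back to 8-bit for storage
```

In bitsandbytes itself this round trip happens inside the fused optimizer kernels, element-by-element in registers, as described above.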
 25 | 26 | 27 | A comparison of memory and time saved using 8-bit and 32-bit optimizers. 28 | 29 |
 30 | 31 | ## Stable embedding layer 32 | 33 | The stable embedding layer improves the training stability of the standard word embedding layer for NLP tasks. It addresses the challenge of non-uniform input distributions and mitigates extreme gradient variations. This means the stable embedding layer can support more aggressive quantization strategies without compromising training stability, and it can help achieve stable training outcomes, which is particularly important for models dealing with diverse and complex language data. 34 | 35 | There are three features of the stable embedding layer: 36 | 37 | - Initialization: utilizes Xavier uniform initialization to maintain consistent variance, reducing the likelihood of large gradients. 38 | - Normalization: incorporates layer normalization before adding positional embeddings, aiding in output stability. 39 | - Optimizer states: employs 32-bit optimizer states exclusively for this layer to enhance stability, while the rest of the model may use standard 16-bit precision. 40 | 41 | ## Paged optimizers 42 | 43 | Paged optimizers are built on top of the [unified memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) feature of CUDA. Unified memory provides a single memory space the GPU and CPU can easily access. While this feature is not supported by PyTorch, it has been added to bitsandbytes. 44 | 45 | Paged optimizers work like regular CPU paging, which means they *only become active if you run out of GPU memory*. When that happens, memory is transferred page-by-page from GPU to CPU. The memory is mapped, meaning that pages are pre-allocated on the CPU but they are not updated automatically. Pages are only updated if the memory is accessed or a swapping operation is launched. 46 | 47 | The unified memory feature is less efficient than regular asynchronous memory transfers, and you usually won't be able to get full PCIe memory bandwidth utilization. If you do a manual prefetch, transfer speeds can be high but still only about half or worse than the full PCIe memory bandwidth (tested on 16x lanes PCIe 3.0). 48 | 49 | This means performance depends highly on the particular use-case. For example, if you evict 1 GB of memory per forward-backward-optimizer loop, then you can expect about 50% of the PCIe bandwidth as time in the best case. So, 1 GB for PCIe 3.0 with 16x lanes would run at 16 GB/s, which is `1/(16*0.5) = 1/8 s = 125 ms` of overhead per optimizer step. Other overhead can be estimated for the particular use-case given a PCIe interface, lanes, and the memory evicted in each iteration. 50 | 51 | Compared to CPU offloading, a paged optimizer has zero overhead if all the memory fits onto the device and only some overhead if some of the memory needs to be evicted. For offloading, you usually offload fixed parts of the model and need to offload and onload all of this memory with each iteration through the model (sometimes twice for both forward and backward pass). 52 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["scikit-build-core", "setuptools >= 77.0.3", "trove-classifiers>=2025.8.6.13"] 3 | build-backend = "scikit_build_core.setuptools.build_meta" 4 | 5 | [project] 6 | name = "bitsandbytes" 7 | dynamic = ["version"] 8 | description = "k-bit optimizers and matrix multiplication routines."
9 | authors = [{name="Tim Dettmers", email="dettmers@cs.washington.edu"}] 10 | maintainers = [ 11 | {name="Titus von Köller", email="titus@huggingface.co"}, 12 | {name="Matthew Douglas", email="matthew.douglas@huggingface.co"} 13 | ] 14 | requires-python = ">=3.10" 15 | readme = "README.md" 16 | license = "MIT" 17 | license-files = ["LICENSE"] 18 | keywords = [ 19 | "gpu", 20 | "optimizers", 21 | "optimization", 22 | "8-bit", 23 | "quantization", 24 | "compression" 25 | ] 26 | classifiers = [ 27 | "Development Status :: 4 - Beta", 28 | "Environment :: GPU :: NVIDIA CUDA :: 11.8", 29 | "Environment :: GPU :: NVIDIA CUDA :: 12", 30 | "Environment :: GPU :: NVIDIA CUDA :: 13", 31 | "Intended Audience :: Developers", 32 | "Intended Audience :: Science/Research", 33 | "Operating System :: POSIX :: Linux", 34 | "Operating System :: MacOS", 35 | "Operating System :: Microsoft :: Windows", 36 | "Programming Language :: C++", 37 | "Programming Language :: Python :: Implementation :: CPython", 38 | "Programming Language :: Python :: 3.10", 39 | "Programming Language :: Python :: 3.11", 40 | "Programming Language :: Python :: 3.12", 41 | "Programming Language :: Python :: 3.13", 42 | "Programming Language :: Python :: 3.14", 43 | "Topic :: Scientific/Engineering :: Artificial Intelligence" 44 | ] 45 | dependencies = [ 46 | "torch>=2.3,<3", 47 | "numpy>=1.17", 48 | "packaging>=20.9", 49 | ] 50 | 51 | [project.urls] 52 | homepage = "https://github.com/bitsandbytes-foundation/bitsandbytes" 53 | changelog = "https://github.com/bitsandbytes-foundation/bitsandbytes/blob/main/CHANGELOG.md" 54 | docs = "https://huggingface.co/docs/bitsandbytes/main" 55 | issues = "https://github.com/bitsandbytes-foundation/bitsandbytes/issues" 56 | 57 | [project.optional-dependencies] 58 | benchmark = ["pandas", "matplotlib"] 59 | docs = ["hf-doc-builder==0.5.0"] 60 | dev = [ 61 | "bitsandbytes[test]", 62 | "build>=1.0.0,<2", 63 | "ruff~=0.14.3", 64 | "pre-commit>=3.5.0,<4", 65 | "wheel>=0.42,<1" 66 | ] 67 | test = [ 68 | "einops~=0.8.0", 69 | "lion-pytorch==0.2.3", 70 | "pytest~=8.3", 71 | "scipy>=1.11.4,<2", 72 | "transformers>=4.30.1,<5" 73 | ] 74 | 75 | [tool.setuptools] 76 | package-data = { "*" = ["libbitsandbytes*.*", "py.typed"] } 77 | 78 | [tool.setuptools.packages.find] 79 | include = ["bitsandbytes*"] 80 | 81 | [tool.setuptools.dynamic] 82 | version = {attr = "bitsandbytes.__version__"} 83 | 84 | [tool.coverage.report] 85 | exclude_also = [ 86 | # exclude backward() functions from coverage, as they are invoked from C++ 87 | 'def backward\(ctx' 88 | ] 89 | 90 | [tool.pytest.ini_options] 91 | addopts = "-rP -m 'not slow and not benchmark and not deprecated'" 92 | # ; --cov=bitsandbytes 93 | # ; # contexts: record which test ran which line; can be seen in html coverage report 94 | # ; --cov-context=test 95 | # ; --cov-report html 96 | log_cli = true 97 | log_cli_level = "INFO" 98 | log_file = "logs/pytest.log" 99 | markers = [ 100 | "benchmark: mark test as a benchmark", 101 | "deprecated: mark test as covering a deprecated feature", 102 | "slow: mark test as slow", 103 | ] 104 | 105 | [tool.ruff] 106 | src = [ 107 | "bitsandbytes", 108 | "tests", 109 | "benchmarking" 110 | ] 111 | target-version = "py310" 112 | line-length = 119 113 | 114 | [tool.ruff.lint] 115 | select = [ 116 | "B", # bugbear: security warnings 117 | "E", # pycodestyle (error) 118 | "W", # pycodestyle (warning) 119 | "F", # pyflakes 120 | "I", # isort 121 | "ISC", # implicit string concatenation 122 | "UP", # alert you when better syntax is 
available in your python version 123 | "RUF", # the ruff developer's own rules 124 | ] 125 | ignore = [ 126 | "B007", # Loop control variable not used within the loop body (TODO: enable) 127 | "B028", # Warning without stacklevel (TODO: enable) 128 | "B905", # zip without explicit `strict=` kwarg 129 | "E501", # Suppress line-too-long warnings: trust yapf's judgement on this one. 130 | "E701", # Multiple statements on one line (TODO: enable) 131 | "E712", # Allow using if x == False, as it's not always equivalent to if x. 132 | "E731", # Do not use lambda 133 | "RUF012",# Mutable class attribute annotations 134 | "RUF034",# Useless if-else (TODO: enable) 135 | "UP045", # Use `X | None` instead of `Optional[X]` 136 | ] 137 | 138 | [tool.ruff.lint.extend-per-file-ignores] 139 | "**/__init__.py" = ["F401"] # allow unused imports in __init__.py 140 | "{benchmarking,tests}/**/*.py" = [ 141 | "B007", 142 | "B011", 143 | "B023", 144 | "E701", 145 | "E731", 146 | "F841", 147 | "UP030", 148 | ] 149 | "bitsandbytes/**/triton/**/*.py" = [ 150 | "I001", # import order 151 | ] 152 | 153 | [tool.ruff.lint.isort] 154 | combine-as-imports = true 155 | detect-same-package = true 156 | force-sort-within-sections = true 157 | known-first-party = ["bitsandbytes"] 158 | 159 | [[tool.mypy.overrides]] 160 | module = "triton.*" 161 | ignore_missing_imports = true 162 | 163 | [[tool.mypy.overrides]] 164 | module = "scipy.stats" 165 | ignore_missing_imports = true 166 | 167 | [tool.scikit-build] 168 | cmake.build-type = "Release" 169 | cmake.build-args = ["--config", "Release"] 170 | wheel.cmake = false 171 | -------------------------------------------------------------------------------- /benchmarking/switchback/make_plot_with_jsonl.py: -------------------------------------------------------------------------------- 1 | import matplotlib.gridspec as gridspec 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | 5 | cmap = plt.get_cmap("cool") 6 | 7 | if __name__ == "__main__": 8 | fig = plt.figure(tight_layout=True, figsize=(12, 3.5)) 9 | gs = gridspec.GridSpec(1, 2) 10 | 11 | dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096] 12 | batch_size_for_plot1 = 32768 13 | batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17] 14 | dims_to_xtick = [1024, 2048, 4096] 15 | logscale_plot1 = True 16 | 17 | ax = fig.add_subplot(gs[0, 0]) 18 | 19 | # TODO: change this to what you want. 
20 | rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True) 21 | df = rdf[rdf.batch_size == batch_size_for_plot1] 22 | 23 | # first plot the time occupied by different operations 24 | for k, marker, ls, color, name in [ 25 | ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"), 26 | ( 27 | "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", 28 | "o", 29 | "-", 30 | "C4", 31 | "SwitchBack int8 (sum of parts)", 32 | ), 33 | ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"), 34 | ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"), 35 | ("standard_gx", "^", ":", "gray", "Matmul GX (both)"), 36 | ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"), 37 | ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"), 38 | ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"), 39 | ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"), 40 | ("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"), 41 | ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"), 42 | ]: 43 | xs = [] 44 | ys = [] 45 | for embed_dim in dims_to_consider: 46 | # average over dim -> 4*dim and 4*dim -> dim 47 | df_ = df[df.dim_in == embed_dim] 48 | df_ = df_[df_.dim_out == embed_dim * 4] 49 | xs.append(embed_dim) 50 | y_ = 0 51 | for k_ in k.split("+"): 52 | y_ += df_[k_].values[0] 53 | df_ = df[df.dim_in == embed_dim * 4] 54 | df_ = df_[df_.dim_out == embed_dim] 55 | for k_ in k.split("+"): 56 | y_ += df_[k_].values[0] 57 | ys.append(y_ * 0.5) 58 | 59 | ax.plot( 60 | xs, 61 | ys, 62 | color=color, 63 | label=name, 64 | marker=marker, 65 | markersize=5 if marker == "s" else 5, 66 | linestyle=ls, 67 | linewidth=2 if "+" in k else 1.0, 68 | ) 69 | 70 | ax.set_xlabel("dim", fontsize=13) 71 | ax.set_ylabel("time (ms)", fontsize=13) 72 | 73 | ax.grid() 74 | 75 | ax.set_xscale("log") 76 | if logscale_plot1: 77 | ax.set_yscale("log") 78 | 79 | ax.tick_params(axis="x", labelsize=11) 80 | ax.tick_params(axis="y", labelsize=11) 81 | 82 | ax.set_xticks(dims_to_xtick) 83 | ax.set_xticklabels(dims_to_xtick) 84 | ax.set_xticks([], minor=True) 85 | 86 | leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10) 87 | leg.get_texts()[0].set_fontweight("bold") 88 | leg.get_texts()[1].set_fontweight("bold") 89 | plt.subplots_adjust(left=0.1) 90 | ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20) 91 | 92 | ax = fig.add_subplot(gs[0, 1]) 93 | 94 | # now plot the % speedup for different batch sizes 95 | for j, batch_size in enumerate(batch_sizes_for_plot2): 96 | all_xs, all_ys = [], [] 97 | for k, marker, ls, color, name in [ 98 | ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"), 99 | ( 100 | "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", 101 | "o", 102 | "-", 103 | "C4", 104 | "SwitchBack int8 (total time)", 105 | ), 106 | ]: 107 | xs, ys = [], [] 108 | df = rdf[rdf.batch_size == batch_size] 109 | for embed_dim in dims_to_consider: 110 | df_ = df[df.dim_in == embed_dim] 111 | df_ = df_[df_.dim_out == embed_dim * 4] 112 | xs.append(embed_dim) 113 | y_ = 0 114 | for k_ in k.split("+"): 115 | y_ += df_[k_].values[0] 116 | df_ = df[df.dim_in == embed_dim * 4] 117 | df_ = df_[df_.dim_out == embed_dim] 118 | for k_ in 
k.split("+"): 119 | y_ += df_[k_].values[0] 120 | ys.append(y_ * 0.5) 121 | all_xs.append(xs) 122 | all_ys.append(ys) 123 | 124 | color = cmap(j * 0.25) 125 | real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] 126 | markers = ["^", "v", "P", "o"] 127 | ax.plot( 128 | all_xs[0], 129 | real_ys, 130 | color=color, 131 | label=f"batch * sequence length = {batch_size}", 132 | marker=markers[j], 133 | markersize=5 if marker == "s" else 5, 134 | ) 135 | 136 | ax.legend() 137 | ax.set_xlabel("dim", fontsize=13) 138 | ax.set_xscale("log") 139 | ax.grid() 140 | ax.set_ylabel(r"% speedup", fontsize=13) 141 | 142 | ax.tick_params(axis="x", labelsize=11) 143 | ax.tick_params(axis="y", labelsize=11) 144 | 145 | ax.set_xticks(dims_to_xtick) 146 | ax.set_xticklabels(dims_to_xtick) 147 | ax.set_xticks([], minor=True) 148 | 149 | ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20) 150 | 151 | plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight") 152 | -------------------------------------------------------------------------------- /docs/source/integrations.mdx: -------------------------------------------------------------------------------- 1 | # Integrations 2 | 3 | bitsandbytes is widely integrated with many of the libraries in the Hugging Face and wider PyTorch ecosystem. This guide provides a brief overview of the integrations and how to use bitsandbytes with them. For more details, you should refer to the linked documentation for each library. 4 | 5 | ## Transformers 6 | 7 | > [!TIP] 8 | > Learn more in the bitsandbytes Transformers integration [guide](https://huggingface.co/docs/transformers/quantization#bitsandbytes). 9 | 10 | With Transformers, it's very easy to load any model in 4 or 8-bit and quantize them on the fly. To configure the quantization parameters, specify them in the [`~transformers.BitsAndBytesConfig`] class. 11 | 12 | For example, to load and quantize a model to 4-bits and use the bfloat16 data type for compute: 13 | 14 | > [!WARNING] 15 | > bfloat16 is the ideal `compute_dtype` if your hardware supports it. While the default `compute_dtype`, float32, ensures backward compatibility (due to wide-ranging hardware support) and numerical stability, it is large and slows down computations. In contrast, float16 is smaller and faster but can lead to numerical instabilities. bfloat16 combines the best aspects of both; it offers the numerical stability of float32 and the reduced memory footprint and speed of a 16-bit data type. Check if your hardware supports bfloat16 and configure it using the `bnb_4bit_compute_dtype` parameter in [`~transformers.BitsAndBytesConfig`]! 16 | 17 | ```py 18 | from transformers import AutoModelForCausalLM, BitsAndBytesConfig 19 | 20 | quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) 21 | model_4bit = AutoModelForCausalLM.from_pretrained( 22 | "bigscience/bloom-1b7", 23 | device_map=device_map, 24 | quantization_config=quantization_config, 25 | ) 26 | ``` 27 | 28 | ### 8-bit optimizers 29 | 30 | You can use any of the 8-bit or paged optimizers with Transformers by passing them to the [`~transformers.Trainer`] class on initialization. All bitsandbytes optimizers are supported by passing the correct string in the [`~transformers.TrainingArguments`] `optim` parameter. 
For example, to load a [`~bitsandbytes.optim.PagedAdamW32bit`] optimizer: 31 | 32 | ```py 33 | from transformers import TrainingArguments, Trainer 34 | 35 | training_args = TrainingArguments( 36 | ..., 37 | optim="paged_adamw_32bit", 38 | ) 39 | trainer = Trainer(model, training_args, ...) 40 | trainer.train() 41 | ``` 42 | 43 | ## PEFT 44 | 45 | > [!TIP] 46 | > Learn more in the bitsandbytes PEFT integration [guide](https://huggingface.co/docs/peft/developer_guides/quantization#quantization). 47 | 48 | PEFT builds on the bitsandbytes Transformers integration, and extends it for training with a few more steps. Let's prepare the 4-bit model from the section above for training. 49 | 50 | Call the [`~peft.prepare_model_for_kbit_training`] method to prepare the model for training. This only works for Transformers models! 51 | 52 | ```py 53 | from peft import prepare_model_for_kbit_training 54 | 55 | model_4bit = prepare_model_for_kbit_training(model_4bit) 56 | ``` 57 | 58 | Setup a [`~peft.LoraConfig`] to use QLoRA: 59 | 60 | ```py 61 | from peft import LoraConfig 62 | 63 | config = LoraConfig( 64 | r=16, 65 | lora_alpha=8, 66 | target_modules="all-linear", 67 | lora_dropout=0.05 68 | bias="none", 69 | task_type="CAUSAL_LM" 70 | ) 71 | ``` 72 | 73 | Now call the [`~peft.get_peft_model`] function on your model and config to create a trainable [`PeftModel`]. 74 | 75 | ```py 76 | from peft import get_peft_model 77 | 78 | model = get_peft_model(model_4bit, config) 79 | ``` 80 | 81 | ## Accelerate 82 | 83 | > [!TIP] 84 | > Learn more in the bitsandbytes Accelerate integration [guide](https://huggingface.co/docs/accelerate/usage_guides/quantization). 85 | 86 | bitsandbytes is also easily usable from Accelerate and you can quantize any PyTorch model by passing a [`~accelerate.utils.BnbQuantizationConfig`] with your desired settings, and then calling the [`~accelerate.utils.load_and_quantize_model`] function to quantize it. 87 | 88 | ```py 89 | from accelerate import init_empty_weights 90 | from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model 91 | from mingpt.model import GPT 92 | 93 | model_config = GPT.get_default_config() 94 | model_config.model_type = 'gpt2-xl' 95 | model_config.vocab_size = 50257 96 | model_config.block_size = 1024 97 | 98 | with init_empty_weights(): 99 | empty_model = GPT(model_config) 100 | 101 | bnb_quantization_config = BnbQuantizationConfig( 102 | load_in_4bit=True, 103 | bnb_4bit_compute_dtype=torch.bfloat16, # optional 104 | bnb_4bit_use_double_quant=True, # optional 105 | bnb_4bit_quant_type="nf4" # optional 106 | ) 107 | 108 | quantized_model = load_and_quantize_model( 109 | empty_model, 110 | weights_location=weights_location, 111 | bnb_quantization_config=bnb_quantization_config, 112 | device_map = "auto" 113 | ) 114 | ``` 115 | 116 | ## PyTorch Lightning and Lightning Fabric 117 | 118 | bitsandbytes is available from: 119 | 120 | - [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), a deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale. 121 | - [Lightning Fabric](https://lightning.ai/docs/fabric/stable/), a fast and lightweight way to scale PyTorch models without boilerplate. 122 | 123 | Learn more in the bitsandbytes PyTorch Lightning integration [guide](https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#quantization-via-bitsandbytes). 
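The linked guide covers 4-bit and 8-bit quantization through Lightning's precision plugins. Independently of that, because bitsandbytes optimizers are drop-in replacements for `torch.optim` optimizers, they can also be returned directly from `configure_optimizers`. The sketch below is a minimal illustration only; the module, layer sizes, and learning rate are made up for the example.

```py
import bitsandbytes as bnb
import pytorch_lightning as pl
import torch


class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(128, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        # swap torch.optim.AdamW for the bitsandbytes 8-bit variant
        return bnb.optim.AdamW8bit(self.parameters(), lr=1e-3)
```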
124 | 125 | 126 | ## Lit-GPT 127 | 128 | bitsandbytes is integrated with [Lit-GPT](https://github.com/Lightning-AI/lit-gpt), a hackable implementation of state-of-the-art open-source large language models. Lit-GPT is based on Lightning Fabric, and it can be used for quantization during training, finetuning, and inference. 129 | 130 | Learn more in the bitsandbytes Lit-GPT integration [guide](https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md). 131 | 132 | ## Blog posts 133 | 134 | To learn in more detail about some of bitsandbytes integrations, take a look at the following blog posts: 135 | 136 | - [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) 137 | - [A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes](https://huggingface.co/blog/hf-bitsandbytes-integration) 138 | -------------------------------------------------------------------------------- /benchmarking/inference_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference benchmarking tool. 3 | 4 | Requirements: 5 | transformers 6 | accelerate 7 | bitsandbytes 8 | optimum-benchmark 9 | 10 | Usage: python inference_benchmark.py model_id 11 | 12 | options: 13 | -h, --help show this help message and exit 14 | --configs {bf16,fp16,nf4,nf4-dq,int8,int8-decomp} [{bf16,fp16,nf4,nf4-dq,int8,int8-decomp} ...] 15 | --bf16 16 | --fp16 17 | --nf4 18 | --nf4-dq 19 | --int8 20 | --int8-decomp 21 | --batches BATCHES [BATCHES ...] 22 | --input-length INPUT_LENGTH 23 | --out-dir OUT_DIR 24 | --iterations ITERATIONS 25 | --warmup-runs WARMUP_RUNS 26 | --output-length OUTPUT_LENGTH 27 | """ 28 | 29 | import argparse 30 | from pathlib import Path 31 | 32 | from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig 33 | from optimum_benchmark.logging_utils import setup_logging 34 | import torch 35 | 36 | torch.backends.cudnn.benchmark = False 37 | torch.backends.cudnn.deterministic = True 38 | 39 | BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8 40 | 41 | WEIGHTS_CONFIGS = { 42 | "fp16": {"torch_dtype": "float16", "quantization_scheme": None, "quantization_config": {}}, 43 | "bf16": {"torch_dtype": "bfloat16", "quantization_scheme": None, "quantization_config": {}}, 44 | "nf4": { 45 | "torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16", 46 | "quantization_scheme": "bnb", 47 | "quantization_config": { 48 | "load_in_4bit": True, 49 | "bnb_4bit_quant_type": "nf4", 50 | "bnb_4bit_use_double_quant": False, 51 | "bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16", 52 | }, 53 | }, 54 | "nf4-dq": { 55 | "torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16", 56 | "quantization_scheme": "bnb", 57 | "quantization_config": { 58 | "load_in_4bit": True, 59 | "bnb_4bit_quant_type": "nf4", 60 | "bnb_4bit_use_double_quant": True, 61 | "bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16", 62 | }, 63 | }, 64 | "int8-decomp": { 65 | "torch_dtype": "float16", 66 | "quantization_scheme": "bnb", 67 | "quantization_config": { 68 | "load_in_8bit": True, 69 | "llm_int8_threshold": 6.0, 70 | }, 71 | }, 72 | "int8": { 73 | "torch_dtype": "float16", 74 | "quantization_scheme": "bnb", 75 | "quantization_config": { 76 | "load_in_8bit": True, 77 | "llm_int8_threshold": 0.0, 78 | }, 79 | }, 80 | } 81 | 82 | 
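# Each WEIGHTS_CONFIGS entry maps a config name to the torch_dtype,
# quantization_scheme, and quantization_config keyword arguments that are
# unpacked into the optimum-benchmark PyTorchConfig in run_benchmark() below.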
83 | def parse_args(): 84 | parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool") 85 | 86 | parser.add_argument("model_id", type=str, help="The model checkpoint to use.") 87 | 88 | parser.add_argument( 89 | "--configs", 90 | nargs="+", 91 | choices=["bf16", "fp16", "nf4", "nf4-dq", "int8", "int8-decomp"], 92 | default=["nf4", "int8", "int8-decomp"], 93 | ) 94 | parser.add_argument("--bf16", dest="configs", action="append_const", const="bf16") 95 | parser.add_argument("--fp16", dest="configs", action="append_const", const="fp16") 96 | parser.add_argument("--nf4", dest="configs", action="append_const", const="nf4") 97 | parser.add_argument("--nf4-dq", dest="configs", action="append_const", const="nf4-dq") 98 | parser.add_argument("--int8", dest="configs", action="append_const", const="int8") 99 | parser.add_argument("--int8-decomp", dest="configs", action="append_const", const="int8-decomp") 100 | 101 | parser.add_argument("--batches", nargs="+", type=int, default=[1, 8, 16, 32]) 102 | parser.add_argument("--input-length", type=int, default=64) 103 | 104 | parser.add_argument("--out-dir", type=str, default="reports") 105 | 106 | parser.add_argument("--iterations", type=int, default=10, help="Number of iterations for each benchmark run") 107 | parser.add_argument( 108 | "--warmup-runs", type=int, default=10, help="Number of warmup runs to discard before measurement" 109 | ) 110 | parser.add_argument( 111 | "--output-length", 112 | type=int, 113 | default=64, 114 | help="If set, `max_new_tokens` and `min_new_tokens` will be set to this value.", 115 | ) 116 | 117 | return parser.parse_args() 118 | 119 | 120 | def run_benchmark(args, config, batch_size): 121 | launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="warn", start_method="spawn") 122 | scenario_config = InferenceConfig( 123 | latency=True, 124 | memory=True, 125 | input_shapes={"batch_size": batch_size, "sequence_length": args.input_length}, 126 | iterations=args.iterations, 127 | warmup_runs=args.warmup_runs, 128 | # set duration to 0 to disable the duration-based stopping criterion 129 | # this is IMPORTANT to ensure that all benchmarks run the same number of operations, regardless of hardware speed/bottlenecks 130 | duration=0, 131 | # for consistent results, set a fixed min and max for output tokens 132 | generate_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length}, 133 | forward_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length}, 134 | ) 135 | 136 | backend_config = PyTorchConfig( 137 | device="cuda", 138 | device_ids="0", 139 | device_map="auto", 140 | no_weights=False, 141 | model=args.model_id, 142 | **WEIGHTS_CONFIGS[config], 143 | ) 144 | 145 | test_name = ( 146 | f"benchmark-{config}" 147 | f"-bsz-{batch_size}" 148 | f"-isz-{args.input_length}" 149 | f"-osz-{args.output_length}" 150 | f"-iter-{args.iterations}" 151 | f"-wrmup-{args.warmup_runs}" 152 | ) 153 | benchmark_config = BenchmarkConfig( 154 | name=test_name, 155 | scenario=scenario_config, 156 | launcher=launcher_config, 157 | backend=backend_config, 158 | ) 159 | 160 | out_path = out_dir / (test_name + ".json") 161 | print(f"[{test_name}] Starting:") 162 | benchmark_report = Benchmark.launch(benchmark_config) 163 | benchmark_report.save_json(out_path) 164 | 165 | 166 | if __name__ == "__main__": 167 | setup_logging(level="INFO") 168 | args = parse_args() 169 | 170 | out_dir = Path(args.out_dir) 171 | out_dir.mkdir(parents=True, exist_ok=True) 
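# Sweep every requested (batch size, quantization config) pair; each run writes
# its benchmark-<config>-bsz-... .json report into the output directory.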
172 | 173 | for batch_size in args.batches: 174 | for config in args.configs: 175 | run_benchmark(args, config, batch_size) 176 | -------------------------------------------------------------------------------- /benchmarking/switchback/speed_benchmark.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | import torch 5 | 6 | from bitsandbytes.triton.int8_matmul_mixed_dequantize import ( 7 | int8_matmul_mixed_dequantize, 8 | ) 9 | from bitsandbytes.triton.int8_matmul_rowwise_dequantize import ( 10 | int8_matmul_rowwise_dequantize, 11 | ) 12 | from bitsandbytes.triton.quantize_columnwise_and_transpose import ( 13 | quantize_columnwise_and_transpose, 14 | ) 15 | from bitsandbytes.triton.quantize_global import ( 16 | quantize_global, 17 | quantize_global_transpose, 18 | ) 19 | from bitsandbytes.triton.quantize_rowwise import quantize_rowwise 20 | 21 | # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. 22 | 23 | 24 | def get_time(k, fn, info_dict): 25 | for _ in range(repeat // 2): 26 | fn() 27 | 28 | torch.cuda.synchronize() 29 | start = time.time() 30 | for _ in range(repeat): 31 | fn() 32 | 33 | torch.cuda.synchronize() 34 | end = time.time() 35 | ms = (end - start) / repeat * 1000 36 | print(f"time {k}: {ms:.3f} ms") 37 | info_dict[k] = ms 38 | 39 | 40 | if __name__ == "__main__": 41 | torch.manual_seed(0) 42 | wm = 4 43 | for dim in [1024, 1280, 1408, 1664, 2048, 4096]: 44 | # note "batch_size" is actually "batch_size * embed_dim", which is why it's large 45 | for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]: 46 | # switch switches dim_in and dim_out 47 | for switch in [False, True]: 48 | # hparams 49 | repeat = 64 50 | batch_size = batch_size 51 | dim_out = dim * wm 52 | dim_in = dim 53 | if switch: 54 | dim_out = dim 55 | dim_in = wm * dim 56 | 57 | dim_in = round(dim_in) 58 | dim_out = round(dim_out) 59 | 60 | # simulate forward pass 61 | x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() 62 | g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() 63 | w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda() 64 | 65 | x_int8 = x.clone().to(torch.int8) 66 | g_int8 = g.clone().to(torch.int8) 67 | w_int8 = w.clone().to(torch.int8) 68 | wt_int8 = w.t().contiguous().clone().to(torch.int8) 69 | state_x_rowwise = x.max(dim=1)[0] 70 | state_g_rowwise = g.max(dim=1)[0] 71 | state_w_columnwise = w.max(dim=0)[0] 72 | state_w_rowwise = w.max(dim=1)[0] 73 | state_w_global = w.max() 74 | 75 | info = { 76 | "repeat": repeat, 77 | "batch_size": batch_size, 78 | "dim_out": dim_out, 79 | "dim_in": dim_in, 80 | "wm": wm, 81 | "switch": switch, 82 | } 83 | 84 | get_time("standard_fwd", lambda: x.matmul(w.t()), info) 85 | get_time("standard_gw", lambda: g.t().matmul(x), info) 86 | get_time("standard_gx", lambda: g.matmul(w), info) 87 | get_time( 88 | "rowwise_fwd", 89 | lambda: int8_matmul_rowwise_dequantize( 90 | x_int8, 91 | w_int8.t(), 92 | state_x_rowwise, 93 | state_w_columnwise, 94 | None, 95 | ), 96 | info, 97 | ) 98 | get_time( 99 | "rowwise_bwd", 100 | lambda: int8_matmul_rowwise_dequantize( 101 | g_int8, 102 | wt_int8.t(), 103 | state_x_rowwise, 104 | state_w_rowwise, 105 | None, 106 | ), 107 | info, 108 | ) 109 | get_time( 110 | "global_fwd", 111 | lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), 112 | info, 113 | ) 114 | get_time( 115 | "global_bwd", 116 | lambda: 
int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), 117 | info, 118 | ) 119 | get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info) 120 | get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info) 121 | get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info) 122 | get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info) 123 | get_time("w_quantize_global", lambda: quantize_global(w), info) 124 | get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info) 125 | 126 | time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"] 127 | time_rowwise = ( 128 | info["x_quantize_rowwise"] 129 | + info["g_quantize_rowwise"] 130 | + info["w_quantize_colwise_transpose"] 131 | + info["w_quantize_rowwise"] 132 | + info["standard_gw"] 133 | + info["rowwise_fwd"] 134 | + info["rowwise_bwd"] 135 | ) 136 | time_global = ( 137 | info["x_quantize_rowwise"] 138 | + info["g_quantize_rowwise"] 139 | + info["w_quantize_global"] 140 | + info["w_quantize_global_transpose"] 141 | + info["standard_gw"] 142 | + info["global_fwd"] 143 | + info["global_bwd"] 144 | ) 145 | 146 | print("TOTAL STANDARD", time_standard) 147 | print("TOTAL ROWWISE", time_rowwise) 148 | print("TOTAL GLOBAL", time_global) 149 | 150 | print("speedup", -100 * (time_global - time_standard) / time_standard) 151 | 152 | info["time_standard"] = time_standard 153 | info["time_rowwise"] = time_rowwise 154 | info["time_global"] = time_global 155 | 156 | info_json = json.dumps(info) 157 | 158 | # TODO: change this to what you want. 159 | with open("speed_benchmark/info.jsonl", "a") as file: 160 | file.write(info_json + "\n") 161 | -------------------------------------------------------------------------------- /csrc/kernels.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | #include 8 | 9 | #ifndef kernels 10 | #define kernels 11 | 12 | __global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n); 13 | __global__ void kDequantize(float* code, unsigned char* A, float* out, const int n); 14 | 15 | template 16 | __global__ void kQuantizeBlockwise( 17 | float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, 18 | const int rand_offset, const int n 19 | ); 20 | template 21 | __global__ void 22 | kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n); 23 | 24 | template 25 | __global__ void kPreconditionOptimizer32bit2State( 26 | T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps, 27 | const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n 28 | ); 29 | 30 | template 31 | __global__ void kOptimizer32bit2State( 32 | T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, 33 | const float beta1, const float beta2, const float beta3, const float alpha, const float eps, 34 | const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, 35 | const int n 36 | ); 37 | 38 | template 39 | __global__ void kPreconditionOptimizer32bit1State( 40 | T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, 41 | const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n 42 | ); 43 | 44 | template 45 | __global__ void kOptimizer32bit1State( 46 | T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1, 47 | const float beta2, const float eps, const float weight_decay, const int step, const float lr, 48 | const float gnorm_scale, const bool skip_zeros, const int n 49 | ); 50 | 51 | template 52 | __global__ void kPreconditionOptimizerStatic8bit1State( 53 | T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1, 54 | const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, 55 | float* new_max1, const float weight_decay, const float gnorm_scale, const int n 56 | ); 57 | 58 | template 59 | __global__ void kOptimizerStatic8bit1State( 60 | T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm, 61 | const float beta1, const float beta2, const float eps, const int step, const float lr, 62 | float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, 63 | const int n 64 | ); 65 | 66 | template 67 | __global__ void kPreconditionOptimizerStatic8bit2State( 68 | T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2, 69 | float* unorm, const float beta1, const float beta2, const float eps, const int step, 70 | float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, 71 | float* new_max1, float* new_max2, const float gnorm_scale, const int n 72 | ); 73 | 74 | template 75 | __global__ void kOptimizerStatic8bit2State( 76 | T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm, 77 | const float param_norm, const float beta1, const float beta2, const float eps, const int step, const 
float lr, 78 | float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, 79 | float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n 80 | ); 81 | 82 | template 83 | __global__ void kOptimizerStatic8bit2StateBlockwise( 84 | T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, 85 | const float beta3, const float alpha, const float eps, const int step, const float lr, 86 | float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, 87 | float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n 88 | ); 89 | 90 | template 91 | __global__ void kOptimizerStatic8bit1StateBlockwise( 92 | T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, 93 | const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay, 94 | const float gnorm_scale, const bool skip_zeros, const int n 95 | ); 96 | 97 | template 98 | __global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n); 99 | 100 | template 101 | __global__ void kspmm_coo_very_sparse_naive( 102 | int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, 103 | float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB 104 | ); 105 | 106 | template 107 | __global__ void kdequant_mm_int32_fp16( 108 | int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, 109 | half* __restrict__ const bias, const int numRows, const int numCols, const int n 110 | ); 111 | 112 | template 113 | __global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols); 114 | 115 | template 116 | __global__ void kgemm_4bit_inference_naive( 117 | int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, 118 | int lda, int ldb, int ldc, int blocksize 119 | ); 120 | 121 | template __global__ void kfunc(T* A, T* B, T value, long n); 122 | 123 | #endif 124 | -------------------------------------------------------------------------------- /csrc/kernels_hip.cuh: -------------------------------------------------------------------------------- 1 | // !!! This is a file automatically generated by hipify!!! 2 | #include "hip/hip_runtime.h" 3 | // Copyright (c) Facebook, Inc. and its affiliates. 4 | // 5 | // This source code is licensed under the MIT license found in the 6 | // LICENSE file in the root directory of this source tree. 
7 | 8 | #include 9 | #include 10 | 11 | #ifndef kernels 12 | #define kernels 13 | 14 | __global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n); 15 | __global__ void kDequantize(float* code, unsigned char* A, float* out, const int n); 16 | 17 | template 18 | __global__ void kQuantizeBlockwise( 19 | float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, 20 | const int rand_offset, const int n 21 | ); 22 | template 23 | __global__ void 24 | kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n); 25 | 26 | template 27 | __global__ void kPreconditionOptimizer32bit2State( 28 | T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps, 29 | const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n 30 | ); 31 | 32 | template 33 | __global__ void kOptimizer32bit2State( 34 | T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, 35 | const float beta1, const float beta2, const float beta3, const float alpha, const float eps, 36 | const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, 37 | const int n 38 | ); 39 | 40 | template 41 | __global__ void kPreconditionOptimizer32bit1State( 42 | T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, 43 | const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n 44 | ); 45 | 46 | template 47 | __global__ void kOptimizer32bit1State( 48 | T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1, 49 | const float beta2, const float eps, const float weight_decay, const int step, const float lr, 50 | const float gnorm_scale, const bool skip_zeros, const int n 51 | ); 52 | 53 | template 54 | __global__ void kPreconditionOptimizerStatic8bit1State( 55 | T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1, 56 | const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, 57 | float* new_max1, const float weight_decay, const float gnorm_scale, const int n 58 | ); 59 | 60 | template 61 | __global__ void kOptimizerStatic8bit1State( 62 | T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm, 63 | const float beta1, const float beta2, const float eps, const int step, const float lr, 64 | float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, 65 | const int n 66 | ); 67 | 68 | template 69 | __global__ void kPreconditionOptimizerStatic8bit2State( 70 | T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2, 71 | float* unorm, const float beta1, const float beta2, const float eps, const int step, 72 | float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, 73 | float* new_max1, float* new_max2, const float gnorm_scale, const int n 74 | ); 75 | 76 | template 77 | __global__ void kOptimizerStatic8bit2State( 78 | T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm, 79 | const float param_norm, const float beta1, const float beta2, const float eps, const int step, 
const float lr, 80 | float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, 81 | float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n 82 | ); 83 | 84 | template 85 | __global__ void kOptimizerStatic8bit2StateBlockwise( 86 | T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, 87 | const float beta3, const float alpha, const float eps, const int step, const float lr, 88 | float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, 89 | float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n 90 | ); 91 | 92 | template 93 | __global__ void kOptimizerStatic8bit1StateBlockwise( 94 | T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, 95 | const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay, 96 | const float gnorm_scale, const bool skip_zeros, const int n 97 | ); 98 | 99 | template 100 | __global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n); 101 | 102 | template 103 | __global__ void kspmm_coo_very_sparse_naive( 104 | int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, 105 | float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB 106 | ); 107 | 108 | template 109 | __global__ void kdequant_mm_int32_fp16( 110 | int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, 111 | half* __restrict__ const bias, const int numRows, const int numCols, const int n 112 | ); 113 | 114 | template 115 | __global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols); 116 | 117 | template 118 | __global__ void kgemm_4bit_inference_naive( 119 | int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, 120 | int lda, int ldb, int ldc, int blocksize 121 | ); 122 | 123 | template __global__ void kfunc(T* A, T* B, T value, long n); 124 | 125 | #endif 126 | -------------------------------------------------------------------------------- /bitsandbytes/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from bitsandbytes.optim.optimizer import Optimizer1State 6 | 7 | 8 | class SGD(Optimizer1State): 9 | def __init__( 10 | self, 11 | params, 12 | lr, 13 | momentum=0, 14 | dampening=0, 15 | weight_decay=0, 16 | nesterov=False, 17 | optim_bits=32, 18 | args=None, 19 | min_8bit_size=4096, 20 | percentile_clipping=100, 21 | block_wise=True, 22 | ): 23 | """ 24 | Base SGD optimizer. 25 | 26 | Arguments: 27 | params (`torch.tensor`): 28 | The input parameters to optimize. 29 | lr (`float`): 30 | The learning rate. 31 | momentum (`float`, defaults to 0): 32 | The momentum value speeds up the optimizer by taking bigger steps. 33 | dampening (`float`, defaults to 0): 34 | The dampening value reduces the momentum of the optimizer. 35 | weight_decay (`float`, defaults to 0.0): 36 | The weight decay value for the optimizer. 
37 | nesterov (`bool`, defaults to `False`): 38 | Whether to use Nesterov momentum. 39 | optim_bits (`int`, defaults to 32): 40 | The number of bits of the optimizer state. 41 | args (`object`, defaults to `None`): 42 | An object with additional arguments. 43 | min_8bit_size (`int`, defaults to 4096): 44 | The minimum number of elements of the parameter tensors for 8-bit optimization. 45 | percentile_clipping (`int`, defaults to 100): 46 | Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. 47 | block_wise (`bool`, defaults to `True`): 48 | Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. 49 | """ 50 | if momentum == 0: 51 | raise NotImplementedError("SGD without momentum is not supported!") 52 | super().__init__( 53 | "momentum", 54 | params, 55 | lr, 56 | (momentum, dampening), 57 | 0.0, 58 | weight_decay, 59 | optim_bits, 60 | args, 61 | min_8bit_size, 62 | percentile_clipping, 63 | block_wise, 64 | ) 65 | 66 | 67 | class SGD8bit(Optimizer1State): 68 | def __init__( 69 | self, 70 | params, 71 | lr, 72 | momentum=0, 73 | dampening=0, 74 | weight_decay=0, 75 | nesterov=False, 76 | args=None, 77 | min_8bit_size=4096, 78 | percentile_clipping=100, 79 | block_wise=True, 80 | ): 81 | """ 82 | 8-bit SGD optimizer. 83 | 84 | Arguments: 85 | params (`torch.tensor`): 86 | The input parameters to optimize. 87 | lr (`float`): 88 | The learning rate. 89 | momentum (`float`, defaults to 0): 90 | The momentum value speeds up the optimizer by taking bigger steps. 91 | dampening (`float`, defaults to 0): 92 | The dampening value reduces the momentum of the optimizer. 93 | weight_decay (`float`, defaults to 0.0): 94 | The weight decay value for the optimizer. 95 | nesterov (`bool`, defaults to `False`): 96 | Whether to use Nesterov momentum. 97 | args (`object`, defaults to `None`): 98 | An object with additional arguments. 99 | min_8bit_size (`int`, defaults to 4096): 100 | The minimum number of elements of the parameter tensors for 8-bit optimization. 101 | percentile_clipping (`int`, defaults to 100): 102 | Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. 103 | block_wise (`bool`, defaults to `True`): 104 | Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. 105 | """ 106 | if momentum == 0: 107 | raise NotImplementedError("SGD without momentum is not supported!") 108 | super().__init__( 109 | "momentum", 110 | params, 111 | lr, 112 | (momentum, dampening), 113 | 0.0, 114 | weight_decay, 115 | 8, 116 | args, 117 | min_8bit_size, 118 | percentile_clipping, 119 | block_wise, 120 | ) 121 | 122 | 123 | class SGD32bit(Optimizer1State): 124 | def __init__( 125 | self, 126 | params, 127 | lr, 128 | momentum=0, 129 | dampening=0, 130 | weight_decay=0, 131 | nesterov=False, 132 | args=None, 133 | min_8bit_size=4096, 134 | percentile_clipping=100, 135 | block_wise=True, 136 | ): 137 | """ 138 | 32-bit SGD optimizer. 139 | 140 | Arguments: 141 | params (`torch.tensor`): 142 | The input parameters to optimize. 143 | lr (`float`): 144 | The learning rate. 145 | momentum (`float`, defaults to 0): 146 | The momentum value speeds up the optimizer by taking bigger steps. 147 | dampening (`float`, defaults to 0): 148 | The dampening value reduces the momentum of the optimizer. 
149 | weight_decay (`float`, defaults to 0.0): 150 | The weight decay value for the optimizer. 151 | nesterov (`bool`, defaults to `False`): 152 | Whether to use Nesterov momentum. 153 | args (`object`, defaults to `None`): 154 | An object with additional arguments. 155 | min_8bit_size (`int`, defaults to 4096): 156 | The minimum number of elements of the parameter tensors for 8-bit optimization. 157 | percentile_clipping (`int`, defaults to 100): 158 | Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. 159 | block_wise (`bool`, defaults to `True`): 160 | Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. 161 | """ 162 | if momentum == 0: 163 | raise NotImplementedError("SGD without momentum is not supported!") 164 | super().__init__( 165 | "momentum", 166 | params, 167 | lr, 168 | (momentum, dampening), 169 | 0.0, 170 | weight_decay, 171 | 32, 172 | args, 173 | min_8bit_size, 174 | percentile_clipping, 175 | block_wise, 176 | ) 177 | -------------------------------------------------------------------------------- /docs/source/fsdp_qlora.md: -------------------------------------------------------------------------------- 1 | # FSDP-QLoRA 2 | 3 | FSDP-QLoRA combines data parallelism (FSDP enables sharding model parameters, optimizer states, and gradients across GPUs), 4-bit quantization, and LoRA to train LLMs up to 70B parameters on a dual 24GB GPU system. This technique was released by [Answer.AI](https://www.answer.ai/posts/2024-03-06-fsdp-qlora) in collaboration with bitsandbytes to make training LLMs more efficient and accessible for everyone. 4 | 5 | This guide provides a brief guide on how bitsandbytes supports storing quantized weights to enable FSDP-QLoRA, and how to run training with the Hugging Face libraries. 6 | 7 | > [!TIP] 8 | > Other changes required for bitsandbytes to support FSDP-QLoRA, such as reconstructing the weights from the quantization metadata and preventing quantizing already quantized weights when they're moved from a CPU to GPU, are documented in this [Pull Request](https://github.com/bitsandbytes-foundation/bitsandbytes/pull/970) and described in the [Enabling 70B Finetuning on Consumer GPUs](https://www.answer.ai/posts/2024-03-14-fsdp-qlora-deep-dive) blog post. We highly recommend reading these resources for a better understanding of FSDP-QLoRA! 9 | 10 | ## Quantized data storage 11 | 12 | FSDP only supports sharding float data types which can be problematic because quantized weights are typically stored as integer data types (uint8). bitsandbytes doesn't have this problem because it uses `StoreChar` to read and write quantized weights regardless of the data type storage. This makes it simple to add a `quant_storage` parameter to the [`~nn.Linear4bit`] and [`~nn.Params4bit`] classes and set it to `torch.uint8` to maintain backward compatibility with the codebase. With the `quant_storage` parameter, you can select any of the FSDP supported data types to shard [`~nn.Linear4bit`] with such as bfloat16, float16 or float32. 13 | 14 | You'll typically access and configure this option from [`transformers.BitsAndBytesConfig`] by setting the `bnb_4bit_quant_storage` parameter. It is very **important** the `quant_storage` data type matches the data types used throughout the model because FSDP can only wrap layers and modules that have the *same floating data type*. 
Making sure the data types are aligned will ensure the model is correctly sharded. 15 | 16 | > [!TIP] 17 | > The `compute_dtype` is the data type used for computation inside the CUDA kernel, where the 4-bit quantized weights are unpacked from the data type in `quant_storage` and dequantized to `compute_dtype`. We recommend using torch.bfloat16 (if available on your hardware) for better numerical stability. 18 | 19 | ```py 20 | from transformers import BitsAndBytesConfig, AutoModelForCausalLM 21 | 22 | bnb_config = BitsAndBytesConfig( 23 | load_in_4bit=True, 24 | bnb_4bit_quant_type="nf4", 25 | bnb_4bit_compute_dtype=torch.bfloat16, 26 | bnb_4bit_quant_storage=torch.bfloat16, 27 | ) 28 | 29 | model = AutoModelForCausalLM.from_pretrained( 30 | "meta-llama/Llama-2-70b", 31 | quantization_config=bnb_config, 32 | torch_dtype=torch.bfloat16, 33 | ) 34 | ``` 35 | 36 | Check out this [section](https://hf.co/docs/peft/main/en/accelerate/fsdp#use-peft-qlora-and-fsdp-for-finetuning-large-models-on-multiple-gpus) of the PEFT documentation for the config file and training code to run FSDP-QLoRA training. 37 | 38 | ## Training 39 | 40 | > [!TIP] 41 | > FSDP is a distributed training framework that needs to be launched as a distributed training job with a library like [Accelerate](https://hf.co/docs/accelerate/index) or [torchrun](https://pytorch.org/docs/stable/elastic/run.html). The launch command provided in this section uses Accelerate to launch the training script. 42 | 43 | bitsandbytes is deeply integrated with the Hugging Face ecosystem, making it easy to use with libraries like [Transformers](https://hf.co/docs/transformers), [PEFT](https://hf.co/docs/peft), and [TRL](https://hf.co/docs/trl). 44 | 45 | PEFT provides a configuration file ([fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml)), launch command ([run_peft_qlora_fsdp.sh](https://github.com/huggingface/peft/blob/main/examples/sft/run_peft_qlora_fsdp.sh)), and training script ([train.py](https://github.com/huggingface/peft/blob/main/examples/sft/train.py)) for running FSDP-QLoRA. To learn more, check out the [Use PEFT QLoRA and FSDP for finetuning large models on multiple GPUs](https://huggingface.co/docs/peft/main/en/accelerate/fsdp#use-peft-qlora-and-fsdp-for-finetuning-large-models-on-multiple-gpus) documentation. This section briefly covers the steps to run FSDP-QLoRA training. 46 | 47 | Before you begin, make sure you have the latest libraries installed. 48 | 49 | ```bash 50 | pip install -U bitsandbytes accelerate transformers peft trl 51 | ``` 52 | 53 | The important change that enables FSDP-QLoRA training is the `bnb_4bit_quant_storage` parameter in the [`~transformers.BitsAndBytesConfig`] class. This allows you to set the storage data type of the quantized weights to a float data type. 54 | 55 | ```py 56 | from transformers import BitsAndBytesConfig 57 | 58 | bnb_config = BitsAndBytesConfig( 59 | load_in_4bit=True, 60 | bnb_4bit_quant_type="nf4", 61 | bnb_4bit_compute_dtype=torch.bfloat16, 62 | bnb_4bit_use_double_quant=True, 63 | bnb_4bit_quant_storage=torch.bfloat16, 64 | ) 65 | ``` 66 | 67 | Pass the [`~transformers.BitsAndBytesConfig`] to a model to set it up for FSDP-QLoRA. You should set the `torch_dtype` parameter to match `bnb_4bit_quant_storage` so that the [`~nn.Linear4bit`] layers are wrapped identically to the `Linear` layers. If the storage types do not match, then each [`~nn.Linear4bit`] layer is wrapped individually. 
68 | 69 | ```py 70 | from transformers import AutoModelForCausalLM 71 | 72 | model = AutoModelForCausalLM.from_pretrained( 73 | "meta-llama/Llama-2-70b", 74 | quantization_config=bnb_config, 75 | torch_dtype=torch.bfloat16, 76 | ) 77 | ``` 78 | 79 | Configure the [`~peft.LoraConfig`] class for QLoRA training by setting `target_modules="all-linear"`. 80 | 81 | ```py 82 | from peft import LoraConfig 83 | 84 | peft_config = LoraConfig( 85 | lora_alpha=16, 86 | lora_dropout=0.1, 87 | r=64, 88 | bias="none", 89 | task_type="CAUSAL_LM", 90 | target_modules="all-linear", 91 | ) 92 | ``` 93 | 94 | Now you can pass everything to the [`~trl.SFTTrainer`] for training. 95 | 96 | ```py 97 | from trl import SFTTrainer 98 | 99 | trainer = SFTTrainer( 100 | model=model, 101 | train_dataset=dataset, 102 | peft_config=peft_config, 103 | processing_class=tokenizer, 104 | args=training_arguments, 105 | ) 106 | trainer.train() 107 | ``` 108 | 109 | ## Resources 110 | 111 | To learn more about FSDP and QLoRA, check out the following resources: 112 | 113 | - The [AnswerDotAI/fsdp_qlora](https://github.com/AnswerDotAI/fsdp_qlora) repository. 114 | - The introductory [You can now train a 70b language model at home](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) blog post by Answer.AI. 115 | - For an introduction to FSDP, read the [Introducing PyTorch Fully Sharded Data Parallel (FSDP) API](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api) blog post. 116 | - For more details about QLoRA, take a look at the [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post. 117 | -------------------------------------------------------------------------------- /bitsandbytes/backends/triton/kernels_8bit_quant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import triton 4 | import triton.language as tl 5 | 6 | 7 | # @triton.autotune( 8 | # configs=[ 9 | # # triton.Config({'SPLIT_SIZE': 64}), 10 | # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32), 11 | # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), 12 | # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32), 13 | # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), 14 | # # triton.Config({'SPLIT_SIZE': 128}), 15 | # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32), 16 | # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), 17 | # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32), 18 | # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), 19 | # triton.Config({"SPLIT_SIZE": 256}), 20 | # # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32), 21 | # # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), 22 | # triton.Config({"SPLIT_SIZE": 512}), 23 | # # triton.Config({'SPLIT_SIZE': 1024}), 24 | # ], 25 | # key=["num_paired_elements", "QUANT_BLOCK"], 26 | # ) 27 | @triton.jit 28 | def dequant_8bit_kernel( 29 | a_ptr, 30 | out_ptr, 31 | code_ptr, 32 | absmax_ptr, 33 | n, 34 | QUANT_BLOCK: tl.constexpr, 35 | SPLIT_SIZE: tl.constexpr, 36 | ): 37 | pid = tl.program_id(axis=0) 38 | block_start = pid * SPLIT_SIZE 39 | offsets = block_start + tl.arange(0, 
SPLIT_SIZE) 40 | mask = offsets < n 41 | out_dq = dequant_8bit_blockwise_kernel_util(a_ptr, offsets, code_ptr, absmax_ptr, mask, QUANT_BLOCK) 42 | tl.store(out_ptr + offsets, out_dq, mask) 43 | 44 | 45 | def dequant_8bit_blockwise( 46 | a: torch.Tensor, 47 | absmax: torch.Tensor, 48 | quant_state_code: torch.Tensor, 49 | quant_blocksize: int = 64, 50 | dtype: torch.dtype = None, 51 | out: torch.Tensor = None, 52 | ): 53 | n = a.numel() 54 | if out is None: 55 | if dtype is None: 56 | raise ValueError("If out is None, dtype must be specified") 57 | out = torch.empty_like(a, dtype=dtype, device=a.device) 58 | 59 | SPLIT_SIZE = 256 60 | # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),) 61 | grid = (triton.cdiv(n, SPLIT_SIZE),) 62 | dequant_8bit_kernel[grid]( 63 | a, 64 | out, 65 | quant_state_code, 66 | absmax, 67 | n, 68 | quant_blocksize, 69 | SPLIT_SIZE, 70 | ) 71 | return out 72 | 73 | 74 | # @triton.autotune( 75 | # configs=[ 76 | # triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), 77 | # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), 78 | # triton.Config({"SPLIT_NUM_BLOCKS": 1}), 79 | # triton.Config({"SPLIT_NUM_BLOCKS": 2}), 80 | # ], 81 | # key=["n_elements"], 82 | # ) 83 | @triton.jit 84 | def quantize_8bit_blockwise_kernel( 85 | A_ptr, 86 | code_ptr, 87 | absmax_ptr, 88 | out_ptr, 89 | n_elements, 90 | BLOCK_SIZE: tl.constexpr, 91 | CODE_SIZE: tl.constexpr, 92 | SPLIT_NUM_BLOCKS: tl.constexpr, 93 | ): 94 | block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS 95 | thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) 96 | 97 | offsets = block_start_idx * BLOCK_SIZE + thread_idx 98 | mask = offsets < n_elements 99 | 100 | A = tl.load(A_ptr + offsets, mask=mask, other=0.0) 101 | 102 | quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS) 103 | tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax) 104 | tl.store(out_ptr + offsets, quantized, mask=mask) 105 | 106 | 107 | def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None): 108 | n = A.numel() 109 | blocks = -(n // -blocksize) 110 | 111 | if absmax is None: 112 | absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) 113 | if out is None: 114 | out = torch.empty_like(A.flatten(), dtype=torch.uint8) 115 | 116 | split_num_blocks = 1 117 | grid = (triton.cdiv(blocks, split_num_blocks),) 118 | # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),) 119 | quantize_8bit_blockwise_kernel[grid]( 120 | A_ptr=A, 121 | code_ptr=code, 122 | absmax_ptr=absmax, 123 | out_ptr=out, 124 | n_elements=n, 125 | BLOCK_SIZE=blocksize, 126 | CODE_SIZE=code.numel(), 127 | SPLIT_NUM_BLOCKS=split_num_blocks, 128 | # num_warps=1, 129 | # num_stages=2, 130 | ) 131 | out = out.reshape(A.shape) 132 | 133 | return out, absmax 134 | 135 | 136 | @triton.jit 137 | def quantize_8bit_blockwise_kernel_util( 138 | a, 139 | code_ptr, 140 | CODE_SIZE: tl.constexpr, 141 | BLOCK_SIZE: tl.constexpr, 142 | N_PER_TH: tl.constexpr, 143 | ): 144 | # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) 145 | a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE)) 146 | 147 | # Calculating absmax for each block 148 | absmax = tl.max(tl.abs(a_reshaped), axis=1) 149 | 150 | a_normalized = a_reshaped / absmax[:, None] 151 | a_normalized = tl.clamp(a_normalized, -1.0, 1.0) 152 | 153 | lower_pivot = tl.zeros((N_PER_TH, BLOCK_SIZE), 
dtype=tl.int32) 154 | upper_pivot = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) 155 | 156 | # ceil(log2(code_size)) = 8, actually, in general case should be input parameter 157 | for _ in range(8): 158 | pivot = (lower_pivot + upper_pivot) // 2 159 | val = tl.load(code_ptr + pivot) 160 | is_higher = a_normalized > val # code[pivot] 161 | lower_pivot = tl.where(is_higher, pivot, lower_pivot) 162 | upper_pivot = tl.where(is_higher, upper_pivot, pivot) 163 | 164 | # Choose closest level 165 | lower_val = tl.load(code_ptr + lower_pivot) 166 | upper_val = tl.load(code_ptr + upper_pivot) 167 | lower_dist = tl.abs(a_normalized - lower_val) 168 | upper_dist = tl.abs(a_normalized - upper_val) 169 | quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) 170 | 171 | # too slow approach 172 | # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) 173 | # quantized = tl.argmin(diff, axis=2).to(tl.uint8) 174 | 175 | quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,)) 176 | return quantized_flat, absmax 177 | 178 | 179 | @triton.jit 180 | def dequant_8bit_blockwise_kernel_util( 181 | a_ptr, 182 | offsets, 183 | code_ptr, 184 | absmax_ptr, 185 | mask, 186 | BLOCK_SIZE: tl.constexpr, 187 | ): 188 | a = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8) 189 | scaled_int8 = tl.load(code_ptr + a, mask) 190 | # Load scales 191 | absmax_offsets = offsets // BLOCK_SIZE 192 | absmax = tl.load(absmax_ptr + absmax_offsets, mask=mask, other=0.0, eviction_policy="evict_last") 193 | # Apply scales 194 | out_dq = scaled_int8 * absmax 195 | return out_dq 196 | --------------------------------------------------------------------------------
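The blockwise quantize/dequantize wrappers in `kernels_8bit_quant.py` above can be exercised with a short round trip. The following is an illustrative sketch only: it assumes a CUDA device with Triton available and uses `create_dynamic_map` from `bitsandbytes.functional` to build the 256-entry 8-bit codebook.

```py
import torch

from bitsandbytes.backends.triton.kernels_8bit_quant import (
    dequant_8bit_blockwise,
    quantize_blockwise_triton,
)
from bitsandbytes.functional import create_dynamic_map

# Build the dynamic 8-bit codebook and some test data on the GPU.
code = create_dynamic_map().to("cuda")
A = torch.randn(4096, device="cuda", dtype=torch.float32)

# Quantize in blocks of 64 values, then dequantize back to float32.
q, absmax = quantize_blockwise_triton(A, code, blocksize=64)
A_hat = dequant_8bit_blockwise(q, absmax, code, quant_blocksize=64, dtype=A.dtype)

print((A - A_hat).abs().max())  # small blockwise quantization error
```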