├── .buckconfig ├── .editorconfig ├── .git-blame-ignore-revs ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ └── feature-request.yml ├── dependabot.yml └── workflows │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ ├── lint.yml │ ├── python-package.yml │ ├── stale.yml.disabled │ └── upload_pr_documentation.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .style.yapf ├── .vscode ├── extensions.json └── settings.json ├── CHANGELOG.md ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE.md ├── README.md ├── _typos.toml ├── benchmarking └── switchback │ ├── README.md │ ├── info_a100_py2.jsonl │ ├── make_plot_with_jsonl.py │ ├── plot_with_info.pdf │ └── speed_benchmark.py ├── bitsandbytes ├── __init__.py ├── __main__.py ├── autograd │ ├── __init__.py │ └── _functions.py ├── cextension.py ├── cuda_setup │ ├── __init__.py │ ├── env_vars.py │ └── main.py ├── functional.py ├── nn │ ├── __init__.py │ ├── modules.py │ └── triton_based_modules.py ├── optim │ ├── __init__.py │ ├── adagrad.py │ ├── adam.py │ ├── adamw.py │ ├── lamb.py │ ├── lars.py │ ├── lion.py │ ├── optimizer.py │ ├── rmsprop.py │ └── sgd.py ├── research │ ├── __init__.py │ ├── autograd │ │ ├── __init__.py │ │ └── _functions.py │ └── nn │ │ ├── __init__.py │ │ └── modules.py ├── triton │ ├── __init__.py │ ├── dequantize_rowwise.py │ ├── int8_matmul_mixed_dequantize.py │ ├── int8_matmul_rowwise_dequantize.py │ ├── quantize_columnwise_and_transpose.py │ ├── quantize_global.py │ ├── quantize_rowwise.py │ └── triton_utils.py └── utils.py ├── check_bnb_install.py ├── csrc ├── common.cpp ├── common.h ├── cpu_ops.cpp ├── cpu_ops.h ├── kernels.cu ├── kernels.cuh ├── mps_kernels.metal ├── mps_ops.h ├── mps_ops.mm ├── ops.cu ├── ops.cuh └── pythonInterface.cpp ├── deploy.sh ├── docs └── source │ ├── _toctree.yml │ ├── algorithms.mdx │ ├── compiling.mdx │ ├── contributing.mdx │ ├── errors.mdx │ ├── faqs.mdx │ ├── index.mdx │ ├── installation.mdx │ ├── integrations.mdx │ ├── nonpytorchcuda.mdx │ ├── optimizers.mdx │ ├── quantization.mdx │ ├── quickstart.mdx │ └── resources.mdx ├── environment-bnb.yml ├── environment.yml ├── examples └── int8_inference_huggingface.py ├── include ├── AAlloc.h ├── Algo-Direct-Common.h ├── Algo-Direct2.h ├── AlgoXCodes.h ├── BinAlgo.h ├── BinSearch.h ├── Portable.h ├── SIMD.h └── Type.h ├── install_cuda.py ├── install_cuda.sh ├── pyproject.toml ├── pytest.ini ├── requirements-ci.txt ├── requirements-dev.txt ├── scripts └── stale.py ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── helpers.py ├── test_autograd.py ├── test_cuda_setup_evaluator.py ├── test_functional.py ├── test_generation.py ├── test_linear4bit.py ├── test_linear8bitlt.py ├── test_modules.py ├── test_optim.py └── test_triton.py /.buckconfig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickardp/bitsandbytes/927f7167e3395ec26f859f294c1d4979a70a718a/.buckconfig -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | [*] 2 | trim_trailing_whitespace = true 3 | insert_final_newline = true 4 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # ran black and isort for coherent code formatting 2 | bfa0e33294f2b1dc25e65a33be2397f989824298 3 | 4 | # reran black with 
linelength 80 for greater readability 5 | ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 6 | 7 | # Remove f-prefix from strings that don't use formatting 8 | 7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve bitsandbytes 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info 8 | description: Please share your relevant system information with us 9 | placeholder: platform, python version, hardware, ... 10 | validations: 11 | required: true 12 | 13 | - type: textarea 14 | id: reproduction 15 | validations: 16 | required: true 17 | attributes: 18 | label: Reproduction 19 | description: | 20 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. 21 | Please provide the simplest reproducer as possible so that we can quickly fix the issue. 22 | 23 | placeholder: | 24 | Reproducer: 25 | 26 | - type: textarea 27 | id: expected-behavior 28 | validations: 29 | required: true 30 | attributes: 31 | label: Expected behavior 32 | description: "A clear and concise description of what you would expect to happen." 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a proposal/request for a new feature 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request 11 | description: | 12 | A clear and concise description of the feature proposal. 13 | 14 | - type: textarea 15 | id: motivation 16 | validations: 17 | required: true 18 | attributes: 19 | label: Motivation 20 | description: | 21 | Please outline the motivation for the proposal. Is your feature request related to a problem? 22 | 23 | - type: textarea 24 | id: contribution 25 | validations: 26 | required: true 27 | attributes: 28 | label: Your contribution 29 | description: | 30 | Is there any way that you could help, e.g. by submitting a PR? 
31 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | groups: 8 | major: 9 | update-types: [major] 10 | minor-patch: 11 | update-types: [minor, patch] 12 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - doc-builder* 8 | - v*-release 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.sha }} 15 | package: bitsandbytes 16 | repo_owner: TimDettmers 17 | secrets: 18 | hf_token: ${{ secrets.HUGGINGFACE_PUSH }} 19 | -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | 6 | concurrency: 7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | build: 12 | if: github.repository == 'TimDettmers/bitsandbytes' 13 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 14 | with: 15 | commit_sha: ${{ github.event.pull_request.head.sha }} 16 | pr_number: ${{ github.event.number }} 17 | package: bitsandbytes 18 | repo_owner: TimDettmers 19 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | Lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v4 15 | with: 16 | python-version: "3.12" 17 | - uses: pre-commit/action@v3.0.0 18 | env: 19 | RUFF_OUTPUT_FORMAT: github 20 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml.disabled: -------------------------------------------------------------------------------- 1 | name: Stale Bot 2 | 3 | on: 4 | schedule: 5 | - cron: "0 15 * * *" 6 | 7 | jobs: 8 | close_stale_issues: 9 | name: Close Stale Issues 10 | if: github.repository == 'TimDettmers/bitsandbytes' 11 | runs-on: ubuntu-latest 12 | env: 13 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 14 | steps: 15 | - uses: actions/checkout@v3 16 | 17 | - name: Setup Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: 3.8 21 | 22 | - name: Install requirements 23 | run: | 24 | pip install PyGithub 25 | - name: Close stale issues 26 | run: | 27 | python scripts/stale.py 28 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: 
bitsandbytes 14 | secrets: 15 | hf_token: ${{ secrets.HUGGINGFACE_PUSH }} 16 | comment_bot_token: ${{ secrets.GITHUB_TOKEN }} 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | *.dll 7 | *.dylib 8 | *.o 9 | *.obj 10 | *.air 11 | *.metallib 12 | 13 | # CMake generated files 14 | CMakeCache.txt 15 | CMakeScripts/ 16 | cmake_install.cmake 17 | Makefile 18 | CMakeFiles/ 19 | *.sln 20 | *.vcxproj* 21 | *.xcodeproj/ 22 | bitsandbytes.dir/ 23 | Debug/ 24 | Release/ 25 | 26 | # IDE local files 27 | .vs/ 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | pip-wheel-metadata/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # IPython 101 | profile_default/ 102 | ipython_config.py 103 | 104 | # pyenv 105 | .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # vim 152 | *.swp 153 | 154 | dependencies 155 | cuda_build 156 | output/ 157 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.2.0 4 | hooks: 5 | - id: ruff 6 | args: 7 | - --fix 8 | # - id: ruff-format # TODO: enable when the time is right 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v4.5.0 11 | hooks: 12 | - id: check-merge-conflict 13 | - id: check-yaml 14 | - id: end-of-file-fixer 15 | - id: fix-byte-order-marker 16 | - id: trailing-whitespace 17 | - id: mixed-line-ending 18 | args: 19 | - --fix=lf 20 | - repo: https://github.com/crate-ci/typos 21 | rev: v1.17.2 22 | hooks: 23 | - id: typos 24 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = True 3 | ALLOW_MULTILINE_LAMBDAS = True 4 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = True 5 | COLUMN_LIMIT = 88 6 | COALESCE_BRACKETS = True 7 | SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = True 8 | SPACES_BEFORE_COMMENT = 2 9 | SPLIT_BEFORE_BITWISE_OPERATOR = True 10 | SPLIT_BEFORE_FIRST_ARGUMENT = True 11 | SPLIT_BEFORE_LOGICAL_OPERATOR = True 12 | SPLIT_BEFORE_NAMED_ASSIGNS = True 13 | SPLIT_COMPLEX_COMPREHENSION = True 14 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.python", 4 | "charliermarsh.ruff", 5 | "twxs.cmake" 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "ruff.fixAll": true, 3 | "ruff.lint.run": "onType", 4 | "editor.codeActionsOnSave": { 5 | "source.fixAll": "always" 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # This CMake config hopefully makes it easier to compile. 2 | # Ensure the CUDA Toolkit is available on your path. Then run: 3 | # For GCC: `cmake -B build . && cmake --build build` 4 | # For MSVC: `cmake -B build . && cmake --build build --config Release` 5 | # You can also use the following options and variables 6 | # - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend 7 | # - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support 8 | # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version 9 | # is whatever CMake finds on your path. 
10 | # - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. 11 | # Separate by semicolons, i.e. `-DCOMPUTE_CAPABILITY=89;90` 12 | # Check your compute capability here: https://developer.nvidia.com/cuda-gpus 13 | # - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler 14 | cmake_minimum_required(VERSION 3.22.1) 15 | 16 | project(bitsandbytes LANGUAGES CXX) 17 | 18 | # Define included source files 19 | set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) 20 | set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) 21 | set(MPS_FILES csrc/mps_ops.mm) 22 | set(METAL_FILES csrc/mps_kernels.metal) 23 | # C++ sources are always included 24 | list(APPEND SRC_FILES ${CPP_FILES}) 25 | 26 | set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)") 27 | set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps) 28 | option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) 29 | 30 | if(APPLE) 31 | set(CMAKE_OSX_DEPLOYMENT_TARGET 13.1) 32 | endif() 33 | 34 | set(BNB_OUTPUT_NAME "bitsandbytes") 35 | 36 | message(STATUS "Building with backend ${COMPUTE_BACKEND}") 37 | 38 | if(${COMPUTE_BACKEND} STREQUAL "cuda") 39 | if(APPLE) 40 | message(FATAL_ERROR "CUDA is not supported on macOS" ) 41 | endif() 42 | option(NO_CUBLASLT "Disable CUBLAS" OFF) 43 | set(BUILD_CUDA ON) 44 | set(BUILD_MPS OFF) 45 | message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}") 46 | elseif(${COMPUTE_BACKEND} STREQUAL "mps") 47 | if(NOT APPLE) 48 | message(FATAL_ERROR "MPS is only supported on macOS" ) 49 | endif() 50 | set(BUILD_CUDA OFF) 51 | set(BUILD_MPS ON) 52 | else() 53 | set(BUILD_CUDA OFF) 54 | set(BUILD_MPS OFF) 55 | endif() 56 | 57 | 58 | if(BUILD_CUDA) 59 | enable_language(CUDA) # This will fail if CUDA is not found 60 | find_package(CUDAToolkit REQUIRED) 61 | 62 | # Convert the CUDA version from X.Y.z to XY. There's probably a shorter way of doing this 63 | string(REGEX MATCH "^[0-9]+.[0-9]+" _CUDA_VERSION_FIRST_TWO "${CMAKE_CUDA_COMPILER_VERSION}") 64 | string(REPLACE "." "" CUDA_VERSION_SHORT "${_CUDA_VERSION_FIRST_TWO}") 65 | 66 | # Expose a cache variable that the user can set to ensure the correct version of CUDA is found 67 | set(CUDA_VERSION "${CUDA_VERSION_SHORT}" CACHE STRING "Expected CUDA Version Shortcode") 68 | 69 | message(STATUS "CUDA Version: ${CUDA_VERSION_SHORT} (${CMAKE_CUDA_COMPILER_VERSION})") 70 | message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") 71 | 72 | # It should match the discovered version 73 | if(NOT CUDA_VERSION STREQUAL "${CUDA_VERSION_SHORT}") 74 | message(FATAL_ERROR "You've specified CUDA version ${CUDA_VERSION} however the CUDA compiler found is ${CUDA_VERSION_SHORT}." 75 | " Ensure the desired CUDA compiler is the first one available on your PATH." 76 | ) 77 | endif() 78 | 79 | if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS "11.0") 80 | message(FATAL_ERROR "CUDA Version < 11 is not supported") 81 | elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") 82 | message(FATAL_ERROR "CUDA Version > 12 is not supported") 83 | endif() 84 | 85 | string(APPEND CMAKE_CUDA_FLAGS " --use_fast_math") 86 | if(PTXAS_VERBOSE) 87 | # Verbose? Outputs register usage information, and other things... 
88 | string(APPEND CMAKE_CUDA_FLAGS " -Xptxas=-v") 89 | endif() 90 | 91 | foreach(capability ${CMAKE_CUDA_ARCHITECTURES_ALL}) 92 | # Most of the items here are like: `xx-real`, so we just extract the `xx` portion 93 | string(REGEX MATCH "[0-9]+" capability_id "${capability}") 94 | if(capability_id GREATER 0) 95 | list(APPEND POSSIBLE_CAPABILITIES ${capability_id}) 96 | endif() 97 | endforeach() 98 | 99 | # This can be changed via -D argument to CMake 100 | # By default all possible capabilities are compiled 101 | set(COMPUTE_CAPABILITY "${POSSIBLE_CAPABILITIES}" CACHE STRING "Compute Capabilities Targeted") 102 | 103 | message(STATUS "CUDA Capabilities Available: ${POSSIBLE_CAPABILITIES}") 104 | message(STATUS "CUDA Capabilities Selected: ${COMPUTE_CAPABILITY}") 105 | 106 | foreach(capability ${COMPUTE_CAPABILITY}) 107 | string(APPEND CMAKE_CUDA_FLAGS " -gencode arch=compute_${capability},code=sm_${capability}") 108 | endforeach() 109 | 110 | message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}") 111 | 112 | list(APPEND SRC_FILES ${CUDA_FILES}) 113 | 114 | string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") 115 | if(NO_CUBLASLT) 116 | string(APPEND BNB_OUTPUT_NAME "_nocublaslt") 117 | endif() 118 | add_compile_definitions(BUILD_CUDA) 119 | elseif(BUILD_MPS) 120 | if(NOT APPLE) 121 | message(FATAL_ERROR "MPS is only supported on macOS" ) 122 | endif() 123 | 124 | enable_language(OBJCXX) 125 | 126 | list(APPEND SRC_FILES ${MPS_FILES}) 127 | 128 | string(APPEND BNB_OUTPUT_NAME "_mps") 129 | add_compile_definitions(BUILD_MPS) 130 | file(MAKE_DIRECTORY "build") 131 | add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib" 132 | COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES} 133 | COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib" 134 | DEPENDS "${METAL_FILES}" 135 | COMMENT "Compiling Metal kernels" 136 | VERBATIM) 137 | add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") 138 | else() 139 | set(LIBSUFFIX "cpu") 140 | set(GPU_SOURCES) 141 | endif() 142 | 143 | 144 | if(WIN32) 145 | # Export all symbols 146 | set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) 147 | endif() 148 | 149 | # Weird MSVC hacks 150 | if(MSVC) 151 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast") 152 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2 /fp:fast") 153 | endif() 154 | 155 | set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) 156 | add_library(bitsandbytes SHARED ${SRC_FILES}) 157 | target_compile_features(bitsandbytes PUBLIC cxx_std_14) 158 | target_include_directories(bitsandbytes PUBLIC csrc include) 159 | 160 | 161 | if(BUILD_CUDA) 162 | target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 163 | target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse) 164 | if(NO_CUBLASLT) 165 | target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT) 166 | else() 167 | target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt) 168 | endif() 169 | 170 | set_target_properties(bitsandbytes 171 | PROPERTIES 172 | CUDA_SEPARABLE_COMPILATION ON 173 | ) 174 | endif() 175 | if(BUILD_MPS) 176 | add_dependencies(bitsandbytes metallib) 177 | target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") 178 | endif() 179 | 180 | if(WIN32) 181 | set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") 182 | endif() 183 | 
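# Illustrative only (a hypothetical invocation, not part of this build script):
# configuring a CUDA build for Ada/Hopper GPUs with
#   cmake -B build -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="89;90" .
#   cmake --build build --config Release
# and with, say, CUDA 12.1 first on the PATH would set BNB_OUTPUT_NAME to
# "bitsandbytes_cuda121" ("_nocublaslt" appended when NO_CUBLASLT=ON), so the
# shared library below would be emitted as e.g. libbitsandbytes_cuda121.so.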
set_target_properties(bitsandbytes PROPERTIES OUTPUT_NAME ${BNB_OUTPUT_NAME}) 184 | if(MSVC) 185 | set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_RELEASE bitsandbytes) 186 | set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_DEBUG bitsandbytes) 187 | set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE bitsandbytes) 188 | set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG bitsandbytes) 189 | endif() 190 | 191 | set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY bitsandbytes) 192 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to bitsandbytes 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to bitsandbytes, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NOTICE.md: -------------------------------------------------------------------------------- 1 | The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license. 2 | 3 | We thank Fabio Cannizzo for this work on FastBinarySearch which is included in this project. 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `bitsandbytes` 2 | 3 | The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 & 4-bit quantization functions. 4 | 5 | The library includes quantization primitives for 8-bit & 4-bit operations, through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit` and 8-bit optimizers through `bitsandbytes.optim` module. 6 | 7 | There are ongoing efforts to support further hardware backends, i.e. Intel CPU + GPU, AMD GPU, Apple Silicon. Windows support is quite far along and is on its way as well. 8 | 9 | **Please head to the official documentation page:** 10 | 11 | **[https://huggingface.co/docs/bitsandbytes/main](https://huggingface.co/docs/bitsandbytes/main)** 12 | 13 | ## License 14 | 15 | The majority of bitsandbytes is licensed under MIT, however small portions of the project are available under separate license terms, as the parts adapted from Pytorch are licensed under the BSD license. 16 | 17 | We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization. 18 | -------------------------------------------------------------------------------- /_typos.toml: -------------------------------------------------------------------------------- 1 | [files] 2 | 3 | [default.extend-identifiers] 4 | 5 | [type.py.extend-words] 6 | "BA" = "BA" # used as a commented-out variable in tests 7 | 8 | [type.cuda.extend-words] 9 | "subtile" = "subtile" 10 | "subtiles" = "subtiles" 11 | "transation" = "transation" # TODO: is this transition, transaction, translation..? 
12 | -------------------------------------------------------------------------------- /benchmarking/switchback/README.md: -------------------------------------------------------------------------------- 1 | Steps: 2 | 3 | 1. Run `python speed_benchmark/speed_benchmark.py` which times operations and writes their time to `speed_benchmark/info_a100_py2.jsonl` (change the name of the jsonl to a different name for your profiling). 4 | 2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces the `speed_benchmark/plot_with_info.pdf`. Again make sure you change the jsonl which is being processed. 5 | -------------------------------------------------------------------------------- /benchmarking/switchback/make_plot_with_jsonl.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.gridspec as gridspec 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | 6 | cmap=plt.get_cmap('cool') 7 | 8 | if __name__ == '__main__': 9 | 10 | fig = plt.figure(tight_layout=True, figsize=(12,3.5)) 11 | gs = gridspec.GridSpec(1, 2) 12 | 13 | dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096] 14 | batch_size_for_plot1 = 32768 15 | batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17] 16 | dims_to_xtick = [1024, 2048, 4096] 17 | logscale_plot1 = True 18 | 19 | ax = fig.add_subplot(gs[0, 0]) 20 | 21 | # TODO: change this to what you want. 22 | rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True) 23 | df = rdf[rdf.batch_size == batch_size_for_plot1] 24 | 25 | # first plot the time occupied by different operations 26 | for k, marker, ls, color, name in [ 27 | ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'), 28 | ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), 29 | 30 | ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'), 31 | ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'), 32 | ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'), 33 | 34 | ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), 35 | ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), 36 | 37 | ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), 38 | ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), 39 | ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'), 40 | ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'), 41 | ]: 42 | xs = [] 43 | ys = [] 44 | for embed_dim in dims_to_consider: 45 | # average over dim -> 4*dim and 4*dim -> dim 46 | df_ = df[df.dim_in == embed_dim] 47 | df_ = df_[df_.dim_out == embed_dim * 4] 48 | xs.append(embed_dim) 49 | y_ = 0 50 | for k_ in k.split('+'): 51 | y_ += df_[k_].values[0] 52 | df_ = df[df.dim_in == embed_dim * 4] 53 | df_ = df_[df_.dim_out == embed_dim] 54 | for k_ in k.split('+'): 55 | y_ += df_[k_].values[0] 56 | ys.append(y_ * 0.5) 57 | 58 | 59 | ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.) 
60 | 61 | 62 | ax.set_xlabel('dim', fontsize=13) 63 | ax.set_ylabel('time (ms)', fontsize=13) 64 | 65 | ax.grid() 66 | 67 | ax.set_xscale('log') 68 | if logscale_plot1: 69 | ax.set_yscale('log') 70 | 71 | ax.tick_params(axis='x', labelsize=11) 72 | ax.tick_params(axis='y', labelsize=11) 73 | 74 | ax.set_xticks(dims_to_xtick) 75 | ax.set_xticklabels(dims_to_xtick) 76 | ax.set_xticks([], minor=True) 77 | 78 | leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) 79 | leg.get_texts()[0].set_fontweight('bold') 80 | leg.get_texts()[1].set_fontweight('bold') 81 | plt.subplots_adjust(left=0.1) 82 | ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20) 83 | 84 | 85 | ax = fig.add_subplot(gs[0, 1]) 86 | 87 | # now plot the % speedup for different batch sizes 88 | for j, batch_size in enumerate(batch_sizes_for_plot2): 89 | all_xs, all_ys = [], [] 90 | for k, marker, ls, color, name in [ 91 | ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), 92 | ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), 93 | ]: 94 | 95 | xs, ys = [], [] 96 | df = rdf[rdf.batch_size == batch_size] 97 | for embed_dim in dims_to_consider: 98 | df_ = df[df.dim_in == embed_dim] 99 | df_ = df_[df_.dim_out == embed_dim * 4] 100 | xs.append(embed_dim) 101 | y_ = 0 102 | for k_ in k.split('+'): 103 | y_ += df_[k_].values[0] 104 | df_ = df[df.dim_in == embed_dim * 4] 105 | df_ = df_[df_.dim_out == embed_dim] 106 | for k_ in k.split('+'): 107 | y_ += df_[k_].values[0] 108 | ys.append(y_ * 0.5) 109 | all_xs.append(xs) 110 | all_ys.append(ys) 111 | 112 | color = cmap(j * 0.25) 113 | real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] 114 | markers = ['^', 'v', 'P', 'o'] 115 | ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) 116 | 117 | ax.legend() 118 | ax.set_xlabel('dim', fontsize=13) 119 | ax.set_xscale('log') 120 | ax.grid() 121 | ax.set_ylabel(r'% speedup', fontsize=13) 122 | 123 | 124 | ax.tick_params(axis='x', labelsize=11) 125 | ax.tick_params(axis='y', labelsize=11) 126 | 127 | ax.set_xticks(dims_to_xtick) 128 | ax.set_xticklabels(dims_to_xtick) 129 | ax.set_xticks([], minor=True) 130 | 131 | ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) 132 | 133 | 134 | 135 | plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight') 136 | -------------------------------------------------------------------------------- /benchmarking/switchback/plot_with_info.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickardp/bitsandbytes/927f7167e3395ec26f859f294c1d4979a70a718a/benchmarking/switchback/plot_with_info.pdf -------------------------------------------------------------------------------- /benchmarking/switchback/speed_benchmark.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | import torch 5 | 6 | from bitsandbytes.triton.int8_matmul_mixed_dequantize import ( 7 | int8_matmul_mixed_dequantize, 8 | ) 9 | from bitsandbytes.triton.int8_matmul_rowwise_dequantize import ( 10 | int8_matmul_rowwise_dequantize, 11 | ) 12 | from 
bitsandbytes.triton.quantize_columnwise_and_transpose import ( 13 | quantize_columnwise_and_transpose, 14 | ) 15 | from bitsandbytes.triton.quantize_global import ( 16 | quantize_global, 17 | quantize_global_transpose, 18 | ) 19 | from bitsandbytes.triton.quantize_rowwise import quantize_rowwise 20 | 21 | # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. 22 | 23 | def get_time(k, fn, info_dict): 24 | 25 | for _ in range(repeat // 2): 26 | fn() 27 | 28 | torch.cuda.synchronize() 29 | start = time.time() 30 | for _ in range(repeat): 31 | fn() 32 | 33 | torch.cuda.synchronize() 34 | end = time.time() 35 | ms = (end - start) / repeat * 1000 36 | print(f"time {k}: {ms:.3f} ms") 37 | info_dict[k] = ms 38 | 39 | if __name__ == '__main__': 40 | torch.manual_seed(0) 41 | wm = 4 42 | for dim in [1024, 1280, 1408, 1664, 2048, 4096]: 43 | # note "batch_size" is actually "batch_size * embed_dim", which is why it's large 44 | for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]: 45 | 46 | # switch switches dim_in and dim_out 47 | for switch in [False, True]: 48 | 49 | # hparams 50 | repeat = 64 51 | batch_size = batch_size 52 | dim_out = dim * wm 53 | dim_in = dim 54 | if switch: 55 | dim_out = dim 56 | dim_in = wm * dim 57 | 58 | dim_in = round(dim_in) 59 | dim_out = round(dim_out) 60 | 61 | # simulate forward pass 62 | x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() 63 | g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() 64 | w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda() 65 | 66 | x_int8 = x.clone().to(torch.int8) 67 | g_int8 = g.clone().to(torch.int8) 68 | w_int8 = w.clone().to(torch.int8) 69 | wt_int8 = w.t().contiguous().clone().to(torch.int8) 70 | state_x_rowwise = x.max(dim=1)[0] 71 | state_g_rowwise = g.max(dim=1)[0] 72 | state_w_columnwise = w.max(dim=0)[0] 73 | state_w_rowwise = w.max(dim=1)[0] 74 | state_w_global = w.max() 75 | 76 | info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch} 77 | 78 | get_time('standard_fwd', lambda : x.matmul(w.t()), info) 79 | get_time('standard_gw', lambda : g.t().matmul(x), info) 80 | get_time('standard_gx', lambda : g.matmul(w), info) 81 | get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info) 82 | get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info) 83 | get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info) 84 | get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info) 85 | get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info) 86 | get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info) 87 | get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info) 88 | get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info) 89 | get_time('w_quantize_global', lambda : quantize_global(w), info) 90 | get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info) 91 | 92 | time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw'] 93 | time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + 
info['rowwise_fwd'] + info['rowwise_bwd'] 94 | time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] 95 | 96 | print('TOTAL STANDARD', time_standard) 97 | print('TOTAL ROWWISE', time_rowwise) 98 | print('TOTAL GLOBAL', time_global) 99 | 100 | print('speedup', -100*(time_global - time_standard)/time_standard) 101 | 102 | info['time_standard'] = time_standard 103 | info['time_rowwise'] = time_rowwise 104 | info['time_global'] = time_global 105 | 106 | info_json = json.dumps(info) 107 | 108 | # TODO: change this to what you want. 109 | with open("speed_benchmark/info.jsonl", "a") as file: 110 | file.write(info_json + "\n") 111 | -------------------------------------------------------------------------------- /bitsandbytes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import cuda_setup, research, utils 7 | from .autograd._functions import ( 8 | MatmulLtState, 9 | bmm_cublas, 10 | matmul, 11 | matmul_4bit, 12 | matmul_cublas, 13 | mm_cublas, 14 | ) 15 | from .cextension import COMPILED_WITH_CUDA 16 | from .nn import modules 17 | 18 | if COMPILED_WITH_CUDA: 19 | from .optim import adam 20 | 21 | __pdoc__ = { 22 | "libbitsandbytes": False, 23 | "optim.optimizer.Optimizer8bit": False, 24 | "optim.optimizer.MockArgs": False, 25 | } 26 | 27 | __version__ = "0.43.0.dev" 28 | 29 | PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" 30 | -------------------------------------------------------------------------------- /bitsandbytes/__main__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import sys 4 | from warnings import warn 5 | 6 | import torch 7 | 8 | HEADER_WIDTH = 60 9 | 10 | 11 | def find_dynamic_library(folder, filename): 12 | for ext in ("so", "dll", "dylib"): 13 | yield from glob.glob(os.path.join(folder, "**", filename + ext)) 14 | 15 | 16 | def generate_bug_report_information(): 17 | print_header("") 18 | print_header("BUG REPORT INFORMATION") 19 | print_header("") 20 | print('') 21 | 22 | path_sources = [ 23 | ("ANACONDA CUDA PATHS", os.environ.get("CONDA_PREFIX")), 24 | ("/usr/local CUDA PATHS", "/usr/local"), 25 | ("CUDA PATHS", os.environ.get("CUDA_PATH")), 26 | ("WORKING DIRECTORY CUDA PATHS", os.getcwd()), 27 | ] 28 | try: 29 | ld_library_path = os.environ.get("LD_LIBRARY_PATH") 30 | if ld_library_path: 31 | for path in set(ld_library_path.strip().split(os.pathsep)): 32 | path_sources.append((f"LD_LIBRARY_PATH {path} CUDA PATHS", path)) 33 | except Exception as e: 34 | print(f"Could not parse LD_LIBRARY_PATH: {e}") 35 | 36 | for name, path in path_sources: 37 | if path and os.path.isdir(path): 38 | print_header(name) 39 | print(list(find_dynamic_library(path, '*cuda*'))) 40 | print("") 41 | 42 | 43 | def print_header( 44 | txt: str, width: int = HEADER_WIDTH, filler: str = "+" 45 | ) -> None: 46 | txt = f" {txt} " if txt else "" 47 | print(txt.center(width, filler)) 48 | 49 | 50 | def print_debug_info() -> None: 51 | from . import PACKAGE_GITHUB_URL 52 | print( 53 | "\nAbove we output some debug information. 
Please provide this info when " 54 | f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n" 55 | ) 56 | 57 | 58 | def main(): 59 | generate_bug_report_information() 60 | 61 | from . import COMPILED_WITH_CUDA 62 | from .cuda_setup.main import get_compute_capabilities 63 | 64 | print_header("OTHER") 65 | print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") 66 | print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") 67 | print_header("") 68 | print_header("DEBUG INFO END") 69 | print_header("") 70 | print("Checking that the library is importable and CUDA is callable...") 71 | print("\nWARNING: Please be sure to sanitize sensitive info from any such env vars!\n") 72 | 73 | try: 74 | from bitsandbytes.optim import Adam 75 | 76 | p = torch.nn.Parameter(torch.rand(10, 10).cuda()) 77 | a = torch.rand(10, 10).cuda() 78 | 79 | p1 = p.data.sum().item() 80 | 81 | adam = Adam([p]) 82 | 83 | out = a * p 84 | loss = out.sum() 85 | loss.backward() 86 | adam.step() 87 | 88 | p2 = p.data.sum().item() 89 | 90 | assert p1 != p2 91 | print("SUCCESS!") 92 | print("Installation was successful!") 93 | except ImportError: 94 | print() 95 | warn( 96 | f"WARNING: {__package__} is currently running as CPU-only!\n" 97 | "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" 98 | f"If you think that this is so erroneously,\nplease report an issue!" 99 | ) 100 | print_debug_info() 101 | except Exception as e: 102 | print(e) 103 | print_debug_info() 104 | sys.exit(1) 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /bitsandbytes/autograd/__init__.py: -------------------------------------------------------------------------------- 1 | from ._functions import get_inverse_transform_indices, undo_layout 2 | -------------------------------------------------------------------------------- /bitsandbytes/cextension.py: -------------------------------------------------------------------------------- 1 | import ctypes as ct 2 | from warnings import warn 3 | 4 | import torch 5 | 6 | from bitsandbytes.cuda_setup.main import CUDASetup 7 | 8 | setup = CUDASetup.get_instance() 9 | if setup.initialized != True: 10 | setup.run_cuda_setup() 11 | 12 | lib = setup.lib 13 | try: 14 | if lib is None and torch.cuda.is_available(): 15 | CUDASetup.get_instance().generate_instructions() 16 | CUDASetup.get_instance().print_log_stack() 17 | raise RuntimeError(''' 18 | CUDA Setup failed despite GPU being available. Please run the following command to get more information: 19 | 20 | python -m bitsandbytes 21 | 22 | Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them 23 | to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes 24 | and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') 25 | _ = lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False 26 | lib.get_context.restype = ct.c_void_p 27 | lib.get_cusparse.restype = ct.c_void_p 28 | lib.cget_managed_ptr.restype = ct.c_void_p 29 | COMPILED_WITH_CUDA = True 30 | except AttributeError as ex: 31 | warn("The installed version of bitsandbytes was compiled without GPU support. 
" 32 | "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") 33 | COMPILED_WITH_CUDA = False 34 | print(str(ex)) 35 | 36 | 37 | # print the setup details after checking for errors so we do not print twice 38 | #if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': 39 | #setup.print_log_stack() 40 | -------------------------------------------------------------------------------- /bitsandbytes/cuda_setup/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickardp/bitsandbytes/927f7167e3395ec26f859f294c1d4979a70a718a/bitsandbytes/cuda_setup/__init__.py -------------------------------------------------------------------------------- /bitsandbytes/cuda_setup/env_vars.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | 5 | def to_be_ignored(env_var: str, value: str) -> bool: 6 | ignorable = { 7 | "PWD", # PWD: this is how the shell keeps track of the current working dir 8 | "OLDPWD", 9 | "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated 10 | "SSH_TTY", 11 | "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks 12 | "HOME", # Linux shell default 13 | "TMUX", # Terminal Multiplexer 14 | "XDG_DATA_DIRS", # XDG: Desktop environment stuff 15 | "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff 16 | "XDG_RUNTIME_DIR", 17 | "MAIL", # something related to emails 18 | "SHELL", # binary for currently invoked shell 19 | "DBUS_SESSION_BUS_ADDRESS", # hardware related 20 | "PATH", # this is for finding binaries, not libraries 21 | "LESSOPEN", # related to the `less` command 22 | "LESSCLOSE", 23 | "_", # current Python interpreter 24 | } 25 | return env_var in ignorable 26 | 27 | 28 | def might_contain_a_path(candidate: str) -> bool: 29 | return os.sep in candidate 30 | 31 | 32 | def is_active_conda_env(env_var: str) -> bool: 33 | return "CONDA_PREFIX" == env_var 34 | 35 | 36 | def is_other_conda_env_var(env_var: str) -> bool: 37 | return "CONDA" in env_var 38 | 39 | 40 | def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: 41 | return is_active_conda_env(env_var) or ( 42 | might_contain_a_path(value) and not 43 | is_other_conda_env_var(env_var) and not 44 | to_be_ignored(env_var, value) 45 | ) 46 | 47 | 48 | def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: 49 | return { 50 | env_var: value 51 | for env_var, value in os.environ.items() 52 | if is_relevant_candidate_env_var(env_var, value) 53 | } 54 | -------------------------------------------------------------------------------- /bitsandbytes/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from .modules import ( 6 | Embedding, 7 | Int8Params, 8 | Linear4bit, 9 | Linear8bitLt, 10 | LinearFP4, 11 | LinearNF4, 12 | OutlierAwareLinear, 13 | Params4bit, 14 | StableEmbedding, 15 | SwitchBackLinearBnb, 16 | ) 17 | from .triton_based_modules import ( 18 | StandardLinear, 19 | SwitchBackLinear, 20 | SwitchBackLinearGlobal, 21 | SwitchBackLinearVectorwise, 22 | ) 23 | -------------------------------------------------------------------------------- /bitsandbytes/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from bitsandbytes.cextension import COMPILED_WITH_CUDA 7 | 8 | from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit 9 | from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit 10 | from .adamw import ( 11 | AdamW, 12 | AdamW8bit, 13 | AdamW32bit, 14 | PagedAdamW, 15 | PagedAdamW8bit, 16 | PagedAdamW32bit, 17 | ) 18 | from .lamb import LAMB, LAMB8bit, LAMB32bit 19 | from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS 20 | from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit 21 | from .optimizer import GlobalOptimManager 22 | from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit 23 | from .sgd import SGD, SGD8bit, SGD32bit 24 | -------------------------------------------------------------------------------- /bitsandbytes/optim/adagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
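# A minimal training-step sketch (illustrative only; the model, shapes, and
# hyperparameters are assumptions). The 8-bit optimizers exported from
# bitsandbytes.optim keep the familiar torch.optim-style constructor:
#
#   import torch
#   import bitsandbytes as bnb
#
#   model = torch.nn.Linear(4096, 4096).cuda()
#   optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)  # instead of torch.optim.Adam
#
#   loss = model(torch.randn(8, 4096, device="cuda")).sum()
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()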
5 | from bitsandbytes.optim.optimizer import Optimizer1State 6 | 7 | 8 | class Adagrad(Optimizer1State): 9 | def __init__( 10 | self, 11 | params, 12 | lr=1e-2, 13 | lr_decay=0, 14 | weight_decay=0, 15 | initial_accumulator_value=0, 16 | eps=1e-10, 17 | optim_bits=32, 18 | args=None, 19 | min_8bit_size=4096, 20 | percentile_clipping=100, 21 | block_wise=True, 22 | ): 23 | if not 0.0 <= lr: 24 | raise ValueError(f"Invalid learning rate: {lr}") 25 | if not 0.0 <= weight_decay: 26 | raise ValueError( 27 | f"Invalid weight_decay value: {weight_decay}" 28 | ) 29 | if not 0.0 <= eps: 30 | raise ValueError(f"Invalid epsilon value: {eps}") 31 | if initial_accumulator_value != 0.0: 32 | raise ValueError("Initial accumulator value != 0.0 not supported!") 33 | if lr_decay != 0.0: 34 | raise ValueError("Lr Decay != 0.0 not supported!") 35 | super().__init__( 36 | "adagrad", 37 | params, 38 | lr, 39 | (0.0, 0.0), 40 | eps, 41 | weight_decay, 42 | optim_bits, 43 | args, 44 | min_8bit_size, 45 | percentile_clipping, 46 | block_wise, 47 | ) 48 | 49 | 50 | class Adagrad8bit(Optimizer1State): 51 | def __init__( 52 | self, 53 | params, 54 | lr=1e-2, 55 | lr_decay=0, 56 | weight_decay=0, 57 | initial_accumulator_value=0, 58 | eps=1e-10, 59 | optim_bits=8, 60 | args=None, 61 | min_8bit_size=4096, 62 | percentile_clipping=100, 63 | block_wise=True, 64 | ): 65 | if not 0.0 <= lr: 66 | raise ValueError(f"Invalid learning rate: {lr}") 67 | if not 0.0 <= weight_decay: 68 | raise ValueError( 69 | f"Invalid weight_decay value: {weight_decay}" 70 | ) 71 | if not 0.0 <= eps: 72 | raise ValueError(f"Invalid epsilon value: {eps}") 73 | if initial_accumulator_value != 0.0: 74 | raise ValueError("Initial accumulator value != 0.0 not supported!") 75 | if lr_decay != 0.0: 76 | raise ValueError("Lr Decay != 0.0 not supported!") 77 | assert block_wise 78 | super().__init__( 79 | "adagrad", 80 | params, 81 | lr, 82 | (0.0, 0.0), 83 | eps, 84 | weight_decay, 85 | 8, 86 | args, 87 | min_8bit_size, 88 | percentile_clipping, 89 | block_wise, 90 | ) 91 | 92 | 93 | class Adagrad32bit(Optimizer1State): 94 | def __init__( 95 | self, 96 | params, 97 | lr=1e-2, 98 | lr_decay=0, 99 | weight_decay=0, 100 | initial_accumulator_value=0, 101 | eps=1e-10, 102 | optim_bits=32, 103 | args=None, 104 | min_8bit_size=4096, 105 | percentile_clipping=100, 106 | block_wise=True, 107 | ): 108 | if not 0.0 <= lr: 109 | raise ValueError(f"Invalid learning rate: {lr}") 110 | if not 0.0 <= weight_decay: 111 | raise ValueError( 112 | f"Invalid weight_decay value: {weight_decay}" 113 | ) 114 | if not 0.0 <= eps: 115 | raise ValueError(f"Invalid epsilon value: {eps}") 116 | if initial_accumulator_value != 0.0: 117 | raise ValueError("Initial accumulator value != 0.0 not supported!") 118 | if lr_decay != 0.0: 119 | raise ValueError("Lr Decay != 0.0 not supported!") 120 | super().__init__( 121 | "adagrad", 122 | params, 123 | lr, 124 | (0.0, 0.0), 125 | eps, 126 | weight_decay, 127 | 32, 128 | args, 129 | min_8bit_size, 130 | percentile_clipping, 131 | block_wise, 132 | ) 133 | -------------------------------------------------------------------------------- /bitsandbytes/optim/adamw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from bitsandbytes.optim.optimizer import Optimizer2State 6 | 7 | 8 | class AdamW(Optimizer2State): 9 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 10 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 11 | super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) 12 | 13 | class AdamW8bit(Optimizer2State): 14 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 15 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 16 | super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) 17 | 18 | class AdamW32bit(Optimizer2State): 19 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 20 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 21 | super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) 22 | 23 | 24 | class PagedAdamW(Optimizer2State): 25 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 26 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 27 | super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 28 | 29 | class PagedAdamW8bit(Optimizer2State): 30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 31 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 32 | super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 33 | 34 | class PagedAdamW32bit(Optimizer2State): 35 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 36 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 37 | super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 38 | -------------------------------------------------------------------------------- /bitsandbytes/optim/lamb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
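The AdamW family above is a thin wrapper around `Optimizer2State`: the classes differ only in the state precision they pass (`optim_bits`, a literal 8, or a literal 32) and in whether `is_paged=True` is forced, which backs the optimizer state with paged CUDA memory that can spill to the CPU under memory pressure. Note that `amsgrad` is accepted for signature compatibility but never forwarded. A sketch with an illustrative two-layer model (sizes and learning rate are placeholders):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096), torch.nn.Linear(4096, 4096)).cuda()

# 8-bit AdamW whose state tensors are additionally paged between GPU and CPU.
opt = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)

out = model(torch.randn(16, 4096, device="cuda"))
out.sum().backward()
opt.step()
opt.zero_grad()
```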
5 | from bitsandbytes.optim.optimizer import Optimizer2State 6 | 7 | 8 | class LAMB(Optimizer2State): 9 | def __init__( 10 | self, 11 | params, 12 | lr=1e-3, 13 | bias_correction=True, 14 | betas=(0.9, 0.999), 15 | eps=1e-8, 16 | weight_decay=0, 17 | amsgrad=False, 18 | adam_w_mode=True, 19 | optim_bits=32, 20 | args=None, 21 | min_8bit_size=4096, 22 | percentile_clipping=100, 23 | block_wise=False, 24 | max_unorm=1.0, 25 | ): 26 | super().__init__( 27 | "lamb", 28 | params, 29 | lr, 30 | betas, 31 | eps, 32 | weight_decay, 33 | optim_bits, 34 | args, 35 | min_8bit_size, 36 | percentile_clipping, 37 | block_wise, 38 | max_unorm=1.0, 39 | ) 40 | 41 | 42 | class LAMB8bit(Optimizer2State): 43 | def __init__( 44 | self, 45 | params, 46 | lr=1e-3, 47 | bias_correction=True, 48 | betas=(0.9, 0.999), 49 | eps=1e-8, 50 | weight_decay=0, 51 | amsgrad=False, 52 | adam_w_mode=True, 53 | args=None, 54 | min_8bit_size=4096, 55 | percentile_clipping=100, 56 | block_wise=False, 57 | max_unorm=1.0, 58 | ): 59 | super().__init__( 60 | "lamb", 61 | params, 62 | lr, 63 | betas, 64 | eps, 65 | weight_decay, 66 | 8, 67 | args, 68 | min_8bit_size, 69 | percentile_clipping, 70 | block_wise, 71 | max_unorm=1.0, 72 | ) 73 | 74 | 75 | class LAMB32bit(Optimizer2State): 76 | def __init__( 77 | self, 78 | params, 79 | lr=1e-3, 80 | bias_correction=True, 81 | betas=(0.9, 0.999), 82 | eps=1e-8, 83 | weight_decay=0, 84 | amsgrad=False, 85 | adam_w_mode=True, 86 | args=None, 87 | min_8bit_size=4096, 88 | percentile_clipping=100, 89 | block_wise=False, 90 | max_unorm=1.0, 91 | ): 92 | super().__init__( 93 | "lamb", 94 | params, 95 | lr, 96 | betas, 97 | eps, 98 | weight_decay, 99 | 32, 100 | args, 101 | min_8bit_size, 102 | percentile_clipping, 103 | block_wise, 104 | max_unorm=1.0, 105 | ) 106 | -------------------------------------------------------------------------------- /bitsandbytes/optim/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import torch 6 | from torch.optim import Optimizer 7 | 8 | from bitsandbytes.optim.optimizer import Optimizer1State 9 | 10 | 11 | class LARS(Optimizer1State): 12 | def __init__( 13 | self, 14 | params, 15 | lr, 16 | momentum=0, 17 | dampening=0, 18 | weight_decay=0, 19 | nesterov=False, 20 | optim_bits=32, 21 | args=None, 22 | min_8bit_size=4096, 23 | percentile_clipping=100, 24 | max_unorm=0.02, 25 | ): 26 | if momentum == 0: 27 | raise NotImplementedError( 28 | "LARS without momentum is not supported!" 29 | ) 30 | super().__init__( 31 | "lars", 32 | params, 33 | lr, 34 | (momentum, dampening), 35 | 0.0, 36 | weight_decay, 37 | optim_bits, 38 | args, 39 | min_8bit_size, 40 | percentile_clipping, 41 | max_unorm=max_unorm, 42 | block_wise=False, 43 | ) 44 | 45 | 46 | class LARS8bit(Optimizer1State): 47 | def __init__( 48 | self, 49 | params, 50 | lr, 51 | momentum=0, 52 | dampening=0, 53 | weight_decay=0, 54 | nesterov=False, 55 | args=None, 56 | min_8bit_size=4096, 57 | percentile_clipping=100, 58 | max_unorm=0.02, 59 | ): 60 | if momentum == 0: 61 | raise NotImplementedError( 62 | "LARS without momentum is not supported!" 
63 | ) 64 | super().__init__( 65 | "lars", 66 | params, 67 | lr, 68 | (momentum, dampening), 69 | 0.0, 70 | weight_decay, 71 | 8, 72 | args, 73 | min_8bit_size, 74 | percentile_clipping, 75 | max_unorm=max_unorm, 76 | block_wise=False, 77 | ) 78 | 79 | 80 | class LARS32bit(Optimizer1State): 81 | def __init__( 82 | self, 83 | params, 84 | lr, 85 | momentum=0, 86 | dampening=0, 87 | weight_decay=0, 88 | nesterov=False, 89 | args=None, 90 | min_8bit_size=4096, 91 | percentile_clipping=100, 92 | max_unorm=0.02, 93 | ): 94 | if momentum == 0: 95 | raise NotImplementedError( 96 | "LARS without momentum is not supported!" 97 | ) 98 | super().__init__( 99 | "lars", 100 | params, 101 | lr, 102 | (momentum, dampening), 103 | 0.0, 104 | weight_decay, 105 | 32, 106 | args, 107 | min_8bit_size, 108 | percentile_clipping, 109 | max_unorm=max_unorm, 110 | block_wise=False, 111 | ) 112 | 113 | 114 | class PytorchLARS(Optimizer): 115 | def __init__( 116 | self, 117 | params, 118 | lr=0.01, 119 | momentum=0, 120 | dampening=0, 121 | weight_decay=0, 122 | nesterov=False, 123 | max_unorm=0.02, 124 | ): 125 | if lr < 0.0: 126 | raise ValueError(f"Invalid learning rate: {lr}") 127 | if momentum < 0.0: 128 | raise ValueError(f"Invalid momentum value: {momentum}") 129 | if weight_decay < 0.0: 130 | raise ValueError( 131 | f"Invalid weight_decay value: {weight_decay}" 132 | ) 133 | 134 | defaults = dict( 135 | lr=lr, 136 | momentum=momentum, 137 | dampening=dampening, 138 | weight_decay=weight_decay, 139 | nesterov=nesterov, 140 | max_unorm=max_unorm, 141 | ) 142 | if nesterov and (momentum <= 0 or dampening != 0): 143 | raise ValueError( 144 | "Nesterov momentum requires a momentum and zero dampening" 145 | ) 146 | super().__init__(params, defaults) 147 | 148 | def __setstate__(self, state): 149 | super().__setstate__(state) 150 | for group in self.param_groups: 151 | group.setdefault("nesterov", False) 152 | 153 | @torch.no_grad() 154 | def step(self, closure=None): 155 | """Performs a single optimization step. 156 | 157 | Args: 158 | closure (callable, optional): A closure that reevaluates the model 159 | and returns the loss. 
160 | """ 161 | loss = None 162 | if closure is not None: 163 | with torch.enable_grad(): 164 | loss = closure() 165 | 166 | for group in self.param_groups: 167 | params_with_grad = [] 168 | d_p_list = [] 169 | momentum_buffer_list = [] 170 | weight_decay = group["weight_decay"] 171 | momentum = group["momentum"] 172 | dampening = group["dampening"] 173 | nesterov = group["nesterov"] 174 | max_unorm = group["max_unorm"] 175 | lr = group["lr"] 176 | 177 | for p in group["params"]: 178 | if p.grad is None: 179 | continue 180 | 181 | state = self.state[p] 182 | d_p = p.grad 183 | if weight_decay != 0: 184 | d_p = d_p.add(p, alpha=weight_decay) 185 | 186 | if momentum != 0: 187 | buf = state.get("momentum_buffer", None) 188 | 189 | if buf is None: 190 | buf = torch.clone(d_p).detach() 191 | state["momentum_buffer"] = buf 192 | else: 193 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 194 | 195 | if nesterov: 196 | update = d_p + buf * momentum 197 | else: 198 | update = buf 199 | 200 | update_scale = 1.0 201 | if max_unorm > 0.0: 202 | assert p.dtype == torch.float32 203 | pnorm = torch.norm(p.detach()) 204 | unorm = torch.norm(update) 205 | if unorm > max_unorm * pnorm: 206 | update_scale = max_unorm * pnorm / unorm 207 | 208 | p.add_(update, alpha=-lr * update_scale) 209 | 210 | return loss 211 | -------------------------------------------------------------------------------- /bitsandbytes/optim/lion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from bitsandbytes.optim.optimizer import Optimizer1State 6 | 7 | 8 | class Lion(Optimizer1State): 9 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 10 | super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) 11 | 12 | class Lion8bit(Optimizer1State): 13 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 14 | super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) 15 | 16 | class Lion32bit(Optimizer1State): 17 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 18 | super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) 19 | 20 | 21 | class PagedLion(Optimizer1State): 22 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 23 | super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 24 | 25 | class PagedLion8bit(Optimizer1State): 26 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 27 | super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 28 | 29 | 
class PagedLion32bit(Optimizer1State): 30 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 31 | super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 32 | -------------------------------------------------------------------------------- /bitsandbytes/optim/rmsprop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from bitsandbytes.optim.optimizer import Optimizer1State 6 | 7 | 8 | class RMSprop(Optimizer1State): 9 | def __init__( 10 | self, 11 | params, 12 | lr=1e-2, 13 | alpha=0.99, 14 | eps=1e-8, 15 | weight_decay=0, 16 | momentum=0, 17 | centered=False, 18 | optim_bits=32, 19 | args=None, 20 | min_8bit_size=4096, 21 | percentile_clipping=100, 22 | block_wise=True, 23 | ): 24 | if alpha == 0: 25 | raise NotImplementedError( 26 | "RMSprop with alpha==0.0 is not supported!" 27 | ) 28 | if centered: 29 | raise NotImplementedError("Centered RMSprop is not supported!") 30 | super().__init__( 31 | "rmsprop", 32 | params, 33 | lr, 34 | (alpha, momentum), 35 | eps, 36 | weight_decay, 37 | optim_bits, 38 | args, 39 | min_8bit_size, 40 | percentile_clipping, 41 | block_wise, 42 | ) 43 | 44 | 45 | class RMSprop8bit(Optimizer1State): 46 | def __init__( 47 | self, 48 | params, 49 | lr=1e-2, 50 | alpha=0.99, 51 | eps=1e-8, 52 | weight_decay=0, 53 | momentum=0, 54 | centered=False, 55 | args=None, 56 | min_8bit_size=4096, 57 | percentile_clipping=100, 58 | block_wise=True, 59 | ): 60 | if alpha == 0: 61 | raise NotImplementedError( 62 | "RMSprop with alpha==0.0 is not supported!" 63 | ) 64 | if centered: 65 | raise NotImplementedError("Centered RMSprop is not supported!") 66 | super().__init__( 67 | "rmsprop", 68 | params, 69 | lr, 70 | (alpha, momentum), 71 | eps, 72 | weight_decay, 73 | 8, 74 | args, 75 | min_8bit_size, 76 | percentile_clipping, 77 | block_wise, 78 | ) 79 | 80 | 81 | class RMSprop32bit(Optimizer1State): 82 | def __init__( 83 | self, 84 | params, 85 | lr=1e-2, 86 | alpha=0.99, 87 | eps=1e-8, 88 | weight_decay=0, 89 | momentum=0, 90 | centered=False, 91 | args=None, 92 | min_8bit_size=4096, 93 | percentile_clipping=100, 94 | block_wise=True, 95 | ): 96 | 97 | if alpha == 0: 98 | raise NotImplementedError( 99 | "RMSprop with alpha==0.0 is not supported!" 100 | ) 101 | if centered: 102 | raise NotImplementedError("Centered RMSprop is not supported!") 103 | super().__init__( 104 | "rmsprop", 105 | params, 106 | lr, 107 | (alpha, momentum), 108 | eps, 109 | weight_decay, 110 | 32, 111 | args, 112 | min_8bit_size, 113 | percentile_clipping, 114 | block_wise, 115 | ) 116 | -------------------------------------------------------------------------------- /bitsandbytes/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
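All of the optimizers in this package route their hyperparameters through the shared `Optimizer1State`/`Optimizer2State` machinery, and `GlobalOptimManager` (re-exported from `optimizer.py` in the package `__init__`) can override that configuration per parameter, for example to keep a small, sensitive tensor in 32-bit state while everything else uses 8-bit. The sketch below shows one such pattern; the model layout and the choice of which weight to pin to 32-bit are illustrative.

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096), torch.nn.Linear(4096, 2))

# Register parameters up front (before creating the optimizer) so that
# per-parameter overrides can take effect.
mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())

model = model.cuda()
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# Keep the small output head in 32-bit optimizer state.
mng.override_config(model[1].weight, 'optim_bits', 32)
```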
5 | from bitsandbytes.optim.optimizer import Optimizer1State 6 | 7 | 8 | class SGD(Optimizer1State): 9 | def __init__( 10 | self, 11 | params, 12 | lr, 13 | momentum=0, 14 | dampening=0, 15 | weight_decay=0, 16 | nesterov=False, 17 | optim_bits=32, 18 | args=None, 19 | min_8bit_size=4096, 20 | percentile_clipping=100, 21 | block_wise=True, 22 | ): 23 | if momentum == 0: 24 | raise NotImplementedError("SGD without momentum is not supported!") 25 | super().__init__( 26 | "momentum", 27 | params, 28 | lr, 29 | (momentum, dampening), 30 | 0.0, 31 | weight_decay, 32 | optim_bits, 33 | args, 34 | min_8bit_size, 35 | percentile_clipping, 36 | block_wise, 37 | ) 38 | 39 | 40 | class SGD8bit(Optimizer1State): 41 | def __init__( 42 | self, 43 | params, 44 | lr, 45 | momentum=0, 46 | dampening=0, 47 | weight_decay=0, 48 | nesterov=False, 49 | args=None, 50 | min_8bit_size=4096, 51 | percentile_clipping=100, 52 | block_wise=True, 53 | ): 54 | if momentum == 0: 55 | raise NotImplementedError("SGD without momentum is not supported!") 56 | super().__init__( 57 | "momentum", 58 | params, 59 | lr, 60 | (momentum, dampening), 61 | 0.0, 62 | weight_decay, 63 | 8, 64 | args, 65 | min_8bit_size, 66 | percentile_clipping, 67 | block_wise, 68 | ) 69 | 70 | 71 | class SGD32bit(Optimizer1State): 72 | def __init__( 73 | self, 74 | params, 75 | lr, 76 | momentum=0, 77 | dampening=0, 78 | weight_decay=0, 79 | nesterov=False, 80 | args=None, 81 | min_8bit_size=4096, 82 | percentile_clipping=100, 83 | block_wise=True, 84 | ): 85 | if momentum == 0: 86 | raise NotImplementedError("SGD without momentum is not supported!") 87 | super().__init__( 88 | "momentum", 89 | params, 90 | lr, 91 | (momentum, dampening), 92 | 0.0, 93 | weight_decay, 94 | 32, 95 | args, 96 | min_8bit_size, 97 | percentile_clipping, 98 | block_wise, 99 | ) 100 | -------------------------------------------------------------------------------- /bitsandbytes/research/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import nn 2 | from .autograd._functions import ( 3 | matmul_fp8_global, 4 | matmul_fp8_mixed, 5 | switchback_bnb, 6 | ) 7 | -------------------------------------------------------------------------------- /bitsandbytes/research/autograd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickardp/bitsandbytes/927f7167e3395ec26f859f294c1d4979a70a718a/bitsandbytes/research/autograd/__init__.py -------------------------------------------------------------------------------- /bitsandbytes/research/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import LinearFP8Global, LinearFP8Mixed 2 | -------------------------------------------------------------------------------- /bitsandbytes/research/nn/modules.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import bitsandbytes as bnb 7 | 8 | T = TypeVar("T", bound="torch.nn.Module") 9 | 10 | 11 | class LinearFP8Mixed(nn.Linear): 12 | def __init__(self, input_features, output_features, bias=True): 13 | super().__init__(input_features, output_features, bias) 14 | self.bw_code = None 15 | self.fw_code = None 16 | array = [4096, 2048, 1024, 512, 256, 128, 64, 0] 17 | for i, k in enumerate(array): 18 | if input_features > array[i + 1]: 19 | self.bsz = k 20 | break 21 | for i, k in enumerate(array): 22 | if output_features > array[i + 1]: 23 | self.bsz2 = k 24 | break 25 | 26 | def forward(self, x: torch.Tensor): 27 | if self.fw_code is None: 28 | self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) 29 | self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) 30 | 31 | out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) 32 | if self.bias is not None: 33 | out += self.bias 34 | 35 | return out 36 | 37 | class LinearFP8Global(nn.Linear): 38 | def __init__(self, input_features, output_features, bias=True): 39 | super().__init__(input_features, output_features, bias) 40 | self.bw_code = None 41 | self.fw_code = None 42 | array = [4096, 2048, 1024, 512, 256, 128, 64, 0] 43 | for i, k in enumerate(array): 44 | if input_features > array[i + 1]: 45 | self.bsz = k 46 | break 47 | for i, k in enumerate(array): 48 | if output_features > array[i + 1]: 49 | self.bsz2 = k 50 | break 51 | 52 | def forward(self, x: torch.Tensor): 53 | if self.fw_code is None: 54 | self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) 55 | self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) 56 | 57 | out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) 58 | if self.bias is not None: 59 | out += self.bias 60 | 61 | return out 62 | -------------------------------------------------------------------------------- /bitsandbytes/triton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickardp/bitsandbytes/927f7167e3395ec26f859f294c1d4979a70a718a/bitsandbytes/triton/__init__.py -------------------------------------------------------------------------------- /bitsandbytes/triton/dequantize_rowwise.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from 
bitsandbytes.triton.triton_utils import is_triton_available 6 | 7 | if not is_triton_available(): 8 | def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None 9 | else: 10 | 11 | import triton 12 | import triton.language as tl 13 | 14 | # rowwise quantize 15 | 16 | # TODO: autotune this better. 17 | @triton.autotune( 18 | configs=[ 19 | triton.Config({}, num_stages=1, num_warps=8), 20 | triton.Config({}, num_stages=2, num_warps=8), 21 | triton.Config({}, num_stages=4, num_warps=8), 22 | triton.Config({}, num_stages=8, num_warps=8), 23 | triton.Config({}, num_stages=1), 24 | triton.Config({}, num_stages=2), 25 | triton.Config({}, num_stages=4), 26 | triton.Config({}, num_stages=8), 27 | triton.Config({}, num_warps=1), 28 | triton.Config({}, num_warps=2), 29 | triton.Config({}, num_warps=4), 30 | triton.Config({}, num_warps=8), 31 | ], 32 | key=['n_elements'] 33 | ) 34 | @triton.jit 35 | def _dequantize_rowwise( 36 | x_ptr, 37 | state_x, 38 | output_ptr, 39 | inv_127, 40 | n_elements, 41 | BLOCK_SIZE: tl.constexpr, 42 | P2: tl.constexpr, 43 | ): 44 | pid = tl.program_id(axis=0) 45 | block_start = pid * BLOCK_SIZE 46 | arange = tl.arange(0, P2) 47 | offsets = block_start + arange 48 | row_mask = arange < BLOCK_SIZE 49 | x = tl.load(x_ptr + offsets, mask=row_mask) 50 | max_val = tl.load(state_x + pid) 51 | output = max_val * x * inv_127 52 | tl.store(output_ptr + offsets, output, mask=row_mask) 53 | 54 | 55 | def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): 56 | output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) 57 | 58 | P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) 59 | 60 | assert x.is_cuda and output.is_cuda 61 | n_elements = output.numel() 62 | grid = lambda meta: (x.shape[0],) 63 | _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) 64 | return output 65 | -------------------------------------------------------------------------------- /bitsandbytes/triton/int8_matmul_mixed_dequantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from bitsandbytes.triton.triton_utils import is_triton_available 4 | 5 | if not is_triton_available(): 6 | def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None 7 | else: 8 | 9 | import triton 10 | import triton.language as tl 11 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time 12 | 13 | 14 | # This is a matmul kernel based on triton.ops.matmul 15 | # It is modified to support rowwise quantized input and global quantized weight 16 | # It's purpose is fused matmul then dequantize 17 | # It does support bias. 
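`dequantize_rowwise` above is the inverse of `quantize_rowwise` (defined further down in this package): it multiplies each int8 row by its stored absolute maximum and 1/127. When Triton is missing, both fall back to stubs that return `None`. A minimal round-trip sketch, with an illustrative tensor shape and assuming a CUDA device plus an installed Triton:

```python
import torch

from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise

if is_triton_available() and torch.cuda.is_available():
    x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
    x_int8, row_absmax = quantize_rowwise(x)             # int8 codes + per-row absmax (fp16)
    x_restored = dequantize_rowwise(x_int8, row_absmax)  # fp16, equal to x up to int8 rounding
    print((x - x_restored).abs().max())
```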
18 | 19 | def init_to_zero(name): 20 | return lambda nargs: nargs[name].zero_() 21 | 22 | def get_configs_io_bound(): 23 | configs = [] 24 | for num_stages in [2, 3, 4, 5, 6]: 25 | for block_m in [16, 32]: 26 | for block_k in [32, 64]: 27 | for block_n in [32, 64, 128, 256]: 28 | num_warps = 2 if block_n <= 64 else 4 29 | configs.append( 30 | triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, 31 | num_stages=num_stages, num_warps=num_warps)) 32 | # split_k 33 | for split_k in [2, 4, 8, 16]: 34 | configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, 35 | num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) 36 | return configs 37 | 38 | 39 | @triton.autotune( 40 | configs=[ 41 | # basic configs for compute-bound matmuls 42 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 43 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 44 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 45 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 46 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 47 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 48 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 49 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 50 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), 51 | # good for int8 52 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 53 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 54 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 55 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 56 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 57 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 58 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 59 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 60 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), 61 | *get_configs_io_bound(), 62 | ], 63 | key=['M', 'N', 'K'], 64 | prune_configs_by={ 65 | 'early_config_prune': early_config_prune, 66 | 'perf_model': estimate_matmul_time, 67 | 'top_k': 10 68 | }, 69 | ) 70 | @triton.heuristics({ 71 | 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, 72 | }) 73 | @triton.jit 74 | def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, 75 | stride_am, stride_ak, 76 | stride_bk, stride_bn, 77 | stride_cm, stride_cn, 78 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 79 | GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, 80 | 
ACC_TYPE: tl.constexpr 81 | ): 82 | # matrix multiplication 83 | pid = tl.program_id(0) 84 | pid_z = tl.program_id(1) 85 | grid_m = tl.cdiv(M, BLOCK_M) 86 | grid_n = tl.cdiv(N, BLOCK_N) 87 | # re-order program ID for better L2 performance 88 | width = GROUP_M * grid_n 89 | group_id = pid // width 90 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M) 91 | pid_m = group_id * GROUP_M + (pid % group_size) 92 | pid_n = (pid % width) // (group_size) 93 | # do matrix multiplication 94 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 95 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 96 | ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 97 | rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) 98 | rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) 99 | # pointers 100 | A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) 101 | B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) 102 | 103 | # rematerialize rm and rn to save registers 104 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 105 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 106 | 107 | w_factor = tl.load(state_w_ptr) 108 | x_factor = tl.load(state_x_ptr + ram)[:, None] 109 | 110 | # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) 111 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) 112 | for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): 113 | if EVEN_K: 114 | a = tl.load(A) 115 | b = tl.load(B) 116 | else: 117 | k_remaining = K - k * (BLOCK_K * SPLIT_K) 118 | a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) 119 | b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) 120 | acc += tl.dot(a, b) 121 | A += BLOCK_K * SPLIT_K * stride_ak 122 | B += BLOCK_K * SPLIT_K * stride_bk 123 | 124 | acc = (w_factor * (x_factor * (acc * divfactor))) 125 | acc = acc.to(C.dtype.element_ty) 126 | 127 | # conditionally add bias 128 | if has_bias: 129 | bias = tl.load(bias + rn).to(C.dtype.element_ty) 130 | acc = acc + bias[None, :] 131 | 132 | C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) 133 | mask = (rm < M)[:, None] & (rn < N)[None, :] 134 | # handles write-back with reduction-splitting 135 | if SPLIT_K == 1: 136 | tl.store(C, acc, mask=mask) 137 | else: 138 | tl.atomic_add(C, acc, mask=mask) 139 | 140 | 141 | def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): 142 | device = a.device 143 | divfactor = 1. / (127. * 127.) 
144 | has_bias = 0 if bias is None else 1 145 | # handle non-contiguous inputs if necessary 146 | if a.stride(0) > 1 and a.stride(1) > 1: 147 | a = a.contiguous() 148 | if b.stride(0) > 1 and b.stride(1) > 1: 149 | b = b.contiguous() 150 | # checks constraints 151 | assert a.shape[1] == b.shape[0], "incompatible dimensions" 152 | M, K = a.shape 153 | _, N = b.shape 154 | # allocates output 155 | c = torch.empty((M, N), device=device, dtype=torch.float16) 156 | # accumulator types 157 | ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 158 | # launch int8_matmul_mixed_dequantize kernel 159 | grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) 160 | _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, 161 | a.stride(0), a.stride(1), 162 | b.stride(0), b.stride(1), 163 | c.stride(0), c.stride(1), 164 | GROUP_M=8, ACC_TYPE=ACC_TYPE) 165 | return c 166 | -------------------------------------------------------------------------------- /bitsandbytes/triton/int8_matmul_rowwise_dequantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from bitsandbytes.triton.triton_utils import is_triton_available 4 | 5 | if not is_triton_available(): 6 | def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None 7 | else: 8 | import triton 9 | import triton.language as tl 10 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time 11 | 12 | # This is a matmul kernel based on triton.ops.matmul 13 | # It is modified to support rowwise quantized input and columnwise quantized weight 14 | # It's purpose is fused matmul then dequantize 15 | # It does support bias. 
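The wrapper above expects `a` to be the row-quantized int8 activations, `b` the globally quantized int8 weight (transposed to K x N), their quantization states, and an optional fp16 bias; the rowwise variant that follows is identical except that the weight carries one scale per output column. A minimal sketch of how these pieces fit together, mirroring the SwitchBack linear layers exported earlier from `triton_based_modules.py` (shapes are illustrative; Triton and a CUDA device are assumed):

```python
import torch

from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_global import quantize_global
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

if is_triton_available() and torch.cuda.is_available():
    X = torch.randn(32, 1024, device="cuda", dtype=torch.float16)    # activations, (M, K)
    W = torch.randn(2048, 1024, device="cuda", dtype=torch.float16)  # weight, (N, K) as in nn.Linear

    X_int8, state_X = quantize_rowwise(X)   # one scale per activation row
    W_int8, state_W = quantize_global(W)    # a single scale for the whole weight

    # Fused int8 matmul + dequantize; the fp16 result approximates X @ W.t().
    out = int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, None)
```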
16 | 17 | def init_to_zero(name): 18 | return lambda nargs: nargs[name].zero_() 19 | 20 | 21 | def get_configs_io_bound(): 22 | configs = [] 23 | for num_stages in [2, 3, 4, 5, 6]: 24 | for block_m in [16, 32]: 25 | for block_k in [32, 64]: 26 | for block_n in [32, 64, 128, 256]: 27 | num_warps = 2 if block_n <= 64 else 4 28 | configs.append( 29 | triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, 30 | num_stages=num_stages, num_warps=num_warps)) 31 | # split_k 32 | for split_k in [2, 4, 8, 16]: 33 | configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, 34 | num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) 35 | return configs 36 | 37 | 38 | @triton.autotune( 39 | configs=[ 40 | # basic configs for compute-bound matmuls 41 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 42 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 43 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 44 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 45 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 46 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 47 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 48 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 49 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), 50 | # good for int8 51 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 52 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 53 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 54 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 55 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 56 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 57 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 58 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 59 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), 60 | *get_configs_io_bound(), 61 | ], 62 | key=['M', 'N', 'K'], 63 | prune_configs_by={ 64 | 'early_config_prune': early_config_prune, 65 | 'perf_model': estimate_matmul_time, 66 | 'top_k': 10 67 | }, 68 | ) 69 | @triton.heuristics({ 70 | 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, 71 | }) 72 | @triton.jit 73 | def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, 74 | stride_am, stride_ak, 75 | stride_bk, stride_bn, 76 | stride_cm, stride_cn, 77 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 78 | GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, 79 | ACC_TYPE: 
tl.constexpr 80 | ): 81 | # matrix multiplication 82 | pid = tl.program_id(0) 83 | pid_z = tl.program_id(1) 84 | grid_m = tl.cdiv(M, BLOCK_M) 85 | grid_n = tl.cdiv(N, BLOCK_N) 86 | # re-order program ID for better L2 performance 87 | width = GROUP_M * grid_n 88 | group_id = pid // width 89 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M) 90 | pid_m = group_id * GROUP_M + (pid % group_size) 91 | pid_n = (pid % width) // (group_size) 92 | # do matrix multiplication 93 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 94 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 95 | ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 96 | rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) 97 | rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) 98 | # pointers 99 | A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) 100 | B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) 101 | 102 | # rematerialize rm and rn to save registers 103 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 104 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 105 | 106 | w_factor = tl.load(state_w_ptr + rbn)[None, :] 107 | x_factor = tl.load(state_x_ptr + ram)[:, None] 108 | 109 | # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) 110 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) 111 | for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): 112 | if EVEN_K: 113 | a = tl.load(A) 114 | b = tl.load(B) 115 | else: 116 | k_remaining = K - k * (BLOCK_K * SPLIT_K) 117 | a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) 118 | b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) 119 | acc += tl.dot(a, b) 120 | A += BLOCK_K * SPLIT_K * stride_ak 121 | B += BLOCK_K * SPLIT_K * stride_bk 122 | 123 | acc = (w_factor * (x_factor * (acc * divfactor))) 124 | acc = acc.to(C.dtype.element_ty) 125 | 126 | if has_bias: 127 | bias = tl.load(bias + rn).to(C.dtype.element_ty) 128 | acc = acc + bias[None, :] 129 | 130 | C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) 131 | mask = (rm < M)[:, None] & (rn < N)[None, :] 132 | # handles write-back with reduction-splitting 133 | if SPLIT_K == 1: 134 | tl.store(C, acc, mask=mask) 135 | else: 136 | tl.atomic_add(C, acc, mask=mask) 137 | 138 | 139 | def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): 140 | divfactor = 1. / (127. * 127.) 
141 | 142 | has_bias = 0 if bias is None else 1 143 | 144 | device = a.device 145 | # handle non-contiguous inputs if necessary 146 | if a.stride(0) > 1 and a.stride(1) > 1: 147 | a = a.contiguous() 148 | if b.stride(0) > 1 and b.stride(1) > 1: 149 | b = b.contiguous() 150 | # checks constraints 151 | assert a.shape[1] == b.shape[0], "incompatible dimensions" 152 | M, K = a.shape 153 | _, N = b.shape 154 | # allocates output 155 | c = torch.empty((M, N), device=device, dtype=torch.float16) 156 | # accumulator types 157 | ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 158 | # launch int8_matmul_rowwise_dequantize kernel 159 | grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) 160 | _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, 161 | a.stride(0), a.stride(1), 162 | b.stride(0), b.stride(1), 163 | c.stride(0), c.stride(1), 164 | GROUP_M=8, ACC_TYPE=ACC_TYPE) 165 | return c 166 | -------------------------------------------------------------------------------- /bitsandbytes/triton/quantize_columnwise_and_transpose.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from bitsandbytes.triton.triton_utils import is_triton_available 6 | 7 | if not is_triton_available(): 8 | def quantize_columnwise_and_transpose(x: torch.Tensor): return None 9 | else: 10 | 11 | import triton 12 | import triton.language as tl 13 | 14 | # This kernel does fused columnwise quantization and transpose. 15 | 16 | # TODO: autotune this better. 17 | @triton.autotune( 18 | configs=[ 19 | triton.Config({}, num_stages=1), 20 | triton.Config({}, num_stages=2), 21 | triton.Config({}, num_stages=4), 22 | triton.Config({}, num_stages=8), 23 | triton.Config({}, num_stages=16), 24 | triton.Config({}, num_stages=1, num_warps=8), 25 | triton.Config({}, num_stages=2, num_warps=8), 26 | triton.Config({}, num_stages=4, num_warps=8), 27 | triton.Config({}, num_stages=8, num_warps=8), 28 | triton.Config({}, num_stages=16, num_warps=8), 29 | triton.Config({}, num_warps=1), 30 | triton.Config({}, num_warps=2), 31 | triton.Config({}, num_warps=4), 32 | triton.Config({}, num_warps=8), 33 | ], 34 | key=['n_elements'] 35 | ) 36 | @triton.jit 37 | def _quantize_columnwise_and_transpose( 38 | x_ptr, 39 | output_ptr, 40 | output_maxs, 41 | n_elements, 42 | M : tl.constexpr, N : tl.constexpr, 43 | BLOCK_SIZE: tl.constexpr, 44 | P2: tl.constexpr, 45 | ): 46 | pid = tl.program_id(axis=0) 47 | block_start = pid 48 | p2_arange = tl.arange(0, P2) 49 | p2_arange_mask = p2_arange < M 50 | arange = p2_arange * N 51 | offsets = block_start + arange 52 | x = tl.load(x_ptr + offsets, mask=p2_arange_mask) 53 | abs_x = tl.abs(x) 54 | max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) 55 | output = tl.libdevice.llrint(127. 
* (x / max_val)) 56 | 57 | new_start = pid * M 58 | new_offsets = new_start + p2_arange 59 | tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) 60 | tl.store(output_maxs + pid, max_val) 61 | 62 | def quantize_columnwise_and_transpose(x: torch.Tensor): 63 | M, N = x.shape 64 | output = torch.empty(N, M, device=x.device, dtype=torch.int8) 65 | output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) 66 | 67 | P2 = int(2 ** (math.ceil(math.log2(M)))) 68 | 69 | assert x.is_cuda and output.is_cuda 70 | n_elements = output.numel() 71 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 72 | _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) 73 | return output, output_maxs 74 | -------------------------------------------------------------------------------- /bitsandbytes/triton/quantize_global.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from bitsandbytes.triton.triton_utils import is_triton_available 5 | 6 | if not is_triton_available(): 7 | def quantize_global_transpose(input): return None 8 | def quantize_global(x: torch.Tensor): return None 9 | else: 10 | 11 | import triton 12 | import triton.language as tl 13 | 14 | # global quantize 15 | @triton.autotune( 16 | configs=[ 17 | triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), 18 | triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), 19 | 20 | ], 21 | key=['n_elements'] 22 | ) 23 | @triton.jit 24 | def _quantize_global( 25 | x_ptr, 26 | absmax_inv_ptr, 27 | output_ptr, 28 | n_elements, 29 | BLOCK_SIZE: tl.constexpr, 30 | ): 31 | pid = tl.program_id(axis=0) 32 | block_start = pid * BLOCK_SIZE 33 | offsets = block_start + tl.arange(0, BLOCK_SIZE) 34 | mask = offsets < n_elements 35 | x = tl.load(x_ptr + offsets, mask=mask) 36 | absmax_inv = tl.load(absmax_inv_ptr) 37 | output = tl.libdevice.llrint(127. * (x * absmax_inv)) 38 | tl.store(output_ptr + offsets, output, mask=mask) 39 | 40 | def quantize_global(x: torch.Tensor): 41 | absmax = x.abs().max().unsqueeze(0) 42 | absmax_inv = 1./ absmax 43 | output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) 44 | assert x.is_cuda and output.is_cuda 45 | n_elements = output.numel() 46 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 47 | _quantize_global[grid](x, absmax_inv, output, n_elements) 48 | return output, absmax 49 | 50 | 51 | # global quantize and transpose 52 | @triton.autotune( 53 | configs=[ 54 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), 55 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), 56 | 57 | # ... 
58 | ], 59 | key=['M', 'N'] 60 | ) 61 | @triton.jit 62 | def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, 63 | BLOCK_M : tl.constexpr, 64 | BLOCK_N : tl.constexpr, 65 | GROUP_M : tl.constexpr): 66 | pid = tl.program_id(0) 67 | grid_m = (M + BLOCK_M - 1) // BLOCK_M 68 | grid_n = (N + BLOCK_N - 1) // BLOCK_N 69 | 70 | width = GROUP_M * grid_n 71 | group_id = pid // width 72 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M) 73 | pid_m = group_id * GROUP_M + (pid % group_size) 74 | pid_n = (pid % width) // group_size 75 | 76 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 77 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 78 | A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) 79 | mask = (rm < M)[:, None] & (rn < N)[None, :] 80 | a = tl.load(A, mask=mask) 81 | absmax_inv = tl.load(absmax_inv_ptr) 82 | 83 | # rematerialize to save registers 84 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 85 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 86 | B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) 87 | mask = (rm < M)[:, None] & (rn < N)[None, :] 88 | 89 | output = tl.libdevice.llrint(127. * (a * absmax_inv)) 90 | 91 | tl.store(B, output, mask=mask) 92 | 93 | def quantize_global_transpose(input): 94 | absmax = input.abs().max().unsqueeze(0) 95 | absmax_inv = 1./ absmax 96 | M, N = input.shape 97 | out = torch.empty(N, M, device='cuda', dtype=torch.int8) 98 | 99 | assert out.size(0) == N and out.size(1) == M 100 | assert input.stride(0) == 1 or input.stride(1) == 1 101 | assert out.stride(0) == 1 or out.stride(1) == 1 102 | 103 | grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) 104 | _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) 105 | return out, absmax 106 | -------------------------------------------------------------------------------- /bitsandbytes/triton/quantize_rowwise.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from bitsandbytes.triton.triton_utils import is_triton_available 6 | 7 | if not is_triton_available(): 8 | def quantize_rowwise(x: torch.Tensor): return None 9 | else: 10 | 11 | import triton 12 | import triton.language as tl 13 | 14 | # rowwise quantize 15 | 16 | # TODO: autotune this better. 17 | @triton.autotune( 18 | configs=[ 19 | triton.Config({}, num_stages=1, num_warps=8), 20 | triton.Config({}, num_stages=2, num_warps=8), 21 | triton.Config({}, num_stages=4, num_warps=8), 22 | triton.Config({}, num_stages=8, num_warps=8), 23 | triton.Config({}, num_stages=1), 24 | triton.Config({}, num_stages=2), 25 | triton.Config({}, num_stages=4), 26 | triton.Config({}, num_stages=8), 27 | triton.Config({}, num_warps=1), 28 | triton.Config({}, num_warps=2), 29 | triton.Config({}, num_warps=4), 30 | triton.Config({}, num_warps=8), 31 | ], 32 | key=['n_elements'] 33 | ) 34 | @triton.jit 35 | def _quantize_rowwise( 36 | x_ptr, 37 | output_ptr, 38 | output_maxs, 39 | n_elements, 40 | BLOCK_SIZE: tl.constexpr, 41 | P2: tl.constexpr, 42 | ): 43 | pid = tl.program_id(axis=0) 44 | block_start = pid * BLOCK_SIZE 45 | arange = tl.arange(0, P2) 46 | offsets = block_start + arange 47 | row_mask = arange < BLOCK_SIZE 48 | x = tl.load(x_ptr + offsets, mask=row_mask) 49 | 50 | abs_x = tl.abs(x) 51 | max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) 52 | output = tl.libdevice.llrint(127. 
* (x / max_val)) 53 | tl.store(output_ptr + offsets, output, mask=row_mask) 54 | tl.store(output_maxs + pid, max_val) 55 | 56 | def quantize_rowwise(x: torch.Tensor): 57 | output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) 58 | output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) 59 | 60 | P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) 61 | 62 | assert x.is_cuda and output.is_cuda 63 | n_elements = output.numel() 64 | grid = lambda meta: (x.shape[0],) 65 | _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) 66 | return output, output_maxs 67 | -------------------------------------------------------------------------------- /bitsandbytes/triton/triton_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def is_triton_available(): 5 | return importlib.util.find_spec("triton") is not None 6 | -------------------------------------------------------------------------------- /bitsandbytes/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shlex 3 | import subprocess 4 | from typing import Tuple 5 | 6 | import torch 7 | 8 | 9 | def outlier_hook(module, input): 10 | assert isinstance(module, torch.nn.Linear) 11 | tracer = OutlierTracer.get_instance() 12 | hvalue = tracer.get_hvalue(module.weight) 13 | if hvalue not in tracer.hvalue2outlier_idx: 14 | outlier_idx = find_outlier_dims(module.weight) 15 | tracer.outliers.append(outlier_idx) 16 | tracer.hvalues.append(hvalue) 17 | if len(tracer.outliers) > 1: 18 | # assign the current layer the outlier idx found from the weight 19 | # of the previous linear layer 20 | if tracer.outliers[-1].numel() > 0: 21 | assert tracer.outliers[-1].max() < module.weight.shape[1] 22 | tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1] 23 | 24 | else: 25 | # first layer, we cannot use the weight for outlier detection 26 | # we follow a mixed approach: 27 | # (1) zscore test of std of hidden dimension 28 | # (2) magnitude > 6 test 29 | merged = input[0].view(-1, input[0].shape[-1]) 30 | # (1) zscore test of std of hidden dimension 31 | outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3) 32 | # (2) magnitude > 6 test 33 | dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1))) 34 | outlier_idx2 = torch.where(dims > 0)[0] 35 | outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique() 36 | tracer.hvalue2outlier_idx[hvalue] = outlier_idx 37 | else: 38 | for hook in tracer.hooks: 39 | hook.remove() 40 | 41 | 42 | class OutlierTracer: 43 | _instance = None 44 | 45 | def __init__(self): 46 | raise RuntimeError("Call get_instance() instead") 47 | 48 | def initialize(self, model): 49 | self.last_w = None 50 | self.current_outlier_dims = None 51 | self.hvalues = [] 52 | self.outliers = [] 53 | self.hvalue2outlier_idx = {} 54 | self.initialized = True 55 | self.hooks = [] 56 | 57 | for n, m in model.named_modules(): 58 | if isinstance(m, torch.nn.Linear): 59 | self.hooks.append(m.register_forward_pre_hook(outlier_hook)) 60 | 61 | def is_initialized(self): 62 | return getattr(self, 'initialized', False) 63 | 64 | def get_hvalue(self, weight): 65 | return weight.data.storage().data_ptr() 66 | 67 | def get_outliers(self, weight): 68 | if not self.is_initialized(): 69 | print('Outlier tracer is not initialized...') 70 | return None 71 | hvalue = self.get_hvalue(weight) 72 | if hvalue in self.hvalue2outlier_idx: 73 | return 
self.hvalue2outlier_idx[hvalue] 74 | else: 75 | return None 76 | 77 | @classmethod 78 | def get_instance(cls): 79 | if cls._instance is None: 80 | cls._instance = cls.__new__(cls) 81 | return cls._instance 82 | 83 | def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False): 84 | if rdm: 85 | return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long() 86 | 87 | m = weight.mean(reduction_dim) 88 | mm = m.mean() 89 | mstd = m.std() 90 | zm = (m-mm)/mstd 91 | 92 | std = weight.std(reduction_dim) 93 | stdm = std.mean() 94 | stdstd = std.std() 95 | 96 | zstd = (std-stdm)/stdstd 97 | 98 | if topk is not None: 99 | val, idx = torch.topk(std.abs(), k=topk, dim=0) 100 | else: 101 | idx = torch.where(zstd > zscore)[0] 102 | 103 | return idx 104 | 105 | 106 | def execute_and_return(command_string: str) -> Tuple[str, str]: 107 | def _decode(subprocess_err_out_tuple): 108 | return tuple( 109 | to_decode.decode("UTF-8").strip() 110 | for to_decode in subprocess_err_out_tuple 111 | ) 112 | 113 | def execute_and_return_decoded_std_streams(command_string): 114 | return _decode( 115 | subprocess.Popen( 116 | shlex.split(command_string), 117 | stdout=subprocess.PIPE, 118 | stderr=subprocess.PIPE, 119 | ).communicate() 120 | ) 121 | 122 | std_out, std_err = execute_and_return_decoded_std_streams(command_string) 123 | return std_out, std_err 124 | 125 | 126 | 127 | def replace_linear( 128 | model, 129 | linear_replacement, 130 | skip_modules=("lm_head",), 131 | copy_weights=False, 132 | post_processing_function=None, 133 | ): 134 | """ 135 | Replace linear modules with a new Linear module. 136 | Parameters: 137 | model (`torch.nn.Module`): 138 | Input model or `torch.nn.Module` as the function is run recursively. 139 | linear_replacement (`torch.nn.Module`): 140 | The linear module that replaces the old one. Only expects standard arguments. 141 | If other arguments need to be passed, use a lambda. 142 | skip_modules (`List[str]`, *optional*, defaults to `lm_head`): 143 | List of modules names not to convert. Defaults to `lm_head`. 144 | copy_weights (`bool`): 145 | Copy the weights from the old linear module to the new one 146 | post_processing_fun_name (`str`): 147 | A function name of the replacement linear class that is called 148 | after processing. 149 | """ 150 | for name, module in model.named_children(): 151 | if len(list(module.children())) > 0: 152 | replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function) 153 | 154 | if isinstance(module, torch.nn.Linear) and name not in skip_modules: 155 | old_module = model._modules[name] 156 | model._modules[name] = linear_replacement( 157 | module.in_features, 158 | module.out_features, 159 | module.bias is not None, 160 | ) 161 | if copy_weights: 162 | model._modules[name].weight = old_module.weight 163 | model._modules[name].bias = old_module.bias 164 | 165 | if post_processing_function is not None: 166 | func = getattr(module, post_processing_function, None) 167 | if func is not None: func(module) 168 | return model 169 | 170 | 171 | def pack_dict_to_tensor(source_dict): 172 | """ 173 | Pack a dictionary into a torch tensor for storing quant_state items in state_dict. 174 | 175 | Parameters: 176 | - source_dict: The dictionary to be packed. 177 | 178 | Returns: 179 | A torch tensor containing the packed data. 
180 | """ 181 | json_str = json.dumps(source_dict) 182 | json_bytes = json_str.encode('utf-8') 183 | tensor_data = torch.tensor(list(json_bytes), dtype=torch.uint8) 184 | 185 | return tensor_data 186 | 187 | 188 | def unpack_tensor_to_dict(tensor_data): 189 | """ 190 | Unpack a torch tensor into a Python dictionary. 191 | 192 | Parameters: 193 | - tensor_data: The torch tensor containing the packed data. 194 | 195 | Returns: 196 | A Python dictionary containing the unpacked data. 197 | """ 198 | json_bytes = bytes(tensor_data.cpu().numpy()) 199 | json_str = json_bytes.decode('utf-8') 200 | unpacked_dict = json.loads(json_str) 201 | 202 | return unpacked_dict 203 | -------------------------------------------------------------------------------- /check_bnb_install.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import bitsandbytes as bnb 4 | 5 | p = torch.nn.Parameter(torch.rand(10,10).cuda()) 6 | a = torch.rand(10,10).cuda() 7 | 8 | p1 = p.data.sum().item() 9 | 10 | adam = bnb.optim.Adam([p]) 11 | 12 | out = a*p 13 | loss = out.sum() 14 | loss.backward() 15 | adam.step() 16 | 17 | p2 = p.data.sum().item() 18 | 19 | assert p1 != p2 20 | print('SUCCESS!') 21 | print('Installation was successful!') 22 | -------------------------------------------------------------------------------- /csrc/common.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | void quantize_block(const quantize_block_args& args) { 5 | // 1. find absmax in block 6 | // 2. divide input value by absmax to normalize into [-1.0, 1.0] 7 | // 3. do binary search to find the closest value 8 | // 4. check minimal distance 9 | // 5. store index 10 | 11 | // 1. find absmax in block 12 | float absmax_block = -FLT_MAX; 13 | for (long long i = args.block_idx; i < args.block_end; i++) 14 | absmax_block = fmax(absmax_block, fabs(args.A[i])); 15 | 16 | args.absmax[args.block_idx / args.blocksize] = absmax_block; 17 | 18 | for (long long i = args.block_idx; i < args.block_end; i++) { 19 | // 2. divide input value by absmax to normalize into [-1.0, 1.0] 20 | // 3. do binary search to find the closest value 21 | float normed_value = args.A[i] / absmax_block; 22 | long long idx = args.bin_searcher->scalar(normed_value); 23 | 24 | // 4. check minimal distance 25 | // The binary search returns always the value to the left, which might not be the closest value 26 | if (idx < 255) { 27 | float dist_left = fabs(normed_value - (args.code[idx])); 28 | float dist_right = fabs(normed_value - (args.code[idx + 1])); 29 | if (dist_right < dist_left) { idx += 1; } 30 | } 31 | 32 | // 5. 
store index 33 | args.out[i] = (unsigned char) idx; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /csrc/common.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifndef common 4 | #define common 5 | 6 | using namespace BinSearch; 7 | 8 | #define BLOCK_SIZE 16384 9 | 10 | struct quantize_block_args { 11 | BinAlgo *bin_searcher; 12 | float *code; 13 | float *A; 14 | float *absmax; 15 | unsigned char *out; 16 | long long block_end; 17 | long long block_idx; 18 | long long threadidx; 19 | long long blocksize; 20 | }; 21 | 22 | 23 | void quantize_block(const quantize_block_args& args); 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /csrc/cpu_ops.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | using namespace BinSearch; 6 | 7 | void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) { 8 | for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { 9 | long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; 10 | long long block_end = block_idx + valid_items; 11 | for (long long i = block_idx; i < block_end; i++) 12 | out[i] = code[A[i]] * absmax[block_idx / blocksize]; 13 | } 14 | } 15 | 16 | void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n) 17 | { 18 | 19 | // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below 20 | code[0] = -1.0f; 21 | 22 | long long num_blocks = n / blocksize; 23 | num_blocks += n % blocksize == 0 ? 0 : 1; 24 | 25 | const uint32 elements_code = 256; 26 | BinAlgo bin_searcher(code, elements_code); 27 | 28 | int thread_wave_size = 256; 29 | // we chunk the threads into waves of 256 since the max limit is 30 | // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) 31 | for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) 32 | { 33 | long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; 34 | std::vector threads(valid_chunks); 35 | std::vector args(valid_chunks); 36 | 37 | int chunks_processed = 0; 38 | for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) 39 | { 40 | long long valid_items = n - block_idx >= blocksize ? 
blocksize : n - block_idx; 41 | long long block_end = block_idx + valid_items; 42 | 43 | struct quantize_block_args& arg = args[chunks_processed]; 44 | arg.bin_searcher = &bin_searcher; 45 | arg.code = code; 46 | arg.A = A; 47 | arg.absmax = absmax; 48 | arg.out = out; 49 | arg.block_end = block_end; 50 | arg.block_idx = block_idx; 51 | arg.threadidx = block_idx / blocksize; 52 | arg.blocksize = blocksize; 53 | 54 | threads[chunks_processed] = std::thread([arg] { quantize_block(arg); }); 55 | chunks_processed += 1; 56 | if(chunks_processed == valid_chunks){ break; } 57 | } 58 | 59 | for (int i = 0; i < valid_chunks; i++) 60 | threads[i].join(); 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /csrc/cpu_ops.h: -------------------------------------------------------------------------------- 1 | #ifndef BITSANDBYTES_CPU_OPS_H 2 | #define BITSANDBYTES_CPU_OPS_H 3 | 4 | #include 5 | #include 6 | 7 | void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n); 8 | void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n); 9 | 10 | #endif 11 | -------------------------------------------------------------------------------- /csrc/mps_kernels.metal: -------------------------------------------------------------------------------- 1 | #include 2 | using namespace metal; 3 | 4 | #define HLF_MAX 65504 5 | #define TH 1024 6 | #define NUM 4 7 | #define NUM_BLOCK 4096 8 | 9 | template 10 | static unsigned char quantize_scalar( 11 | float rand, 12 | device float* code, 13 | float x) 14 | { 15 | int pivot = 127; 16 | int upper_pivot = 255; 17 | int lower_pivot = 0; 18 | 19 | float lower = -1.0f; 20 | float upper = 1.0f; 21 | 22 | float val = code[pivot]; 23 | // i>>=1 = {32, 16, 8, 4, 2, 1} 24 | for(int i = 64; i > 0; i>>=1) 25 | { 26 | if(x > val) 27 | { 28 | lower_pivot = pivot; 29 | lower = val; 30 | pivot+=i; 31 | } 32 | else 33 | { 34 | upper_pivot = pivot; 35 | upper = val; 36 | pivot-=i; 37 | } 38 | val = code[pivot]; 39 | } 40 | 41 | if(upper_pivot == 255) 42 | upper = code[upper_pivot]; 43 | if(lower_pivot == 0) 44 | lower = code[lower_pivot]; 45 | 46 | if(!STOCHASTIC) 47 | { 48 | if(x > val) 49 | { 50 | float midpoint = (upper+val)*0.5f; 51 | if(x > midpoint) 52 | { 53 | return upper_pivot; 54 | } 55 | else 56 | return pivot; 57 | } 58 | else 59 | { 60 | float midpoint = (lower+val)*0.5f; 61 | if(x < midpoint) 62 | return lower_pivot; 63 | else 64 | return pivot; 65 | } 66 | } 67 | else 68 | { 69 | if(x > val) 70 | { 71 | float dist_to_upper = fabs(upper-x); 72 | float dist_full = upper-val; 73 | if(rand >= dist_to_upper/dist_full) return upper_pivot; 74 | else return pivot; 75 | } 76 | else 77 | { 78 | float dist_to_lower = fabs(lower-x); 79 | float dist_full = val-lower; 80 | if(rand >= dist_to_lower/dist_full) return lower_pivot; 81 | else return pivot; 82 | } 83 | } 84 | } 85 | 86 | kernel void quantize(device float* code [[buffer(0)]], 87 | device float* A [[buffer(1)]], 88 | device uchar* out [[buffer(2)]], 89 | constant uint& n [[buffer(3)]], 90 | uint id [[thread_position_in_grid]]) { 91 | const uint n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); 92 | uint valid_items = (id / NUM_BLOCK + 1 == (n + NUM_BLOCK - 1) / NUM_BLOCK) ? 
n - (id / NUM_BLOCK * NUM_BLOCK) : NUM_BLOCK; 93 | const uint base_idx = (id / NUM_BLOCK * NUM_BLOCK); 94 | 95 | float vals[NUM]; 96 | uchar qvals[NUM]; 97 | 98 | for (uint i = base_idx; i < n_full; i += ((n + NUM_BLOCK - 1) / NUM_BLOCK) * NUM_BLOCK) { 99 | valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; 100 | 101 | threadgroup_barrier(mem_flags::mem_threadgroup); 102 | 103 | for (uint j = 0; j < valid_items; j++) { 104 | vals[j] = A[i + j]; 105 | } 106 | 107 | for (uint j = 0; j < valid_items; j++) { 108 | qvals[j] = quantize_scalar(0.0f, code, vals[j]); 109 | } 110 | 111 | threadgroup_barrier(mem_flags::mem_threadgroup); 112 | 113 | for (uint j = 0; j < valid_items; j++) { 114 | out[i + j] = qvals[j]; 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /csrc/mps_ops.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickardp/bitsandbytes/927f7167e3395ec26f859f294c1d4979a70a718a/csrc/mps_ops.h -------------------------------------------------------------------------------- /csrc/mps_ops.mm: -------------------------------------------------------------------------------- 1 | #import 2 | 3 | #define HLF_MAX 65504 4 | #define TH 1024 5 | #define NUM 4 6 | #define NUM_BLOCK 4096 7 | 8 | static inline MPSGraph* get_graph() 9 | { 10 | static MPSGraph* cur = nil; 11 | if(!cur) { 12 | cur = [[MPSGraph alloc] init]; 13 | } 14 | return cur; 15 | } 16 | 17 | static inline id get_device() 18 | { 19 | NSError *error = nil; 20 | static id device = nil; 21 | if(!device) { 22 | device = MTLCreateSystemDefaultDevice(); 23 | } 24 | if(!device) { 25 | NSLog(@"Failed to get MPS device"); 26 | abort(); 27 | } 28 | return device; 29 | } 30 | 31 | static inline id get_library() 32 | { 33 | NSError *error = nil; 34 | static id library = nil; 35 | if(!library) { 36 | library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error]; 37 | } 38 | if(!library) { 39 | NSLog(@"Failed to load bitsandbytes.metallib"); 40 | abort(); 41 | } 42 | return library; 43 | } 44 | 45 | /*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n) 46 | { 47 | id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 dataType:MPSDataTypeInt8 axis:0 name:@"out"]; 48 | return out; 49 | }*/ 50 | 51 | 52 | // MPSGraph function for quantize 53 | extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) 54 | { 55 | id device = get_device(); 56 | id library = get_library(); 57 | static id kernel = nil; 58 | if(!kernel) { 59 | kernel = [library newFunctionWithName:@"quantize"]; 60 | if(!kernel) { 61 | NSLog(@"Failed to load bitsandbytes.metallib"); 62 | abort(); 63 | } 64 | } 65 | NSLog(@"Not implemented"); 66 | return nil; 67 | } 68 | -------------------------------------------------------------------------------- /csrc/ops.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | #ifndef ops_H 8 | #define ops_H 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include 23 | #include 24 | 25 | 26 | 27 | #define CUDA_CHECK_RETURN(value) { \ 28 | cudaError_t _m_cudaStat = value; \ 29 | if (_m_cudaStat != cudaSuccess) { \ 30 | fprintf(stderr, "Error %s at line %d in file %s\n", \ 31 | cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ 32 | exit(1); \ 33 | } } 34 | 35 | #define THREADS_PER_BLOCKS (512) 36 | 37 | #define CHECK_CUSPARSE(value) { \ 38 | cusparseStatus_t _m_cudaStat = value; \ 39 | if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ 40 | fprintf(stderr, "Error %s at line %d in file %s\n", \ 41 | cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ 42 | exit(1); \ 43 | } } 44 | 45 | 46 | #define THREADS_PER_BLOCKS (512) 47 | 48 | 49 | inline void checkCudaStatus(cudaError_t status) { 50 | if (status != cudaSuccess) { 51 | printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status)); 52 | throw std::logic_error("cuda API failed"); 53 | } 54 | } 55 | 56 | inline int checkCublasStatus(cublasStatus_t status) { 57 | if (status != CUBLAS_STATUS_SUCCESS) { 58 | printf("cuBLAS API failed with status %d\n", status); 59 | //throw std::logic_error("cuBLAS API failed"); 60 | return 1; 61 | } 62 | return 0; 63 | } 64 | 65 | typedef enum Operations_t 66 | { 67 | ksmul = 0, 68 | } Operations_t; 69 | 70 | typedef enum Optimizer_t 71 | { 72 | ADAM = 0, 73 | MOMENTUM = 1, 74 | RMSPROP = 2, 75 | LARS = 3, 76 | ADAGRAD = 4, 77 | LION = 5, 78 | } Optimizer_t; 79 | 80 | typedef enum Transform_t 81 | { 82 | ROW = 0, 83 | COL = 1, 84 | COL32 = 2, 85 | COL_TURING = 3, 86 | COL_AMPERE = 4, 87 | } Transform_t; 88 | 89 | typedef enum DataType_t 90 | { 91 | General8bit = 0, 92 | FP4 = 1, 93 | NF4 = 2, 94 | } DataType_t; 95 | 96 | typedef enum Funcs_t 97 | { 98 | FILL = 0, 99 | ARANGE = 1, 100 | _MUL = 2, 101 | } Funcs_t; 102 | 103 | class Context 104 | { 105 | public: 106 | cublasHandle_t m_handle; 107 | 108 | Context() 109 | { 110 | cublasHandle_t handle; 111 | cublasCreate_v2(&handle); 112 | m_handle = handle; 113 | } 114 | 115 | }; 116 | 117 | class ContextLt 118 | { 119 | public: 120 | cublasLtHandle_t m_handle; 121 | 122 | ContextLt() 123 | { 124 | cublasLtHandle_t handle; 125 | cublasLtCreate(&handle); 126 | m_handle = handle; 127 | } 128 | 129 | }; 130 | 131 | class ContextCusparse 132 | { 133 | public: 134 | cusparseHandle_t m_handle; 135 | 136 | ContextCusparse() 137 | { 138 | cusparseHandle_t handle; 139 | cusparseCreate(&handle); 140 | m_handle = handle; 141 | } 142 | 143 | }; 144 | 145 | 146 | template void estimateQuantiles(T *A, float *code, float offset, int n); 147 | 148 | void quantize(float *code, float *A, unsigned char *out, int n); 149 | void dequantize(float *code, unsigned char *A, float *out, int n); 150 | template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); 151 | template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); 152 | 153 | template void optimizer32bit(T* g, T* p, 154 | float* state1, float* state2, float *unorm, float max_unorm, float param_norm, 155 | float beta1, float beta2, float eps, float weight_decay, 156 | int step, float lr, const float gnorm_scale, bool skip_zeros, int n); 157 | 158 | template void optimizerStatic8bit(T* p, T* g, unsigned char* 
state1, unsigned char* state2, 159 | float *unorm, float max_unorm, float param_norm, 160 | float beta1, float beta2, 161 | float eps, int step, float lr, 162 | float* quantiles1, float* quantiles2, 163 | float* max1, float* max2, float* new_max1, float* new_max2, 164 | float weight_decay, 165 | const float gnorm_scale, int n); 166 | 167 | template void optimizerStatic8bitBlockwise(T* p, T* g, 168 | unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, 169 | float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, 170 | bool skip_zeros, int n); 171 | 172 | template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); 173 | 174 | void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); 175 | 176 | void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); 177 | void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, 178 | long long int strideA, long long int strideB, long long int strideC, int batchCount); 179 | 180 | 181 | template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); 182 | 183 | template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); 184 | void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); 185 | void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); 186 | void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); 187 | void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, 188 | int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); 189 | 190 | template void transformRowToFormat(char * A, char *out, int rows, int cols); 191 | 192 | void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); 193 | 194 | template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); 195 | 196 | template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); 197 | 198 | void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); 199 | 200 | template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); 201 | template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); 202 | template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); 203 | 204 | template void func(T *A, T *B, T value, long n); 205 | 206 | #endif 207 | -------------------------------------------------------------------------------- 
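The blockwise `quantizeBlockwise`/`dequantizeBlockwise` entry points declared in the header above follow the same scheme as `quantize_block` in `csrc/common.cpp` and `quantize_cpu`/`dequantize_cpu` in `csrc/cpu_ops.cpp`: find the absolute maximum of each block, normalize the block into [-1.0, 1.0], and store the index of the closest entry in a 256-value code table; dequantization looks the value up again and rescales by the block's absmax. The following NumPy sketch is only an illustration of that scheme (the function names are made up for this example; it is not the `bitsandbytes` Python API and skips edge cases such as all-zero blocks):

```py
import numpy as np

def quantize_blockwise_ref(code, A, blocksize=4096):
    """Reference sketch: map each block of A to uint8 indices into the 256-entry code table."""
    out = np.empty(A.size, dtype=np.uint8)
    absmax = np.empty((A.size + blocksize - 1) // blocksize, dtype=np.float32)
    for b, start in enumerate(range(0, A.size, blocksize)):
        block = A[start:start + blocksize]
        absmax[b] = np.abs(block).max()                     # 1. per-block absmax
        normed = block / absmax[b]                          # 2. normalize into [-1, 1]
        dist = np.abs(normed[:, None] - code[None, :])      # 3./4. distance to every code value
        out[start:start + blocksize] = dist.argmin(axis=1)  # 5. store index of the closest one
    return out, absmax

def dequantize_blockwise_ref(code, out, absmax, blocksize=4096):
    """Inverse: look up the code value and rescale by the block's absmax."""
    A = code[out].astype(np.float32)
    for b, start in enumerate(range(0, A.size, blocksize)):
        A[start:start + blocksize] *= absmax[b]
    return A
```

The CPU implementation shown earlier arrives at the same indices with a binary search over the sorted code table and distributes blocks over waves of threads, while the functions declared in this header implement the same per-block scheme with CUDA kernels.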
/deploy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | BASE_PATH=$1 3 | 4 | echo "MAKE SURE LD_LIBRARY_PATH IS EMPTY!" 5 | echo $LD_LIBRARY_PATH 6 | 7 | if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then 8 | echo "Compilation unsuccessful!" 1>&2 9 | exit 64 10 | fi 11 | 12 | 13 | module unload cuda && echo "no module function available. Probably not on a slurm cluster." 14 | module unload gcc && echo "no module function available. Probably not on a slurm cluster." 15 | 16 | rm -rf dist build 17 | make cleaneggs 18 | make cleanlibs 19 | 20 | rm -rf build/* 21 | export CUDA_HOME= 22 | export CUDA_VERSION= 23 | make cpuonly CUDA_VERSION="CPU" 24 | 25 | if [ ! -f "./bitsandbytes/libbitsandbytes_cpu.so" ]; then 26 | # Control will enter here if $DIRECTORY doesn't exist. 27 | echo "Compilation unsuccessful!" 1>&2 28 | exit 64 29 | fi 30 | 31 | rm -rf build/* 32 | export CUDA_HOME=$BASE_PATH/cuda-11.0 33 | make cuda110 CUDA_VERSION=110 34 | 35 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110.so" ]; then 36 | # Control will enter here if $DIRECTORY doesn't exist. 37 | echo "Compilation unsuccessful!" 1>&2 38 | exit 64 39 | fi 40 | 41 | rm -rf build/* 42 | export CUDA_HOME=$BASE_PATH/cuda-11.1 43 | make cuda11x CUDA_VERSION=111 44 | 45 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111.so" ]; then 46 | # Control will enter here if $DIRECTORY doesn't exist. 47 | echo "Compilation unsuccessful!" 1>&2 48 | exit 64 49 | fi 50 | 51 | rm -rf build/* 52 | export CUDA_HOME=$BASE_PATH/cuda-11.4 53 | make cuda11x CUDA_VERSION=114 54 | 55 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114.so" ]; then 56 | # Control will enter here if $DIRECTORY doesn't exist. 57 | echo "Compilation unsuccessful!" 1>&2 58 | exit 64 59 | fi 60 | 61 | rm -rf build/* 62 | export CUDA_HOME=$BASE_PATH/cuda-11.5 63 | make cuda11x CUDA_VERSION=115 64 | 65 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115.so" ]; then 66 | # Control will enter here if $DIRECTORY doesn't exist. 67 | echo "Compilation unsuccessful!" 1>&2 68 | exit 64 69 | fi 70 | 71 | rm -rf build/* 72 | export CUDA_HOME=$BASE_PATH/cuda-11.7 73 | make cuda11x CUDA_VERSION=117 74 | 75 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117.so" ]; then 76 | # Control will enter here if $DIRECTORY doesn't exist. 77 | echo "Compilation unsuccessful!" 1>&2 78 | exit 64 79 | fi 80 | 81 | rm -rf build/* 82 | export CUDA_HOME=$BASE_PATH/cuda-11.8 83 | make cuda118 CUDA_VERSION=118 84 | 85 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118.so" ]; then 86 | # Control will enter here if $DIRECTORY doesn't exist. 87 | echo "Compilation unsuccessful!" 1>&2 88 | exit 64 89 | fi 90 | 91 | rm -rf build/* 92 | export CUDA_HOME=$BASE_PATH/cuda-12.0 93 | make cuda12x CUDA_VERSION=120 94 | 95 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120.so" ]; then 96 | # Control will enter here if $DIRECTORY doesn't exist. 97 | echo "Compilation unsuccessful!" 1>&2 98 | exit 64 99 | fi 100 | 101 | rm -rf build/* 102 | export CUDA_HOME=$BASE_PATH/cuda-12.1 103 | make cuda12x CUDA_VERSION=121 104 | 105 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then 106 | # Control will enter here if $DIRECTORY doesn't exist. 107 | echo "Compilation unsuccessful!" 1>&2 108 | exit 64 109 | fi 110 | 111 | rm -rf build/* 112 | export CUDA_HOME=$BASE_PATH/cuda-12.2 113 | make cuda12x CUDA_VERSION=122 114 | 115 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda122.so" ]; then 116 | # Control will enter here if $DIRECTORY doesn't exist. 
117 | echo "Compilation unsuccessful!" 1>&2 118 | exit 64 119 | fi 120 | 121 | rm -rf build/* 122 | export CUDA_HOME=$BASE_PATH/cuda-12.3 123 | make cuda12x CUDA_VERSION=123 124 | 125 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda123.so" ]; then 126 | # Control will enter here if $DIRECTORY doesn't exist. 127 | echo "Compilation unsuccessful!" 1>&2 128 | exit 64 129 | fi 130 | 131 | ############################# START NO CUBLASLT ############################################# 132 | # binaries without 8-bit matmul support START HERE 133 | # ########################################################################################### 134 | 135 | rm -rf build/* 136 | export CUDA_HOME=$BASE_PATH/cuda-11.0 137 | make cuda110_nomatmul CUDA_VERSION=110 138 | 139 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110_nocublaslt.so" ]; then 140 | # Control will enter here if $DIRECTORY doesn't exist. 141 | echo "Compilation unsuccessful!" 1>&2 142 | exit 64 143 | fi 144 | 145 | 146 | rm -rf build/* 147 | export CUDA_HOME=$BASE_PATH/cuda-11.1 148 | make cuda11x_nomatmul CUDA_VERSION=111 149 | 150 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111_nocublaslt.so" ]; then 151 | # Control will enter here if $DIRECTORY doesn't exist. 152 | echo "Compilation unsuccessful!" 1>&2 153 | exit 64 154 | fi 155 | 156 | rm -rf build/* 157 | export CUDA_HOME=$BASE_PATH/cuda-11.4 158 | make cuda11x_nomatmul CUDA_VERSION=114 159 | 160 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114_nocublaslt.so" ]; then 161 | # Control will enter here if $DIRECTORY doesn't exist. 162 | echo "Compilation unsuccessful!" 1>&2 163 | exit 64 164 | fi 165 | 166 | rm -rf build/* 167 | export CUDA_HOME=$BASE_PATH/cuda-11.5 168 | make cuda11x_nomatmul CUDA_VERSION=115 169 | 170 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115_nocublaslt.so" ]; then 171 | # Control will enter here if $DIRECTORY doesn't exist. 172 | echo "Compilation unsuccessful!" 1>&2 173 | exit 64 174 | fi 175 | 176 | rm -rf build/* 177 | export CUDA_HOME=$BASE_PATH/cuda-11.7 178 | make cuda11x_nomatmul CUDA_VERSION=117 179 | 180 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117_nocublaslt.so" ]; then 181 | # Control will enter here if $DIRECTORY doesn't exist. 182 | echo "Compilation unsuccessful!" 1>&2 183 | exit 64 184 | fi 185 | 186 | rm -rf build/* 187 | export CUDA_HOME=$BASE_PATH/cuda-11.8 188 | make cuda118_nomatmul CUDA_VERSION=118 189 | 190 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so" ]; then 191 | # Control will enter here if $DIRECTORY doesn't exist. 192 | echo "Compilation unsuccessful!" 1>&2 193 | exit 64 194 | fi 195 | 196 | rm -rf build/* 197 | export CUDA_HOME=$BASE_PATH/cuda-12.0 198 | make cuda12x_nomatmul CUDA_VERSION=120 199 | 200 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so" ]; then 201 | # Control will enter here if $DIRECTORY doesn't exist. 202 | echo "Compilation unsuccessful!" 1>&2 203 | exit 64 204 | fi 205 | 206 | rm -rf build/* 207 | export CUDA_HOME=$BASE_PATH/cuda-12.1 208 | make cuda12x_nomatmul CUDA_VERSION=121 209 | 210 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so" ]; then 211 | # Control will enter here if $DIRECTORY doesn't exist. 212 | echo "Compilation unsuccessful!" 1>&2 213 | exit 64 214 | fi 215 | 216 | rm -rf build/* 217 | export CUDA_HOME=$BASE_PATH/cuda-12.2 218 | make cuda12x_nomatmul CUDA_VERSION=122 219 | 220 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda122_nocublaslt.so" ]; then 221 | # Control will enter here if $DIRECTORY doesn't exist. 
222 | echo "Compilation unsuccessful!" 1>&2 223 | exit 64 224 | fi 225 | 226 | rm -rf build/* 227 | export CUDA_HOME=$BASE_PATH/cuda-12.3 228 | make cuda12x_nomatmul CUDA_VERSION=123 229 | 230 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda123_nocublaslt.so" ]; then 231 | # Control will enter here if $DIRECTORY doesn't exist. 232 | echo "Compilation unsuccessful!" 1>&2 233 | exit 64 234 | fi 235 | 236 | python -m build 237 | python -m twine upload dist/* --verbose 238 | -------------------------------------------------------------------------------- /docs/source/_toctree.yml: -------------------------------------------------------------------------------- 1 | - title: Get started 2 | sections: 3 | - local: index 4 | title: Index 5 | - local: quickstart 6 | title: Quickstart 7 | - local: installation 8 | title: Installation 9 | - title: Features & Integrations 10 | sections: 11 | - local: quantization 12 | title: Quantization 13 | - local: optimizers 14 | title: Optimizers 15 | - local: integrations 16 | title: Integrations 17 | - local: algorithms 18 | title: Algorithms 19 | - title: Support & Learning 20 | sections: 21 | - local: resources 22 | title: Papers, resources & how to cite 23 | - local: errors 24 | title: Errors & Solutions 25 | - local: nonpytorchcuda 26 | title: Non-PyTorch CUDA 27 | - local: compiling 28 | title: Compilation from Source (extended) 29 | - local: faqs 30 | title: FAQs (Frequently Asked Questions) 31 | - title: Contributors Guidelines 32 | sections: 33 | - local: contributing 34 | title: Contributing 35 | -------------------------------------------------------------------------------- /docs/source/algorithms.mdx: -------------------------------------------------------------------------------- 1 | # Other algorithms 2 | _WIP: Still incomplete... Community contributions would be greatly welcome!_ 3 | 4 | This is an overview of the `bnb.functional` API in `bitsandbytes` that we think would also be useful as standalone entities. 5 | 6 | ## Using Int8 Matrix Multiplication 7 | 8 | For straight Int8 matrix multiplication with mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter: 9 | 10 | ```py 11 | bnb.matmul(..., threshold=6.0) 12 | ``` 13 | -------------------------------------------------------------------------------- /docs/source/compiling.mdx: -------------------------------------------------------------------------------- 1 | # Compiling from Source[[compiling]] 2 | 3 | To compile from source, the CUDA Toolkit is required. Ensure `nvcc` is installed; if not, follow these steps to install it along with the CUDA Toolkit: 4 | 5 | ```bash 6 | wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh 7 | # Use the following syntax: cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH 8 | # CUDA_VERSION options include 110 to 122 9 | # EXPORT_TO_BASH: 0 for False, 1 for True 10 | 11 | # Example for installing CUDA 11.7 at ~/local/cuda-11.7 and exporting the path to .bashrc: 12 | bash install_cuda.sh 117 ~/local 1 13 | ``` 14 | 15 | For a single compile run with a specific CUDA version, set `CUDA_HOME` to point to your CUDA installation directory. For instance, to compile using CUDA 11.7 located at `~/local/cuda-11.7`, use: 16 | 17 | ``` 18 | CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x 19 | ``` 20 | 21 | ## General Compilation Steps 22 | 23 | 1. 
Use `CUDA_VERSION=XXX make [target]` to compile, where `[target]` includes options like `cuda92`, `cuda10x`, `cuda11x`, and others. 24 | 2. Install with `python setup.py install`. 25 | 26 | Ensure `nvcc` is available on your system. If using Anaconda, determine your CUDA version with PyTorch using `conda list | grep cudatoolkit` and match it by downloading the corresponding version from the [CUDA Toolkit Archive](https://developer.nvidia.com/cuda-toolkit-archive). 27 | 28 | To install CUDA locally without administrative rights: 29 | 30 | ```bash 31 | wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh 32 | # Follow the same syntax and example as mentioned earlier 33 | ``` 34 | 35 | The compilation process relies on the `CUDA_HOME` environment variable to locate CUDA. If `CUDA_HOME` is unset, it will attempt to infer the location from `nvcc`. If `nvcc` is not in your path, you may need to add it or set `CUDA_HOME` manually. For example, if `python -m bitsandbytes` indicates your CUDA path as `/usr/local/cuda-11.7`, you can set `CUDA_HOME` to this path. 36 | 37 | If compilation issues arise, please report them. 38 | 39 | ## Compilation for Kepler Architecture 40 | 41 | From version 0.39.1, bitsandbytes no longer includes Kepler binaries in pip installations, requiring manual compilation. Follow the general steps and use `cuda11x_nomatmul_kepler` for Kepler-targeted compilation. 42 | -------------------------------------------------------------------------------- /docs/source/contributing.mdx: -------------------------------------------------------------------------------- 1 | # Contributors guidelines 2 | ... still under construction ... (feel free to propose materials, `bitsandbytes` is a community project) 3 | 4 | ## Setup pre-commit hooks 5 | - Install pre-commit hooks with `pip install pre-commit`. 6 | - Run `pre-commit autoupdate` once to configure the hooks. 7 | - Re-run `pre-commit autoupdate` every time a new hook is added. 8 | 9 | Now all the pre-commit hooks will run automatically when you try to commit; if they introduce changes, you need to re-add the changed files before you can commit and push. 10 | 11 | ## Doc-string syntax 12 | 13 | We're following NumPy doc-string conventions, with the only notable difference being that we use Markdown instead of Rich text format (RTF) for markup within the doc-strings. 14 | 15 | Please see the existing documentation to see how to generate autodocs. 16 | 17 | ## Documentation 18 | - [guideline for documentation syntax](https://github.com/huggingface/doc-builder#readme) 19 | - images shall be uploaded via PR in the `bitsandbytes/` directory [here](https://huggingface.co/datasets/huggingface/documentation-images) 20 | - find the documentation builds for each PR in a link posted to the PR, such as https://moon-ci-docs.huggingface.co/docs/bitsandbytes/pr_1012/en/introduction 21 | -------------------------------------------------------------------------------- /docs/source/errors.mdx: -------------------------------------------------------------------------------- 1 | # Errors & Solutions 2 | 3 | ## No kernel image available 4 | 5 | This problem arises when the CUDA version loaded by bitsandbytes is not supported by your GPU, or when your PyTorch CUDA version does not match. 6 | 7 | To solve this problem, you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME`` as well as ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. 
This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Do they match the CUDA version shown by ``nvidia-smi``? 8 | 9 | If you are feeling lucky, you can also try to compile the library from source. This can still be problematic if your PATH variables contain multiple CUDA versions. As such, it is recommended to figure out path conflicts before you proceed with compilation. 10 | 11 | ## `fatbinwrap` 12 | 13 | This error occurs if there is a mismatch between the CUDA version of the compiled C++ library and the CUDA runtime. Make sure you have the right CUDA version in your `$PATH` and `$LD_LIBRARY_PATH` variables. In the conda base environment you can find the library under: 14 | 15 | ```bash 16 | ls $CONDA_PREFIX/lib/*cudart* 17 | ``` 18 | Make sure this path is appended to the `LD_LIBRARY_PATH` so bnb can find the CUDA runtime library (cudart). 19 | 20 | If this does not fix the issue, please try compilation from source next. 21 | 22 | If this does not work, please open an issue and paste the environment that is printed when you call `make`, together with the associated error when running bnb. 23 | -------------------------------------------------------------------------------- /docs/source/faqs.mdx: -------------------------------------------------------------------------------- 1 | # FAQs 2 | 3 | Please submit your questions in [this GitHub Discussion thread](https://github.com/TimDettmers/bitsandbytes/discussions/1013) if you feel that they will likely affect a lot of other users and that they haven't been sufficiently covered in the documentation. 4 | 5 | We'll pick the most generally applicable ones and post the QAs here or integrate them into the general documentation (also feel free to submit doc PRs, please). 6 | 7 | # ... under construction ... 8 | -------------------------------------------------------------------------------- /docs/source/index.mdx: -------------------------------------------------------------------------------- 1 | # `bitsandbytes` 2 | 3 | The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 + 4-bit quantization functions. 4 | 5 | The library includes quantization primitives for 8-bit & 4-bit operations through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit`, and 8-bit optimizers through the `bitsandbytes.optim` module. 6 | 7 | There are ongoing efforts to support further hardware backends, i.e. Intel CPU + GPU, AMD GPU, Apple Silicon. Windows support is on its way as well. 8 | 9 | ## API documentation 10 | 11 | - [Linear4bit](quantization#linear4bit) 12 | - [Linear8bit](quantization#linear8bit) 13 | - [StableEmbedding](optimizers#stableembedding) 14 | 15 | # License 16 | 17 | The majority of bitsandbytes is licensed under MIT; however, portions of the project are available under separate license terms, as the parts adapted from PyTorch are licensed under the BSD license. 18 | 19 | We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization. 
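As a minimal usage sketch (mirroring `check_bnb_install.py` earlier in this repository and assuming a working CUDA-enabled install), the 8-bit optimizers in `bitsandbytes.optim` are drop-in replacements for their `torch.optim` counterparts:

```py
import torch
import bitsandbytes as bnb

# Toy layer for illustration; any torch.nn.Module works the same way.
layer = torch.nn.Linear(128, 64).cuda()
optimizer = bnb.optim.Adam8bit(layer.parameters(), lr=1e-3)  # 8-bit optimizer state

out = layer(torch.randn(16, 128, device="cuda"))
out.sum().backward()
optimizer.step()
optimizer.zero_grad()
```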
20 | -------------------------------------------------------------------------------- /docs/source/installation.mdx: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | Note that currently `bitsandbytes` is only supported on CUDA GPU hardware; support for AMD GPUs and M1 chips (macOS) is coming soon. 4 | 5 | 6 | 7 | 8 | ## Hardware requirements: 9 | - LLM.int8(): NVIDIA Turing (RTX 20xx; T4) or Ampere GPU (RTX 30xx; A4-A100), i.e. a GPU from 2018 or newer. 10 | - 8-bit optimizers and quantization: NVIDIA Kepler GPU or newer (>=GTX 78X). 11 | 12 | Supported CUDA versions: 10.2 - 12.0 #TODO: check currently supported versions 13 | 14 | ## Linux 15 | 16 | ### From PyPI 17 | 18 | ```bash 19 | pip install bitsandbytes 20 | ``` 21 | 22 | ### From source 23 | 24 | ```bash 25 | git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/ 26 | CUDA_VERSION=XXX make cuda12x 27 | python setup.py install 28 | ``` 29 | 30 | with `XXX` being your CUDA version; for versions <12.0 call `make cuda11x`. Note that support for non-CUDA GPUs (e.g. AMD, Intel) is also coming soon. 31 | 32 | For a more detailed compilation guide, head to the [dedicated page on the topic](./compiling) 33 | 34 | 35 | 36 | 37 | ## Windows 38 | 39 | Currently, Windows users need to build bitsandbytes from source: 40 | 41 | ```bash 42 | git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/ 43 | cmake -B build -DBUILD_CUDA=ON -S . 44 | cmake --build build --config Release 45 | python -m build --wheel 46 | ``` 47 | 48 | Big thanks to [wkpark](https://github.com/wkpark), [Jamezo97](https://github.com/Jamezo97), [rickardp](https://github.com/rickardp), [akx](https://github.com/akx) for their amazing contributions to make bitsandbytes compatible with Windows. 49 | 50 | For a more detailed compilation guide, head to the [dedicated page on the topic](./compiling) 51 | 52 | 53 | 54 | 55 | ## MacOS 56 | 57 | Mac support is still a work in progress. Please make sure to check out the [Apple Silicon implementation coordination issue](https://github.com/TimDettmers/bitsandbytes/issues/1020) to get notified about the discussions and progress with respect to MacOS integration. 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /docs/source/integrations.mdx: -------------------------------------------------------------------------------- 1 | # Transformers 2 | 3 | With Transformers it's very easy to load any model in 4-bit or 8-bit, quantizing it on the fly with bitsandbytes primitives. 4 | 5 | Please review the [bitsandbytes section in the Transformers docs](https://huggingface.co/docs/transformers/v4.37.2/en/quantization#bitsandbytes). 6 | 7 | Details about the `BitsAndBytesConfig` can be found [here](https://huggingface.co/docs/transformers/v4.37.2/en/main_classes/quantization#transformers.BitsAndBytesConfig). 8 | 9 | ## Beware: bf16 is the optimal compute data type 10 | If your hardware supports it, `bf16` is the optimal compute dtype. The default is `float32` for backward compatibility and numerical stability. `float16` often leads to numerical instabilities, but `bfloat16` provides the benefits of both worlds: numerical stability and significant computation speedup. 
Therefore, be sure to check if your hardware supports `bf16` and configure it using the `bnb_4bit_compute_dtype` parameter in `BitsAndBytesConfig`: 11 | 12 | ```py 13 | import torch 14 | from transformers import BitsAndBytesConfig 15 | 16 | quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) 17 | ``` 18 | 19 | # PEFT 20 | With `PEFT`, you can use QLoRA out of the box with `LoraConfig` and a 4-bit base model. 21 | 22 | Please review the [bitsandbytes section in the PEFT docs](https://huggingface.co/docs/peft/developer_guides/quantization#quantize-a-model). 23 | 24 | # Accelerate 25 | 26 | Bitsandbytes is also easily usable from within Accelerate. 27 | 28 | Please review the [bitsandbytes section in the Accelerate docs](https://huggingface.co/docs/accelerate/en/usage_guides/quantization). 29 | 30 | # Trainer for the optimizers 31 | 32 | You can use any of the 8-bit and/or paged optimizers by simply passing them to the `transformers.Trainer` class on initialization. All bnb optimizers are supported by passing the correct string to the `optim` attribute of `TrainingArguments`, e.g. `paged_adamw_32bit`. 33 | 34 | See the [official API docs for reference](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Trainer). 35 | 36 | To summarize how these integrations work across transformers / PEFT / Trainer: 37 | with transformers you can load any model in 8-bit / 4-bit precision; with PEFT you can use QLoRA out of the box with `LoraConfig` and a 4-bit base model; and with the Trainer, all bnb optimizers are supported by passing the correct string to the `optim` attribute of `TrainingArguments`, e.g. `paged_adamw_32bit`. 38 | 39 | # Blog posts 40 | 41 | - [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) 42 | - [A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes](https://huggingface.co/blog/hf-bitsandbytes-integration) 43 | -------------------------------------------------------------------------------- /docs/source/nonpytorchcuda.mdx: -------------------------------------------------------------------------------- 1 | # How to use a CUDA version that is different from PyTorch 2 | 3 | Some features of `bitsandbytes` may need a newer CUDA version than is regularly supported by PyTorch binaries from conda / pip. In that case you can use the following instructions to load a precompiled `bitsandbytes` binary that works for you. 4 | 5 | ## Installing or determining the CUDA installation 6 | 7 | Determine the path of the CUDA version that you want to use. Common paths are: 8 | ```bash 9 | /usr/local/cuda 10 | /usr/local/cuda-XX.X 11 | ``` 12 | 13 | where XX.X is the CUDA version number. 
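If you are unsure which CUDA version your PyTorch build ships with (i.e. the version you would be overriding below), a quick check from Python helps; this is a generic sketch and assumes only that PyTorch is installed:

```py
import torch

print(torch.version.cuda)         # CUDA version PyTorch was built against, e.g. "11.8"
print(torch.cuda.is_available())  # whether a CUDA device can actually be used right now
```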
14 | 15 | You can also install the CUDA version that you need locally with a script provided by `bitsandbytes` as follows: 16 | 17 | ```bash 18 | wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh 19 | # Syntax: install_cuda.sh CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH 20 | # CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122} 21 | # EXPORT_TO_BASH in {0, 1} with 0=False and 1=True 22 | 23 | # For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc 24 | 25 | bash install_cuda.sh 117 ~/local 1 26 | ``` 27 | 28 | ## Setting the environment variables `BNB_CUDA_VERSION` and `LD_LIBRARY_PATH` 29 | 30 | To manually override the CUDA version that PyTorch was installed with, you need to set these two variables, like so: 31 | 32 | ```bash 33 | export BNB_CUDA_VERSION=<VERSION> 34 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<PATH_TO_CUDA> 35 | ``` 36 | 37 | For example, to use the local install path from above: 38 | 39 | ```bash 40 | export BNB_CUDA_VERSION=117 41 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/tim/local/cuda-11.7 42 | ``` 43 | 44 | It is best to add these lines to the `.bashrc` file to make them permanent. 45 | 46 | If you now launch bitsandbytes with these environment variables, the PyTorch CUDA version will be overridden by the new CUDA version and a different bitsandbytes library is loaded (in this case version 117). 47 | -------------------------------------------------------------------------------- /docs/source/quantization.mdx: -------------------------------------------------------------------------------- 1 | # Quantization primitives 2 | 3 | Below you will find the docstrings of the quantization primitives exposed in bitsandbytes. 4 | 5 | ## Linear4bit (QLoRA)[[linear4bit]] 6 | 7 | [[autodoc]] bitsandbytes.nn.Linear4bit 8 | - __init__ 9 | 10 | ## Linear8bitLt[[linear8bit]] 11 | 12 | [[autodoc]] bitsandbytes.nn.Linear8bitLt 13 | - __init__ 14 | -------------------------------------------------------------------------------- /docs/source/quickstart.mdx: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | ## How does it work? 4 | 5 | ... work in progress ... 6 | 7 | (Community contributions would be very welcome!) 8 | 9 | ## Minimal examples 10 | 11 | The following code illustrates the steps above. 12 | 13 | ```py 14 | code examples will soon follow 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/source/resources.mdx: -------------------------------------------------------------------------------- 1 | # Papers, related resources & how to cite 2 | 3 | The academic work below is listed in reverse chronological order. 
4 | 5 | ## [SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression (Jun 2023)](https://arxiv.org/abs/2306.03078) 6 | 7 | Authors: Tim Dettmers, Ruslan Svirschevski, Vage Egiazarian, Denis Kuznedelev, Elias Frantar, Saleh Ashkboos, Alexander Borzunov, Torsten Hoefler, Dan Alistarh 8 | 9 | - [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1666076553665744896) 10 | 11 | ``` 12 | @article{dettmers2023spqr, 13 | title={SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression}, 14 | author={Dettmers, Tim and Svirschevski, Ruslan and Egiazarian, Vage and Kuznedelev, Denis and Frantar, Elias and Ashkboos, Saleh and Borzunov, Alexander and Hoefler, Torsten and Alistarh, Dan}, 15 | journal={arXiv preprint arXiv:2306.03078}, 16 | year={2023} 17 | } 18 | ``` 19 | 20 | ## [QLoRA: Efficient Finetuning of Quantized LLMs (May 2023)](https://arxiv.org/abs/2305.14314) 21 | Authors: Tim Dettmers, Artidoro Pagnoni, Ari Holtzman, Luke Zettlemoyer 22 | 23 | - [Video](https://www.youtube.com/watch?v=y9PHWGOa8HA&ab_channel=LondonMachineLearningMeetup) 24 | - [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1661379354507476994) 25 | 26 | ``` 27 | @article{dettmers2023qlora, 28 | title={Qlora: Efficient finetuning of quantized llms}, 29 | author={Dettmers, Tim and Pagnoni, Artidoro and Holtzman, Ari and Zettlemoyer, Luke}, 30 | journal={arXiv preprint arXiv:2305.14314}, 31 | year={2023} 32 | } 33 | ``` 34 | 35 | ## [The case for 4-bit precision: k-bit Inference Scaling Laws (Dec 2022)](https://arxiv.org/abs/2212.09720) 36 | Authors: Tim Dettmers, Luke Zettlemoyer 37 | 38 | - [Video](https://www.youtube.com/watch?v=odlQa6AE1gY&ab_channel=TheInsideView) 39 | - [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1605209171758284805) 40 | 41 | ``` 42 | @inproceedings{dettmers2023case, 43 | title={The case for 4-bit precision: k-bit inference scaling laws}, 44 | author={Dettmers, Tim and Zettlemoyer, Luke}, 45 | booktitle={International Conference on Machine Learning}, 46 | pages={7750--7774}, 47 | year={2023}, 48 | organization={PMLR} 49 | } 50 | ``` 51 | 52 | ## [LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Nov 2022)](https://arxiv.org/abs/2208.07339) 53 | Authors: Tim Dettmers, Mike Lewis, Younes Belkada, Luke Zettlemoyer 54 | 55 | - [LLM.int8() Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) 56 | - [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/) 57 | - [Introduction to Weight Quantization](https://towardsdatascience.com/introduction-to-weight-quantization-2494701b9c0c) 58 | - [Poster](https://twitter.com/Tim_Dettmers/status/1598351301942951937) 59 | 60 | ``` 61 | @article{dettmers2022llm, 62 | title={Llm. 
int8 (): 8-bit matrix multiplication for transformers at scale}, 63 | author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke}, 64 | journal={arXiv preprint arXiv:2208.07339}, 65 | year={2022} 66 | } 67 | ``` 68 | 69 | ## [8-bit Optimizers via Block-wise Quantization (Oct 2021)](https://arxiv.org/abs/2110.02861) 70 | Authors: Tim Dettmers, Mike Lewis, Sam Shleifer, Luke Zettlemoyer 71 | 72 | - [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) 73 | - [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1446472128979562499) 74 | 75 | ``` 76 | @article{DBLP:journals/corr/abs-2110-02861, 77 | author = {Tim Dettmers and 78 | Mike Lewis and 79 | Sam Shleifer and 80 | Luke Zettlemoyer}, 81 | title = {8-bit Optimizers via Block-wise Quantization}, 82 | journal = {CoRR}, 83 | volume = {abs/2110.02861}, 84 | year = {2021}, 85 | url = {https://arxiv.org/abs/2110.02861}, 86 | eprinttype = {arXiv}, 87 | eprint = {2110.02861}, 88 | timestamp = {Thu, 21 Oct 2021 16:20:08 +0200}, 89 | biburl = {https://dblp.org/rec/journals/corr/abs-2110-02861.bib}, 90 | bibsource = {dblp computer science bibliography, https://dblp.org} 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /environment-bnb.yml: -------------------------------------------------------------------------------- 1 | # for cmake build 2 | name: bnb 3 | channels: 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | 8 | dependencies: 9 | - python 10 | #- accelerate 11 | #- einops 12 | - scipy 13 | #- transformers 14 | - pytest 15 | - pytest-cases 16 | - ipython 17 | - debugpy 18 | - yapf 19 | - monkeytype 20 | - rich 21 | - pytest-sugar 22 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: bnb 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | 7 | dependencies: 8 | # Base 9 | - conda-forge::python=3.8 10 | - pytorch::pytorch=>2.1 11 | - pytorch::pytorch-cuda=11.8 12 | - nvidia::cuda=11.8 13 | # Libraries 14 | - conda-forge::accelerate 15 | - conda-forge::einops 16 | - conda-forge::scipy 17 | - conda-forge::transformers 18 | # Development 19 | - conda-forge::pytest 20 | - conda-forge::build # build Python packages 21 | - conda-forge::twine # upload Python packages 22 | - conda-forge::pytest-cases # more readable and composable parametrized tests 23 | - conda-forge::ipython # better interactive shell 24 | - conda-forge::debugpy # debugger-support for VSCode 25 | - conda-forge::ruff # linting 26 | - conda-forge::yapf # code formatting 27 | - conda-forge::monkeytype # infer type annotations 28 | - conda-forge::rich # better, colored tracebacks, etc 29 | - conda-forge::pytest-sugar # better pytest output 30 | # - conda-forge::nodejs # for `doc-builder preview` (optional) 31 | 32 | ## ENV CREATION - steps to reproduce: 33 | # mamba env remove -n bnb 34 | # mamba create -y -n bnb python=3.8 # creating an empty env bypasses conda 35 | # # and leads to much faster env resolution in the next step https://github.com/mamba-org/mamba/issues/633#issuecomment-812272143 36 | # mamba env update -n bnb -f environment.yml 37 | # mamba activate bnb 38 | 39 | ## PIP dependencies (install *after* ENV CREATION): 40 | # pip install --no-cache-dir --no-deps lion_pytorch triton hf-doc-builder watchdog 41 | ## NOTE: conda peft is not up to date, so we install from pip 42 | # cd pip install -e . 
## installs bitsandbytes as editable development install from within repo root dir 43 | 44 | ## ENV UPDATE: 45 | # # add new packages to environment.yml, then: 46 | # mamba env update -n bnb -f environment.yml 47 | -------------------------------------------------------------------------------- /examples/int8_inference_huggingface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | 4 | MAX_NEW_TOKENS = 128 5 | model_name = 'decapoda-research/llama-7b-hf' 6 | 7 | text = 'Hamburg is in which country?\n' 8 | tokenizer = AutoTokenizer.from_pretrained(model_name) 9 | input_ids = tokenizer(text, return_tensors="pt").input_ids 10 | 11 | free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3) 12 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 13 | 14 | n_gpus = torch.cuda.device_count() 15 | max_memory = {i: max_memory for i in range(n_gpus)} 16 | 17 | model = AutoModelForCausalLM.from_pretrained( 18 | model_name, 19 | device_map='auto', 20 | load_in_8bit=True, 21 | max_memory=max_memory 22 | ) 23 | generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) 24 | print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) 25 | -------------------------------------------------------------------------------- /include/AAlloc.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Portable.h" 4 | 5 | namespace BinSearch { 6 | namespace Details { 7 | 8 | template 9 | bool isAligned(const T *p, size_t A) 10 | { 11 | return (reinterpret_cast(p) % A) == 0; 12 | } 13 | 14 | template 15 | struct AlignedVec 16 | { 17 | AlignedVec() 18 | : m_storage(0) 19 | , m_data(0) 20 | , m_sz(0) 21 | { 22 | } 23 | 24 | static size_t nBytes(size_t sz) 25 | { 26 | return sz * sizeof(T) + A; 27 | } 28 | 29 | static size_t shiftAmt(char *p) 30 | { 31 | return A>1? 
(A - (reinterpret_cast(p) % A)) % A: 0; 32 | } 33 | 34 | void setPtr(char *p, size_t sz) 35 | { 36 | m_sz = sz; 37 | m_data = reinterpret_cast(p + shiftAmt(p)); 38 | } 39 | 40 | //void setPtr(T *p, size_t sz) 41 | //{ 42 | // m_sz = sz; 43 | // if (A>1) 44 | // myassert(((reinterpret_cast(p) % A) == 0), "bad alignment"); 45 | // m_data = p; 46 | //} 47 | 48 | // internal allocation 49 | void resize(size_t sz) 50 | { 51 | m_storage = new char[nBytes(sz)]; 52 | setPtr(m_storage, sz); 53 | } 54 | 55 | // external allocation 56 | void set(char *storage, size_t sz) 57 | { 58 | setPtr(storage, sz); 59 | } 60 | 61 | ~AlignedVec() 62 | { 63 | if (m_storage) 64 | delete [] m_storage; 65 | } 66 | 67 | size_t size() const { return m_sz; } 68 | T& operator[](size_t i) { return m_data[i]; } 69 | const T& operator[](size_t i) const { return m_data[i]; } 70 | T* begin() { return m_data; } 71 | T* end() { return m_data+m_sz; } 72 | const T* begin() const { return m_data; } 73 | const T* end() const { return m_data+m_sz; } 74 | T& front() { return m_data[0]; } 75 | T& back() { return m_data[m_sz-1]; } 76 | const T& front() const { return m_data[0]; } 77 | const T& back() const { return m_data[m_sz - 1]; } 78 | 79 | private: 80 | char *m_storage; 81 | T *m_data; 82 | size_t m_sz; 83 | }; 84 | 85 | } // namespace Details 86 | } // namespace BinSearch 87 | -------------------------------------------------------------------------------- /include/AlgoXCodes.h: -------------------------------------------------------------------------------- 1 | ALGOENUM(DirectCacheFMA, 5) 2 | ALGOENUM(DirectFMA, 15) 3 | ALGOENUM(Direct2FMA, 25) 4 | ALGOENUM(DirectCache, 10) 5 | ALGOENUM(Direct, 20) 6 | ALGOENUM(Direct2, 30) 7 | ALGOENUM(Nonary, 40) 8 | ALGOENUM(Pentary, 50) 9 | ALGOENUM(Ternary, 60) 10 | ALGOENUM(Eytzinger, 70) 11 | ALGOENUM(BitSet, 80) 12 | ALGOENUM(ClassicOffset, 90) 13 | #ifdef PAPER_TEST 14 | ALGOENUM(MorinOffset, 100) 15 | ALGOENUM(BitSetNoPad, 110) 16 | ALGOENUM(ClassicMod, 120) 17 | ALGOENUM(MorinBranchy, 130) 18 | ALGOENUM(Classic, 140) 19 | ALGOENUM(LowerBound, 145) 20 | #ifdef USE_MKL 21 | ALGOENUM(MKL, 150) 22 | #endif 23 | #endif 24 | -------------------------------------------------------------------------------- /include/BinAlgo.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Type.h" 4 | #include 5 | 6 | namespace BinSearch { 7 | 8 | template 9 | struct BinAlgo : Details::BinAlgoBase 10 | { 11 | typedef Details::BinAlgoBase base_t; 12 | 13 | BinAlgo(const T* px, const uint32 n) : base_t(px, n), x0(px[0]), xN(px[n-1]), N(n) {} 14 | BinAlgo(const T* px, const uint32 n, const typename base_t::Data& d) : base_t(d), x0(px[0]), xN(px[n-1]), N(n) {} 15 | 16 | FORCE_INLINE 17 | uint32 scalar(T z) const 18 | { 19 | if (!L || z >= x0) 20 | if (!R || z < xN) 21 | return base_t::scalar(z); 22 | else 23 | return N; 24 | else 25 | return std::numeric_limits::max(); 26 | } 27 | 28 | 29 | FORCE_INLINE 30 | void vectorial(uint32 *pr, const T *pz, uint32 n) const 31 | { 32 | if (!L && !R) { 33 | Details::Loop::loop(*this, pr, pz, n); 34 | } 35 | else { 36 | const uint32 nElem = base_t::nElem; 37 | const uint32 idealbufsize = 256; 38 | const uint32 bufsize = nElem * (idealbufsize / nElem + ((idealbufsize % nElem) ? 
1 : 0)); 39 | T databuf[bufsize]; 40 | uint32 resbuf[bufsize]; 41 | uint32 indexbuf[bufsize]; 42 | 43 | uint32 *prend = pr + n; 44 | while(pr != prend) { 45 | uint32 cnt = 0; 46 | uint32 niter = std::min(bufsize, (uint32)std::distance(pr,prend)); 47 | for (uint32 j = 0; j < niter; ++j) { 48 | T z = pz[j]; 49 | // FIXME: use SSE2? 50 | if (!L || z >= x0) 51 | if (!R || z < xN) { 52 | databuf[cnt] = z; 53 | indexbuf[cnt] = j; 54 | ++cnt; 55 | } 56 | else 57 | pr[j] = N; 58 | else 59 | pr[j] = std::numeric_limits::max(); 60 | } 61 | // FIXME: merge these two loops 62 | Details::Loop::loop(*this, resbuf, databuf, cnt); 63 | for (uint32 j = 0; j < cnt; ++j) 64 | pr[indexbuf[j]] = resbuf[j]; 65 | pr += niter; 66 | pz += niter; 67 | } 68 | } 69 | } 70 | 71 | Details::CondData x0; 72 | Details::CondData xN; 73 | Details::CondData N; 74 | }; 75 | 76 | 77 | } // namespace BinSearch 78 | -------------------------------------------------------------------------------- /include/BinSearch.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "AAlloc.h" 4 | #include "BinAlgo.h" 5 | #include "SIMD.h" 6 | 7 | #include 8 | #include 9 | 10 | 11 | #include "Algo-Direct2.h" 12 | -------------------------------------------------------------------------------- /include/Portable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #if defined(__aarch64__) 8 | #ifdef __CUDACC__ 9 | #undef USE_NEON // Doesn't work with nvcc, undefined symbols 10 | #else 11 | #include 12 | #undef USE_NEON // Not yet implemented 13 | #endif 14 | #undef USE_AVX // x86_64 only 15 | #undef USE_AVX2 // x86_64 only 16 | #undef USE_SSE2 // x86_64 only 17 | #undef USE_SSE41 // x86_64 only 18 | #undef USE_SSE42 // x86_64 only 19 | #undef USE_FMA // x86_64 only 20 | #ifdef USE_NEON 21 | typedef float32x4_t __m128; 22 | typedef int32x4_t __m128i; 23 | typedef float64x2_t __m128d; 24 | #else 25 | typedef struct {float a; float b; float c; float d;} __m128; 26 | typedef struct {int a; int b; int c; int d;} __m128i; 27 | typedef struct {double a; double b;} __m128d; 28 | #endif 29 | #else 30 | #undef USE_NEON // ARM64 only 31 | #ifdef __FMA__ 32 | #define USE_FMA 33 | #endif 34 | #if !defined(__SSE2__) && !defined(_MSC_VER) 35 | #error Compiler must support SSE2 36 | #endif 37 | #define USE_SSE2 38 | 39 | #if defined(__aarch64__) 40 | #else 41 | #ifdef __AVX2__ 42 | #define USE_AVX2 43 | #endif 44 | 45 | #ifdef __AVX__ 46 | #define USE_AVX 47 | #endif 48 | 49 | 50 | #ifdef __SSE4_1__ 51 | #define USE_SSE41 52 | #endif 53 | 54 | #ifdef __SSE4_2__ 55 | #define USE_SSE42 56 | #endif 57 | #endif 58 | #endif 59 | 60 | #ifndef _MSC_VER 61 | #include 62 | #endif 63 | 64 | namespace BinSearch { 65 | 66 | #ifndef _MSC_VER 67 | typedef int8_t int8; 68 | typedef uint8_t uint8; 69 | typedef int32_t int32; 70 | typedef uint32_t uint32; 71 | typedef int64_t int64; 72 | typedef uint64_t uint64; 73 | #else 74 | typedef __int8 int8; 75 | typedef unsigned __int8 uint8; 76 | typedef __int32 int32; 77 | typedef unsigned __int32 uint32; 78 | typedef __int64 int64; 79 | typedef unsigned __int64 uint64; 80 | #endif 81 | 82 | namespace Details { 83 | 84 | #define myassert(cond, msg) if (!cond){ std::ostringstream os; os << "\nassertion failed: " << #cond << ", " << msg << "\n"; throw std::invalid_argument(os.str()); } 85 | 86 | // log2 is not defined in VS2008 87 | #if defined(_MSC_VER) 88 | inline 
uint32 log2 (uint32 val) { 89 | if (val == 1) return 0; 90 | uint32 ret = 0; 91 | do { 92 | ret++; 93 | val >>= 1; 94 | } while (val > 1); 95 | return ret; 96 | } 97 | #endif 98 | 99 | #ifdef _DEBUG 100 | #define DEBUG 101 | #endif 102 | 103 | #ifdef _MSC_VER 104 | # define FORCE_INLINE __forceinline 105 | # define NO_INLINE __declspec(noinline) 106 | #else 107 | # define NO_INLINE __attribute__((noinline)) 108 | # ifdef DEBUG 109 | # define FORCE_INLINE NO_INLINE 110 | # else 111 | # define FORCE_INLINE __attribute__((always_inline)) inline 112 | # endif 113 | #endif 114 | 115 | #ifdef USE_AVX 116 | #define COMISS "vcomiss" 117 | #define COMISD "vcomisd" 118 | #else 119 | #define COMISS "comiss" 120 | #define COMISD "comisd" 121 | #endif 122 | 123 | // nextafter is not defined in VS2008 124 | #if defined(_MSC_VER) && (_MSC_VER <= 1500) 125 | #include 126 | inline float mynext(float x) 127 | { 128 | return _nextafterf(x, std::numeric_limits::max()); 129 | } 130 | 131 | inline double mynext(double x) 132 | { 133 | return _nextafter(x, std::numeric_limits::max()); 134 | } 135 | inline float myprev(float x) 136 | { 137 | return _nextafterf(x, -std::numeric_limits::max()); 138 | } 139 | 140 | inline double myprev(double x) 141 | { 142 | return _nextafter(x, -std::numeric_limits::max()); 143 | } 144 | #else 145 | inline float mynext(float x) 146 | { 147 | return std::nextafterf(x, std::numeric_limits::max()); 148 | } 149 | 150 | inline double mynext(double x) 151 | { 152 | return std::nextafter(x, std::numeric_limits::max()); 153 | } 154 | inline float myprev(float x) 155 | { 156 | return std::nextafterf(x, -std::numeric_limits::max()); 157 | } 158 | 159 | inline double myprev(double x) 160 | { 161 | return std::nextafter(x, -std::numeric_limits::max()); 162 | } 163 | #endif 164 | 165 | template 166 | inline T next(T x) 167 | { 168 | for (int i = 0; i < 4; ++i) 169 | x = mynext(x); 170 | return x; 171 | } 172 | 173 | template 174 | inline T prev(T x) 175 | { 176 | for (int i = 0; i < 4; ++i) 177 | x = myprev(x); 178 | return x; 179 | } 180 | 181 | } // namespace Details 182 | } // namespace BinSearch 183 | -------------------------------------------------------------------------------- /include/Type.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "Portable.h" 8 | 9 | using std::size_t; 10 | 11 | namespace BinSearch { 12 | 13 | enum InstrSet { Scalar, SSE, AVX, Neon }; 14 | 15 | #define ALGOENUM(x, b) x, 16 | enum Algos 17 | { 18 | #include "AlgoXCodes.h" 19 | }; 20 | #undef ALGOENUM 21 | 22 | namespace Details { 23 | 24 | template 25 | struct InstrIntTraits; 26 | 27 | template 28 | struct InstrFloatTraits; 29 | 30 | // base class for algorithm supporting the method: 31 | // uint32 scalar(T z) const 32 | template 33 | struct AlgoScalarBase; 34 | 35 | // base class for algorithm supporting the following methods, constants and definitions: 36 | // static const uint32 nElem 37 | // struct Constants; 38 | // void initConstants(Constants& cst) const 39 | // void vectorial(uint32 *pr, const T *pz, const Constants& cst) const 40 | // The function vectorial processes nElem items 41 | template 42 | struct AlgoVecBase; 43 | 44 | template struct IntTraits; 45 | 46 | template <> struct IntTraits 47 | { 48 | typedef uint32 itype; 49 | }; 50 | template <> struct IntTraits 51 | { 52 | typedef uint64 itype; 53 | }; 54 | 55 | template 56 | struct Body 57 | { 58 | template 59 | FORCE_INLINE static void 
iteration(const Expr& e, uint32 *ri, const T* zi, const typename Expr::Constants& cst) 60 | { 61 | e.vectorial(ri, zi, cst); 62 | Body::template iteration(e, ri + D, zi + D, cst); 63 | } 64 | 65 | }; 66 | 67 | template <> 68 | struct Body<0> 69 | { 70 | template 71 | FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const H&) 72 | { 73 | } 74 | }; 75 | 76 | template 77 | struct Loop 78 | { 79 | typedef Algo algo_type; 80 | static const uint32 M = 4; 81 | static const uint32 D = algo_type::nElem; 82 | 83 | FORCE_INLINE static void loop(const algo_type& e, uint32 *ri, const T* zi, uint32 n) 84 | { 85 | typename algo_type::Constants cst; 86 | e.initConstants(cst); 87 | 88 | uint32 j = 0; 89 | while (j + (D*M) <= n) { 90 | Details::Body::template iteration(e, ri + j, zi + j, cst); 91 | j += (D*M); 92 | } 93 | while (j + D <= n) { 94 | e.vectorial(ri + j, zi + j, cst); 95 | j += D; 96 | } 97 | while (D > 1 && j < n) { 98 | ri[j] = e.scalar(zi[j]); 99 | j += 1; 100 | } 101 | } 102 | }; 103 | 104 | template 105 | struct _Pipeliner 106 | { 107 | template 108 | FORCE_INLINE static void go(const Expr& e, Data* d) 109 | { 110 | e.template run(d); 111 | _Pipeliner::go(e, d); 112 | } 113 | }; 114 | 115 | template 116 | struct _Pipeliner 117 | { 118 | template 119 | FORCE_INLINE static void go(const Expr& e, Data* d) 120 | { 121 | } 122 | }; 123 | 124 | template 125 | struct Pipeliner 126 | { 127 | template 128 | FORCE_INLINE static void go(const Expr& e, Data* d) 129 | { 130 | _Pipeliner::go(e, d); 131 | } 132 | }; 133 | 134 | 135 | #if 1 136 | template 137 | char is_complete_impl(char (*)[sizeof(T)]); 138 | 139 | template 140 | long is_complete_impl(...); 141 | 142 | template 143 | struct IsComplete 144 | { 145 | static const bool value = sizeof(is_complete_impl(0)) == sizeof(char); 146 | }; 147 | #else 148 | template 149 | std::true_type is_complete_impl(T *); 150 | 151 | std::false_type is_complete_impl(...); 152 | 153 | template 154 | struct IsComplete : decltype(is_complete_impl(std::declval())) {}; 155 | #endif 156 | 157 | template 158 | struct AlgoScalarToVec : AlgoScalarBase 159 | { 160 | typedef AlgoScalarBase base_t; 161 | 162 | AlgoScalarToVec(const typename base_t::Data& d) : base_t(d) {} 163 | AlgoScalarToVec(const T* px, const uint32 n) : base_t(px, n) {} 164 | 165 | static const uint32 nElem = 1; 166 | 167 | struct Constants 168 | { 169 | }; 170 | 171 | void initConstants(Constants& cst) const 172 | { 173 | } 174 | 175 | FORCE_INLINE 176 | void vectorial(uint32 *pr, const T *pz, const Constants& cst) const 177 | { 178 | *pr = base_t::scalar(*pz); 179 | } 180 | }; 181 | 182 | template 183 | struct conditional { typedef T type; }; 184 | 185 | template 186 | struct conditional { typedef F type; }; 187 | 188 | template 189 | struct CondData 190 | { 191 | FORCE_INLINE CondData(T x) : v(x) {} 192 | FORCE_INLINE operator const T&() const { return v;} 193 | private: 194 | T v; 195 | }; 196 | 197 | template 198 | struct CondData 199 | { 200 | FORCE_INLINE CondData(T) {} 201 | FORCE_INLINE operator const T() const { return 0;} 202 | }; 203 | 204 | template 205 | struct BinAlgoBase : Details::conditional< Details::IsComplete>::value 206 | , Details::AlgoVecBase 207 | , Details::AlgoScalarToVec 208 | >::type 209 | { 210 | typedef typename Details::conditional< Details::IsComplete>::value 211 | , Details::AlgoVecBase 212 | , Details::AlgoScalarToVec 213 | >::type base_t; 214 | 215 | BinAlgoBase(const T* px, const uint32 n) : base_t(px, n) {} 216 | BinAlgoBase(const typename 
base_t::Data& d) : base_t(d) {} 217 | }; 218 | 219 | } // namespace Details 220 | 221 | } // namespace BinSearch 222 | -------------------------------------------------------------------------------- /install_cuda.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | from urllib.request import urlretrieve 5 | 6 | cuda_versions = { 7 | "92": "https://developer.nvidia.com/compute/cuda/9.2/Prod2/local_installers/cuda_9.2.148_396.37_linux", 8 | "100": "https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux", 9 | "101": "https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run", 10 | "102": "https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run", 11 | "110": "https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda_11.0.3_450.51.06_linux.run", 12 | "111": "https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run", 13 | "112": "https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run", 14 | "113": "https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run", 15 | "114": "https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run", 16 | "115": "https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run", 17 | "116": "https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run", 18 | "117": "https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run", 19 | "118": "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run", 20 | "120": "https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda_12.0.0_525.60.13_linux.run", 21 | "121": "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run", 22 | "122": "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run", 23 | "123": "https://developer.download.nvidia.com/compute/cuda/12.3.1/local_installers/cuda_12.3.1_545.23.08_linux.run", 24 | } 25 | 26 | 27 | def install_cuda(version, base_path, download_path): 28 | formatted_version = f"{version[:-1]}.{version[-1]}" 29 | folder = f"cuda-{formatted_version}" 30 | install_path = os.path.join(base_path, folder) 31 | 32 | if os.path.exists(install_path): 33 | print(f"Removing existing CUDA version {version} at {install_path}...") 34 | subprocess.run(["rm", "-rf", install_path], check=True) 35 | 36 | url = cuda_versions[version] 37 | filename = url.split('/')[-1] 38 | filepath = os.path.join(download_path, filename) 39 | 40 | if not os.path.exists(filepath): 41 | print(f"Downloading CUDA version {version} from {url}...") 42 | urlretrieve(url, filepath) 43 | else: 44 | print(f"Installer for CUDA version {version} already downloaded.") 45 | 46 | # Make the installer executable 47 | subprocess.run(["chmod", "+x", filepath], check=True) 48 | 49 | # Install CUDA 50 | print(f"Installing CUDA version {version}...") 51 | install_command = [ 52 | "bash", filepath, 53 | "--no-drm", "--no-man-page", "--override", 54 | "--toolkitpath=" + install_path, "--toolkit", 
"--silent" 55 | ] 56 | 57 | print(f"Running command: {' '.join(install_command)}") 58 | 59 | try: 60 | subprocess.run(install_command, check=True) 61 | except subprocess.CalledProcessError as e: 62 | print(f"Installation failed for CUDA version {version}: {e}") 63 | return 64 | finally: 65 | # Delete the installer file 66 | os.remove(filepath) 67 | 68 | print(f"CUDA version {version} installed at {install_path}") 69 | 70 | def main(): 71 | user_base_path = os.path.expanduser("~/cuda") 72 | system_base_path = "/usr/local/cuda" 73 | base_path = user_base_path # default to user-specific installation 74 | download_path = "/tmp" # default download path 75 | 76 | if len(sys.argv) < 2: 77 | print("Usage: python install_cuda.py [user/system] [download_path]") 78 | sys.exit(1) 79 | 80 | version = sys.argv[1] 81 | if len(sys.argv) > 2: 82 | base_path = system_base_path if sys.argv[2] == "system" else user_base_path 83 | if len(sys.argv) > 3: 84 | download_path = sys.argv[3] 85 | 86 | if not os.path.exists(base_path): 87 | os.makedirs(base_path) 88 | if not os.path.exists(download_path): 89 | os.makedirs(download_path) 90 | 91 | # Install CUDA version(s) 92 | if version == "all": 93 | for ver in cuda_versions.keys(): 94 | install_cuda(ver, base_path, download_path) 95 | elif version in cuda_versions: 96 | install_cuda(version, base_path, download_path) 97 | else: 98 | print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}") 99 | sys.exit(1) 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /install_cuda.sh: -------------------------------------------------------------------------------- 1 | URL92=https://developer.nvidia.com/compute/cuda/9.2/Prod2/local_installers/cuda_9.2.148_396.37_linux 2 | URL100=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux 3 | URL101=https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run 4 | URL102=https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run 5 | URL110=https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda_11.0.3_450.51.06_linux.run 6 | URL111=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run 7 | URL112=https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run 8 | URL113=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run 9 | URL114=https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run 10 | URL115=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run 11 | URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run 12 | URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run 13 | URL118=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run 14 | URL120=https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda_12.0.0_525.60.13_linux.run 15 | URL121=https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run 16 | 
URL122=https://developer.download.nvidia.com/compute/cuda/12.2.1/local_installers/cuda_12.2.1_535.86.10_linux.run 17 | URL123=https://developer.download.nvidia.com/compute/cuda/12.3.1/local_installers/cuda_12.3.1_545.23.08_linux.run 18 | 19 | 20 | CUDA_VERSION=$1 21 | BASE_PATH=$2 22 | EXPORT_BASHRC=$3 23 | 24 | if [[ -n "$CUDA_VERSION" ]]; then 25 | if [[ "$CUDA_VERSION" -eq "92" ]]; then 26 | URL=$URL92 27 | FOLDER=cuda-9.2 28 | elif [[ "$CUDA_VERSION" -eq "100" ]]; then 29 | URL=$URL100 30 | FOLDER=cuda-10.0 31 | elif [[ "$CUDA_VERSION" -eq "101" ]]; then 32 | URL=$URL101 33 | FOLDER=cuda-10.1 34 | elif [[ "$CUDA_VERSION" -eq "102" ]]; then 35 | URL=$URL102 36 | FOLDER=cuda-10.2 37 | elif [[ "$CUDA_VERSION" -eq "110" ]]; then 38 | URL=$URL110 39 | FOLDER=cuda-11.0 40 | elif [[ "$CUDA_VERSION" -eq "111" ]]; then 41 | URL=$URL111 42 | FOLDER=cuda-11.1 43 | elif [[ "$CUDA_VERSION" -eq "112" ]]; then 44 | URL=$URL112 45 | FOLDER=cuda-11.2 46 | elif [[ "$CUDA_VERSION" -eq "113" ]]; then 47 | URL=$URL113 48 | FOLDER=cuda-11.3 49 | elif [[ "$CUDA_VERSION" -eq "114" ]]; then 50 | URL=$URL114 51 | FOLDER=cuda-11.4 52 | elif [[ "$CUDA_VERSION" -eq "115" ]]; then 53 | URL=$URL115 54 | FOLDER=cuda-11.5 55 | elif [[ "$CUDA_VERSION" -eq "116" ]]; then 56 | URL=$URL116 57 | FOLDER=cuda-11.6 58 | elif [[ "$CUDA_VERSION" -eq "117" ]]; then 59 | URL=$URL117 60 | FOLDER=cuda-11.7 61 | elif [[ "$CUDA_VERSION" -eq "118" ]]; then 62 | URL=$URL118 63 | FOLDER=cuda-11.8 64 | elif [[ "$CUDA_VERSION" -eq "120" ]]; then 65 | URL=$URL120 66 | FOLDER=cuda-12.0 67 | elif [[ "$CUDA_VERSION" -eq "121" ]]; then 68 | URL=$URL121 69 | FOLDER=cuda-12.1 70 | elif [[ "$CUDA_VERSION" -eq "122" ]]; then 71 | URL=$URL122 72 | FOLDER=cuda-12.2 73 | elif [[ "$CUDA_VERSION" -eq "123" ]]; then 74 | URL=$URL123 75 | FOLDER=cuda-12.3 76 | else 77 | echo "argument error: No cuda version passed as input. Choose among versions 92 to 123" 78 | fi 79 | else 80 | echo "argument error: No cuda version passed as input. Choose among versions 92 to 123" 81 | fi 82 | 83 | FILE=$(basename $URL) 84 | 85 | if [[ -n "$CUDA_VERSION" ]]; then 86 | echo $URL 87 | echo $FILE 88 | wget $URL 89 | bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent 90 | if [ "$EXPORT_BASHRC" -eq "1" ]; then 91 | echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc 92 | echo "export PATH=\$PATH:$BASE_PATH/$FOLDER/bin" >> ~/.bashrc 93 | source ~/.bashrc 94 | fi 95 | else 96 | echo "" 97 | fi 98 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ "setuptools", "wheel" ] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.ruff] 6 | src = [ 7 | "bitsandbytes", 8 | "tests", 9 | "benchmarking" 10 | ] 11 | select = [ 12 | "B", # bugbear: security warnings 13 | "E", # pycodestyle 14 | "F", # pyflakes 15 | "I", # isort 16 | "ISC", # implicit string concatenation 17 | "UP", # alert you when better syntax is available in your python version 18 | "RUF", # the ruff developer's own rules 19 | ] 20 | target-version = "py38" 21 | ignore = [ 22 | "B007", # Loop control variable not used within the loop body (TODO: enable) 23 | "B028", # Warning without stacklevel (TODO: enable) 24 | "E501", # Supress line-too-long warnings: trust yapf's judgement on this one. 
25 | "E701", # Multiple statements on one line (TODO: enable) 26 | "E712", # Allow using if x == False, as it's not always equivalent to if x. 27 | "E731", # Do not use lambda 28 | "F841", # Local assigned but not used (TODO: enable, these are likely bugs) 29 | "RUF012", # Mutable class attribute annotations 30 | ] 31 | ignore-init-module-imports = true # allow to expose in __init__.py via imports 32 | 33 | [tool.ruff.extend-per-file-ignores] 34 | "**/__init__.py" = ["F401"] # allow unused imports in __init__.py 35 | "{benchmarking,tests}/**/*.py" = [ 36 | "B007", 37 | "B011", 38 | "B023", 39 | "E701", 40 | "E731", 41 | "F841", 42 | "UP030", 43 | ] 44 | 45 | [tool.ruff.isort] 46 | combine-as-imports = true 47 | detect-same-package = true 48 | force-sort-within-sections = true 49 | known-first-party = ["bitsandbytes"] 50 | 51 | [[tool.mypy.overrides]] 52 | module = "triton.*" 53 | ignore_missing_imports = true 54 | 55 | [[tool.mypy.overrides]] 56 | module = "scipy.stats" 57 | ignore_missing_imports = true 58 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -rP 3 | ; --cov=bitsandbytes 4 | ; # contexts: record which test ran which line; can be seen in html coverage report 5 | ; --cov-context=test 6 | ; --cov-report html 7 | 8 | log_cli = True 9 | log_cli_level = INFO 10 | log_file = logs/pytest.log 11 | markers = 12 | benchmark: mark test as benchmark 13 | slow: mark test as slow 14 | -------------------------------------------------------------------------------- /requirements-ci.txt: -------------------------------------------------------------------------------- 1 | # Requirements used for GitHub actions 2 | pytest==7.2.2 3 | einops==0.6.0 4 | wheel==0.40.0 5 | lion-pytorch==0.0.6 6 | scipy==1.11.4 7 | pandas==2.2.0 8 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Requirements used for local development 2 | setuptools>=63 3 | pytest~=7.2.2 4 | einops~=0.6.0 5 | wheel~=0.40.0 6 | lion-pytorch~=0.0.6 7 | scipy~=1.11.4 8 | pandas~=2.2.0 9 | matplotlib~=3.8.2 10 | -------------------------------------------------------------------------------- /scripts/stale.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team, the AllenNLP library authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Script to close stale issue. Taken in part from the AllenNLP repository. 16 | https://github.com/allenai/allennlp. 
17 | """ 18 | from datetime import datetime as dt, timezone 19 | import os 20 | 21 | from github import Github 22 | 23 | # All labels that we don't want to touch 24 | LABELS_TO_EXEMPT = [ 25 | "feature-request", 26 | ] 27 | 28 | 29 | def main(): 30 | g = Github(os.environ["GITHUB_TOKEN"]) 31 | repo = g.get_repo("TimDettmers/bitsandbytes") 32 | open_issues = repo.get_issues(state="open") 33 | 34 | for issue in open_issues: 35 | comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True) 36 | last_comment = comments[0] if len(comments) > 0 else None 37 | if ( 38 | last_comment is not None 39 | and last_comment.user.login == "github-actions[bot]" 40 | and (dt.now(timezone.utc) - issue.updated_at).days > 7 41 | and (dt.now(timezone.utc) - issue.created_at).days >= 30 42 | and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) 43 | ): 44 | issue.edit(state="closed") 45 | elif ( 46 | (dt.now(timezone.utc) - issue.updated_at).days > 23 47 | and (dt.now(timezone.utc) - issue.created_at).days >= 30 48 | and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) 49 | ): 50 | issue.create_comment( 51 | "This issue has been automatically marked as stale because it has not had " 52 | "recent activity. If you think this still needs to be addressed " 53 | "please comment on this thread.\n\n" 54 | ) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | import glob 6 | import os 7 | 8 | from setuptools import find_packages, setup 9 | from setuptools.dist import Distribution 10 | 11 | libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.*")) 12 | libs = [os.path.basename(p) for p in libs] 13 | print("libs:", libs) 14 | 15 | 16 | def read(fname): 17 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 18 | 19 | 20 | # Tested with wheel v0.29.0 21 | class BinaryDistribution(Distribution): 22 | def has_ext_modules(self): 23 | return True 24 | 25 | 26 | setup( 27 | name="bitsandbytes", 28 | version="0.43.0.dev0", 29 | author="Tim Dettmers", 30 | author_email="dettmers@cs.washington.edu", 31 | description="k-bit optimizers and matrix multiplication routines.", 32 | license="MIT", 33 | keywords="gpu optimizers optimization 8-bit quantization compression", 34 | url="https://github.com/TimDettmers/bitsandbytes", 35 | packages=find_packages(), 36 | package_data={"": libs}, 37 | install_requires=["torch", "numpy"], 38 | extras_require={ 39 | "benchmark": ["pandas", "matplotlib"], 40 | "test": ["scipy"], 41 | }, 42 | long_description=read("README.md"), 43 | long_description_content_type="text/markdown", 44 | classifiers=[ 45 | "Development Status :: 4 - Beta", 46 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 47 | ], 48 | distclass=BinaryDistribution, 49 | ) 50 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rickardp/bitsandbytes/927f7167e3395ec26f859f294c1d4979a70a718a/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | 5 | def pytest_runtest_call(item): 6 | try: 7 | item.runtest() 8 | except NotImplementedError as nie: 9 | if "NO_CUBLASLT" in str(nie): 10 | pytest.skip("CUBLASLT not available") 11 | raise 12 | except AssertionError as ae: 13 | if str(ae) == "Torch not compiled with CUDA enabled": 14 | pytest.skip("Torch not compiled with CUDA enabled") 15 | raise 16 | 17 | 18 | @pytest.fixture(scope="session") 19 | def requires_cuda() -> bool: 20 | cuda_available = torch.cuda.is_available() 21 | if not cuda_available: 22 | pytest.skip("CUDA is required") 23 | return cuda_available 24 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import random 3 | from typing import Any 4 | 5 | import torch 6 | 7 | test_dims_rng = random.Random(42) 8 | 9 | 10 | def get_test_dims(min: int, max: int, *, n: int) -> list[int]: 11 | return [test_dims_rng.randint(min, max) for _ in range(n)] 12 | 13 | 14 | def format_with_label(label: str, value: Any) -> str: 15 | if isinstance(value, bool): 16 | formatted = "T" if value else "F" 17 | elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value): 18 | formatted = "".join("T" if b else "F" for b in value) 19 | else: 20 | formatted = str(value) 21 | return f"{label}={formatted}" 22 | 23 | 24 | def id_formatter(label: str): 25 | """ 26 | Return a function that formats the value given to it with the given label. 
27 | """ 28 | return lambda value: format_with_label(label, value) 29 | 30 | 31 | DTYPE_NAMES = { 32 | torch.bfloat16: "bf16", 33 | torch.bool: "bool", 34 | torch.float16: "fp16", 35 | torch.float32: "fp32", 36 | torch.float64: "fp64", 37 | torch.int32: "int32", 38 | torch.int64: "int64", 39 | torch.int8: "int8", 40 | } 41 | 42 | 43 | def describe_dtype(dtype: torch.dtype) -> str: 44 | return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2] 45 | 46 | 47 | TRUE_FALSE = (True, False) 48 | BOOLEAN_TRIPLES = list( 49 | product(TRUE_FALSE, repeat=3) 50 | ) # all combinations of (bool, bool, bool) 51 | BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool) 52 | -------------------------------------------------------------------------------- /tests/test_cuda_setup_evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | 7 | # hardcoded test. Not good, but a sanity check for now 8 | # TODO: improve this 9 | def test_manual_override(requires_cuda): 10 | manual_cuda_path = str(Path('/mmfs1/home/dettmers/data/local/cuda-12.2')) 11 | 12 | pytorch_version = torch.version.cuda.replace('.', '') 13 | 14 | assert pytorch_version != 122 # TODO: this will never be true... 15 | 16 | os.environ['CUDA_HOME']=f'{manual_cuda_path}' 17 | os.environ['BNB_CUDA_VERSION']='122' 18 | #assert str(manual_cuda_path) in os.environ['LD_LIBRARY_PATH'] 19 | import bitsandbytes as bnb 20 | loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name 21 | #assert loaded_lib == 'libbitsandbytes_cuda122.so' 22 | -------------------------------------------------------------------------------- /tests/test_generation.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import math 3 | 4 | import pytest 5 | import torch 6 | import transformers 7 | from transformers import ( 8 | AutoModelForCausalLM, 9 | BitsAndBytesConfig, 10 | ) 11 | 12 | from tests.helpers import TRUE_FALSE, describe_dtype, id_formatter 13 | 14 | 15 | def get_4bit_config(): 16 | return BitsAndBytesConfig( 17 | load_in_4bit=True, 18 | load_in_8bit=False, 19 | llm_int8_threshold=6.0, 20 | llm_int8_has_fp16_weight=False, 21 | bnb_4bit_compute_dtype=torch.float16, 22 | bnb_4bit_use_double_quant=True, 23 | bnb_4bit_quant_type='nf4', 24 | ) 25 | 26 | 27 | def get_model_and_tokenizer(config): 28 | model_name_or_path, quant_type = config 29 | bnb_config = get_4bit_config() 30 | if quant_type == '16bit': 31 | bnb_config.load_in_4bit = False 32 | else: 33 | bnb_config.bnb_4bit_quant_type= quant_type 34 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, 35 | quantization_config=bnb_config, 36 | max_memory={0:'48GB'}, 37 | device_map='auto', 38 | torch_dtype=torch.bfloat16 39 | ).eval() 40 | 41 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) 42 | 43 | return model, tokenizer 44 | 45 | def get_prompt_for_generation_eval(text, add_roles=True): 46 | description = ( 47 | "A chat between a curious human and an artificial intelligence assistant. " 48 | "The assistant gives helpful, detailed, and polite answers to the user's questions."
49 | ) 50 | if add_roles: 51 | prompt = f'{description} ### Human: {text} ### Assistant:' 52 | else: 53 | prompt = f'{description} {text}' 54 | return prompt 55 | 56 | def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval): 57 | text = prompt_func(text) 58 | inputs = tokenizer(text, return_tensors="pt").to('cuda:0') 59 | outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config) 60 | return tokenizer.decode(outputs[0], skip_special_tokens=True) 61 | 62 | models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7'] 63 | dtypes = ['nf4', 'fp4'] 64 | 65 | @pytest.fixture(scope='session', params=product(models, dtypes)) 66 | def model_and_tokenizer(request): 67 | model, tokenizer = get_model_and_tokenizer(request.param) 68 | yield request.param, model, tokenizer 69 | del model 70 | 71 | 72 | @pytest.mark.parametrize("DQ", TRUE_FALSE, ids=id_formatter("dq")) 73 | @pytest.mark.parametrize("inference_kernel", TRUE_FALSE, ids=id_formatter("inference_kernel")) 74 | @pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) 75 | @pytest.mark.slow 76 | def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): 77 | fixture_config, model, tokenizer = model_and_tokenizer 78 | 79 | generation_config = transformers.GenerationConfig( 80 | max_new_tokens=20, 81 | do_sample=True, 82 | top_p=0.9, 83 | temperature=0.7, 84 | ) 85 | generation_config.max_new_tokens = 20 86 | 87 | 88 | #text = 'Please write down the first 50 digits of pi.' 89 | #text = get_prompt_for_generation_eval(text) 90 | #text += ' Sure, here the first 50 digits of pi: 3.14159' 91 | n_cases = 6 92 | text = '3.14159' 93 | if hasattr(model.config, 'quantization_config'): 94 | model.config.quantization_config.bnb_4bit_compute_dtype = dtype 95 | model.config.quantization_config.bnb_4bit_use_double_quant = DQ 96 | 97 | if not inference_kernel: 98 | text = [text]*n_cases 99 | inputs = tokenizer(text, return_tensors="pt").to('cuda:0') 100 | x = inputs['input_ids'] 101 | outputs = [] 102 | if inference_kernel: 103 | for i in range(n_cases): 104 | output = model.generate(x, generation_config=generation_config) 105 | textout = tokenizer.decode(output[0], skip_special_tokens=True) 106 | outputs.append(textout) 107 | else: 108 | outputs = model.generate(x, generation_config=generation_config) 109 | outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs] 110 | 111 | 112 | assert len(outputs) == n_cases 113 | failure_count = 0 114 | for i in range(n_cases): 115 | if not outputs[i][:len(str(math.pi))] == str(math.pi): 116 | failure_count += 1 117 | failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4) 118 | if failure_count > failure_max: 119 | print(math.pi) 120 | for out in outputs: 121 | print(out) 122 | raise ValueError(f'Failure count: {failure_count}/{n_cases}') 123 | -------------------------------------------------------------------------------- /tests/test_linear4bit.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tempfile import TemporaryDirectory 3 | 4 | import pytest 5 | import torch 6 | 7 | import bitsandbytes as bnb 8 | from tests.helpers import TRUE_FALSE 9 | 10 | storage = { 11 | 'uint8': torch.uint8, 12 | 'float16': torch.float16, 13 | 'bfloat16': torch.bfloat16, 14 | 'float32': torch.float32 15 | } 16 | 17 | @pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32']) 18 | @pytest.mark.parametrize("bias", 
TRUE_FALSE) 19 | @pytest.mark.parametrize("compress_statistics", TRUE_FALSE) 20 | @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) 21 | def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage): 22 | original_dtype = torch.float16 23 | compute_dtype = None 24 | device = "cuda" 25 | layer_shape = (300, 400) 26 | 27 | linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer 28 | 29 | # Quantizing original layer 30 | linear_q = bnb.nn.Linear4bit( 31 | linear.in_features, 32 | linear.out_features, 33 | bias=bias, 34 | compute_dtype=compute_dtype, 35 | compress_statistics=compress_statistics, 36 | quant_type=quant_type, 37 | device="meta", 38 | ) 39 | new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False) 40 | linear_q.weight = new_weight 41 | if bias: 42 | linear_q.bias = torch.nn.Parameter(linear.bias) 43 | linear_q = linear_q.to(device) 44 | 45 | # saving to state_dict: 46 | sd = linear_q.state_dict() 47 | 48 | # restoring from state_dict: 49 | bias_data2 = sd.pop("bias", None) 50 | weight_data2 = sd.pop("weight") 51 | weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2) 52 | 53 | # creating new layer with same params: 54 | linear_q2 = bnb.nn.Linear4bit( 55 | linear.in_features, 56 | linear.out_features, 57 | bias=bias, 58 | compute_dtype=compute_dtype, 59 | compress_statistics=compress_statistics, 60 | quant_type=quant_type, 61 | device="meta", 62 | ) 63 | # loading weights from state_dict: 64 | linear_q2.weight = weight2 65 | if bias: 66 | linear_q2.bias = torch.nn.Parameter(bias_data2) 67 | linear_q2 = linear_q2.to(device) 68 | 69 | # MATCHING 70 | a, b = linear_q.weight, linear_q2.weight 71 | 72 | # Quantizing original layer with specified quant_storage type 73 | linear_qs = bnb.nn.Linear4bit( 74 | linear.in_features, 75 | linear.out_features, 76 | bias=bias, 77 | compute_dtype=compute_dtype, 78 | compress_statistics=compress_statistics, 79 | quant_type=quant_type, 80 | quant_storage=storage[quant_storage], 81 | device="meta", 82 | ) 83 | linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage]) 84 | if bias: 85 | linear_qs.bias = torch.nn.Parameter(linear.bias) 86 | linear_qs = linear_qs.to(device) 87 | 88 | assert a.device == b.device 89 | assert a.dtype == b.dtype 90 | assert torch.equal(a, b) 91 | 92 | q0 = a.quant_state 93 | q1 = b.quant_state 94 | for attr in ('code', 'dtype', 'blocksize', 'absmax'): 95 | c, d = getattr(q0, attr), getattr(q1, attr) 96 | if isinstance(c, torch.Tensor): 97 | assert torch.equal(c, d) 98 | else: 99 | assert c == d, f"{c} != {d}" 100 | 101 | if q0.state2 is not None: 102 | for attr in ('code', 'dtype', 'blocksize', 'absmax'): 103 | c, d = getattr(q0.state2, attr), getattr(q1.state2, attr) 104 | if isinstance(c, torch.Tensor): 105 | assert torch.equal(c, d) 106 | else: 107 | assert c == d, f"{c} != {d}" 108 | 109 | if bias: 110 | a, b = linear_q.bias, linear_q2.bias 111 | assert a.device == b.device 112 | assert a.dtype == b.dtype 113 | assert torch.equal(a, b) 114 | 115 | # Forward test 116 | x = torch.rand(42, layer_shape[0], device=device) 117 | a = linear_q(x) 118 | b = linear_q2(x) 119 | c = linear_qs(x) 120 | assert a.device == b.device 121 | assert a.dtype == b.dtype 122 | assert a.device == c.device 123 | assert a.dtype == c.dtype 124 | assert torch.equal(a, b) 125 | assert torch.equal(a, c) 126 | 127 | # Test moving to CPU and 
back to GPU 128 | linear_q2.to('cpu') 129 | linear_q2.to(device) 130 | d = linear_qs(x) 131 | assert c.dtype == d.dtype 132 | assert c.device == d.device 133 | assert torch.equal(c, d) 134 | 135 | # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias 136 | with TemporaryDirectory() as tmpdir: 137 | state_path_4bit = os.path.join(tmpdir, "state_4bit.pth") 138 | state_path = os.path.join(tmpdir, "state.pth") 139 | torch.save(linear.state_dict(), state_path) 140 | torch.save(linear_q.state_dict(), state_path_4bit) 141 | 142 | size_orig, size_4 = os.path.getsize(state_path), os.path.getsize( 143 | state_path_4bit 144 | ) 145 | size_ratio = size_4 / size_orig 146 | target_compression = 0.143 if original_dtype == torch.float32 else 0.29 # these numbers get lower as weight shape increases 147 | ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" 148 | assert size_ratio < target_compression, ratio_error_msg 149 | -------------------------------------------------------------------------------- /tests/test_linear8bitlt.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | import os 3 | from tempfile import TemporaryDirectory 4 | 5 | import pytest 6 | import torch 7 | 8 | import bitsandbytes as bnb 9 | from bitsandbytes import functional as F 10 | from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout 11 | from bitsandbytes.nn.modules import Linear8bitLt 12 | from tests.helpers import TRUE_FALSE, id_formatter 13 | 14 | # contributed by Alex Borzunov, see: 15 | # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py 16 | 17 | @pytest.mark.skipif( 18 | not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), 19 | reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", 20 | ) 21 | def test_layout_exact_match(): 22 | x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda() 23 | for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"): 24 | transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) 25 | tile_indices = get_inverse_transform_indices(transform, tile_size) 26 | cxb = transform(x) 27 | 28 | torch.cuda.synchronize() 29 | restored_x = undo_layout(cxb, tile_indices) 30 | torch.cuda.synchronize() 31 | assert restored_x.is_contiguous() 32 | assert torch.all(torch.eq(restored_x, x)) 33 | 34 | 35 | def test_linear_no_igemmlt(): 36 | linear = torch.nn.Linear(1024, 3072) 37 | x = torch.randn(3, 1024, dtype=torch.half) 38 | linear_custom = Linear8bitLt( 39 | linear.in_features, 40 | linear.out_features, 41 | linear.bias is not None, 42 | has_fp16_weights=False, 43 | threshold=6.0, 44 | ) 45 | linear_custom.state.force_no_igemmlt = True 46 | 47 | linear_custom.weight = bnb.nn.Int8Params( 48 | linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False 49 | ).to(linear.weight.dtype) 50 | linear_custom.bias = linear.bias 51 | linear_custom = linear_custom.cuda() 52 | linear = linear.half().cuda() 53 | 54 | x_ref = x.clone().cuda().requires_grad_(True) 55 | x_ours = x.clone().cuda().requires_grad_(True) 56 | fx_ref = linear(x_ref).float() 57 | grad_proj = torch.randn_like(fx_ref) 58 | (fx_ref * grad_proj).mean().backward() 59 | 60 | fx_ours = linear_custom(x_ours).float() 61 | (fx_ours * grad_proj).mean().backward() 62 | assert torch.allclose(fx_ref, fx_ours, atol=0.02) 63 | 
assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01) 64 | assert not linear_custom.state.has_fp16_weights 65 | assert linear_custom.state.CB is not None 66 | assert linear_custom.state.CxB is None 67 | 68 | 69 | @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) 70 | @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward")) 71 | @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda")) 72 | @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt")) 73 | def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt): 74 | linear = torch.nn.Linear(32, 96) 75 | x = torch.randn(3, 32, dtype=torch.half) 76 | 77 | linear_custom = Linear8bitLt( 78 | linear.in_features, 79 | linear.out_features, 80 | linear.bias is not None, 81 | has_fp16_weights=has_fp16_weights, 82 | threshold=6.0, 83 | ) 84 | if force_no_igemmlt: 85 | linear_custom.state.force_no_igemmlt = True 86 | 87 | linear_custom.weight = bnb.nn.Int8Params( 88 | linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights 89 | ) 90 | linear_custom.bias = linear.bias 91 | linear_custom = linear_custom.cuda() 92 | 93 | if serialize_before_forward: 94 | state_dict_8bit = linear_custom.state_dict() 95 | 96 | x_first = x.clone().cuda().requires_grad_(True) 97 | fx_first = linear_custom(x_first).float() 98 | grad_proj = torch.randn_like(fx_first) 99 | (fx_first * grad_proj).mean().backward() 100 | 101 | if not serialize_before_forward: 102 | state_dict_8bit = linear_custom.state_dict() 103 | 104 | with TemporaryDirectory() as tmpdir: 105 | state_path_8bit = os.path.join(tmpdir, "state_8bit.pth") 106 | state_path = os.path.join(tmpdir, "state.pth") 107 | 108 | torch.save(linear.state_dict(), state_path) 109 | torch.save(state_dict_8bit, state_path_8bit) 110 | 111 | if not has_fp16_weights: 112 | assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path) 113 | 114 | new_state_dict = torch.load(state_path_8bit) 115 | 116 | new_linear_custom = Linear8bitLt( 117 | linear.in_features, 118 | linear.out_features, 119 | linear.bias is not None, 120 | has_fp16_weights=has_fp16_weights, 121 | threshold=6.0, 122 | ) 123 | if force_no_igemmlt: 124 | new_linear_custom.state.force_no_igemmlt = True 125 | 126 | if deserialize_before_cuda: 127 | with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): 128 | new_linear_custom.load_state_dict(new_state_dict, strict=True) 129 | 130 | new_linear_custom = new_linear_custom.cuda() 131 | 132 | if not deserialize_before_cuda: 133 | new_linear_custom.load_state_dict(new_state_dict, strict=True) 134 | 135 | x_second = x.clone().cuda().requires_grad_(True) 136 | fx_second = new_linear_custom(x_second).float() 137 | (fx_second * grad_proj).mean().backward() 138 | 139 | # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised 140 | if has_fp16_weights or not deserialize_before_cuda: 141 | assert torch.allclose(fx_first, fx_second, atol=1e-5) 142 | assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5) 143 | -------------------------------------------------------------------------------- /tests/test_triton.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from bitsandbytes.nn import Linear8bitLt 5 
| from bitsandbytes.nn.triton_based_modules import SwitchBackLinear 6 | from bitsandbytes.triton.triton_utils import is_triton_available 7 | from tests.helpers import TRUE_FALSE 8 | 9 | 10 | @pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, 11 | reason="This test requires triton and a GPU with compute capability 8.0 or higher.") 12 | @pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE) 13 | def test_switchback(vector_wise_quantization): 14 | for dim in [83]: 15 | for batch in [13]: 16 | 17 | standard = torch.nn.Linear(dim, 4 * dim).cuda().half() 18 | switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() 19 | baseline = Linear8bitLt(dim, 4 * dim).cuda().half() 20 | switchback.weight.data.copy_(standard.weight) 21 | switchback.bias.data.copy_(standard.bias) 22 | baseline.weight.data.copy_(standard.weight) 23 | baseline.bias.data.copy_(standard.bias) 24 | 25 | x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True) 26 | x2 = x1.clone().detach().requires_grad_(True) 27 | x3 = x1.clone().detach().requires_grad_(True) 28 | 29 | out_standard = standard(x1) 30 | (2**10 * out_standard.abs().mean()).backward() 31 | 32 | print(x2.dtype) 33 | out_sb = switchback(x2) 34 | (2**10 * out_sb.abs().mean()).backward() 35 | 36 | out_baseline = baseline(x3) 37 | (2**10 * out_baseline.abs().mean()).backward() 38 | 39 | err_sb = (out_standard - out_sb).abs().mean() 40 | err_baseline = (out_standard - out_baseline).abs().mean() 41 | print('OUT', err_sb, err_baseline) 42 | assert err_sb < 2 * err_baseline 43 | 44 | err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean() 45 | err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean() 46 | 47 | print('GW2', err_sb, err_baseline) 48 | assert err_sb < 2 * err_baseline 49 | 50 | err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean() 51 | err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean() 52 | 53 | print('GW1', err_sb, err_baseline) 54 | assert err_sb < 2 * err_baseline 55 | 56 | err_sb = (x1.grad - x2.grad).abs().mean() 57 | err_baseline = (x1.grad - x3.grad).abs().mean() 58 | 59 | print('GX1', err_sb, err_baseline) 60 | assert err_sb < 2 * err_baseline 61 | --------------------------------------------------------------------------------
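The test modules above double as the clearest usage documentation for the quantized layers. As a quick orientation, here is a minimal sketch distilled from tests/test_linear4bit.py: it swaps an ordinary fp16 torch.nn.Linear for bnb.nn.Linear4bit with NF4 quantization. The module and argument names are taken from that test rather than from separate API documentation, so treat the exact call signatures as assumptions verified only against this snapshot of the repository.

import torch
import bitsandbytes as bnb

# Start from an ordinary fp16 linear layer on the CPU.
fp16_linear = torch.nn.Linear(300, 400, dtype=torch.float16, device="cpu")

# Build the 4-bit replacement on the "meta" device, as the test does, so no
# real weight storage is allocated until the quantized weight is assigned.
linear_4bit = bnb.nn.Linear4bit(
    fp16_linear.in_features,
    fp16_linear.out_features,
    bias=True,
    compute_dtype=torch.float16,
    compress_statistics=True,
    quant_type="nf4",
    device="meta",
)

# Wrapping the fp16 weight in Params4bit marks it for quantization; the actual
# NF4 quantization happens when the module is moved to the GPU below.
linear_4bit.weight = bnb.nn.Params4bit(data=fp16_linear.weight, quant_type="nf4", requires_grad=False)
linear_4bit.bias = torch.nn.Parameter(fp16_linear.bias)
linear_4bit = linear_4bit.to("cuda")

# Inference with the quantized layer looks the same as with the fp16 original.
x = torch.rand(8, 300, dtype=torch.float16, device="cuda")
y = linear_4bit(x)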