├── .buckconfig ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── NOTICE.md ├── README.md ├── benchmarking └── switchback │ ├── README.md │ ├── info_a100_py2.jsonl │ ├── make_plot_with_jsonl.py │ ├── plot_with_info.pdf │ └── speed_benchmark.py ├── bitsandbytes ├── __init__.py ├── __main__.py ├── autograd │ ├── __init__.py │ └── _functions.py ├── cextension.py ├── cuda_setup │ ├── __init__.py │ ├── env_vars.py │ └── main.py ├── functional.py ├── nn │ ├── __init__.py │ ├── modules.py │ └── triton_based_modules.py ├── optim │ ├── __init__.py │ ├── adagrad.py │ ├── adam.py │ ├── adamw.py │ ├── lamb.py │ ├── lars.py │ ├── lion.py │ ├── optimizer.py │ ├── rmsprop.py │ └── sgd.py ├── research │ ├── __init__.py │ ├── autograd │ │ ├── __init__.py │ │ └── _functions.py │ └── nn │ │ ├── __init__.py │ │ └── modules.py ├── triton │ ├── __init__.py │ ├── dequantize_rowwise.py │ ├── int8_matmul_mixed_dequanitze.py │ ├── int8_matmul_rowwise_dequantize.py │ ├── quantize_columnwise_and_transpose.py │ ├── quantize_global.py │ ├── quantize_rowwise.py │ └── triton_utils.py └── utils.py ├── check_bnb_install.py ├── compile_from_source.md ├── csrc ├── common.cpp ├── common.h ├── cpu_ops.cpp ├── cpu_ops.h ├── kernels.cu ├── kernels.cuh ├── ops.cu ├── ops.cuh └── pythonInterface.c ├── cuda_install.sh ├── deploy.sh ├── environment.yml ├── errors_and_solutions.md ├── examples └── int8_inference_huggingface.py ├── how_to_use_nonpytorch_cuda.md ├── howto_config_override.md ├── include ├── AAlloc.h ├── Algo-Direct-Common.h ├── Algo-Direct2.h ├── AlgoXCodes.h ├── BinAlgo.h ├── BinSearch.h ├── Portable.h ├── SIMD.h └── Type.h ├── pyproject.toml ├── requirements.txt ├── setup.py └── tests ├── test_autograd.py ├── test_cuda_setup_evaluator.py ├── test_functional.py ├── test_generation.py ├── test_linear8bitlt.py ├── test_modules.py ├── test_optim.py └── test_triton.py /.buckconfig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jllllll/bitsandbytes/e229fbce66adde7c2a6bc58cbe7d57c1f4a0ba02/.buckconfig -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # vim 132 | *.swp 133 | 134 | dependencies 135 | cuda_build 136 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to bitsandbytes 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to bitsandbytes, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) 2 | ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH))) 3 | 4 | GPP:= /usr/bin/g++ 5 | #GPP:= /sw/gcc/11.2.0/bin/g++ 6 | ifeq ($(CUDA_HOME),) 7 | CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev) 8 | endif 9 | 10 | ifndef CUDA_VERSION 11 | $(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU) 12 | CUDA_VERSION:= 13 | endif 14 | 15 | 16 | 17 | NVCC := $(CUDA_HOME)/bin/nvcc 18 | 19 | ########################################### 20 | 21 | CSRC := $(ROOT_DIR)/csrc 22 | BUILD_DIR:= $(ROOT_DIR)/build 23 | 24 | FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu 25 | FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c 26 | 27 | INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include 28 | LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib 29 | 30 | # NVIDIA NVCC compilation flags 31 | COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell 32 | COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell 33 | COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal 34 | COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal 35 | COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta 36 | 37 | CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler 38 | CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler 39 | 40 | # Later versions of CUDA support the new architectures 41 | CC_CUDA11x := -gencode arch=compute_75,code=sm_75 42 | CC_CUDA11x += -gencode arch=compute_80,code=sm_80 43 | CC_CUDA11x += -gencode arch=compute_86,code=sm_86 44 | 45 | 46 | CC_cublasLt110 := -gencode arch=compute_75,code=sm_75 47 | CC_cublasLt110 += -gencode arch=compute_80,code=sm_80 48 | 49 | CC_cublasLt111 := -gencode arch=compute_75,code=sm_75 50 | CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 51 | CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 52 | 53 | CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 54 | CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 55 | 56 | 57 | all: $(BUILD_DIR) env 58 | $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) 59 | $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 60 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) 61 | 62 | cuda110_nomatmul_kepler: $(BUILD_DIR) env 63 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT 64 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 65 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) 66 | 67 | cuda11x_nomatmul_kepler: $(BUILD_DIR) env 68 | 
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT 69 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 70 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) 71 | 72 | 73 | cuda110_nomatmul: $(BUILD_DIR) env 74 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT 75 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 76 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) 77 | 78 | cuda11x_nomatmul: $(BUILD_DIR) env 79 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT 80 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 81 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) 82 | 83 | cuda118_nomatmul: $(BUILD_DIR) env 84 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT 85 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 86 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) 87 | 88 | cuda12x_nomatmul: $(BUILD_DIR) env 89 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT 90 | $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 91 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) 92 | 93 | cuda110: $(BUILD_DIR) env 94 | $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) 95 | $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 96 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) 97 | 98 | cuda11x: $(BUILD_DIR) env 99 | $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) 
$(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) 100 | $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 101 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) 102 | 103 | cuda118: $(BUILD_DIR) env 104 | $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) 105 | $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 106 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) 107 | 108 | cuda12x: $(BUILD_DIR) env 109 | $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) 110 | $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o 111 | $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) 112 | 113 | cpuonly: $(BUILD_DIR) env 114 | $(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so 115 | 116 | env: 117 | @echo "ENVIRONMENT" 118 | @echo "============================" 119 | @echo "CUDA_VERSION: $(CUDA_VERSION)" 120 | @echo "============================" 121 | @echo "NVCC path: $(NVCC)" 122 | @echo "GPP path: $(GPP) VERSION: `$(GPP) --version | head -n 1`" 123 | @echo "CUDA_HOME: $(CUDA_HOME)" 124 | @echo "CONDA_PREFIX: $(CONDA_PREFIX)" 125 | @echo "PATH: $(PATH)" 126 | @echo "LD_LIBRARY_PATH: $(LD_LIBRARY_PATH)" 127 | @echo "============================" 128 | 129 | $(BUILD_DIR): 130 | mkdir -p build 131 | mkdir -p dependencies 132 | 133 | $(ROOT_DIR)/dependencies/cub: 134 | git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub 135 | cd dependencies/cub; git checkout 1.11.0 136 | 137 | clean: 138 | rm build/* 139 | 140 | cleaneggs: 141 | rm -rf *.egg* 142 | 143 | cleanlibs: 144 | rm ./bitsandbytes/libbitsandbytes*.so 145 | -------------------------------------------------------------------------------- /NOTICE.md: -------------------------------------------------------------------------------- 1 | The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license. 2 | 3 | We thank Fabio Cannizzo for this work on FastBinarySearch which is included in this project. 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bitsandbytes 2 | 3 | The bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions. 
4 | 5 | 6 | 7 | Resources: 8 | - [8-bit Optimizer Paper](https://arxiv.org/abs/2110.02861) -- [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) -- [Docs](https://bitsandbytes.readthedocs.io/en/latest/) 9 | 10 | - [LLM.int8() Paper](https://arxiv.org/abs/2208.07339) -- [LLM.int8() Software Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) -- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/) 11 | 12 | ## TL;DR 13 | **Requirements** 14 | Python >=3.8. Linux distribution (Ubuntu, etc.) + CUDA > 10.0. 15 | 16 | (Deprecated: CUDA 10.0 is deprecated; only CUDA >= 11.0 will be supported with release 0.39.0.) 17 | 18 | **Installation**: 19 | 20 | ``pip install bitsandbytes`` 21 | 22 | In some cases you may need to compile from source. If this happens, please consider submitting a bug report with the output of `python -m bitsandbytes`. What follows are short instructions that might work out of the box if `nvcc` is installed. If they do not work, see further below. 23 | 24 | Compilation quickstart: 25 | ```bash 26 | git clone https://github.com/timdettmers/bitsandbytes.git 27 | cd bitsandbytes 28 | 29 | # CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121} 30 | # make argument in {cuda110, cuda11x, cuda12x} 31 | # if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes 32 | CUDA_VERSION=117 make cuda11x 33 | python setup.py install 34 | ``` 35 | 36 | **Using Int8 inference with HuggingFace Transformers** 37 | 38 | ```python 39 | from transformers import AutoModelForCausalLM 40 | model = AutoModelForCausalLM.from_pretrained( 41 | 'decapoda-research/llama-7b-hf', 42 | device_map='auto', 43 | load_in_8bit=True, 44 | max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB') 45 | ``` 46 | 47 | A more detailed example can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py). 48 | 49 | **Using 8-bit optimizer**: 50 | 1. Comment out optimizer: ``#torch.optim.Adam(....)`` 51 | 2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same) 52 | 3. Replace embedding layer if necessary: ``torch.nn.Embedding(..) -> bnb.nn.Embedding(..)`` 53 | 54 | 55 | **Using 8-bit Inference**: 56 | 1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)`` 57 | 2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same) 58 | 3. There are two modes: 59 | - Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default) 60 | - Int8 inference. Pass the argument ``has_fp16_weights=False`` 61 | 4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``.
62 | ```python 63 | # LLM.int8() 64 | linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0) 65 | # inputs need to be fp16 66 | out = linear(x.to(torch.float16)) 67 | ``` 68 | 69 | 70 | ## Features 71 | - 8-bit Matrix multiplication with mixed precision decomposition 72 | - LLM.int8() inference 73 | - 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory) 74 | - Stable Embedding Layer: Improved stability through better initialization and normalization 75 | - 8-bit quantization: Quantile, Linear, and Dynamic quantization 76 | - Fast quantile estimation: Up to 100x faster than other algorithms 77 | 78 | ## Requirements & Installation 79 | 80 | Requirements: anaconda, cudatoolkit, pytorch 81 | 82 | Hardware requirements: 83 | - LLM.int8(): NVIDIA Turing (RTX 20xx; T4) or Ampere GPU (RTX 30xx; A4-A100); i.e., a GPU from 2018 or newer. 84 | - 8-bit optimizers and quantization: NVIDIA Kepler GPU or newer (>=GTX 78X). 85 | 86 | Supported CUDA versions: 10.2 - 12.0 87 | 88 | The bitsandbytes library is currently only supported on Linux distributions. Windows is not supported at the moment. 89 | 90 | The requirements can best be fulfilled by installing pytorch via anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website. 91 | 92 | To install run: 93 | 94 | ``pip install bitsandbytes`` 95 | 96 | ## Using bitsandbytes 97 | 98 | ### Using Int8 Matrix Multiplication 99 | 100 | For straight Int8 matrix multiplication with mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter: 101 | ```python 102 | bnb.matmul(..., threshold=6.0) 103 | ``` 104 | 105 | For instructions on how to use LLM.int8() inference layers in your own code, see the TL;DR above, or see [this blog post](https://github.com/huggingface/transformers) for extended instructions. 106 | 107 | ### Using the 8-bit Optimizers 108 | 109 | With bitsandbytes, 8-bit optimizers can be used by changing a single line of code in your codebase. For NLP models we also recommend using the StableEmbedding layers (see below), which improve results and help with stable 8-bit optimization. To get started with 8-bit optimizers, it is sufficient to replace your old optimizer with the 8-bit optimizer in the following way: 110 | ```python 111 | import bitsandbytes as bnb 112 | 113 | # adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer 114 | adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer 115 | adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8) # equivalent 116 | 117 | 118 | torch.nn.Embedding(...) -> bnb.nn.StableEmbedding(...) # recommended for NLP models 119 | ``` 120 | 121 | Note that by default all parameter tensors with fewer than 4096 elements are kept at 32-bit even if you initialize those parameters with 8-bit optimizers. This is done since such small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm).
You can change this behavior like so: 122 | ``` 123 | # parameter tensors with fewer than 16384 values are optimized in 32-bit 124 | # it is recommended to use multiples of 4096 125 | adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) 126 | ``` 127 | 128 | ### Change Bits and other Hyperparameters for Individual Parameters 129 | 130 | If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameters while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details; a short sketch is also included near the end of this README, just before the citation section. 131 | 132 | ### Fairseq Users 133 | 134 | To use the Stable Embedding Layer, override the respective `build_embedding(...)` function of your model. Make sure to also use the `--no-scale-embedding` flag to disable scaling of the word embedding layer (not replaced with layer norm). You can use the optimizers by replacing the optimizer in the respective file (`adam.py` etc.). 135 | 136 | ## Release and Feature History 137 | 138 | For upcoming features and changes and full history see [Patch Notes](CHANGELOG.md). 139 | 140 | ## Errors 141 | 142 | 1. RuntimeError: CUDA error: no kernel image is available for execution on the device. [Solution](errors_and_solutions.md#No-kernel-image-available) 143 | 2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_) 144 | 145 | ## Compile from source 146 | To compile from source, you need an installation of CUDA. If `nvcc` is not installed, you can install the CUDA Toolkit with nvcc through the following commands. 147 | 148 | ```bash 149 | wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh 150 | # Syntax: cuda_install.sh CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH 151 | # CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121} 152 | # EXPORT_TO_BASH in {0, 1} with 0=False and 1=True 153 | 154 | # For example, the following installs CUDA 11.8 to ~/local/cuda-11.8 and exports the path to your .bashrc 155 | bash cuda_install.sh 118 ~/local 1 156 | ``` 157 | 158 | To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`. For example, the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the CUDA version at `~/local/cuda-11.7`: 159 | 160 | ``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x`` 161 | 162 | For more detailed instructions, please follow the [compile_from_source.md](compile_from_source.md) instructions. 163 | 164 | ## License 165 | 166 | The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license. 167 | 168 | We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization.
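
## Example: per-parameter optimizer configuration (sketch)

As referenced in the *Change Bits and other Hyperparameters for Individual Parameters* section above, here is a minimal sketch of the two-step `GlobalOptimManager` workflow (register parameters while they are still on the CPU, then override the config for the ones you want to treat differently). It follows the pattern described in [howto_config_override.md](howto_config_override.md); the tiny `torch.nn.Sequential` model is a hypothetical stand-in for your own model, not part of the library.

```python
import torch
import bitsandbytes as bnb

# hypothetical toy model, used only so there are parameters to register
model = torch.nn.Sequential(
    torch.nn.Embedding(1024, 128),
    torch.nn.Linear(128, 128),
)

mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())  # (1) register while parameters are still on the CPU

model = model.cuda()
# 8-bit optimizer states for all parameters ...
adam = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)

# (2) ... except the embedding weight, which is overridden to use 32-bit state
mng.override_config(model[0].weight, 'optim_bits', 32)
```

The same `override_config` call can also set other hyperparameters (for example a different learning rate) for specific parameters; see the linked guide for the available options.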
169 | 170 | ## How to cite us 171 | If you found this library and found LLM.int8() useful, please consider citing our work: 172 | 173 | ```bibtex 174 | @article{dettmers2022llmint8, 175 | title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale}, 176 | author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke}, 177 | journal={arXiv preprint arXiv:2208.07339}, 178 | year={2022} 179 | } 180 | ``` 181 | 182 | For 8-bit optimizers or quantization routines, please consider citing the following work: 183 | 184 | ```bibtex 185 | @article{dettmers2022optimizers, 186 | title={8-bit Optimizers via Block-wise Quantization}, 187 | author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke}, 188 | journal={9th International Conference on Learning Representations, ICLR}, 189 | year={2022} 190 | } 191 | ``` 192 | -------------------------------------------------------------------------------- /benchmarking/switchback/README.md: -------------------------------------------------------------------------------- 1 | Steps: 2 | 3 | 1. Run `python speed_benchmark/speed_benchmark.py` which times operations and writes their time to `speed_benchmark/info_a100_py2.jsonl` (change the name of the jsonl to a different name for your profiling). 4 | 2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces the `speed_benchmark/plot_with_info.pdf`. Again make sure you change the jsonl which is being processed. -------------------------------------------------------------------------------- /benchmarking/switchback/make_plot_with_jsonl.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import numpy as np 4 | import os 5 | 6 | import matplotlib.gridspec as gridspec 7 | 8 | cmap=plt.get_cmap('cool') 9 | 10 | if __name__ == '__main__': 11 | 12 | fig = plt.figure(tight_layout=True, figsize=(12,3.5)) 13 | gs = gridspec.GridSpec(1, 2) 14 | 15 | dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096] 16 | batch_size_for_plot1 = 32768 17 | batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17] 18 | dims_to_xtick = [1024, 2048, 4096] 19 | logscale_plot1 = True 20 | 21 | ax = fig.add_subplot(gs[0, 0]) 22 | 23 | # TODO: change this to what you want. 
24 | rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True) 25 | df = rdf[rdf.batch_size == batch_size_for_plot1] 26 | 27 | # first plot the time occupied by different operations 28 | for k, marker, ls, color, name in [ 29 | ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'), 30 | ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), 31 | 32 | ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'), 33 | ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'), 34 | ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'), 35 | 36 | ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), 37 | ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), 38 | 39 | ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), 40 | ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), 41 | ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'), 42 | ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'), 43 | ]: 44 | xs = [] 45 | ys = [] 46 | for embed_dim in dims_to_consider: 47 | # average over dim -> 4*dim and 4*dim -> dim 48 | df_ = df[df.dim_in == embed_dim] 49 | df_ = df_[df_.dim_out == embed_dim * 4] 50 | xs.append(embed_dim) 51 | y_ = 0 52 | for k_ in k.split('+'): 53 | y_ += df_[k_].values[0] 54 | df_ = df[df.dim_in == embed_dim * 4] 55 | df_ = df_[df_.dim_out == embed_dim] 56 | for k_ in k.split('+'): 57 | y_ += df_[k_].values[0] 58 | ys.append(y_ * 0.5) 59 | 60 | 61 | ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)
62 | 63 | 64 | ax.set_xlabel('dim', fontsize=13) 65 | ax.set_ylabel('time (ms)', fontsize=13) 66 | 67 | ax.grid() 68 | 69 | ax.set_xscale('log') 70 | if logscale_plot1: 71 | ax.set_yscale('log') 72 | 73 | ax.tick_params(axis='x', labelsize=11) 74 | ax.tick_params(axis='y', labelsize=11) 75 | 76 | ax.set_xticks(dims_to_xtick) 77 | ax.set_xticklabels(dims_to_xtick) 78 | ax.set_xticks([], minor=True) 79 | 80 | leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) 81 | leg.get_texts()[0].set_fontweight('bold') 82 | leg.get_texts()[1].set_fontweight('bold') 83 | plt.subplots_adjust(left=0.1) 84 | ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20) 85 | 86 | 87 | ax = fig.add_subplot(gs[0, 1]) 88 | 89 | # now plot the % speedup for different batch sizes 90 | for j, batch_size in enumerate(batch_sizes_for_plot2): 91 | all_xs, all_ys = [], [] 92 | for k, marker, ls, color, name in [ 93 | ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), 94 | ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), 95 | ]: 96 | 97 | xs, ys = [], [] 98 | df = rdf[rdf.batch_size == batch_size] 99 | for embed_dim in dims_to_consider: 100 | df_ = df[df.dim_in == embed_dim] 101 | df_ = df_[df_.dim_out == embed_dim * 4] 102 | xs.append(embed_dim) 103 | y_ = 0 104 | for k_ in k.split('+'): 105 | y_ += df_[k_].values[0] 106 | df_ = df[df.dim_in == embed_dim * 4] 107 | df_ = df_[df_.dim_out == embed_dim] 108 | for k_ in k.split('+'): 109 | y_ += df_[k_].values[0] 110 | ys.append(y_ * 0.5) 111 | all_xs.append(xs) 112 | all_ys.append(ys) 113 | 114 | color = cmap(j * 0.25) 115 | real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] 116 | markers = ['^', 'v', 'P', 'o'] 117 | ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) 118 | 119 | ax.legend() 120 | ax.set_xlabel('dim', fontsize=13) 121 | ax.set_xscale('log') 122 | ax.grid() 123 | ax.set_ylabel(r'% speedup', fontsize=13) 124 | 125 | 126 | ax.tick_params(axis='x', labelsize=11) 127 | ax.tick_params(axis='y', labelsize=11) 128 | 129 | ax.set_xticks(dims_to_xtick) 130 | ax.set_xticklabels(dims_to_xtick) 131 | ax.set_xticks([], minor=True) 132 | 133 | ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) 134 | 135 | 136 | 137 | plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight') 138 | 139 | -------------------------------------------------------------------------------- /benchmarking/switchback/plot_with_info.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jllllll/bitsandbytes/e229fbce66adde7c2a6bc58cbe7d57c1f4a0ba02/benchmarking/switchback/plot_with_info.pdf -------------------------------------------------------------------------------- /benchmarking/switchback/speed_benchmark.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | 7 | from bitsandbytes.triton.quantize_rowwise import quantize_rowwise 8 | from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose 9 | from bitsandbytes.triton.int8_matmul_rowwise_dequantize 
import int8_matmul_rowwise_dequantize 10 | from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose 11 | from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze 12 | 13 | # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. 14 | 15 | def get_time(k, fn, info_dict): 16 | 17 | for _ in range(repeat // 2): 18 | fn() 19 | 20 | torch.cuda.synchronize() 21 | start = time.time() 22 | for _ in range(repeat): 23 | fn() 24 | 25 | torch.cuda.synchronize() 26 | end = time.time() 27 | ms = (end - start) / repeat * 1000 28 | print(f"time {k}: {ms:.3f} ms") 29 | info_dict[k] = ms 30 | 31 | if __name__ == '__main__': 32 | torch.manual_seed(0) 33 | wm = 4 34 | for dim in [1024, 1280, 1408, 1664, 2048, 4096]: 35 | # note "batch_size" is actually "batch_size * embed_dim", which is why it's large 36 | for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]: 37 | 38 | # switch switches dim_in and dim_out 39 | for switch in [False, True]: 40 | 41 | # hparams 42 | repeat = 64 43 | batch_size = batch_size 44 | dim_out = dim * wm 45 | dim_in = dim 46 | if switch: 47 | dim_out = dim 48 | dim_in = wm * dim 49 | 50 | dim_in = round(dim_in) 51 | dim_out = round(dim_out) 52 | 53 | # simulate forward pass 54 | x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() 55 | g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() 56 | w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda() 57 | 58 | x_int8 = x.clone().to(torch.int8) 59 | g_int8 = g.clone().to(torch.int8) 60 | w_int8 = w.clone().to(torch.int8) 61 | wt_int8 = w.t().contiguous().clone().to(torch.int8) 62 | state_x_rowwise = x.max(dim=1)[0] 63 | state_g_rowwise = g.max(dim=1)[0] 64 | state_w_columnwise = w.max(dim=0)[0] 65 | state_w_rowwise = w.max(dim=1)[0] 66 | state_w_global = w.max() 67 | 68 | info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch} 69 | 70 | get_time('standard_fwd', lambda : x.matmul(w.t()), info) 71 | get_time('standard_gw', lambda : g.t().matmul(x), info) 72 | get_time('standard_gx', lambda : g.matmul(w), info) 73 | get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info) 74 | get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info) 75 | get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info) 76 | get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info) 77 | get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info) 78 | get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info) 79 | get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info) 80 | get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info) 81 | get_time('w_quantize_global', lambda : quantize_global(w), info) 82 | get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info) 83 | 84 | time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw'] 85 | time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd'] 86 | time_global = 
info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] 87 | 88 | print('TOTAL STANDARD', time_standard) 89 | print('TOTAL ROWWISE', time_rowwise) 90 | print('TOTAL GLOBAL', time_global) 91 | 92 | print('speedup', -100*(time_global - time_standard)/time_standard) 93 | 94 | info['time_standard'] = time_standard 95 | info['time_rowwise'] = time_rowwise 96 | info['time_global'] = time_global 97 | 98 | info_json = json.dumps(info) 99 | 100 | # TODO: change this to what you want. 101 | with open("speed_benchmark/info.jsonl", "a") as file: 102 | file.write(info_json + "\n") 103 | -------------------------------------------------------------------------------- /bitsandbytes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import cuda_setup, utils, research 7 | from .autograd._functions import ( 8 | MatmulLtState, 9 | bmm_cublas, 10 | matmul, 11 | matmul_cublas, 12 | mm_cublas, 13 | matmul_4bit 14 | ) 15 | from .cextension import COMPILED_WITH_CUDA 16 | from .nn import modules 17 | 18 | if COMPILED_WITH_CUDA: 19 | from .optim import adam 20 | 21 | __pdoc__ = { 22 | "libbitsandbytes": False, 23 | "optim.optimizer.Optimizer8bit": False, 24 | "optim.optimizer.MockArgs": False, 25 | } 26 | 27 | PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" 28 | -------------------------------------------------------------------------------- /bitsandbytes/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shlex 4 | import subprocess 5 | 6 | from warnings import warn 7 | from typing import Tuple 8 | from os.path import isdir 9 | 10 | import torch 11 | 12 | HEADER_WIDTH = 60 13 | 14 | def execute_and_return(command_string: str) -> Tuple[str, str]: 15 | def _decode(subprocess_err_out_tuple): 16 | return tuple( 17 | to_decode.decode("UTF-8").strip() 18 | for to_decode in subprocess_err_out_tuple 19 | ) 20 | 21 | def execute_and_return_decoded_std_streams(command_string): 22 | return _decode( 23 | subprocess.Popen( 24 | shlex.split(command_string), 25 | stdout=subprocess.PIPE, 26 | stderr=subprocess.PIPE, 27 | ).communicate() 28 | ) 29 | 30 | std_out, std_err = execute_and_return_decoded_std_streams(command_string) 31 | return std_out, std_err 32 | 33 | def find_file_recursive(folder, filename): 34 | cmd = f'find {folder} -name {filename}' 35 | out, err = execute_and_return(cmd) 36 | if len(err) > 0: 37 | raise RuntimeError('Something when wrong when trying to find file. 
Maybe you do not have a linux system?') 38 | 39 | return out 40 | 41 | 42 | def generate_bug_report_information(): 43 | print_header("") 44 | print_header("BUG REPORT INFORMATION") 45 | print_header("") 46 | print('') 47 | 48 | if 'CONDA_PREFIX' in os.environ: 49 | paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so') 50 | print_header("ANACONDA CUDA PATHS") 51 | print(paths) 52 | print('') 53 | if isdir('/usr/local/'): 54 | paths = find_file_recursive('/usr/local', '*cuda*so') 55 | print_header("/usr/local CUDA PATHS") 56 | print(paths) 57 | print('') 58 | 59 | if isdir(os.getcwd()): 60 | paths = find_file_recursive(os.getcwd(), '*cuda*so') 61 | print_header("WORKING DIRECTORY CUDA PATHS") 62 | print(paths) 63 | print('') 64 | 65 | print_header("LD_LIBRARY CUDA PATHS") 66 | if 'LD_LIBRARY_PATH' in os.environ: 67 | lib_path = os.environ['LD_LIBRARY_PATH'].strip() 68 | for path in set(lib_path.split(':')): 69 | try: 70 | if isdir(path): 71 | print_header(f"{path} CUDA PATHS") 72 | paths = find_file_recursive(path, '*cuda*so') 73 | print(paths) 74 | except: 75 | print(f'Could not read LD_LIBRARY_PATH: {path}') 76 | print('') 77 | 78 | 79 | 80 | 81 | 82 | def print_header( 83 | txt: str, width: int = HEADER_WIDTH, filler: str = "+" 84 | ) -> None: 85 | txt = f" {txt} " if txt else "" 86 | print(txt.center(width, filler)) 87 | 88 | 89 | def print_debug_info() -> None: 90 | print( 91 | "\nAbove we output some debug information. Please provide this info when " 92 | f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n" 93 | ) 94 | 95 | 96 | generate_bug_report_information() 97 | 98 | 99 | from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL 100 | from .cuda_setup.env_vars import to_be_ignored 101 | from .cuda_setup.main import get_compute_capabilities 102 | 103 | 104 | print_header("OTHER") 105 | print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") 106 | print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") 107 | print_header("") 108 | print_header("DEBUG INFO END") 109 | print_header("") 110 | print( 111 | """ 112 | Running a quick check that: 113 | + library is importable 114 | + CUDA function is callable 115 | """ 116 | ) 117 | print("\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n") 118 | 119 | try: 120 | from bitsandbytes.optim import Adam 121 | 122 | p = torch.nn.Parameter(torch.rand(10, 10).cuda()) 123 | a = torch.rand(10, 10).cuda() 124 | 125 | p1 = p.data.sum().item() 126 | 127 | adam = Adam([p]) 128 | 129 | out = a * p 130 | loss = out.sum() 131 | loss.backward() 132 | adam.step() 133 | 134 | p2 = p.data.sum().item() 135 | 136 | assert p1 != p2 137 | print("SUCCESS!") 138 | print("Installation was successful!") 139 | sys.exit(0) 140 | 141 | except ImportError: 142 | print() 143 | warn( 144 | f"WARNING: {__package__} is currently running as CPU-only!\n" 145 | "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" 146 | f"If you think that this is so erroneously,\nplease report an issue!" 
147 | ) 148 | print_debug_info() 149 | sys.exit(0) 150 | except Exception as e: 151 | print(e) 152 | print_debug_info() 153 | sys.exit(1) 154 | 155 | -------------------------------------------------------------------------------- /bitsandbytes/autograd/__init__.py: -------------------------------------------------------------------------------- 1 | from ._functions import undo_layout, get_inverse_transform_indices 2 | -------------------------------------------------------------------------------- /bitsandbytes/cextension.py: -------------------------------------------------------------------------------- 1 | import ctypes as ct 2 | import os 3 | import torch 4 | 5 | from pathlib import Path 6 | from warnings import warn 7 | 8 | from bitsandbytes.cuda_setup.main import CUDASetup 9 | 10 | 11 | setup = CUDASetup.get_instance() 12 | if setup.initialized != True: 13 | setup.run_cuda_setup() 14 | 15 | lib = setup.lib 16 | try: 17 | if lib is None and torch.cuda.is_available(): 18 | CUDASetup.get_instance().generate_instructions() 19 | CUDASetup.get_instance().print_log_stack() 20 | raise RuntimeError(''' 21 | CUDA Setup failed despite GPU being available. Please run the following command to get more information: 22 | 23 | python -m bitsandbytes 24 | 25 | Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them 26 | to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes 27 | and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') 28 | lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False 29 | lib.get_context.restype = ct.c_void_p 30 | lib.get_cusparse.restype = ct.c_void_p 31 | lib.cget_managed_ptr.restype = ct.c_void_p 32 | COMPILED_WITH_CUDA = True 33 | except AttributeError as ex: 34 | warn("The installed version of bitsandbytes was compiled without GPU support. 
" 35 | "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") 36 | COMPILED_WITH_CUDA = False 37 | print(str(ex)) 38 | 39 | 40 | # print the setup details after checking for errors so we do not print twice 41 | #if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': 42 | #setup.print_log_stack() 43 | -------------------------------------------------------------------------------- /bitsandbytes/cuda_setup/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jllllll/bitsandbytes/e229fbce66adde7c2a6bc58cbe7d57c1f4a0ba02/bitsandbytes/cuda_setup/__init__.py -------------------------------------------------------------------------------- /bitsandbytes/cuda_setup/env_vars.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | 5 | def to_be_ignored(env_var: str, value: str) -> bool: 6 | ignorable = { 7 | "PWD", # PWD: this is how the shell keeps track of the current working dir 8 | "OLDPWD", 9 | "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated 10 | "SSH_TTY", 11 | "HOME", # Linux shell default 12 | "TMUX", # Terminal Multiplexer 13 | "XDG_DATA_DIRS", # XDG: Desktop environment stuff 14 | "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff 15 | "XDG_RUNTIME_DIR", 16 | "MAIL", # something related to emails 17 | "SHELL", # binary for currently invoked shell 18 | "DBUS_SESSION_BUS_ADDRESS", # hardware related 19 | "PATH", # this is for finding binaries, not libraries 20 | "LESSOPEN", # related to the `less` command 21 | "LESSCLOSE", 22 | "_", # current Python interpreter 23 | } 24 | return env_var in ignorable 25 | 26 | 27 | def might_contain_a_path(candidate: str) -> bool: 28 | return "/" in candidate 29 | 30 | 31 | def is_active_conda_env(env_var: str) -> bool: 32 | return "CONDA_PREFIX" == env_var 33 | 34 | 35 | def is_other_conda_env_var(env_var: str) -> bool: 36 | return "CONDA" in env_var 37 | 38 | 39 | def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: 40 | return is_active_conda_env(env_var) or ( 41 | might_contain_a_path(value) and not 42 | is_other_conda_env_var(env_var) and not 43 | to_be_ignored(env_var, value) 44 | ) 45 | 46 | 47 | def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: 48 | return { 49 | env_var: value 50 | for env_var, value in os.environ.items() 51 | if is_relevant_candidate_env_var(env_var, value) 52 | } 53 | -------------------------------------------------------------------------------- /bitsandbytes/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb 6 | from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear 7 | -------------------------------------------------------------------------------- /bitsandbytes/nn/triton_based_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import time 4 | from functools import partial 5 | 6 | from bitsandbytes.triton.triton_utils import is_triton_available 7 | 8 | from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise 9 | from bitsandbytes.triton.quantize_rowwise import quantize_rowwise 10 | from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose 11 | from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize 12 | from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose 13 | from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze 14 | 15 | 16 | class _switchback_global(torch.autograd.Function): 17 | 18 | @staticmethod 19 | def forward(ctx, X_3D, W, bias): 20 | # reshape input to [N * L, D] 21 | X = X_3D.view(-1, X_3D.size(-1)) 22 | 23 | # rowwise quantize for X, global quantize for W 24 | X_int8, state_X = quantize_rowwise(X) 25 | W_int8, state_W = quantize_global(W) 26 | 27 | # save for backward. 28 | ctx.save_for_backward = X, W 29 | 30 | # matmult, fused dequant and add bias 31 | # call "mixed" because we are mixing rowwise quantized and global quantized 32 | return int8_matmul_mixed_dequanitze( 33 | X_int8, W_int8.t(), state_X, state_W, bias 34 | ).view(*X_3D.size()[:-1], -1) 35 | 36 | @staticmethod 37 | def backward(ctx, G_3D): 38 | # reshape input to [N_out * L, D] 39 | G = G_3D.reshape(-1, G_3D.size(-1)) 40 | 41 | grad_X = grad_W = grad_bias = None 42 | 43 | X, W = ctx.save_for_backward 44 | if ctx.needs_input_grad[0]: 45 | # rowwise quantize for G, global quantize for W 46 | # for W, we also fuse the transpose operation because only A @ B^T is supported 47 | # so we transpose once then call .t() in the matmul 48 | G_int8, state_G = quantize_rowwise(G) 49 | W_int8, state_W = quantize_global_transpose(W) 50 | grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view( 51 | *G_3D.size()[:-1], -1 52 | ) 53 | if ctx.needs_input_grad[1]: 54 | # backward pass uses standard weight grad 55 | grad_W = torch.matmul(G.t(), X.to(G.dtype)) 56 | if ctx.needs_input_grad[2]: 57 | grad_bias = G.sum(dim=0) 58 | 59 | return grad_X, grad_W, grad_bias 60 | 61 | class _switchback_vectorrize(torch.autograd.Function): 62 | 63 | @staticmethod 64 | def forward(ctx, X_3D, W, bias): 65 | # reshape input to [N * L, D] 66 | X = X_3D.view(-1, X_3D.size(-1)) 67 | 68 | ctx.save_for_backward = X, W 69 | # rowwise quantize for X 70 | # columnwise quantize for W (first rowwise, transpose later) 71 | X_int8, state_X = quantize_rowwise(X) 72 | W_int8, state_W = quantize_rowwise(W) 73 | 74 | # matmult, fused dequant and add bias 75 | # call kernel which expects rowwise quantized X and W 76 | return int8_matmul_rowwise_dequantize( 77 | X_int8, W_int8.t(), state_X, state_W, bias 78 | ).view(*X_3D.size()[:-1], -1) 79 | 80 | @staticmethod 81 | def backward(ctx, G_3D): 82 | X, W = ctx.save_for_backward 83 | 84 | G = G_3D.reshape(-1, G_3D.size(-1)) 85 | 
86 | grad_X = grad_W = grad_bias = None 87 | 88 | if ctx.needs_input_grad[0]: 89 | # rowwise quantize for G, columnwise quantize for W and fused transpose 90 | # we call .t() for weight later because only A @ B^T is supported 91 | G_int8, state_G = quantize_rowwise(G) 92 | W_int8, state_W = quantize_columnwise_and_transpose(W) 93 | grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( 94 | *G_3D.size()[:-1], -1 95 | ) 96 | if ctx.needs_input_grad[1]: 97 | # backward pass uses standard weight grad 98 | grad_W = torch.matmul(G.t(), X.to(G.dtype)) 99 | if ctx.needs_input_grad[2]: 100 | grad_bias = G.sum(dim=0) 101 | 102 | return grad_X, grad_W, grad_bias 103 | 104 | class _switchback_global_mem_efficient(torch.autograd.Function): 105 | 106 | @staticmethod 107 | def forward(ctx, X_3D, W, bias): 108 | # reshape input to [N * L, D] 109 | X = X_3D.view(-1, X_3D.size(-1)) 110 | X_3D_sz = X_3D.size() 111 | 112 | # rowwise quantize for X, global quantize for W 113 | X_int8, state_X = quantize_rowwise(X) 114 | del X 115 | W_int8, state_W = quantize_global(W) 116 | 117 | # save for backward. 118 | ctx.save_for_backward = X_int8, state_X, W_int8, state_W 119 | 120 | # matmult, fused dequant and add bias 121 | # call "mixed" because we are mixing rowwise quantized and global quantized 122 | return int8_matmul_mixed_dequanitze( 123 | X_int8, W_int8.t(), state_X, state_W, bias 124 | ).view(*X_3D_sz[:-1], -1) 125 | 126 | @staticmethod 127 | def backward(ctx, G_3D): 128 | # reshape input to [N_out * L, D] 129 | G = G_3D.reshape(-1, G_3D.size(-1)) 130 | G_3D_sz = G_3D.size() 131 | 132 | grad_X = grad_W = grad_bias = None 133 | 134 | X_int8, state_X, W_int8, state_W = ctx.save_for_backward 135 | if ctx.needs_input_grad[1]: 136 | real_X = dequantize_rowwise(X_int8, state_X) 137 | del X_int8 138 | grad_W = torch.matmul(G.t(), real_X.to(G.dtype)) 139 | del real_X 140 | if ctx.needs_input_grad[2]: 141 | grad_bias = G.sum(dim=0) 142 | if ctx.needs_input_grad[0]: 143 | G_int8, state_G = quantize_rowwise(G) 144 | del G 145 | W_int8 = W_int8.t().contiguous() 146 | grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view( 147 | *G_3D_sz[:-1], -1 148 | ) 149 | 150 | return grad_X, grad_W, grad_bias 151 | 152 | class SwitchBackLinear(nn.Linear): 153 | def __init__( 154 | self, 155 | in_features: int, 156 | out_features: int, 157 | bias: bool = True, 158 | device=None, 159 | dtype=None, 160 | vector_wise_quantization: bool = False, 161 | mem_efficient : bool = False, 162 | ): 163 | super().__init__(in_features, out_features, bias, device, dtype) 164 | 165 | if not is_triton_available: 166 | raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. 167 | Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') 168 | 169 | # By default, we use the global quantization. 170 | self.vector_wise_quantization = vector_wise_quantization 171 | if self.vector_wise_quantization: 172 | self._fn = _switchback_vectorrize 173 | if mem_efficient: 174 | print('mem efficient is not supported for vector-wise quantization.') 175 | exit(1) 176 | else: 177 | if mem_efficient: 178 | self._fn = _switchback_global_mem_efficient 179 | else: 180 | self._fn = _switchback_global 181 | 182 | def prepare_for_eval(self): 183 | # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass. 184 | # Note this is experimental and not tested thoroughly. 
185 | # Note this needs to be explicitly called with something like 186 | # def cond_prepare(m): 187 | # if hasattr(m, "prepare_for_eval"): 188 | # m.prepare_for_eval() 189 | # model.apply(cond_prepare) 190 | print('=> preparing for eval.') 191 | if self.vector_wise_quantization: 192 | W_int8, state_W = quantize_rowwise(self.weight) 193 | else: 194 | W_int8, state_W = quantize_global(self.weight) 195 | 196 | self.register_buffer("W_int8", W_int8) 197 | self.register_buffer("state_W", state_W) 198 | 199 | del self.weight 200 | 201 | def forward(self, x): 202 | if self.training: 203 | return self._fn.apply(x, self.weight, self.bias) 204 | else: 205 | # If it hasn't been "prepared for eval", run the standard forward pass. 206 | if not hasattr(self, "W_int8"): 207 | return self._fn.apply(x, self.weight, self.bias) 208 | 209 | # Otherwise, use pre-computed weights. 210 | X = x.view(-1, x.size(-1)) 211 | X_int8, state_X = quantize_rowwise(X) 212 | 213 | if self.vector_wise_quantization: 214 | return int8_matmul_rowwise_dequantize( 215 | X_int8, self.W_int8.t(), state_X, self.state_W, self.bias 216 | ).view(*x.size()[:-1], -1) 217 | else: 218 | return int8_matmul_mixed_dequanitze( 219 | X_int8, self.W_int8.t(), state_X, self.state_W, self.bias 220 | ).view(*x.size()[:-1], -1) 221 | 222 | SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False) 223 | SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True) 224 | SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True) 225 | 226 | # This is just the standard linear function. 227 | class StandardLinearFunction(torch.autograd.Function): 228 | @staticmethod 229 | def forward(ctx, input, weight, bias=None): 230 | X = input.view(-1, input.size(-1)) 231 | 232 | ctx.save_for_backward(X, weight, bias) 233 | output = input.matmul(weight.t()) 234 | if bias is not None: 235 | output += bias.unsqueeze(0).expand_as(output) 236 | return output.view(*input.size()[:-1], -1) 237 | 238 | @staticmethod 239 | def backward(ctx, grad_output_3D): 240 | input, weight, bias = ctx.saved_tensors 241 | 242 | grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1)) 243 | 244 | grad_input = grad_weight = grad_bias = None 245 | 246 | if ctx.needs_input_grad[0]: 247 | grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1) 248 | if ctx.needs_input_grad[1]: 249 | grad_weight = grad_output.t().matmul(input.to(grad_output.dtype)) 250 | if bias is not None and ctx.needs_input_grad[2]: 251 | grad_bias = grad_output.sum(0) 252 | 253 | return grad_input, grad_weight, grad_bias 254 | 255 | class StandardLinear(nn.Linear): 256 | 257 | def forward(self, x): 258 | return StandardLinearFunction.apply(x, self.weight, self.bias) 259 | -------------------------------------------------------------------------------- /bitsandbytes/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
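For orientation, a minimal sketch of how SwitchBackLinear from triton_based_modules.py above can be used as a drop-in for nn.Linear; the sizes, device placement, and fp16 cast below are illustrative assumptions, and a CUDA device with triton installed is required:

import torch
import bitsandbytes as bnb

layer = bnb.nn.SwitchBackLinear(1024, 4096).cuda().half()   # global quantization by default
x = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)
y = layer(x)                 # training path: X and W are quantized on the fly

layer.eval()
layer.prepare_for_eval()     # optional: pre-quantize W once for int8 inference
y = layer(x)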
5 | 6 | from bitsandbytes.cextension import COMPILED_WITH_CUDA 7 | 8 | from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit 9 | from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit 10 | from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit 11 | from .lamb import LAMB, LAMB8bit, LAMB32bit 12 | from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS 13 | from .optimizer import GlobalOptimManager 14 | from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit 15 | from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit 16 | from .sgd import SGD, SGD8bit, SGD32bit 17 | -------------------------------------------------------------------------------- /bitsandbytes/optim/adagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from bitsandbytes.optim.optimizer import Optimizer1State 6 | 7 | 8 | class Adagrad(Optimizer1State): 9 | def __init__( 10 | self, 11 | params, 12 | lr=1e-2, 13 | lr_decay=0, 14 | weight_decay=0, 15 | initial_accumulator_value=0, 16 | eps=1e-10, 17 | optim_bits=32, 18 | args=None, 19 | min_8bit_size=4096, 20 | percentile_clipping=100, 21 | block_wise=True, 22 | ): 23 | if not 0.0 <= lr: 24 | raise ValueError(f"Invalid learning rate: {lr}") 25 | if not 0.0 <= weight_decay: 26 | raise ValueError( 27 | f"Invalid weight_decay value: {weight_decay}" 28 | ) 29 | if not 0.0 <= eps: 30 | raise ValueError(f"Invalid epsilon value: {eps}") 31 | if initial_accumulator_value != 0.0: 32 | raise ValueError("Initial accumulator value != 0.0 not supported!") 33 | if lr_decay != 0.0: 34 | raise ValueError("Lr Decay != 0.0 not supported!") 35 | super().__init__( 36 | "adagrad", 37 | params, 38 | lr, 39 | (0.0, 0.0), 40 | eps, 41 | weight_decay, 42 | optim_bits, 43 | args, 44 | min_8bit_size, 45 | percentile_clipping, 46 | block_wise, 47 | ) 48 | 49 | 50 | class Adagrad8bit(Optimizer1State): 51 | def __init__( 52 | self, 53 | params, 54 | lr=1e-2, 55 | lr_decay=0, 56 | weight_decay=0, 57 | initial_accumulator_value=0, 58 | eps=1e-10, 59 | optim_bits=8, 60 | args=None, 61 | min_8bit_size=4096, 62 | percentile_clipping=100, 63 | block_wise=True, 64 | ): 65 | if not 0.0 <= lr: 66 | raise ValueError(f"Invalid learning rate: {lr}") 67 | if not 0.0 <= weight_decay: 68 | raise ValueError( 69 | f"Invalid weight_decay value: {weight_decay}" 70 | ) 71 | if not 0.0 <= eps: 72 | raise ValueError(f"Invalid epsilon value: {eps}") 73 | if initial_accumulator_value != 0.0: 74 | raise ValueError("Initial accumulator value != 0.0 not supported!") 75 | if lr_decay != 0.0: 76 | raise ValueError("Lr Decay != 0.0 not supported!") 77 | assert block_wise 78 | super().__init__( 79 | "adagrad", 80 | params, 81 | lr, 82 | (0.0, 0.0), 83 | eps, 84 | weight_decay, 85 | 8, 86 | args, 87 | min_8bit_size, 88 | percentile_clipping, 89 | block_wise, 90 | ) 91 | 92 | 93 | class Adagrad32bit(Optimizer1State): 94 | def __init__( 95 | self, 96 | params, 97 | lr=1e-2, 98 | lr_decay=0, 99 | weight_decay=0, 100 | initial_accumulator_value=0, 101 | eps=1e-10, 102 | optim_bits=32, 103 | args=None, 104 | min_8bit_size=4096, 105 | percentile_clipping=100, 106 | block_wise=True, 107 | ): 108 | if not 0.0 <= lr: 109 | raise ValueError(f"Invalid learning rate: {lr}") 110 | if not 0.0 <= weight_decay: 111 | raise 
ValueError( 112 | f"Invalid weight_decay value: {weight_decay}" 113 | ) 114 | if not 0.0 <= eps: 115 | raise ValueError(f"Invalid epsilon value: {eps}") 116 | if initial_accumulator_value != 0.0: 117 | raise ValueError("Initial accumulator value != 0.0 not supported!") 118 | if lr_decay != 0.0: 119 | raise ValueError("Lr Decay != 0.0 not supported!") 120 | super().__init__( 121 | "adagrad", 122 | params, 123 | lr, 124 | (0.0, 0.0), 125 | eps, 126 | weight_decay, 127 | 32, 128 | args, 129 | min_8bit_size, 130 | percentile_clipping, 131 | block_wise, 132 | ) 133 | -------------------------------------------------------------------------------- /bitsandbytes/optim/adamw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from bitsandbytes.optim.optimizer import Optimizer2State 6 | 7 | 8 | 9 | class AdamW(Optimizer2State): 10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 11 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 12 | super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) 13 | 14 | class AdamW8bit(Optimizer2State): 15 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 16 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 17 | super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) 18 | 19 | class AdamW32bit(Optimizer2State): 20 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 21 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 22 | super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) 23 | 24 | 25 | class PagedAdamW(Optimizer2State): 26 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 27 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 28 | super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 29 | 30 | class PagedAdamW8bit(Optimizer2State): 31 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 32 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 33 | super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 34 | 35 | class PagedAdamW32bit(Optimizer2State): 36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, 37 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 38 | super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 39 | 40 | -------------------------------------------------------------------------------- /bitsandbytes/optim/lamb.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from bitsandbytes.optim.optimizer import Optimizer2State 6 | 7 | 8 | class LAMB(Optimizer2State): 9 | def __init__( 10 | self, 11 | params, 12 | lr=1e-3, 13 | bias_correction=True, 14 | betas=(0.9, 0.999), 15 | eps=1e-8, 16 | weight_decay=0, 17 | amsgrad=False, 18 | adam_w_mode=True, 19 | optim_bits=32, 20 | args=None, 21 | min_8bit_size=4096, 22 | percentile_clipping=100, 23 | block_wise=False, 24 | max_unorm=1.0, 25 | ): 26 | super().__init__( 27 | "lamb", 28 | params, 29 | lr, 30 | betas, 31 | eps, 32 | weight_decay, 33 | optim_bits, 34 | args, 35 | min_8bit_size, 36 | percentile_clipping, 37 | block_wise, 38 | max_unorm=1.0, 39 | ) 40 | 41 | 42 | class LAMB8bit(Optimizer2State): 43 | def __init__( 44 | self, 45 | params, 46 | lr=1e-3, 47 | bias_correction=True, 48 | betas=(0.9, 0.999), 49 | eps=1e-8, 50 | weight_decay=0, 51 | amsgrad=False, 52 | adam_w_mode=True, 53 | args=None, 54 | min_8bit_size=4096, 55 | percentile_clipping=100, 56 | block_wise=False, 57 | max_unorm=1.0, 58 | ): 59 | super().__init__( 60 | "lamb", 61 | params, 62 | lr, 63 | betas, 64 | eps, 65 | weight_decay, 66 | 8, 67 | args, 68 | min_8bit_size, 69 | percentile_clipping, 70 | block_wise, 71 | max_unorm=1.0, 72 | ) 73 | 74 | 75 | class LAMB32bit(Optimizer2State): 76 | def __init__( 77 | self, 78 | params, 79 | lr=1e-3, 80 | bias_correction=True, 81 | betas=(0.9, 0.999), 82 | eps=1e-8, 83 | weight_decay=0, 84 | amsgrad=False, 85 | adam_w_mode=True, 86 | args=None, 87 | min_8bit_size=4096, 88 | percentile_clipping=100, 89 | block_wise=False, 90 | max_unorm=1.0, 91 | ): 92 | super().__init__( 93 | "lamb", 94 | params, 95 | lr, 96 | betas, 97 | eps, 98 | weight_decay, 99 | 32, 100 | args, 101 | min_8bit_size, 102 | percentile_clipping, 103 | block_wise, 104 | max_unorm=1.0, 105 | ) 106 | -------------------------------------------------------------------------------- /bitsandbytes/optim/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import torch 6 | from torch.optim import Optimizer 7 | 8 | from bitsandbytes.optim.optimizer import Optimizer1State 9 | 10 | 11 | class LARS(Optimizer1State): 12 | def __init__( 13 | self, 14 | params, 15 | lr, 16 | momentum=0, 17 | dampening=0, 18 | weight_decay=0, 19 | nesterov=False, 20 | optim_bits=32, 21 | args=None, 22 | min_8bit_size=4096, 23 | percentile_clipping=100, 24 | max_unorm=0.02, 25 | ): 26 | if momentum == 0: 27 | raise NotImplementedError( 28 | "LARS without momentum is not supported!" 
29 | ) 30 | super().__init__( 31 | "lars", 32 | params, 33 | lr, 34 | (momentum, dampening), 35 | 0.0, 36 | weight_decay, 37 | optim_bits, 38 | args, 39 | min_8bit_size, 40 | percentile_clipping, 41 | max_unorm=max_unorm, 42 | block_wise=False, 43 | ) 44 | 45 | 46 | class LARS8bit(Optimizer1State): 47 | def __init__( 48 | self, 49 | params, 50 | lr, 51 | momentum=0, 52 | dampening=0, 53 | weight_decay=0, 54 | nesterov=False, 55 | args=None, 56 | min_8bit_size=4096, 57 | percentile_clipping=100, 58 | max_unorm=0.02, 59 | ): 60 | if momentum == 0: 61 | raise NotImplementedError( 62 | "LARS without momentum is not supported!" 63 | ) 64 | super().__init__( 65 | "lars", 66 | params, 67 | lr, 68 | (momentum, dampening), 69 | 0.0, 70 | weight_decay, 71 | 8, 72 | args, 73 | min_8bit_size, 74 | percentile_clipping, 75 | max_unorm=max_unorm, 76 | block_wise=False, 77 | ) 78 | 79 | 80 | class LARS32bit(Optimizer1State): 81 | def __init__( 82 | self, 83 | params, 84 | lr, 85 | momentum=0, 86 | dampening=0, 87 | weight_decay=0, 88 | nesterov=False, 89 | args=None, 90 | min_8bit_size=4096, 91 | percentile_clipping=100, 92 | max_unorm=0.02, 93 | ): 94 | if momentum == 0: 95 | raise NotImplementedError( 96 | "LARS without momentum is not supported!" 97 | ) 98 | super().__init__( 99 | "lars", 100 | params, 101 | lr, 102 | (momentum, dampening), 103 | 0.0, 104 | weight_decay, 105 | 32, 106 | args, 107 | min_8bit_size, 108 | percentile_clipping, 109 | max_unorm=max_unorm, 110 | block_wise=False, 111 | ) 112 | 113 | 114 | class PytorchLARS(Optimizer): 115 | def __init__( 116 | self, 117 | params, 118 | lr=0.01, 119 | momentum=0, 120 | dampening=0, 121 | weight_decay=0, 122 | nesterov=False, 123 | max_unorm=0.02, 124 | ): 125 | if lr < 0.0: 126 | raise ValueError(f"Invalid learning rate: {lr}") 127 | if momentum < 0.0: 128 | raise ValueError(f"Invalid momentum value: {momentum}") 129 | if weight_decay < 0.0: 130 | raise ValueError( 131 | f"Invalid weight_decay value: {weight_decay}" 132 | ) 133 | 134 | defaults = dict( 135 | lr=lr, 136 | momentum=momentum, 137 | dampening=dampening, 138 | weight_decay=weight_decay, 139 | nesterov=nesterov, 140 | max_unorm=max_unorm, 141 | ) 142 | if nesterov and (momentum <= 0 or dampening != 0): 143 | raise ValueError( 144 | "Nesterov momentum requires a momentum and zero dampening" 145 | ) 146 | super().__init__(params, defaults) 147 | 148 | def __setstate__(self, state): 149 | super().__setstate__(state) 150 | for group in self.param_groups: 151 | group.setdefault("nesterov", False) 152 | 153 | @torch.no_grad() 154 | def step(self, closure=None): 155 | """Performs a single optimization step. 156 | 157 | Args: 158 | closure (callable, optional): A closure that reevaluates the model 159 | and returns the loss. 
160 | """ 161 | loss = None 162 | if closure is not None: 163 | with torch.enable_grad(): 164 | loss = closure() 165 | 166 | for group in self.param_groups: 167 | params_with_grad = [] 168 | d_p_list = [] 169 | momentum_buffer_list = [] 170 | weight_decay = group["weight_decay"] 171 | momentum = group["momentum"] 172 | dampening = group["dampening"] 173 | nesterov = group["nesterov"] 174 | max_unorm = group["max_unorm"] 175 | lr = group["lr"] 176 | 177 | for p in group["params"]: 178 | if p.grad is None: 179 | continue 180 | 181 | state = self.state[p] 182 | d_p = p.grad 183 | if weight_decay != 0: 184 | d_p = d_p.add(p, alpha=weight_decay) 185 | 186 | if momentum != 0: 187 | buf = state.get("momentum_buffer", None) 188 | 189 | if buf is None: 190 | buf = torch.clone(d_p).detach() 191 | state["momentum_buffer"] = buf 192 | else: 193 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 194 | 195 | if nesterov: 196 | update = d_p + buf * momentum 197 | else: 198 | update = buf 199 | 200 | update_scale = 1.0 201 | if max_unorm > 0.0: 202 | assert p.dtype == torch.float32 203 | pnorm = torch.norm(p.detach()) 204 | unorm = torch.norm(update) 205 | if unorm > max_unorm * pnorm: 206 | update_scale = max_unorm * pnorm / unorm 207 | 208 | p.add_(update, alpha=-lr * update_scale) 209 | 210 | return loss 211 | -------------------------------------------------------------------------------- /bitsandbytes/optim/lion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from bitsandbytes.optim.optimizer import Optimizer1State 6 | 7 | class Lion(Optimizer1State): 8 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 9 | super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) 10 | 11 | class Lion8bit(Optimizer1State): 12 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 13 | super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) 14 | 15 | class Lion32bit(Optimizer1State): 16 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): 17 | super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) 18 | 19 | 20 | class PagedLion(Optimizer1State): 21 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 22 | super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 23 | 24 | class PagedLion8bit(Optimizer1State): 25 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 26 | super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 27 | 28 | class 
PagedLion32bit(Optimizer1State): 29 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): 30 | super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) 31 | -------------------------------------------------------------------------------- /bitsandbytes/optim/rmsprop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from bitsandbytes.optim.optimizer import Optimizer1State 6 | 7 | 8 | class RMSprop(Optimizer1State): 9 | def __init__( 10 | self, 11 | params, 12 | lr=1e-2, 13 | alpha=0.99, 14 | eps=1e-8, 15 | weight_decay=0, 16 | momentum=0, 17 | centered=False, 18 | optim_bits=32, 19 | args=None, 20 | min_8bit_size=4096, 21 | percentile_clipping=100, 22 | block_wise=True, 23 | ): 24 | if alpha == 0: 25 | raise NotImplementedError( 26 | "RMSprop with alpha==0.0 is not supported!" 27 | ) 28 | if centered: 29 | raise NotImplementedError("Centered RMSprop is not supported!") 30 | super().__init__( 31 | "rmsprop", 32 | params, 33 | lr, 34 | (alpha, momentum), 35 | eps, 36 | weight_decay, 37 | optim_bits, 38 | args, 39 | min_8bit_size, 40 | percentile_clipping, 41 | block_wise, 42 | ) 43 | 44 | 45 | class RMSprop8bit(Optimizer1State): 46 | def __init__( 47 | self, 48 | params, 49 | lr=1e-2, 50 | alpha=0.99, 51 | eps=1e-8, 52 | weight_decay=0, 53 | momentum=0, 54 | centered=False, 55 | args=None, 56 | min_8bit_size=4096, 57 | percentile_clipping=100, 58 | block_wise=True, 59 | ): 60 | if alpha == 0: 61 | raise NotImplementedError( 62 | "RMSprop with alpha==0.0 is not supported!" 63 | ) 64 | if centered: 65 | raise NotImplementedError("Centered RMSprop is not supported!") 66 | super().__init__( 67 | "rmsprop", 68 | params, 69 | lr, 70 | (alpha, momentum), 71 | eps, 72 | weight_decay, 73 | 8, 74 | args, 75 | min_8bit_size, 76 | percentile_clipping, 77 | block_wise, 78 | ) 79 | 80 | 81 | class RMSprop32bit(Optimizer1State): 82 | def __init__( 83 | self, 84 | params, 85 | lr=1e-2, 86 | alpha=0.99, 87 | eps=1e-8, 88 | weight_decay=0, 89 | momentum=0, 90 | centered=False, 91 | args=None, 92 | min_8bit_size=4096, 93 | percentile_clipping=100, 94 | block_wise=True, 95 | ): 96 | 97 | if alpha == 0: 98 | raise NotImplementedError( 99 | "RMSprop with alpha==0.0 is not supported!" 100 | ) 101 | if centered: 102 | raise NotImplementedError("Centered RMSprop is not supported!") 103 | super().__init__( 104 | "rmsprop", 105 | params, 106 | lr, 107 | (alpha, momentum), 108 | eps, 109 | weight_decay, 110 | 32, 111 | args, 112 | min_8bit_size, 113 | percentile_clipping, 114 | block_wise, 115 | ) 116 | -------------------------------------------------------------------------------- /bitsandbytes/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
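A usage sketch for the 8-bit and paged optimizers defined in this optim package; the model and hyperparameters below are placeholders:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()   # placeholder model
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.995))
# paged variant: optimizer state is kept in paged memory that can spill to CPU RAM under GPU memory pressure
# optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-3)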
5 | from bitsandbytes.optim.optimizer import Optimizer1State 6 | 7 | 8 | class SGD(Optimizer1State): 9 | def __init__( 10 | self, 11 | params, 12 | lr, 13 | momentum=0, 14 | dampening=0, 15 | weight_decay=0, 16 | nesterov=False, 17 | optim_bits=32, 18 | args=None, 19 | min_8bit_size=4096, 20 | percentile_clipping=100, 21 | block_wise=True, 22 | ): 23 | if momentum == 0: 24 | raise NotImplementedError("SGD without momentum is not supported!") 25 | super().__init__( 26 | "momentum", 27 | params, 28 | lr, 29 | (momentum, dampening), 30 | 0.0, 31 | weight_decay, 32 | optim_bits, 33 | args, 34 | min_8bit_size, 35 | percentile_clipping, 36 | block_wise, 37 | ) 38 | 39 | 40 | class SGD8bit(Optimizer1State): 41 | def __init__( 42 | self, 43 | params, 44 | lr, 45 | momentum=0, 46 | dampening=0, 47 | weight_decay=0, 48 | nesterov=False, 49 | args=None, 50 | min_8bit_size=4096, 51 | percentile_clipping=100, 52 | block_wise=True, 53 | ): 54 | if momentum == 0: 55 | raise NotImplementedError("SGD without momentum is not supported!") 56 | super().__init__( 57 | "momentum", 58 | params, 59 | lr, 60 | (momentum, dampening), 61 | 0.0, 62 | weight_decay, 63 | 8, 64 | args, 65 | min_8bit_size, 66 | percentile_clipping, 67 | block_wise, 68 | ) 69 | 70 | 71 | class SGD32bit(Optimizer1State): 72 | def __init__( 73 | self, 74 | params, 75 | lr, 76 | momentum=0, 77 | dampening=0, 78 | weight_decay=0, 79 | nesterov=False, 80 | args=None, 81 | min_8bit_size=4096, 82 | percentile_clipping=100, 83 | block_wise=True, 84 | ): 85 | if momentum == 0: 86 | raise NotImplementedError("SGD without momentum is not supported!") 87 | super().__init__( 88 | "momentum", 89 | params, 90 | lr, 91 | (momentum, dampening), 92 | 0.0, 93 | weight_decay, 94 | 32, 95 | args, 96 | min_8bit_size, 97 | percentile_clipping, 98 | block_wise, 99 | ) 100 | -------------------------------------------------------------------------------- /bitsandbytes/research/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import nn 2 | from .autograd._functions import ( 3 | switchback_bnb, 4 | matmul_fp8_global, 5 | matmul_fp8_mixed, 6 | ) 7 | -------------------------------------------------------------------------------- /bitsandbytes/research/autograd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jllllll/bitsandbytes/e229fbce66adde7c2a6bc58cbe7d57c1f4a0ba02/bitsandbytes/research/autograd/__init__.py -------------------------------------------------------------------------------- /bitsandbytes/research/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import LinearFP8Mixed, LinearFP8Global 2 | -------------------------------------------------------------------------------- /bitsandbytes/research/nn/modules.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, TypeVar, Union, overload 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor, device, dtype, nn 6 | 7 | import bitsandbytes as bnb 8 | from bitsandbytes.optim import GlobalOptimManager 9 | from bitsandbytes.utils import OutlierTracer, find_outlier_dims 10 | 11 | T = TypeVar("T", bound="torch.nn.Module") 12 | 13 | 14 | class LinearFP8Mixed(nn.Linear): 15 | def __init__(self, input_features, output_features, bias=True): 16 | super().__init__(input_features, output_features, bias) 17 | self.bw_code = None 18 | self.fw_code = None 19 | array = [4096, 2048, 1024, 512, 256, 128, 64, 0] 20 | for i, k in enumerate(array): 21 | if input_features > array[i + 1]: 22 | self.bsz = k 23 | break 24 | for i, k in enumerate(array): 25 | if output_features > array[i + 1]: 26 | self.bsz2 = k 27 | break 28 | 29 | def forward(self, x: torch.Tensor): 30 | if self.fw_code is None: 31 | self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) 32 | self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) 33 | 34 | out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) 35 | if self.bias is not None: 36 | out += self.bias 37 | 38 | return out 39 | 40 | class LinearFP8Global(nn.Linear): 41 | def __init__(self, input_features, output_features, bias=True): 42 | super().__init__(input_features, output_features, bias) 43 | self.bw_code = None 44 | self.fw_code = None 45 | array = [4096, 2048, 1024, 512, 256, 128, 64, 0] 46 | for i, k in enumerate(array): 47 | if input_features > array[i + 1]: 48 | self.bsz = k 49 | break 50 | for i, k in enumerate(array): 51 | if output_features > array[i + 1]: 52 | self.bsz2 = k 53 | break 54 | 55 | def forward(self, x: torch.Tensor): 56 | if self.fw_code is None: 57 | self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) 58 | self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) 59 | 60 | out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) 61 | if self.bias is not None: 62 | out += self.bias 63 | 64 | return out 65 | -------------------------------------------------------------------------------- /bitsandbytes/triton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jllllll/bitsandbytes/e229fbce66adde7c2a6bc58cbe7d57c1f4a0ba02/bitsandbytes/triton/__init__.py 
-------------------------------------------------------------------------------- /bitsandbytes/triton/dequantize_rowwise.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import time 4 | from bitsandbytes.triton.triton_utils import is_triton_available 5 | 6 | if not is_triton_available(): 7 | def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None 8 | else: 9 | 10 | import triton 11 | import triton.language as tl 12 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time 13 | 14 | # rowwise quantize 15 | 16 | # TODO: autotune this better. 17 | @triton.autotune( 18 | configs=[ 19 | triton.Config({}, num_stages=1, num_warps=8), 20 | triton.Config({}, num_stages=2, num_warps=8), 21 | triton.Config({}, num_stages=4, num_warps=8), 22 | triton.Config({}, num_stages=8, num_warps=8), 23 | triton.Config({}, num_stages=1), 24 | triton.Config({}, num_stages=2), 25 | triton.Config({}, num_stages=4), 26 | triton.Config({}, num_stages=8), 27 | triton.Config({}, num_warps=1), 28 | triton.Config({}, num_warps=2), 29 | triton.Config({}, num_warps=4), 30 | triton.Config({}, num_warps=8), 31 | ], 32 | key=['n_elements'] 33 | ) 34 | @triton.jit 35 | def _dequantize_rowwise( 36 | x_ptr, 37 | state_x, 38 | output_ptr, 39 | inv_127, 40 | n_elements, 41 | BLOCK_SIZE: tl.constexpr, 42 | P2: tl.constexpr, 43 | ): 44 | pid = tl.program_id(axis=0) 45 | block_start = pid * BLOCK_SIZE 46 | arange = tl.arange(0, P2) 47 | offsets = block_start + arange 48 | row_mask = arange < BLOCK_SIZE 49 | x = tl.load(x_ptr + offsets, mask=row_mask) 50 | max_val = tl.load(state_x + pid) 51 | output = max_val * x * inv_127 52 | tl.store(output_ptr + offsets, output, mask=row_mask) 53 | 54 | 55 | def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): 56 | output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) 57 | 58 | P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) 59 | 60 | assert x.is_cuda and output.is_cuda 61 | n_elements = output.numel() 62 | grid = lambda meta: (x.shape[0],) 63 | _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) 64 | return output 65 | -------------------------------------------------------------------------------- /bitsandbytes/triton/int8_matmul_mixed_dequanitze.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bitsandbytes.triton.triton_utils import is_triton_available 3 | 4 | if not is_triton_available(): 5 | def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None 6 | else: 7 | 8 | import triton 9 | import triton.language as tl 10 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time 11 | 12 | 13 | # This is a matmul kernel based on triton.ops.matmul 14 | # It is modified to support rowwise quantized input and global quantized weight 15 | # It's purpose is fused matmul then dequantize 16 | # It does support bias. 
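# For orientation: the wrapper defined at the end of this file is called as
#   int8_matmul_mixed_dequanitze(X_int8, W_int8.t(), state_X, state_W, bias)
# with X_int8, state_X produced by quantize_rowwise(X) and W_int8, state_W by quantize_global(W).
# Writing a and b for the first two arguments, the fused kernel computes, in float16,
#   C = state_W * state_X[:, None] * (a @ b) / (127 * 127) + bias
# i.e. a single int8 matmul followed by dequantization with the per-row input scales and the
# one global weight scale.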
17 | 18 | def init_to_zero(name): 19 | return lambda nargs: nargs[name].zero_() 20 | 21 | def get_configs_io_bound(): 22 | configs = [] 23 | for num_stages in [2, 3, 4, 5, 6]: 24 | for block_m in [16, 32]: 25 | for block_k in [32, 64]: 26 | for block_n in [32, 64, 128, 256]: 27 | num_warps = 2 if block_n <= 64 else 4 28 | configs.append( 29 | triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, 30 | num_stages=num_stages, num_warps=num_warps)) 31 | # split_k 32 | for split_k in [2, 4, 8, 16]: 33 | configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, 34 | num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) 35 | return configs 36 | 37 | 38 | @triton.autotune( 39 | configs=[ 40 | # basic configs for compute-bound matmuls 41 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 42 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 43 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 44 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 45 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 46 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 47 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 48 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 49 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), 50 | # good for int8 51 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 52 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 53 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 54 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 55 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 56 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 57 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 58 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 59 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), 60 | ] + get_configs_io_bound(), 61 | key=['M', 'N', 'K'], 62 | prune_configs_by={ 63 | 'early_config_prune': early_config_prune, 64 | 'perf_model': estimate_matmul_time, 65 | 'top_k': 10 66 | }, 67 | ) 68 | @triton.heuristics({ 69 | 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, 70 | }) 71 | @triton.jit 72 | def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, 73 | stride_am, stride_ak, 74 | stride_bk, stride_bn, 75 | stride_cm, stride_cn, 76 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 77 | GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, 78 | ACC_TYPE: 
tl.constexpr 79 | ): 80 | # matrix multiplication 81 | pid = tl.program_id(0) 82 | pid_z = tl.program_id(1) 83 | grid_m = tl.cdiv(M, BLOCK_M) 84 | grid_n = tl.cdiv(N, BLOCK_N) 85 | # re-order program ID for better L2 performance 86 | width = GROUP_M * grid_n 87 | group_id = pid // width 88 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M) 89 | pid_m = group_id * GROUP_M + (pid % group_size) 90 | pid_n = (pid % width) // (group_size) 91 | # do matrix multiplication 92 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 93 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 94 | ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 95 | rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) 96 | rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) 97 | # pointers 98 | A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) 99 | B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) 100 | 101 | # rematerialize rm and rn to save registers 102 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 103 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 104 | 105 | w_factor = tl.load(state_w_ptr) 106 | x_factor = tl.load(state_x_ptr + ram)[:, None] 107 | 108 | # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) 109 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) 110 | for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): 111 | if EVEN_K: 112 | a = tl.load(A) 113 | b = tl.load(B) 114 | else: 115 | k_remaining = K - k * (BLOCK_K * SPLIT_K) 116 | a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) 117 | b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) 118 | acc += tl.dot(a, b) 119 | A += BLOCK_K * SPLIT_K * stride_ak 120 | B += BLOCK_K * SPLIT_K * stride_bk 121 | 122 | acc = (w_factor * (x_factor * (acc * divfactor))) 123 | acc = acc.to(C.dtype.element_ty) 124 | 125 | # conditionally add bias 126 | if has_bias: 127 | bias = tl.load(bias + rn).to(C.dtype.element_ty) 128 | acc = acc + bias[None, :] 129 | 130 | C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) 131 | mask = (rm < M)[:, None] & (rn < N)[None, :] 132 | # handles write-back with reduction-splitting 133 | if SPLIT_K == 1: 134 | tl.store(C, acc, mask=mask) 135 | else: 136 | tl.atomic_add(C, acc, mask=mask) 137 | 138 | 139 | def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): 140 | device = a.device 141 | divfactor = 1. / (127. * 127.) 
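    # a and b were each scaled by 127 / absmax during quantization, so their int32 products carry
    # an extra factor of 127 * 127 that divfactor removes; the absmax scales themselves are
    # re-applied inside the kernel.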
142 | has_bias = 0 if bias is None else 1 143 | # handle non-contiguous inputs if necessary 144 | if a.stride(0) > 1 and a.stride(1) > 1: 145 | a = a.contiguous() 146 | if b.stride(0) > 1 and b.stride(1) > 1: 147 | b = b.contiguous() 148 | # checks constraints 149 | assert a.shape[1] == b.shape[0], "incompatible dimensions" 150 | M, K = a.shape 151 | _, N = b.shape 152 | # allocates output 153 | c = torch.empty((M, N), device=device, dtype=torch.float16) 154 | # accumulator types 155 | ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 156 | # launch int8_matmul_mixed_dequantize kernel 157 | grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) 158 | _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, 159 | a.stride(0), a.stride(1), 160 | b.stride(0), b.stride(1), 161 | c.stride(0), c.stride(1), 162 | GROUP_M=8, ACC_TYPE=ACC_TYPE) 163 | return c 164 | -------------------------------------------------------------------------------- /bitsandbytes/triton/int8_matmul_rowwise_dequantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from bitsandbytes.triton.triton_utils import is_triton_available 4 | 5 | if not is_triton_available(): 6 | def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None 7 | else: 8 | import triton 9 | import triton.language as tl 10 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time 11 | 12 | # This is a matmul kernel based on triton.ops.matmul 13 | # It is modified to support rowwise quantized input and columnwise quantized weight 14 | # It's purpose is fused matmul then dequantize 15 | # It does support bias. 
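# For orientation: the wrapper defined at the end of this file is called as
#   int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias)
# where both operands come from rowwise quantization (quantize_rowwise for X,
# quantize_rowwise or quantize_columnwise_and_transpose for W). Writing a and b for the first
# two arguments, the fused kernel computes, in float16,
#   C = state_X[:, None] * state_W[None, :] * (a @ b) / (127 * 127) + bias
# so each output element is rescaled by its row's input absmax and its column's weight absmax.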
16 | 17 | def init_to_zero(name): 18 | return lambda nargs: nargs[name].zero_() 19 | 20 | 21 | def get_configs_io_bound(): 22 | configs = [] 23 | for num_stages in [2, 3, 4, 5, 6]: 24 | for block_m in [16, 32]: 25 | for block_k in [32, 64]: 26 | for block_n in [32, 64, 128, 256]: 27 | num_warps = 2 if block_n <= 64 else 4 28 | configs.append( 29 | triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, 30 | num_stages=num_stages, num_warps=num_warps)) 31 | # split_k 32 | for split_k in [2, 4, 8, 16]: 33 | configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, 34 | num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) 35 | return configs 36 | 37 | 38 | @triton.autotune( 39 | configs=[ 40 | # basic configs for compute-bound matmuls 41 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 42 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 43 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 44 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 45 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 46 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 47 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 48 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 49 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), 50 | # good for int8 51 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 52 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), 53 | triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 54 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 55 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 56 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 57 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 58 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), 59 | triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), 60 | ] + get_configs_io_bound(), 61 | key=['M', 'N', 'K'], 62 | prune_configs_by={ 63 | 'early_config_prune': early_config_prune, 64 | 'perf_model': estimate_matmul_time, 65 | 'top_k': 10 66 | }, 67 | ) 68 | @triton.heuristics({ 69 | 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, 70 | }) 71 | @triton.jit 72 | def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, 73 | stride_am, stride_ak, 74 | stride_bk, stride_bn, 75 | stride_cm, stride_cn, 76 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 77 | GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, 78 | ACC_TYPE: 
tl.constexpr 79 | ): 80 | # matrix multiplication 81 | pid = tl.program_id(0) 82 | pid_z = tl.program_id(1) 83 | grid_m = tl.cdiv(M, BLOCK_M) 84 | grid_n = tl.cdiv(N, BLOCK_N) 85 | # re-order program ID for better L2 performance 86 | width = GROUP_M * grid_n 87 | group_id = pid // width 88 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M) 89 | pid_m = group_id * GROUP_M + (pid % group_size) 90 | pid_n = (pid % width) // (group_size) 91 | # do matrix multiplication 92 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 93 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 94 | ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 95 | rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) 96 | rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) 97 | # pointers 98 | A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) 99 | B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) 100 | 101 | # rematerialize rm and rn to save registers 102 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 103 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 104 | 105 | w_factor = tl.load(state_w_ptr + rbn)[None, :] 106 | x_factor = tl.load(state_x_ptr + ram)[:, None] 107 | 108 | # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) 109 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) 110 | for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): 111 | if EVEN_K: 112 | a = tl.load(A) 113 | b = tl.load(B) 114 | else: 115 | k_remaining = K - k * (BLOCK_K * SPLIT_K) 116 | a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) 117 | b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) 118 | acc += tl.dot(a, b) 119 | A += BLOCK_K * SPLIT_K * stride_ak 120 | B += BLOCK_K * SPLIT_K * stride_bk 121 | 122 | acc = (w_factor * (x_factor * (acc * divfactor))) 123 | acc = acc.to(C.dtype.element_ty) 124 | 125 | if has_bias: 126 | bias = tl.load(bias + rn).to(C.dtype.element_ty) 127 | acc = acc + bias[None, :] 128 | 129 | C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) 130 | mask = (rm < M)[:, None] & (rn < N)[None, :] 131 | # handles write-back with reduction-splitting 132 | if SPLIT_K == 1: 133 | tl.store(C, acc, mask=mask) 134 | else: 135 | tl.atomic_add(C, acc, mask=mask) 136 | 137 | 138 | def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): 139 | divfactor = 1. / (127. * 127.) 
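    # undo the 127x scaling applied to each operand during int8 quantization; the per-row and
    # per-column absmax factors are multiplied back in inside the kernel.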
140 | 141 | has_bias = 0 if bias is None else 1 142 | 143 | device = a.device 144 | # handle non-contiguous inputs if necessary 145 | if a.stride(0) > 1 and a.stride(1) > 1: 146 | a = a.contiguous() 147 | if b.stride(0) > 1 and b.stride(1) > 1: 148 | b = b.contiguous() 149 | # checks constraints 150 | assert a.shape[1] == b.shape[0], "incompatible dimensions" 151 | M, K = a.shape 152 | _, N = b.shape 153 | # allocates output 154 | c = torch.empty((M, N), device=device, dtype=torch.float16) 155 | # accumulator types 156 | ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 157 | # launch int8_matmul_rowwise_dequantize kernel 158 | grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) 159 | _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, 160 | a.stride(0), a.stride(1), 161 | b.stride(0), b.stride(1), 162 | c.stride(0), c.stride(1), 163 | GROUP_M=8, ACC_TYPE=ACC_TYPE) 164 | return c 165 | -------------------------------------------------------------------------------- /bitsandbytes/triton/quantize_columnwise_and_transpose.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import time 4 | from bitsandbytes.triton.triton_utils import is_triton_available 5 | 6 | if not is_triton_available(): 7 | def quantize_columnwise_and_transpose(x: torch.Tensor): return None 8 | else: 9 | 10 | import triton 11 | import triton.language as tl 12 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time 13 | 14 | # This kernel does fused columnwise quantization and transpose. 15 | 16 | # TODO: autotune this better. 17 | @triton.autotune( 18 | configs=[ 19 | triton.Config({}, num_stages=1), 20 | triton.Config({}, num_stages=2), 21 | triton.Config({}, num_stages=4), 22 | triton.Config({}, num_stages=8), 23 | triton.Config({}, num_stages=16), 24 | triton.Config({}, num_stages=1, num_warps=8), 25 | triton.Config({}, num_stages=2, num_warps=8), 26 | triton.Config({}, num_stages=4, num_warps=8), 27 | triton.Config({}, num_stages=8, num_warps=8), 28 | triton.Config({}, num_stages=16, num_warps=8), 29 | triton.Config({}, num_warps=1), 30 | triton.Config({}, num_warps=2), 31 | triton.Config({}, num_warps=4), 32 | triton.Config({}, num_warps=8), 33 | ], 34 | key=['n_elements'] 35 | ) 36 | @triton.jit 37 | def _quantize_columnwise_and_transpose( 38 | x_ptr, 39 | output_ptr, 40 | output_maxs, 41 | n_elements, 42 | M : tl.constexpr, N : tl.constexpr, 43 | BLOCK_SIZE: tl.constexpr, 44 | P2: tl.constexpr, 45 | ): 46 | pid = tl.program_id(axis=0) 47 | block_start = pid 48 | p2_arange = tl.arange(0, P2) 49 | p2_arange_mask = p2_arange < M 50 | arange = p2_arange * N 51 | offsets = block_start + arange 52 | x = tl.load(x_ptr + offsets, mask=p2_arange_mask) 53 | abs_x = tl.abs(x) 54 | max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) 55 | output = tl.libdevice.llrint(127. 
* (x / max_val)) 56 | 57 | new_start = pid * M 58 | new_offsets = new_start + p2_arange 59 | tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) 60 | tl.store(output_maxs + pid, max_val) 61 | 62 | def quantize_columnwise_and_transpose(x: torch.Tensor): 63 | M, N = x.shape 64 | output = torch.empty(N, M, device=x.device, dtype=torch.int8) 65 | output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) 66 | 67 | P2 = int(2 ** (math.ceil(math.log2(M)))) 68 | 69 | assert x.is_cuda and output.is_cuda 70 | n_elements = output.numel() 71 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 72 | _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) 73 | return output, output_maxs 74 | 75 | -------------------------------------------------------------------------------- /bitsandbytes/triton/quantize_global.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import time 4 | from bitsandbytes.triton.triton_utils import is_triton_available 5 | 6 | if not is_triton_available(): 7 | def quantize_global_transpose(input): return None 8 | def quantize_global(x: torch.Tensor): return None 9 | else: 10 | 11 | import triton 12 | import triton.language as tl 13 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time 14 | 15 | # global quantize 16 | @triton.autotune( 17 | configs=[ 18 | triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), 19 | triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), 20 | 21 | ], 22 | key=['n_elements'] 23 | ) 24 | @triton.jit 25 | def _quantize_global( 26 | x_ptr, 27 | absmax_inv_ptr, 28 | output_ptr, 29 | n_elements, 30 | BLOCK_SIZE: tl.constexpr, 31 | ): 32 | pid = tl.program_id(axis=0) 33 | block_start = pid * BLOCK_SIZE 34 | offsets = block_start + tl.arange(0, BLOCK_SIZE) 35 | mask = offsets < n_elements 36 | x = tl.load(x_ptr + offsets, mask=mask) 37 | absmax_inv = tl.load(absmax_inv_ptr) 38 | output = tl.libdevice.llrint(127. * (x * absmax_inv)) 39 | tl.store(output_ptr + offsets, output, mask=mask) 40 | 41 | def quantize_global(x: torch.Tensor): 42 | absmax = x.abs().max().unsqueeze(0) 43 | absmax_inv = 1./ absmax 44 | output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) 45 | assert x.is_cuda and output.is_cuda 46 | n_elements = output.numel() 47 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 48 | _quantize_global[grid](x, absmax_inv, output, n_elements) 49 | return output, absmax 50 | 51 | 52 | # global quantize and transpose 53 | @triton.autotune( 54 | configs=[ 55 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), 56 | triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), 57 | 58 | # ... 
59 | ], 60 | key=['M', 'N'] 61 | ) 62 | @triton.jit 63 | def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, 64 | BLOCK_M : tl.constexpr, 65 | BLOCK_N : tl.constexpr, 66 | GROUP_M : tl.constexpr): 67 | pid = tl.program_id(0) 68 | grid_m = (M + BLOCK_M - 1) // BLOCK_M 69 | grid_n = (N + BLOCK_N - 1) // BLOCK_N 70 | 71 | width = GROUP_M * grid_n 72 | group_id = pid // width 73 | group_size = min(grid_m - group_id * GROUP_M, GROUP_M) 74 | pid_m = group_id * GROUP_M + (pid % group_size) 75 | pid_n = (pid % width) // group_size 76 | 77 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 78 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 79 | A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) 80 | mask = (rm < M)[:, None] & (rn < N)[None, :] 81 | a = tl.load(A, mask=mask) 82 | absmax_inv = tl.load(absmax_inv_ptr) 83 | 84 | # rematerialize to save registers 85 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 86 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 87 | B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) 88 | mask = (rm < M)[:, None] & (rn < N)[None, :] 89 | 90 | output = tl.libdevice.llrint(127. * (a * absmax_inv)) 91 | 92 | tl.store(B, output, mask=mask) 93 | 94 | def quantize_global_transpose(input): 95 | absmax = input.abs().max().unsqueeze(0) 96 | absmax_inv = 1./ absmax 97 | M, N = input.shape 98 | out = torch.empty(N, M, device='cuda', dtype=torch.int8) 99 | 100 | assert out.size(0) == N and out.size(1) == M 101 | assert input.stride(0) == 1 or input.stride(1) == 1 102 | assert out.stride(0) == 1 or out.stride(1) == 1 103 | 104 | grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) 105 | _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) 106 | return out, absmax 107 | 108 | -------------------------------------------------------------------------------- /bitsandbytes/triton/quantize_rowwise.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import time 4 | 5 | from bitsandbytes.triton.triton_utils import is_triton_available 6 | 7 | if not is_triton_available(): 8 | def quantize_rowwise(x: torch.Tensor): return None 9 | else: 10 | 11 | import triton 12 | import triton.language as tl 13 | from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time 14 | 15 | # rowwise quantize 16 | 17 | # TODO: autotune this better. 
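    # Each row is quantized independently: with absmax = max(|row|), the kernel stores
    # round(127 * x / absmax) as int8 and the row's absmax as float16, so quantize_rowwise()
    # returns (int8 tensor, per-row scales). For example, a row [0.5, -2.0, 1.0] has
    # absmax 2.0 and quantizes to approximately [32, -127, 64].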
18 | @triton.autotune( 19 | configs=[ 20 | triton.Config({}, num_stages=1, num_warps=8), 21 | triton.Config({}, num_stages=2, num_warps=8), 22 | triton.Config({}, num_stages=4, num_warps=8), 23 | triton.Config({}, num_stages=8, num_warps=8), 24 | triton.Config({}, num_stages=1), 25 | triton.Config({}, num_stages=2), 26 | triton.Config({}, num_stages=4), 27 | triton.Config({}, num_stages=8), 28 | triton.Config({}, num_warps=1), 29 | triton.Config({}, num_warps=2), 30 | triton.Config({}, num_warps=4), 31 | triton.Config({}, num_warps=8), 32 | ], 33 | key=['n_elements'] 34 | ) 35 | @triton.jit 36 | def _quantize_rowwise( 37 | x_ptr, 38 | output_ptr, 39 | output_maxs, 40 | n_elements, 41 | BLOCK_SIZE: tl.constexpr, 42 | P2: tl.constexpr, 43 | ): 44 | pid = tl.program_id(axis=0) 45 | block_start = pid * BLOCK_SIZE 46 | arange = tl.arange(0, P2) 47 | offsets = block_start + arange 48 | row_mask = arange < BLOCK_SIZE 49 | x = tl.load(x_ptr + offsets, mask=row_mask) 50 | 51 | abs_x = tl.abs(x) 52 | max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) 53 | output = tl.libdevice.llrint(127. * (x / max_val)) 54 | tl.store(output_ptr + offsets, output, mask=row_mask) 55 | tl.store(output_maxs + pid, max_val) 56 | 57 | def quantize_rowwise(x: torch.Tensor): 58 | output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) 59 | output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) 60 | 61 | P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) 62 | 63 | assert x.is_cuda and output.is_cuda 64 | n_elements = output.numel() 65 | grid = lambda meta: (x.shape[0],) 66 | _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) 67 | return output, output_maxs 68 | 69 | -------------------------------------------------------------------------------- /bitsandbytes/triton/triton_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | def is_triton_available(): 4 | return importlib.util.find_spec("triton") is not None 5 | -------------------------------------------------------------------------------- /bitsandbytes/utils.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | import subprocess 3 | import torch 4 | from typing import Tuple 5 | 6 | def outlier_hook(module, input): 7 | assert isinstance(module, torch.nn.Linear) 8 | tracer = OutlierTracer.get_instance() 9 | hvalue = tracer.get_hvalue(module.weight) 10 | if hvalue not in tracer.hvalue2outlier_idx: 11 | outlier_idx = find_outlier_dims(module.weight) 12 | tracer.outliers.append(outlier_idx) 13 | tracer.hvalues.append(hvalue) 14 | if len(tracer.outliers) > 1: 15 | # assign the current layer the outlier idx found from the weight 16 | # of the previous linear layer 17 | if tracer.outliers[-1].numel() > 0: 18 | assert tracer.outliers[-1].max() < module.weight.shape[1] 19 | tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1] 20 | 21 | else: 22 | # first layer, we cannot use the weight for outlier detection 23 | # we follow a mixed approach: 24 | # (1) zscore test of std of hidden dimension 25 | # (2) magnitude > 6 test 26 | merged = input[0].view(-1, input[0].shape[-1]) 27 | # (1) zscore test of std of hidden dimension 28 | outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3) 29 | # (2) magnitude > 6 test 30 | dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1))) 31 | outlier_idx2 = torch.where(dims > 0)[0] 32 | outlier_idx = torch.cat([outlier_idx, 
outlier_idx2]).unique() 33 | tracer.hvalue2outlier_idx[hvalue] = outlier_idx 34 | else: 35 | for hook in tracer.hooks: 36 | hook.remove() 37 | 38 | 39 | class OutlierTracer(object): 40 | _instance = None 41 | 42 | def __init__(self): 43 | raise RuntimeError("Call get_instance() instead") 44 | 45 | def initialize(self, model): 46 | self.last_w = None 47 | self.current_outlier_dims = None 48 | self.hvalues = [] 49 | self.outliers = [] 50 | self.hvalue2outlier_idx = {} 51 | self.initialized = True 52 | self.hooks = [] 53 | 54 | for n, m in model.named_modules(): 55 | if isinstance(m, torch.nn.Linear): 56 | self.hooks.append(m.register_forward_pre_hook(outlier_hook)) 57 | 58 | def is_initialized(self): 59 | return getattr(self, 'initialized', False) 60 | 61 | def get_hvalue(self, weight): 62 | return weight.data.storage().data_ptr() 63 | 64 | def get_outliers(self, weight): 65 | if not self.is_initialized(): 66 | print('Outlier tracer is not initialized...') 67 | return None 68 | hvalue = self.get_hvalue(weight) 69 | if hvalue in self.hvalue2outlier_idx: 70 | return self.hvalue2outlier_idx[hvalue] 71 | else: 72 | return None 73 | 74 | @classmethod 75 | def get_instance(cls): 76 | if cls._instance is None: 77 | cls._instance = cls.__new__(cls) 78 | return cls._instance 79 | 80 | def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False): 81 | if rdm: 82 | return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long() 83 | 84 | m = weight.mean(reduction_dim) 85 | mm = m.mean() 86 | mstd = m.std() 87 | zm = (m-mm)/mstd 88 | 89 | std = weight.std(reduction_dim) 90 | stdm = std.mean() 91 | stdstd = std.std() 92 | 93 | zstd = (std-stdm)/stdstd 94 | 95 | if topk is not None: 96 | val, idx = torch.topk(std.abs(), k=topk, dim=0) 97 | else: 98 | idx = torch.where(zstd > zscore)[0] 99 | 100 | return idx 101 | 102 | def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): 103 | """ 104 | Replace linear modules with a new Linear module. 105 | 106 | Parameters: 107 | model (`torch.nn.Module`): 108 | Input model or `torch.nn.Module` as the function is run recursively. 109 | linear_replacement (`torch.nn.Module`): 110 | The linear module that replaces the old one. Only expects standard arguments. 111 | If other arguments need to be passed, use a lambda. 112 | skip_modules (`List[str]`, *optional*, defaults to `lm_head`): 113 | List of modules names not to convert. Defaults to `lm_head`. 114 | copy_weights (`bool`): 115 | Copy the weights from the old linear module to the new one 116 | post_processing_fun_name (`str`): 117 | A function name of the replacement linear class that is called 118 | after processing. 
119 | """ 120 | for name, module in model.named_children(): 121 | if len(list(module.children())) > 0: 122 | replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function) 123 | 124 | if isinstance(module, torch.nn.Linear) and name not in skip_modules: 125 | old_module = model._modules[name] 126 | model._modules[name] = linear_replacement( 127 | module.in_features, 128 | module.out_features, 129 | module.bias is not None, 130 | ) 131 | if copy_weights: 132 | model._modules[name].weight = old_module.weight 133 | model._modules[name].bias = old_module.bias 134 | 135 | if post_processing_function is not None: 136 | func = getattr(module, post_processing_function, None) 137 | if func is not None: func(module) 138 | return model 139 | 140 | 141 | 142 | def execute_and_return(command_string: str) -> Tuple[str, str]: 143 | def _decode(subprocess_err_out_tuple): 144 | return tuple( 145 | to_decode.decode("UTF-8").strip() 146 | for to_decode in subprocess_err_out_tuple 147 | ) 148 | 149 | def execute_and_return_decoded_std_streams(command_string): 150 | return _decode( 151 | subprocess.Popen( 152 | shlex.split(command_string), 153 | stdout=subprocess.PIPE, 154 | stderr=subprocess.PIPE, 155 | ).communicate() 156 | ) 157 | 158 | std_out, std_err = execute_and_return_decoded_std_streams(command_string) 159 | return std_out, std_err 160 | 161 | 162 | 163 | def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): 164 | """ 165 | Replace linear modules with a new Linear module. 166 | Parameters: 167 | model (`torch.nn.Module`): 168 | Input model or `torch.nn.Module` as the function is run recursively. 169 | linear_replacement (`torch.nn.Module`): 170 | The linear module that replaces the old one. Only expects standard arguments. 171 | If other arguments need to be passed, use a lambda. 172 | skip_modules (`List[str]`, *optional*, defaults to `lm_head`): 173 | List of modules names not to convert. Defaults to `lm_head`. 174 | copy_weights (`bool`): 175 | Copy the weights from the old linear module to the new one 176 | post_processing_fun_name (`str`): 177 | A function name of the replacement linear class that is called 178 | after processing. 
179 | """ 180 | for name, module in model.named_children(): 181 | if len(list(module.children())) > 0: 182 | replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function) 183 | 184 | if isinstance(module, torch.nn.Linear) and name not in skip_modules: 185 | old_module = model._modules[name] 186 | model._modules[name] = linear_replacement( 187 | module.in_features, 188 | module.out_features, 189 | module.bias is not None, 190 | ) 191 | if copy_weights: 192 | model._modules[name].weight = old_module.weight 193 | model._modules[name].bias = old_module.bias 194 | 195 | if post_processing_function is not None: 196 | func = getattr(module, post_processing_function, None) 197 | if func is not None: func(module) 198 | return model 199 | 200 | -------------------------------------------------------------------------------- /check_bnb_install.py: -------------------------------------------------------------------------------- 1 | import bitsandbytes as bnb 2 | import torch 3 | 4 | p = torch.nn.Parameter(torch.rand(10,10).cuda()) 5 | a = torch.rand(10,10).cuda() 6 | 7 | p1 = p.data.sum().item() 8 | 9 | adam = bnb.optim.Adam([p]) 10 | 11 | out = a*p 12 | loss = out.sum() 13 | loss.backward() 14 | adam.step() 15 | 16 | p2 = p.data.sum().item() 17 | 18 | assert p1 != p2 19 | print('SUCCESS!') 20 | print('Installation was successful!') 21 | -------------------------------------------------------------------------------- /compile_from_source.md: -------------------------------------------------------------------------------- 1 | # Compiling from source 2 | 3 | Basic steps: 4 | 1. `CUDA_VERSION=XXX make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cuda12x, cpuonly` 5 | 2. `python setup.py install` 6 | 7 | To run these steps you will need the `nvcc` compiler, which comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive). 8 | 9 | You can install CUDA locally without sudo by following these steps: 10 | 11 | ```bash 12 | wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh 13 | # Syntax cuda_install.sh CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH 14 | # CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121} 15 | # EXPORT_TO_BASH in {0, 1} with 0=False and 1=True 16 | 17 | # For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc 18 | bash cuda_install.sh 117 ~/local 1 19 | ``` 20 | 21 | By default, the Makefile will look at your `CUDA_HOME` environmental variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler. 22 | 23 | Either `nvcc` needs to be in your path or the `CUDA_HOME` variable needs to be set to the CUDA directory root (e.g. `/usr/local/cuda`) in order for compilation to succeed. 24 | 25 | If you type `nvcc` and it cannot be found, you might need to add it to your path or set the `CUDA_HOME` variable. You can run `python -m bitsandbytes` to find the path to CUDA.
For example if `python -m bitsandbytes` shows you the following: 26 | ``` 27 | ++++++++++++++++++ /usr/local CUDA PATHS +++++++++++++++++++ 28 | /usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudart.so 29 | ``` 30 | You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might be able to compile like this. 31 | 32 | ``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x`` 33 | 34 | 35 | If you have problems compiling the library with these instructions from source, please open an issue. 36 | 37 | ## Compilation with Kepler 38 | 39 | Since 0.39.1 bitsandbytes installed via pip no longer provides Kepler binaries and these need to be compiled from source. Follow the steps above and instead of `cuda11x_nomatmul` etc use `cuda11x_nomatmul_kepler` 40 | 41 | -------------------------------------------------------------------------------- /csrc/common.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | void *quantize_block(void *arguments) { 5 | // 1. find absmax in block 6 | // 2. divide input value by absmax to normalize into [-1.0, 1.0] 7 | // 3. do binary search to find the closest value 8 | // 4. check minimal distance 9 | // 5. store index 10 | 11 | struct quantize_block_args *args = (quantize_block_args *) arguments; 12 | 13 | // 1. find absmax in block 14 | float absmax_block = -FLT_MAX; 15 | for (long long i = args->block_idx; i < args->block_end; i++) 16 | absmax_block = fmax(absmax_block, fabs(args->A[i])); 17 | 18 | args->absmax[args->block_idx / args->blocksize] = absmax_block; 19 | 20 | for (long long i = args->block_idx; i < args->block_end; i++) { 21 | // 2. divide input value by absmax to normalize into [-1.0, 1.0] 22 | // 3. do binary search to find the closest value 23 | float normed_value = args->A[i] / absmax_block; 24 | long long idx = args->bin_searcher->scalar(normed_value); 25 | 26 | // 4. check minimal distance 27 | // The binary search returns always the value to the left, which might not be the closest value 28 | if (idx < 255) { 29 | float dist_left = fabs(normed_value - (args->code[idx])); 30 | float dist_right = fabs(normed_value - (args->code[idx + 1])); 31 | if (dist_right < dist_left) { idx += 1; } 32 | } 33 | 34 | // 5. store index 35 | args->out[i] = (unsigned char) idx; 36 | } 37 | 38 | return NULL; 39 | } 40 | -------------------------------------------------------------------------------- /csrc/common.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifndef common 4 | #define common 5 | 6 | using namespace BinSearch; 7 | 8 | #define BLOCK_SIZE 16384 9 | 10 | struct quantize_block_args { 11 | BinAlgo *bin_searcher; 12 | float *code; 13 | float *A; 14 | float *absmax; 15 | unsigned char *out; 16 | long long block_end; 17 | long long block_idx; 18 | long long threadidx; 19 | long long blocksize; 20 | }; 21 | 22 | 23 | void *quantize_block(void *arguments); 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /csrc/cpu_ops.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | using namespace BinSearch; 6 | 7 | void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) { 8 | for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { 9 | long long valid_items = n - block_idx >= blocksize ? 
blocksize : n - block_idx; 10 | long long block_end = block_idx + valid_items; 11 | for (long long i = block_idx; i < block_end; i++) 12 | out[i] = code[A[i]] * absmax[block_idx / blocksize]; 13 | } 14 | } 15 | 16 | void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n) 17 | { 18 | 19 | // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below 20 | code[0] = -1.0f; 21 | 22 | long long num_blocks = n / blocksize; 23 | num_blocks += n % blocksize == 0 ? 0 : 1; 24 | 25 | const uint32 elements_code = 256; 26 | BinAlgo bin_searcher(code, elements_code); 27 | 28 | int thread_wave_size = 256; 29 | // we chunk the thresds into waves of 256 since the max limit is 30 | // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) 31 | for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) 32 | { 33 | long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; 34 | pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks); 35 | 36 | struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *)); 37 | 38 | for(long long i = 0; i < valid_chunks; i++) 39 | args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); 40 | 41 | int chunks_processed = 0; 42 | for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) 43 | { 44 | long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; 45 | long long block_end = block_idx + valid_items; 46 | 47 | struct quantize_block_args *arg = args[chunks_processed]; 48 | arg->bin_searcher = &bin_searcher; 49 | arg->code = code; 50 | arg->A = A; 51 | arg->absmax = absmax; 52 | arg->out = out; 53 | arg->block_end = block_end; 54 | arg->block_idx = block_idx; 55 | arg->threadidx = block_idx / blocksize; 56 | arg->blocksize = blocksize; 57 | 58 | pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg); 59 | chunks_processed += 1; 60 | if(chunks_processed == valid_chunks){ break; } 61 | } 62 | 63 | for (int i = 0; i < valid_chunks; i++) 64 | int err = pthread_join(threads[i], NULL); 65 | 66 | free(threads); 67 | for (int i = 0; i < valid_chunks; i++) 68 | free(args[i]); 69 | free(args); 70 | 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /csrc/cpu_ops.h: -------------------------------------------------------------------------------- 1 | #ifndef BITSANDBYTES_CPU_OPS_H 2 | #define BITSANDBYTES_CPU_OPS_H 3 | 4 | #include 5 | #include 6 | 7 | void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n); 8 | void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n); 9 | 10 | #endif 11 | -------------------------------------------------------------------------------- /csrc/kernels.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | #include 8 | 9 | #ifndef kernels 10 | #define kernels 11 | 12 | //template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); 13 | 14 | template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); 15 | 16 | __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); 17 | __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); 18 | 19 | template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); 20 | template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); 21 | 22 | template 23 | __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, 24 | float* state1, float* state2, float *unorm, 25 | const float beta1, const float beta2, const float eps, const float weight_decay, 26 | const int step, const float lr, const float gnorm_scale, const int n); 27 | 28 | template 29 | __global__ void kOptimizer32bit2State(T* g, T* p, 30 | float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, 31 | const float beta1, const float beta2, const float eps, const float weight_decay, 32 | const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); 33 | 34 | template 35 | __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, 36 | float* state1, float *unorm, 37 | const float beta1, const float beta2, const float eps, const float weight_decay, 38 | const int step, const float lr, const float gnorm_scale, const int n); 39 | 40 | template 41 | __global__ void kOptimizer32bit1State(T* g, T* p, 42 | float* state1, float *unorm, const float max_unorm, const float param_norm, 43 | const float beta1, const float beta2, const float eps, const float weight_decay, 44 | const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); 45 | 46 | template 47 | __global__ void 48 | kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, 49 | float *unorm, 50 | const float beta1, const float beta2, 51 | const float eps, const int step, 52 | float* __restrict__ const quantiles1, 53 | float* max1, float* new_max1, 54 | const float weight_decay, 55 | const float gnorm_scale, const int n); 56 | 57 | 58 | template 59 | __global__ void 60 | kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, 61 | const float *unorm, const float max_unorm, const float param_norm, 62 | const float beta1, const float beta2, 63 | const float eps, const int step, const float lr, 64 | float* __restrict__ const quantiles1, 65 | float* max1, float* new_max1, 66 | float weight_decay, const float gnorm_scale, const int n); 67 | 68 | 69 | 70 | template 71 | __global__ void 72 | kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, 73 | float *unorm, 74 | const float beta1, const float beta2, 75 | const float eps, const int step, 76 | float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, 77 | float* max1, float* max2, float* new_max1, float* new_max2, 78 | const float gnorm_scale, const int n); 79 | 80 | 81 | template 
82 | __global__ void 83 | kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, 84 | const float *unorm, const float max_unorm, const float param_norm, 85 | const float beta1, const float beta2, 86 | const float eps, const int step, const float lr, 87 | float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, 88 | float* max1, float* max2, float* new_max1, float* new_max2, 89 | float weight_decay, const float gnorm_scale, const int n); 90 | 91 | template __global__ void kOptimizerStatic8bit2StateBlockwise( 92 | T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, 93 | const float beta1, const float beta2, const float eps, const int step, const float lr, 94 | float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, 95 | float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); 96 | 97 | template __global__ void kOptimizerStatic8bit1StateBlockwise( 98 | T* p, T* __restrict__ const g, unsigned char* state1, 99 | const float beta1, const float beta2, 100 | const float eps, const int step, const float lr, 101 | float* __restrict__ const quantiles1, 102 | float* absmax1, 103 | float weight_decay, 104 | const float gnorm_scale, const bool skip_zeros, const int n); 105 | 106 | 107 | template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); 108 | 109 | 110 | __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); 111 | 112 | 113 | template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); 114 | 115 | template __global__ void kdequant_mm_int32_fp16( 116 | int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, 117 | half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); 118 | 119 | template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); 120 | template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); 121 | 122 | template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); 123 | 124 | template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); 125 | 126 | template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); 127 | template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); 128 | template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int 
blocksize); 129 | 130 | template __global__ void kfunc(T *A, T *B, T value, long n); 131 | 132 | #endif 133 | -------------------------------------------------------------------------------- /csrc/ops.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #ifndef ops_H 8 | #define ops_H 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include 24 | #include 25 | 26 | 27 | 28 | #define CUDA_CHECK_RETURN(value) { \ 29 | cudaError_t _m_cudaStat = value; \ 30 | if (_m_cudaStat != cudaSuccess) { \ 31 | fprintf(stderr, "Error %s at line %d in file %s\n", \ 32 | cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ 33 | exit(1); \ 34 | } } 35 | 36 | #define THREADS_PER_BLOCKS (512) 37 | 38 | #define CHECK_CUSPARSE(value) { \ 39 | cusparseStatus_t _m_cudaStat = value; \ 40 | if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ 41 | fprintf(stderr, "Error %s at line %d in file %s\n", \ 42 | cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ 43 | exit(1); \ 44 | } } 45 | 46 | 47 | #define THREADS_PER_BLOCKS (512) 48 | 49 | 50 | inline void checkCudaStatus(cudaError_t status) { 51 | if (status != cudaSuccess) { 52 | printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status)); 53 | throw std::logic_error("cuda API failed"); 54 | } 55 | } 56 | 57 | inline int checkCublasStatus(cublasStatus_t status) { 58 | if (status != CUBLAS_STATUS_SUCCESS) { 59 | printf("cuBLAS API failed with status %d\n", status); 60 | //throw std::logic_error("cuBLAS API failed"); 61 | return 1; 62 | } 63 | return 0; 64 | } 65 | 66 | typedef enum Operations_t 67 | { 68 | ksmul = 0, 69 | } Operations_t; 70 | 71 | typedef enum Optimizer_t 72 | { 73 | ADAM = 0, 74 | MOMENTUM = 1, 75 | RMSPROP = 2, 76 | LARS = 3, 77 | ADAGRAD = 4, 78 | LION = 5, 79 | } Optimizer_t; 80 | 81 | typedef enum Transform_t 82 | { 83 | ROW = 0, 84 | COL = 1, 85 | COL32 = 2, 86 | COL_TURING = 3, 87 | COL_AMPERE = 4, 88 | } Transform_t; 89 | 90 | typedef enum DataType_t 91 | { 92 | General8bit = 0, 93 | FP4 = 1, 94 | NF4 = 2, 95 | } DataType_t; 96 | 97 | typedef enum Funcs_t 98 | { 99 | FILL = 0, 100 | ARANGE = 1, 101 | _MUL = 2, 102 | } Funcs_t; 103 | 104 | class Context 105 | { 106 | public: 107 | cublasHandle_t m_handle; 108 | 109 | Context() 110 | { 111 | cublasHandle_t handle; 112 | cublasCreate_v2(&handle); 113 | m_handle = handle; 114 | } 115 | 116 | }; 117 | 118 | class ContextLt 119 | { 120 | public: 121 | cublasLtHandle_t m_handle; 122 | 123 | ContextLt() 124 | { 125 | cublasLtHandle_t handle; 126 | cublasLtCreate(&handle); 127 | m_handle = handle; 128 | } 129 | 130 | }; 131 | 132 | class ContextCusparse 133 | { 134 | public: 135 | cusparseHandle_t m_handle; 136 | 137 | ContextCusparse() 138 | { 139 | cusparseHandle_t handle; 140 | cusparseCreate(&handle); 141 | m_handle = handle; 142 | } 143 | 144 | }; 145 | 146 | 147 | template void estimateQuantiles(T *A, float *code, float offset, int n); 148 | 149 | void quantize(float *code, float *A, unsigned char *out, int n); 150 | void dequantize(float *code, unsigned char *A, float *out, int n); 151 | template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int 
blocksize, const int n); 152 | template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); 153 | 154 | template void optimizer32bit(T* g, T* p, 155 | float* state1, float* state2, float *unorm, float max_unorm, float param_norm, 156 | float beta1, float beta2, float eps, float weight_decay, 157 | int step, float lr, const float gnorm_scale, bool skip_zeros, int n); 158 | 159 | template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, 160 | float *unorm, float max_unorm, float param_norm, 161 | float beta1, float beta2, 162 | float eps, int step, float lr, 163 | float* quantiles1, float* quantiles2, 164 | float* max1, float* max2, float* new_max1, float* new_max2, 165 | float weight_decay, 166 | const float gnorm_scale, int n); 167 | 168 | template void optimizerStatic8bitBlockwise(T* p, T* g, 169 | unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, 170 | float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, 171 | bool skip_zeros, int n); 172 | 173 | template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); 174 | 175 | void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); 176 | 177 | void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); 178 | void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, 179 | long long int strideA, long long int strideB, long long int strideC, int batchCount); 180 | 181 | 182 | template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); 183 | 184 | template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); 185 | void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); 186 | void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); 187 | void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); 188 | void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, 189 | int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); 190 | 191 | template void transformRowToFormat(char * A, char *out, int rows, int cols); 192 | 193 | void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); 194 | 195 | template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); 196 | 197 | template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); 198 | 199 | void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); 200 | 201 | template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int 
bits); 202 | template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); 203 | template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); 204 | 205 | template void func(T *A, T *B, T value, long n); 206 | 207 | #endif 208 | -------------------------------------------------------------------------------- /cuda_install.sh: -------------------------------------------------------------------------------- 1 | URL92=https://developer.nvidia.com/compute/cuda/9.2/Prod2/local_installers/cuda_9.2.148_396.37_linux 2 | URL100=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux 3 | URL101=https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run 4 | URL102=https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run 5 | URL110=https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda_11.0.3_450.51.06_linux.run 6 | URL111=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run 7 | URL112=https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run 8 | URL113=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run 9 | URL114=https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run 10 | URL115=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run 11 | URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run 12 | URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run 13 | URL118=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run 14 | URL120=https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda_12.0.0_525.60.13_linux.run 15 | URL121=https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run 16 | URL122=https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run 17 | 18 | 19 | CUDA_VERSION=$1 20 | BASE_PATH=$2 21 | EXPORT_BASHRC=$3 22 | 23 | if [[ -n "$CUDA_VERSION" ]]; then 24 | if [[ "$CUDA_VERSION" -eq "92" ]]; then 25 | URL=$URL92 26 | FOLDER=cuda-9.2 27 | elif [[ "$CUDA_VERSION" -eq "100" ]]; then 28 | URL=$URL100 29 | FOLDER=cuda-10.0 30 | elif [[ "$CUDA_VERSION" -eq "101" ]]; then 31 | URL=$URL101 32 | FOLDER=cuda-10.1 33 | elif [[ "$CUDA_VERSION" -eq "102" ]]; then 34 | URL=$URL102 35 | FOLDER=cuda-10.2 36 | elif [[ "$CUDA_VERSION" -eq "110" ]]; then 37 | URL=$URL110 38 | FOLDER=cuda-11.0 39 | elif [[ "$CUDA_VERSION" -eq "111" ]]; then 40 | URL=$URL111 41 | FOLDER=cuda-11.1 42 | elif [[ "$CUDA_VERSION" -eq "112" ]]; then 43 | URL=$URL112 44 | FOLDER=cuda-11.2 45 | elif [[ "$CUDA_VERSION" -eq "113" ]]; then 46 | URL=$URL113 47 | FOLDER=cuda-11.3 48 | elif [[ "$CUDA_VERSION" -eq "114" ]]; then 49 | URL=$URL114 50 | FOLDER=cuda-11.4 51 | elif [[ "$CUDA_VERSION" -eq "115" ]]; then 52 | URL=$URL115 53 | FOLDER=cuda-11.5 54 | elif [[ "$CUDA_VERSION" -eq "116" ]]; then 55 | URL=$URL116 56 | FOLDER=cuda-11.6 57 | elif [[ 
"$CUDA_VERSION" -eq "117" ]]; then 58 | URL=$URL117 59 | FOLDER=cuda-11.7 60 | elif [[ "$CUDA_VERSION" -eq "118" ]]; then 61 | URL=$URL118 62 | FOLDER=cuda-11.8 63 | elif [[ "$CUDA_VERSION" -eq "120" ]]; then 64 | URL=$URL120 65 | FOLDER=cuda-12.0 66 | elif [[ "$CUDA_VERSION" -eq "121" ]]; then 67 | URL=$URL121 68 | FOLDER=cuda-12.1 69 | elif [[ "$CUDA_VERSION" -eq "122" ]]; then 70 | URL=$URL122 71 | FOLDER=cuda-12.2 72 | else 73 | echo "argument error: No cuda version passed as input. Choose among versions 92 to 121" 74 | fi 75 | else 76 | echo "argument error: No cuda version passed as input. Choose among versions 92 to 112" 77 | fi 78 | 79 | FILE=$(basename $URL) 80 | 81 | if [[ -n "$CUDA_VERSION" ]]; then 82 | echo $URL 83 | echo $FILE 84 | wget $URL 85 | bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent 86 | if [ "$EXPORT_BASHRC" -eq "1" ]; then 87 | echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc 88 | echo "export PATH=\$PATH:$BASE_PATH/$FOLDER/bin" >> ~/.bashrc 89 | source ~/.bashrc 90 | fi 91 | else 92 | echo "" 93 | fi 94 | -------------------------------------------------------------------------------- /deploy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | BASE_PATH=$1 3 | 4 | echo "MAKE SURE LD_LIBRARY_PATH IS EMPTY!" 5 | echo $LD_LIBRARY_PATH 6 | 7 | if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then 8 | echo "Compilation unsuccessul!" 1>&2 9 | exit 64 10 | fi 11 | 12 | 13 | module unload cuda && echo "no module function available. Probably not on a slurm cluster." 14 | module unload gcc && echo "no module function available. Probably not on a slurm cluster." 15 | 16 | rm -rf dist build 17 | make cleaneggs 18 | make cleanlibs 19 | 20 | make clean 21 | export CUDA_HOME= 22 | export CUDA_VERSION= 23 | make cpuonly CUDA_VERSION="CPU" 24 | 25 | if [ ! -f "./bitsandbytes/libbitsandbytes_cpu.so" ]; then 26 | # Control will enter here if $DIRECTORY doesn't exist. 27 | echo "Compilation unsuccessul!" 1>&2 28 | exit 64 29 | fi 30 | 31 | make clean 32 | export CUDA_HOME=$BASE_PATH/cuda-11.0 33 | make cuda110 CUDA_VERSION=110 34 | 35 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110.so" ]; then 36 | # Control will enter here if $DIRECTORY doesn't exist. 37 | echo "Compilation unsuccessul!" 1>&2 38 | exit 64 39 | fi 40 | 41 | make clean 42 | export CUDA_HOME=$BASE_PATH/cuda-11.1 43 | make cuda11x CUDA_VERSION=111 44 | 45 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111.so" ]; then 46 | # Control will enter here if $DIRECTORY doesn't exist. 47 | echo "Compilation unsuccessul!" 1>&2 48 | exit 64 49 | fi 50 | 51 | make clean 52 | export CUDA_HOME=$BASE_PATH/cuda-11.4 53 | make cuda11x CUDA_VERSION=114 54 | 55 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114.so" ]; then 56 | # Control will enter here if $DIRECTORY doesn't exist. 57 | echo "Compilation unsuccessul!" 1>&2 58 | exit 64 59 | fi 60 | 61 | make clean 62 | export CUDA_HOME=$BASE_PATH/cuda-11.5 63 | make cuda11x CUDA_VERSION=115 64 | 65 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115.so" ]; then 66 | # Control will enter here if $DIRECTORY doesn't exist. 67 | echo "Compilation unsuccessul!" 1>&2 68 | exit 64 69 | fi 70 | 71 | make clean 72 | export CUDA_HOME=$BASE_PATH/cuda-11.7 73 | make cuda11x CUDA_VERSION=117 74 | 75 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117.so" ]; then 76 | # Control will enter here if $DIRECTORY doesn't exist. 77 | echo "Compilation unsuccessul!" 
1>&2 78 | exit 64 79 | fi 80 | 81 | make clean 82 | export CUDA_HOME=$BASE_PATH/cuda-11.8 83 | make cuda118 CUDA_VERSION=118 84 | 85 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118.so" ]; then 86 | # Control will enter here if $DIRECTORY doesn't exist. 87 | echo "Compilation unsuccessul!" 1>&2 88 | exit 64 89 | fi 90 | 91 | make clean 92 | export CUDA_HOME=$BASE_PATH/cuda-12.0 93 | make cuda12x CUDA_VERSION=120 94 | 95 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120.so" ]; then 96 | # Control will enter here if $DIRECTORY doesn't exist. 97 | echo "Compilation unsuccessul!" 1>&2 98 | exit 64 99 | fi 100 | 101 | make clean 102 | export CUDA_HOME=$BASE_PATH/cuda-12.1 103 | make cuda12x CUDA_VERSION=121 104 | 105 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then 106 | # Control will enter here if $DIRECTORY doesn't exist. 107 | echo "Compilation unsuccessul!" 1>&2 108 | exit 64 109 | fi 110 | 111 | make clean 112 | export CUDA_HOME=$BASE_PATH/cuda-12.2 113 | make cuda12x CUDA_VERSION=122 114 | 115 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda122.so" ]; then 116 | # Control will enter here if $DIRECTORY doesn't exist. 117 | echo "Compilation unsuccessul!" 1>&2 118 | exit 64 119 | fi 120 | 121 | 122 | make clean 123 | export CUDA_HOME=$BASE_PATH/cuda-11.0 124 | make cuda110_nomatmul CUDA_VERSION=110 125 | 126 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110_nocublaslt.so" ]; then 127 | # Control will enter here if $DIRECTORY doesn't exist. 128 | echo "Compilation unsuccessul!" 1>&2 129 | exit 64 130 | fi 131 | 132 | 133 | make clean 134 | export CUDA_HOME=$BASE_PATH/cuda-11.1 135 | make cuda11x_nomatmul CUDA_VERSION=111 136 | 137 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111_nocublaslt.so" ]; then 138 | # Control will enter here if $DIRECTORY doesn't exist. 139 | echo "Compilation unsuccessul!" 1>&2 140 | exit 64 141 | fi 142 | 143 | make clean 144 | export CUDA_HOME=$BASE_PATH/cuda-11.4 145 | make cuda11x_nomatmul CUDA_VERSION=114 146 | 147 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114_nocublaslt.so" ]; then 148 | # Control will enter here if $DIRECTORY doesn't exist. 149 | echo "Compilation unsuccessul!" 1>&2 150 | exit 64 151 | fi 152 | 153 | make clean 154 | export CUDA_HOME=$BASE_PATH/cuda-11.5 155 | make cuda11x_nomatmul CUDA_VERSION=115 156 | 157 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115_nocublaslt.so" ]; then 158 | # Control will enter here if $DIRECTORY doesn't exist. 159 | echo "Compilation unsuccessul!" 1>&2 160 | exit 64 161 | fi 162 | 163 | make clean 164 | export CUDA_HOME=$BASE_PATH/cuda-11.7 165 | make cuda11x_nomatmul CUDA_VERSION=117 166 | 167 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117_nocublaslt.so" ]; then 168 | # Control will enter here if $DIRECTORY doesn't exist. 169 | echo "Compilation unsuccessul!" 1>&2 170 | exit 64 171 | fi 172 | 173 | make clean 174 | export CUDA_HOME=$BASE_PATH/cuda-11.8 175 | make cuda118_nomatmul CUDA_VERSION=118 176 | 177 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so" ]; then 178 | # Control will enter here if $DIRECTORY doesn't exist. 179 | echo "Compilation unsuccessul!" 1>&2 180 | exit 64 181 | fi 182 | 183 | make clean 184 | export CUDA_HOME=$BASE_PATH/cuda-12.0 185 | make cuda12x_nomatmul CUDA_VERSION=120 186 | 187 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so" ]; then 188 | # Control will enter here if $DIRECTORY doesn't exist. 189 | echo "Compilation unsuccessul!" 
1>&2 190 | exit 64 191 | fi 192 | 193 | make clean 194 | export CUDA_HOME=$BASE_PATH/cuda-12.1 195 | make cuda12x_nomatmul CUDA_VERSION=121 196 | 197 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so" ]; then 198 | # Control will enter here if $DIRECTORY doesn't exist. 199 | echo "Compilation unsuccessful!" 1>&2 200 | exit 64 201 | fi 202 | 203 | make clean 204 | export CUDA_HOME=$BASE_PATH/cuda-12.2 205 | make cuda12x_nomatmul CUDA_VERSION=122 206 | 207 | if [ ! -f "./bitsandbytes/libbitsandbytes_cuda122_nocublaslt.so" ]; then 208 | # Control will enter here if $DIRECTORY doesn't exist. 209 | echo "Compilation unsuccessful!" 1>&2 210 | exit 64 211 | fi 212 | 213 | python -m build 214 | python -m twine upload dist/* --verbose 215 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: 8-bit 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - python=3.9 7 | - pytest 8 | - pytorch 9 | - torchaudio 10 | - torchvision 11 | - cudatoolkit=11.1 12 | - typer 13 | - ca-certificates 14 | - certifi 15 | - openssl 16 | -------------------------------------------------------------------------------- /errors_and_solutions.md: -------------------------------------------------------------------------------- 1 | # No kernel image available 2 | 3 | This problem arises when the CUDA version loaded by bitsandbytes is not supported by your GPU, or if your PyTorch CUDA version does not match. To solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Do they match the version reported by ``nvidia-smi``? 4 | 5 | If you are feeling lucky, you can also try to compile the library from source. This can still be problematic if your PATH variables point to multiple CUDA versions. As such, it is recommended to figure out path conflicts before you proceed with compilation. 6 | 7 | 8 | __If you encounter any other error not listed here, please create an issue. This will help resolve your problem and will help out others in the future.__ 9 | 10 | 11 | # fatbinwrap 12 | 13 | This error occurs if there is a mismatch between the CUDA version used to compile the C++ library and the CUDA version loaded at runtime. Make sure you have the right CUDA version in your $PATH and $LD_LIBRARY_PATH variables. In the conda base environment you can find the library under: 14 | ```bash 15 | ls $CONDA_PREFIX/lib/*cudart* 16 | ``` 17 | Make sure this path is appended to the `LD_LIBRARY_PATH` so bnb can find the CUDA runtime library (cudart). 18 | 19 | If this does not fix the issue, please try [compilation from source](compile_from_source.md) next. 20 | 21 | If this does not work, please open an issue and paste the printed environment from calling `make` and the associated error when running bnb.
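As a quick way to gather all of this information in one place, the following commands (a diagnostic sketch, not part of bitsandbytes itself) print the CUDA version PyTorch was built against, the driver information, and the paths discussed above so you can compare them side by side:

```bash
# CUDA version that PyTorch was compiled against
python -c "import torch; print(torch.version.cuda)"
# Driver and GPU information
nvidia-smi
# Paths that may point to competing CUDA installations
echo $CUDA_HOME
echo $LD_LIBRARY_PATH
ls -l $CONDA_PREFIX/lib/*cudart* 2>/dev/null
# bitsandbytes' own diagnostic output
python -m bitsandbytes
```

If these report different CUDA versions, clean up the conflicting paths before recompiling or reinstalling.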
22 | -------------------------------------------------------------------------------- /examples/int8_inference_huggingface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | 4 | MAX_NEW_TOKENS = 128 5 | model_name = 'decapoda-research/llama-7b-hf' 6 | 7 | text = 'Hamburg is in which country?\n' 8 | tokenizer = AutoTokenizer.from_pretrained(model_name) 9 | input_ids = tokenizer(text, return_tensors="pt").input_ids 10 | 11 | free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3) 12 | max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' 13 | 14 | n_gpus = torch.cuda.device_count() 15 | max_memory = {i: max_memory for i in range(n_gpus)} 16 | 17 | model = AutoModelForCausalLM.from_pretrained( 18 | model_name, 19 | device_map='auto', 20 | load_in_8bit=True, 21 | max_memory=max_memory 22 | ) 23 | generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) 24 | print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /how_to_use_nonpytorch_cuda.md: -------------------------------------------------------------------------------- 1 | ## How to use a CUDA version that is different from PyTorch 2 | 3 | Some features of bitsandbytes may need a newer CUDA version than the one regularly supported by PyTorch binaries from conda / pip. In that case you can use the following instructions to load a precompiled bitsandbytes binary that works for you. 4 | 5 | ## Installing or determining the CUDA installation 6 | 7 | Determine the path of the CUDA version that you want to use. Common paths are: 8 | ```bash 9 | /usr/local/cuda 10 | /usr/local/cuda-XX.X 11 | ``` 12 | 13 | where XX.X is the CUDA version number. 14 | 15 | You can also install the CUDA version that you need locally with a script provided by bitsandbytes as follows: 16 | 17 | ```bash 18 | wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh 19 | # Syntax cuda_install.sh CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH 20 | # CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122} 21 | # EXPORT_TO_BASH in {0, 1} with 0=False and 1=True 22 | 23 | # For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc 24 | bash cuda_install.sh 117 ~/local 1 25 | ``` 26 | 27 | ## Setting the environmental variables BNB_CUDA_VERSION and LD_LIBRARY_PATH 28 | 29 | To manually override the PyTorch-installed CUDA version you need to set two variables, like so: 30 | 31 | ```bash 32 | export BNB_CUDA_VERSION=<VERSION> 33 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<PATH> 34 | ``` 35 | 36 | For example, to use the local install path from above: 37 | 38 | ```bash 39 | export BNB_CUDA_VERSION=117 40 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/tim/local/cuda-11.7 41 | ``` 42 | 43 | It is best to add these lines to the `.bashrc` file to make them permanent. 44 | 45 | If you now launch bitsandbytes with these environmental variables, the PyTorch CUDA version will be overridden by the new CUDA version and a different bitsandbytes library is loaded (in this case the version 117 build).
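If you do not want to modify your `.bashrc`, the same override can also be applied for a single command by setting the variables inline. This is only a sketch that reuses the example install path from above; adjust the path to your own installation:

```bash
# One-off override: pick the CUDA 11.7 build of bitsandbytes for this invocation only
BNB_CUDA_VERSION=117 \
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/tim/local/cuda-11.7 \
python -m bitsandbytes
```

If the override took effect, the diagnostic output should show a `libbitsandbytes_cuda117.so` binary being selected instead of the default one.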
46 | -------------------------------------------------------------------------------- /howto_config_override.md: -------------------------------------------------------------------------------- 1 | # How to override config hyperparameters for particular weights/parameters 2 | 3 | If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameters while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details. 4 | 5 | For global overrides in many different places in your code you can do: 6 | ```python 7 | import torch 8 | import bitsandbytes as bnb 9 | 10 | mng = bnb.optim.GlobalOptimManager.get_instance() 11 | 12 | model = MyModel() 13 | mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU 14 | 15 | model = model.cuda() 16 | # use 8-bit optimizer states for all parameters 17 | adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) 18 | 19 | # 2a. override: the parameter model.fc1.weight now uses 32-bit Adam 20 | mng.override_config(model.fc1.weight, 'optim_bits', 32) 21 | 22 | # 2b. override: the two special layers use 23 | # sparse optimization + different learning rate + different Adam betas 24 | mng.override_config([model.special.weight, model.also_special.weight], 25 | key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)}) 26 | ``` 27 | Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm` 28 | 29 | For overrides for particular layers we recommend overriding locally in each module. You can do this by passing the module, the parameter, and its attribute name to the GlobalOptimManager: 30 | ```python 31 | class MyModule(torch.nn.Module): 32 | def __init__(self, din, dout): 33 | super(MyModule, self).__init__() 34 | self.linear = torch.nn.Linear(din, dout) 35 | # optimization will happen in 32-bit and 36 | # learning rate will be set to 0.0001 independent of the main learning rate 37 | config = {'optim_bits': 32, 'lr' : 0.0001} 38 | GlobalOptimManager.get_instance().register_module_override(self, 'weight', config) 39 | 40 | ``` 41 | -------------------------------------------------------------------------------- /include/AAlloc.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Portable.h" 4 | 5 | namespace BinSearch { 6 | namespace Details { 7 | 8 | template 9 | bool isAligned(const T *p, size_t A) 10 | { 11 | return (reinterpret_cast(p) % A) == 0; 12 | } 13 | 14 | template 15 | struct AlignedVec 16 | { 17 | AlignedVec() 18 | : m_storage(0) 19 | , m_data(0) 20 | , m_sz(0) 21 | { 22 | } 23 | 24 | static size_t nBytes(size_t sz) 25 | { 26 | return sz * sizeof(T) + A; 27 | } 28 | 29 | static size_t shiftAmt(char *p) 30 | { 31 | return A>1?
(A - (reinterpret_cast(p) % A)) % A: 0; 32 | } 33 | 34 | void setPtr(char *p, size_t sz) 35 | { 36 | m_sz = sz; 37 | m_data = reinterpret_cast(p + shiftAmt(p)); 38 | } 39 | 40 | //void setPtr(T *p, size_t sz) 41 | //{ 42 | // m_sz = sz; 43 | // if (A>1) 44 | // myassert(((reinterpret_cast(p) % A) == 0), "bad alignment"); 45 | // m_data = p; 46 | //} 47 | 48 | // internal allocation 49 | void resize(size_t sz) 50 | { 51 | m_storage = new char[nBytes(sz)]; 52 | setPtr(m_storage, sz); 53 | } 54 | 55 | // external allocation 56 | void set(char *storage, size_t sz) 57 | { 58 | setPtr(storage, sz); 59 | } 60 | 61 | ~AlignedVec() 62 | { 63 | if (m_storage) 64 | delete [] m_storage; 65 | } 66 | 67 | size_t size() const { return m_sz; } 68 | T& operator[](size_t i) { return m_data[i]; } 69 | const T& operator[](size_t i) const { return m_data[i]; } 70 | T* begin() { return m_data; } 71 | T* end() { return m_data+m_sz; } 72 | const T* begin() const { return m_data; } 73 | const T* end() const { return m_data+m_sz; } 74 | T& front() { return m_data[0]; } 75 | T& back() { return m_data[m_sz-1]; } 76 | const T& front() const { return m_data[0]; } 77 | const T& back() const { return m_data[m_sz - 1]; } 78 | 79 | private: 80 | char *m_storage; 81 | T *m_data; 82 | size_t m_sz; 83 | }; 84 | 85 | } // namespace Details 86 | } // namespace BinSearch 87 | -------------------------------------------------------------------------------- /include/Algo-Direct-Common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include "AAlloc.h" 7 | 8 | namespace BinSearch { 9 | namespace Details { 10 | 11 | namespace DirectAux { 12 | 13 | #define SAFETY_MULTI_PASS true 14 | 15 | template 16 | struct HResults 17 | { 18 | HResults(T h, double ratio, size_t n) : H(h), hRatio(ratio), nInc(n) {} 19 | T H; 20 | double hRatio; 21 | size_t nInc; 22 | }; 23 | 24 | 25 | #ifdef USE_FMA 26 | template struct IsDirect { static const bool value = (A == Direct) || (A == DirectFMA); }; 27 | template struct IsDirect2 { static const bool value = (A == Direct2) || (A == Direct2FMA); }; 28 | template struct IsDirectCache { static const bool value = (A == DirectCache) || (A == DirectCacheFMA); }; 29 | #else 30 | template struct IsDirect { static const bool value = (A == Direct); }; 31 | template struct IsDirect2 { static const bool value = (A == Direct2); }; 32 | template struct IsDirectCache { static const bool value = (A == DirectCache); }; 33 | #endif 34 | 35 | // general definition 36 | template 37 | struct BucketElem 38 | { 39 | FORCE_INLINE void set( uint32 b, const T *) 40 | { 41 | m_b = b; 42 | } 43 | 44 | FORCE_INLINE uint32 index() const { return m_b; } 45 | 46 | private: 47 | uint32 m_b; 48 | }; 49 | 50 | // specialization for DirectCache methods 51 | 52 | template struct MatchingIntType; 53 | template <> struct MatchingIntType { typedef uint64 type; }; 54 | template <> struct MatchingIntType { typedef uint32 type; }; 55 | 56 | template 57 | struct BucketElem::value >::type > 58 | { 59 | typedef typename MatchingIntType::type I; 60 | 61 | void set(uint32 b, const T *xi) 62 | { 63 | u.u.x = xi[b]; 64 | u.u.b = b; 65 | } 66 | 67 | FORCE_INLINE I index() const { return u.u.b; } 68 | FORCE_INLINE T x() const { return u.u.x; } 69 | 70 | private: 71 | union { 72 | double dummy; 73 | struct 74 | { 75 | T x; 76 | I b; 77 | } u; 78 | } u; 79 | }; 80 | 81 | 82 | template 83 | struct DirectTraits 84 | { 85 | static void checkH(T scaler, T x0, T 
xN) 86 | { 87 | T Dn = xN - x0; 88 | T ifmax = Dn * scaler; 89 | myassert((ifmax < std::numeric_limits::max() - (Gap - 1)), 90 | "Problem unfeasible: index size exceeds uint32 capacity:" 91 | << " D[N] =" << Dn 92 | << ", H =" << scaler 93 | << ", H D[n] =" << ifmax << "\n" 94 | ); 95 | } 96 | 97 | FORCE_INLINE static uint32 f(T scaler, T x0, T z) 98 | { 99 | T tmp = scaler * (z - x0); 100 | #ifdef USE_SSE2 101 | return ftoi(FVec1(tmp)); 102 | #else 103 | return static_cast(tmp); 104 | #endif 105 | } 106 | 107 | template 108 | FORCE_INLINE static typename FTOITraits::vec_t f(const FVec& scaler, const FVec& x0, const FVec& z) 109 | { 110 | return ftoi(scaler*(z-x0)); 111 | } 112 | 113 | static T cst0(T scaler, T x0) 114 | { 115 | return x0; 116 | } 117 | }; 118 | 119 | #ifdef USE_FMA 120 | template 121 | struct DirectTraits 122 | { 123 | typedef FVec1 fVec1; 124 | 125 | static void checkH(T scaler, T H_Times_x0, T xN) 126 | { 127 | union { 128 | typename FVec1::vec_t v; 129 | T s; 130 | } ifmax; 131 | ifmax.v = mulSub(fVec1(scaler), fVec1(xN), fVec1(H_Times_x0)); 132 | myassert((ifmax.s < std::numeric_limits::max() - (Gap - 1)), 133 | "Problem unfeasible: index size exceeds uint32 capacity:" 134 | << " H X[0] =" << H_Times_x0 135 | << ", H =" << scaler 136 | << ", X[N] =" << xN 137 | << ", H X[N] - H X[0] =" << ifmax.s << "\n" 138 | ); 139 | } 140 | 141 | FORCE_INLINE static uint32 f(T scaler, T Hx0, T xi) 142 | { 143 | return ftoi(mulSub(fVec1(scaler), fVec1(xi), fVec1(Hx0))); 144 | } 145 | 146 | template 147 | FORCE_INLINE static typename FTOITraits::vec_t f(const FVec& scaler, const FVec& H_Times_X0, const FVec& z) 148 | { 149 | return ftoi(mulSub(scaler, z, H_Times_X0)); 150 | } 151 | 152 | static T cst0(T scaler, T x0) 153 | { 154 | return scaler*x0; 155 | } 156 | }; 157 | #endif 158 | 159 | template 160 | struct DirectInfo 161 | { 162 | static const bool UseFMA = (A == DirectFMA) || (A == Direct2FMA) || (A == DirectCacheFMA); 163 | typedef DirectTraits fun_t; 164 | typedef BucketElem bucket_t; 165 | typedef AlignedVec bucketvec_t; 166 | 167 | struct Data { 168 | Data() : buckets(0), xi(0), scaler(0), cst0(0) {} 169 | Data( const T *x // for Direct must persist if xws=NULL 170 | , uint32 n 171 | , T H 172 | , bucket_t *bws // assumed to gave size nb, as computed below 173 | , T *xws = NULL // assumed to have size (n+Gap-1). 
Optional for Direct, unused for DirectCache, required for DirectGap 174 | ) 175 | : buckets(bws) 176 | , scaler(H) 177 | , cst0(fun_t::cst0(H, x[0])) 178 | { 179 | myassert(((bws != NULL) && (isAligned(bws,64))), "bucket pointer not allocated or incorrectly aligned"); 180 | 181 | uint32 nb = 1 + fun_t::f(H, cst0, x[n-1]); 182 | 183 | const uint32 npad = Gap-1; 184 | const uint32 n_sz = n + npad; // size of padded vector 185 | 186 | if (xws) { 187 | myassert(isAligned(xws,8), "x pointer not allocated or incorrectly aligned"); 188 | std::fill_n(xws, npad, x[0]); // pad in front with x[0] 189 | std::copy(x, x+n, xws + npad); 190 | xi = xws; 191 | } 192 | else { 193 | myassert(Gap==1, "if Gap>1 then X workspace must be provided"); 194 | xi = x; 195 | } 196 | 197 | populateIndex(bws, nb, xi, n_sz, scaler, cst0); 198 | } 199 | 200 | const bucket_t *buckets; 201 | const T *xi; 202 | T scaler; 203 | T cst0; // could be x0 or (scaler*x0), depending if we are using FMA or not 204 | } data; 205 | 206 | static T growStep(T H) 207 | { 208 | T step; 209 | T P = next(H); 210 | while ((step = P - H) == 0) 211 | P = next(P); 212 | return step; 213 | } 214 | 215 | static HResults computeH(const T *px, uint32 nx) 216 | { 217 | myassert((nx > Gap), "Array X too small"); 218 | myassert(((Gap == 1) || (Gap == 2)), "Only tested for these values of Gap"); 219 | 220 | const T x0 = px[0]; 221 | const T xN = px[nx-1]; 222 | 223 | const T range = xN - x0; 224 | myassert((range < std::numeric_limits::max()), "range too large"); 225 | 226 | // check that D_i are strictly increasing and compute minimum value D_{i+Offset}-D_i 227 | T deltaDMin = range; 228 | for (uint32 i = Gap; i < nx; ++i) { 229 | T Dnew = px[i] - x0; 230 | T Dold = px[i - Gap] - x0; 231 | myassert((Dnew > Dold), 232 | "Problem unfeasible: D_i sequence not strictly increasing" 233 | << " X[" << 0 << "]=" << x0 234 | << " X[" << i - Gap << "]=" << px[i - Gap] 235 | << " X[" << i << "]=" << px[i] 236 | << "\n" 237 | ); 238 | T deltaD = Dnew - Dold; 239 | if (deltaD < deltaDMin) 240 | deltaDMin = deltaD; 241 | } 242 | 243 | // initial guess for H 244 | const T H0 = T(1.0) / deltaDMin; 245 | T H = H0; 246 | 247 | T cst0 = fun_t::cst0(H, x0); 248 | fun_t::checkH(H, cst0, xN); 249 | 250 | // adjust H by trial and error until succeed 251 | size_t nInc = 0; 252 | bool modified = false; 253 | size_t npasses = 0; 254 | T step = growStep(H); 255 | uint32 seg_already_checked_from = nx; 256 | do { 257 | myassert((npasses++ < 2), "verification failed\n"); 258 | // if there has been an increase, then check only up to that point 259 | uint32 last_seg_to_be_checked = seg_already_checked_from - 1; 260 | modified = false; 261 | uint32 inew = 0; 262 | for (uint32 i = Gap; i <= last_seg_to_be_checked; ++i) { 263 | uint32 iold = fun_t::f(H, cst0, px[i-Gap]); 264 | uint32 inew = fun_t::f(H, cst0, px[i]); 265 | while (inew == iold) { 266 | seg_already_checked_from = i; 267 | last_seg_to_be_checked = nx-1; // everything needs to be checked 268 | modified = true; 269 | H = H + step; 270 | step *= 2; 271 | // recalculate all constants and indices 272 | cst0 = fun_t::cst0(H, x0); 273 | fun_t::checkH(H, cst0, xN); 274 | iold = fun_t::f(H, cst0, px[i - Gap]); 275 | inew = fun_t::f(H, cst0, px[i]); 276 | } 277 | } 278 | } while (SAFETY_MULTI_PASS && modified); 279 | 280 | return HResults(H, (((double)H) / H0) - 1.0, nInc); 281 | } 282 | 283 | static void populateIndex(BucketElem *buckets, uint32 index_size, const T *px, uint32 x_size, T scaler, T cst0) 284 | { 285 | for (uint32 i = 
x_size-1, b = index_size-1, j=0; ; --i) { 286 | uint32 idx = fun_t::f(scaler, cst0, px[i]); 287 | while (b > idx) { // in the 1st iteration it is j=0 but this condition is always false 288 | buckets[b].set( j, px ); 289 | --b; 290 | } 291 | if (Gap==1 || b == idx) { // if Gap==1, which is known at compile time, the check b==idx is redundant 292 | j = i - (Gap-1); // subtracting (Gap-1) points to the index of the first X-element to check 293 | buckets[b].set(j, px); 294 | if (b-- == 0) 295 | break; 296 | } 297 | } 298 | } 299 | 300 | DirectInfo(const Data& d) 301 | : data(d) 302 | { 303 | } 304 | 305 | DirectInfo(const T* px, const uint32 n) 306 | { 307 | HResults res = computeH(px, n); 308 | 309 | #ifdef PAPER_TEST 310 | nInc = res.nInc; 311 | hRatio = res.hRatio; 312 | #endif 313 | const uint32 npad = Gap-1; 314 | const uint32 n_sz = n + npad; // size of padded vector 315 | 316 | if (npad) 317 | xi.resize(n_sz); 318 | 319 | T H = res.H; 320 | T cst0 = fun_t::cst0(H, px[0]); 321 | const uint32 maxIndex = fun_t::f(H, cst0, px[n-1]); 322 | buckets.resize(maxIndex + 1); 323 | 324 | data = Data(px, n, H, buckets.begin(), (npad? xi.begin(): NULL)); 325 | } 326 | 327 | private: 328 | bucketvec_t buckets; 329 | AlignedVec xi; 330 | 331 | #ifdef PAPER_TEST 332 | public: 333 | double hRatio; 334 | size_t nInc; 335 | #endif 336 | }; 337 | 338 | 339 | } // namespace DirectAux 340 | } // namespace Details 341 | } // namespace BinSearch 342 | -------------------------------------------------------------------------------- /include/AlgoXCodes.h: -------------------------------------------------------------------------------- 1 | ALGOENUM(DirectCacheFMA, 5) 2 | ALGOENUM(DirectFMA, 15) 3 | ALGOENUM(Direct2FMA, 25) 4 | ALGOENUM(DirectCache, 10) 5 | ALGOENUM(Direct, 20) 6 | ALGOENUM(Direct2, 30) 7 | ALGOENUM(Nonary, 40) 8 | ALGOENUM(Pentary, 50) 9 | ALGOENUM(Ternary, 60) 10 | ALGOENUM(Eytzinger, 70) 11 | ALGOENUM(BitSet, 80) 12 | ALGOENUM(ClassicOffset, 90) 13 | #ifdef PAPER_TEST 14 | ALGOENUM(MorinOffset, 100) 15 | ALGOENUM(BitSetNoPad, 110) 16 | ALGOENUM(ClassicMod, 120) 17 | ALGOENUM(MorinBranchy, 130) 18 | ALGOENUM(Classic, 140) 19 | ALGOENUM(LowerBound, 145) 20 | #ifdef USE_MKL 21 | ALGOENUM(MKL, 150) 22 | #endif 23 | #endif 24 | -------------------------------------------------------------------------------- /include/BinAlgo.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Type.h" 4 | #include 5 | 6 | namespace BinSearch { 7 | 8 | template 9 | struct BinAlgo : Details::BinAlgoBase 10 | { 11 | typedef Details::BinAlgoBase base_t; 12 | 13 | BinAlgo(const T* px, const uint32 n) : base_t(px, n), x0(px[0]), xN(px[n-1]), N(n) {} 14 | BinAlgo(const T* px, const uint32 n, const typename base_t::Data& d) : base_t(d), x0(px[0]), xN(px[n-1]), N(n) {} 15 | 16 | FORCE_INLINE 17 | uint32 scalar(T z) const 18 | { 19 | if (!L || z >= x0) 20 | if (!R || z < xN) 21 | return base_t::scalar(z); 22 | else 23 | return N; 24 | else 25 | return std::numeric_limits::max(); 26 | } 27 | 28 | 29 | FORCE_INLINE 30 | void vectorial(uint32 *pr, const T *pz, uint32 n) const 31 | { 32 | if (!L && !R) { 33 | Details::Loop::loop(*this, pr, pz, n); 34 | } 35 | else { 36 | const uint32 nElem = base_t::nElem; 37 | const uint32 idealbufsize = 256; 38 | const uint32 bufsize = nElem * (idealbufsize / nElem + ((idealbufsize % nElem) ? 
1 : 0)); 39 | T databuf[bufsize]; 40 | uint32 resbuf[bufsize]; 41 | uint32 indexbuf[bufsize]; 42 | 43 | uint32 *prend = pr + n; 44 | while(pr != prend) { 45 | uint32 cnt = 0; 46 | uint32 niter = std::min(bufsize, (uint32)std::distance(pr,prend)); 47 | for (uint32 j = 0; j < niter; ++j) { 48 | T z = pz[j]; 49 | // FIXME: use SSE2? 50 | if (!L || z >= x0) 51 | if (!R || z < xN) { 52 | databuf[cnt] = z; 53 | indexbuf[cnt] = j; 54 | ++cnt; 55 | } 56 | else 57 | pr[j] = N; 58 | else 59 | pr[j] = std::numeric_limits::max(); 60 | } 61 | // FIXME: merge these two loops 62 | Details::Loop::loop(*this, resbuf, databuf, cnt); 63 | for (uint32 j = 0; j < cnt; ++j) 64 | pr[indexbuf[j]] = resbuf[j]; 65 | pr += niter; 66 | pz += niter; 67 | } 68 | } 69 | } 70 | 71 | Details::CondData x0; 72 | Details::CondData xN; 73 | Details::CondData N; 74 | }; 75 | 76 | 77 | } // namespace BinSearch 78 | -------------------------------------------------------------------------------- /include/BinSearch.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "AAlloc.h" 4 | #include "BinAlgo.h" 5 | #include "SIMD.h" 6 | 7 | #include 8 | #include 9 | 10 | 11 | #include "Algo-Direct2.h" 12 | -------------------------------------------------------------------------------- /include/Portable.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #ifdef __FMA__ 8 | #define USE_FMA 9 | #endif 10 | 11 | #ifdef __AVX2__ 12 | #define USE_AVX2 13 | #endif 14 | 15 | #ifdef __AVX__ 16 | #define USE_AVX 17 | #endif 18 | 19 | 20 | #ifdef __SSE4_1__ 21 | #define USE_SSE41 22 | #endif 23 | 24 | #ifdef __SSE4_2__ 25 | #define USE_SSE42 26 | #endif 27 | 28 | 29 | #ifndef _MSC_VER 30 | #include 31 | #endif 32 | 33 | namespace BinSearch { 34 | 35 | #ifndef _MSC_VER 36 | typedef int8_t int8; 37 | typedef uint8_t uint8; 38 | typedef int32_t int32; 39 | typedef uint32_t uint32; 40 | typedef int64_t int64; 41 | typedef uint64_t uint64; 42 | #else 43 | typedef __int8 int8; 44 | typedef unsigned __int8 uint8; 45 | typedef __int32 int32; 46 | typedef unsigned __int32 uint32; 47 | typedef __int64 int64; 48 | typedef unsigned __int64 uint64; 49 | #endif 50 | 51 | namespace Details { 52 | 53 | #define myassert(cond, msg) if (!cond){ std::ostringstream os; os << "\nassertion failed: " << #cond << ", " << msg << "\n"; throw std::invalid_argument(os.str()); } 54 | 55 | // log2 is not defined in VS2008 56 | #if defined(_MSC_VER) 57 | inline uint32 log2 (uint32 val) { 58 | if (val == 1) return 0; 59 | uint32 ret = 0; 60 | do { 61 | ret++; 62 | val >>= 1; 63 | } while (val > 1); 64 | return ret; 65 | } 66 | #endif 67 | 68 | #ifdef _DEBUG 69 | #define DEBUG 70 | #endif 71 | 72 | #ifdef _MSC_VER 73 | # define FORCE_INLINE __forceinline 74 | # define NO_INLINE __declspec(noinline) 75 | #else 76 | # define NO_INLINE __attribute__((noinline)) 77 | # ifdef DEBUG 78 | # define FORCE_INLINE NO_INLINE 79 | # else 80 | # define FORCE_INLINE __attribute__((always_inline)) inline 81 | # endif 82 | #endif 83 | 84 | #ifdef USE_AVX 85 | #define COMISS "vcomiss" 86 | #define COMISD "vcomisd" 87 | #else 88 | #define COMISS "comiss" 89 | #define COMISD "comisd" 90 | #endif 91 | 92 | // nextafter is not defined in VS2008 93 | #if defined(_MSC_VER) && (_MSC_VER <= 1500) 94 | #include 95 | inline float mynext(float x) 96 | { 97 | return _nextafterf(x, std::numeric_limits::max()); 98 | } 99 | 100 | inline double 
mynext(double x) 101 | { 102 | return _nextafter(x, std::numeric_limits::max()); 103 | } 104 | inline float myprev(float x) 105 | { 106 | return _nextafterf(x, -std::numeric_limits::max()); 107 | } 108 | 109 | inline double myprev(double x) 110 | { 111 | return _nextafter(x, -std::numeric_limits::max()); 112 | } 113 | #else 114 | inline float mynext(float x) 115 | { 116 | return std::nextafterf(x, std::numeric_limits::max()); 117 | } 118 | 119 | inline double mynext(double x) 120 | { 121 | return std::nextafter(x, std::numeric_limits::max()); 122 | } 123 | inline float myprev(float x) 124 | { 125 | return std::nextafterf(x, -std::numeric_limits::max()); 126 | } 127 | 128 | inline double myprev(double x) 129 | { 130 | return std::nextafter(x, -std::numeric_limits::max()); 131 | } 132 | #endif 133 | 134 | template 135 | inline T next(T x) 136 | { 137 | for (int i = 0; i < 4; ++i) 138 | x = mynext(x); 139 | return x; 140 | } 141 | 142 | template 143 | inline T prev(T x) 144 | { 145 | for (int i = 0; i < 4; ++i) 146 | x = myprev(x); 147 | return x; 148 | } 149 | 150 | } // namepsace Details 151 | } // namespace BinSearch 152 | -------------------------------------------------------------------------------- /include/Type.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "Portable.h" 8 | 9 | using std::size_t; 10 | 11 | namespace BinSearch { 12 | 13 | enum InstrSet { Scalar, SSE, AVX }; 14 | 15 | #define ALGOENUM(x, b) x, 16 | enum Algos 17 | { 18 | #include "AlgoXCodes.h" 19 | }; 20 | #undef ALGOENUM 21 | 22 | namespace Details { 23 | 24 | template 25 | struct InstrIntTraits; 26 | 27 | template 28 | struct InstrFloatTraits; 29 | 30 | // base class for algorithm supporting the method: 31 | // uint32 scalar(T z) const 32 | template 33 | struct AlgoScalarBase; 34 | 35 | // base class for algorithm supporting the following methods, constants and definitions: 36 | // static const uint32 nElem 37 | // struct Constants; 38 | // void initConstants(Constants& cst) const 39 | // void vectorial(uint32 *pr, const T *pz, const Constants& cst) const 40 | // The function vectorial processes nElem items 41 | template 42 | struct AlgoVecBase; 43 | 44 | template struct IntTraits; 45 | 46 | template <> struct IntTraits 47 | { 48 | typedef uint32 itype; 49 | }; 50 | template <> struct IntTraits 51 | { 52 | typedef uint64 itype; 53 | }; 54 | 55 | template 56 | struct Body 57 | { 58 | template 59 | FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const typename Expr::Constants& cst) 60 | { 61 | e.vectorial(ri, zi, cst); 62 | Body::template iteration(e, ri + D, zi + D, cst); 63 | } 64 | 65 | }; 66 | 67 | template <> 68 | struct Body<0> 69 | { 70 | template 71 | FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const H&) 72 | { 73 | } 74 | }; 75 | 76 | template 77 | struct Loop 78 | { 79 | typedef Algo algo_type; 80 | static const uint32 M = 4; 81 | static const uint32 D = algo_type::nElem; 82 | 83 | FORCE_INLINE static void loop(const algo_type& e, uint32 *ri, const T* zi, uint32 n) 84 | { 85 | typename algo_type::Constants cst; 86 | e.initConstants(cst); 87 | 88 | uint32 j = 0; 89 | while (j + (D*M) <= n) { 90 | Details::Body::template iteration(e, ri + j, zi + j, cst); 91 | j += (D*M); 92 | } 93 | while (j + D <= n) { 94 | e.vectorial(ri + j, zi + j, cst); 95 | j += D; 96 | } 97 | while (D > 1 && j < n) { 98 | ri[j] = e.scalar(zi[j]); 99 | j += 1; 100 | } 
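// The three phases above: unrolled blocks of D*M elements via Body<M>, then
// single vector-width chunks of D, then a scalar fallback for the remaining
// (n mod D) elements, which is only reachable when D > 1.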
101 | } 102 | }; 103 | 104 | template 105 | struct _Pipeliner 106 | { 107 | template 108 | FORCE_INLINE static void go(const Expr& e, Data* d) 109 | { 110 | e.template run(d); 111 | _Pipeliner::go(e, d); 112 | } 113 | }; 114 | 115 | template 116 | struct _Pipeliner 117 | { 118 | template 119 | FORCE_INLINE static void go(const Expr& e, Data* d) 120 | { 121 | } 122 | }; 123 | 124 | template 125 | struct Pipeliner 126 | { 127 | template 128 | FORCE_INLINE static void go(const Expr& e, Data* d) 129 | { 130 | _Pipeliner::go(e, d); 131 | } 132 | }; 133 | 134 | 135 | #if 1 136 | template 137 | char is_complete_impl(char (*)[sizeof(T)]); 138 | 139 | template 140 | long is_complete_impl(...); 141 | 142 | template 143 | struct IsComplete 144 | { 145 | static const bool value = sizeof(is_complete_impl(0)) == sizeof(char); 146 | }; 147 | #else 148 | template 149 | std::true_type is_complete_impl(T *); 150 | 151 | std::false_type is_complete_impl(...); 152 | 153 | template 154 | struct IsComplete : decltype(is_complete_impl(std::declval())) {}; 155 | #endif 156 | 157 | template 158 | struct AlgoScalarToVec : AlgoScalarBase 159 | { 160 | typedef AlgoScalarBase base_t; 161 | 162 | AlgoScalarToVec(const typename base_t::Data& d) : base_t(d) {} 163 | AlgoScalarToVec(const T* px, const uint32 n) : base_t(px, n) {} 164 | 165 | static const uint32 nElem = 1; 166 | 167 | struct Constants 168 | { 169 | }; 170 | 171 | void initConstants(Constants& cst) const 172 | { 173 | } 174 | 175 | FORCE_INLINE 176 | void vectorial(uint32 *pr, const T *pz, const Constants& cst) const 177 | { 178 | *pr = base_t::scalar(*pz); 179 | } 180 | }; 181 | 182 | template 183 | struct conditional { typedef T type; }; 184 | 185 | template 186 | struct conditional { typedef F type; }; 187 | 188 | template 189 | struct CondData 190 | { 191 | FORCE_INLINE CondData(T x) : v(x) {} 192 | FORCE_INLINE operator const T&() const { return v;} 193 | private: 194 | T v; 195 | }; 196 | 197 | template 198 | struct CondData 199 | { 200 | FORCE_INLINE CondData(T) {} 201 | FORCE_INLINE operator const T() const { return 0;} 202 | }; 203 | 204 | template 205 | struct BinAlgoBase : Details::conditional< Details::IsComplete>::value 206 | , Details::AlgoVecBase 207 | , Details::AlgoScalarToVec 208 | >::type 209 | { 210 | typedef typename Details::conditional< Details::IsComplete>::value 211 | , Details::AlgoVecBase 212 | , Details::AlgoScalarToVec 213 | >::type base_t; 214 | 215 | BinAlgoBase(const T* px, const uint32 n) : base_t(px, n) {} 216 | BinAlgoBase(const typename base_t::Data& d) : base_t(d) {} 217 | }; 218 | 219 | } // namespace Details 220 | 221 | } // namespace BinSearch 222 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lion-pytorch 2 | pytest 3 | scipy 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import glob 6 | import os 7 | 8 | from setuptools import find_packages, setup 9 | 10 | libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so")) 11 | libs = [os.path.basename(p) for p in libs] 12 | print("libs:", libs) 13 | 14 | 15 | def read(fname): 16 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 17 | 18 | 19 | setup( 20 | name=f"bitsandbytes", 21 | version=f"0.40.2", 22 | author="Tim Dettmers", 23 | author_email="dettmers@cs.washington.edu", 24 | description="k-bit optimizers and matrix multiplication routines.", 25 | license="MIT", 26 | keywords="gpu optimizers optimization 8-bit quantization compression", 27 | url="https://github.com/TimDettmers/bitsandbytes", 28 | packages=find_packages(), 29 | package_data={"": libs}, 30 | long_description=read("README.md"), 31 | long_description_content_type="text/markdown", 32 | classifiers=[ 33 | "Development Status :: 4 - Beta", 34 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 35 | ], 36 | ) 37 | -------------------------------------------------------------------------------- /tests/test_cuda_setup_evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import torch 4 | from pathlib import Path 5 | 6 | # hardcoded test. Not good, but a sanity check for now 7 | def test_manual_override(): 8 | manual_cuda_path = str(Path('/mmfs1/home/dettmers/data/local/cuda-12.2')) 9 | 10 | pytorch_version = torch.version.cuda.replace('.', '') 11 | 12 | assert pytorch_version != 122 13 | 14 | os.environ['CUDA_HOME']='{manual_cuda_path}' 15 | os.environ['CUDA_VERSION']='122' 16 | assert str(manual_cuda_path) in os.environ['LD_LIBRARY_PATH'] 17 | import bitsandbytes as bnb 18 | loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name 19 | assert loaded_lib == 'libbitsandbytes_cuda122.so' 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /tests/test_generation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import math 4 | 5 | from itertools import product 6 | 7 | import transformers 8 | from transformers import ( 9 | AutoConfig, 10 | AutoModelForCausalLM, 11 | AutoTokenizer, 12 | BitsAndBytesConfig, 13 | GenerationConfig, 14 | set_seed, 15 | 16 | ) 17 | 18 | 19 | 20 | def get_4bit_config(): 21 | return BitsAndBytesConfig( 22 | load_in_4bit=True, 23 | load_in_8bit=False, 24 | llm_int8_threshold=6.0, 25 | llm_int8_has_fp16_weight=False, 26 | bnb_4bit_compute_dtype=torch.float16, 27 | bnb_4bit_use_double_quant=True, 28 | bnb_4bit_quant_type='nf4', 29 | ) 30 | 31 | 32 | def get_model_and_tokenizer(config): 33 | model_name_or_path, quant_type = config 34 | bnb_config = get_4bit_config() 35 | if quant_type == '16bit': 36 | bnb_config.load_in_4bit = False 37 | else: 38 | bnb_config.bnb_4bit_quant_type= quant_type 39 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, 40 | quantization_config=bnb_config, 41 | max_memory={0:'48GB'}, 42 | device_map='auto', 43 | torch_dtype=torch.bfloat16 44 | ).eval() 45 | 46 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) 47 | 48 | return model, tokenizer 49 | 50 | def get_prompt_for_generation_eval(text, add_roles=True): 51 | description = ( 52 | "A chat between a curious human 
and an artificial intelligence assistant. " 53 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 54 | ) 55 | if add_roles: 56 | prompt = f'{description} ### Human: {text} ### Assistant:' 57 | else: 58 | prompt = f'{description} {text}' 59 | return prompt 60 | 61 | def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval): 62 | text = prompt_func(text) 63 | inputs = tokenizer(text, return_tensors="pt").to('cuda:0') 64 | outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config) 65 | return tokenizer.decode(outputs[0], skip_special_tokens=True) 66 | 67 | models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7'] 68 | dtypes = ['nf4', 'fp4'] 69 | load_in_4bit = [True, False] 70 | values = list(product(models, dtypes)) 71 | strfunc = lambda lst: [str(x) for x in lst] 72 | ids = ['_'.join(strfunc(x)) for x in values] 73 | @pytest.fixture(scope='session', params=values, ids=ids) 74 | def model_and_tokenizer(request): 75 | model, tokenizer = get_model_and_tokenizer(request.param) 76 | yield request.param, model, tokenizer 77 | del model 78 | 79 | @pytest.mark.parametrize("DQ", [True, False], ids=['DQ_True', 'DQ_False']) 80 | @pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False']) 81 | #@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) 82 | def test_pi(model_and_tokenizer, inference_kernel, DQ): 83 | print('') 84 | dtype = torch.float16 85 | 86 | fixture_config, model, tokenizer = model_and_tokenizer 87 | 88 | generation_config = transformers.GenerationConfig( 89 | max_new_tokens=20, 90 | do_sample=True, 91 | top_p=0.9, 92 | temperature=0.7, 93 | ) 94 | generation_config.max_new_tokens = 20 95 | 96 | 97 | #text = 'Please write down the first 50 digits of pi.' 
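# The commented-out prompt lines here are kept for reference only; the active
# check below feeds the bare prefix '3.14159' and counts how many of the
# n_cases sampled continuations reproduce the leading digits of math.pi,
# tolerating up to failure_max mismatches per model.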
98 | #text = get_prompt_for_generation_eval(text) 99 | #text += ' Sure, here the first 50 digits of pi: 3.14159' 100 | n_cases = 6 101 | text = '3.14159' 102 | if hasattr(model.config, 'quantization_config'): 103 | model.config.quantization_config.bnb_4bit_compute_dtype = dtype 104 | model.config.quantization_config.bnb_4bit_use_double_quant = DQ 105 | 106 | if not inference_kernel: 107 | text = [text]*n_cases 108 | inputs = tokenizer(text, return_tensors="pt").to('cuda:0') 109 | x = inputs['input_ids'] 110 | outputs = [] 111 | if inference_kernel: 112 | for i in range(n_cases): 113 | output = model.generate(x, generation_config=generation_config) 114 | textout = tokenizer.decode(output[0], skip_special_tokens=True) 115 | outputs.append(textout) 116 | else: 117 | outputs = model.generate(x, generation_config=generation_config) 118 | outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs] 119 | 120 | 121 | assert len(outputs) == n_cases 122 | failure_count = 0 123 | for i in range(n_cases): 124 | if not outputs[i][:len(str(math.pi))] == str(math.pi): 125 | failure_count += 1 126 | failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4) 127 | if failure_count > failure_max: 128 | print(math.pi) 129 | for out in outputs: 130 | print(out) 131 | raise ValueError(f'Failure count: {failure_count}/{n_cases}') 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /tests/test_linear8bitlt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import nullcontext 3 | from itertools import product 4 | from tempfile import TemporaryDirectory 5 | 6 | import pytest 7 | import torch 8 | 9 | import bitsandbytes as bnb 10 | from bitsandbytes import functional as F 11 | from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout 12 | from bitsandbytes.nn.modules import Linear8bitLt 13 | 14 | 15 | # contributed by Alex Borzunov, see: 16 | # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py 17 | 18 | @pytest.mark.skipif( 19 | not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), 20 | reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", 21 | ) 22 | def test_layout_exact_match(): 23 | x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda() 24 | for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"): 25 | transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) 26 | tile_indices = get_inverse_transform_indices(transform, tile_size) 27 | cxb = transform(x) 28 | 29 | torch.cuda.synchronize() 30 | restored_x = undo_layout(cxb, tile_indices) 31 | torch.cuda.synchronize() 32 | assert restored_x.is_contiguous() 33 | assert torch.all(torch.eq(restored_x, x)) 34 | 35 | 36 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") 37 | def test_linear_no_igemmlt(): 38 | linear = torch.nn.Linear(1024, 3072) 39 | x = torch.randn(3, 1024, dtype=torch.half) 40 | linear_custom = Linear8bitLt( 41 | linear.in_features, 42 | linear.out_features, 43 | linear.bias is not None, 44 | has_fp16_weights=False, 45 | threshold=6.0, 46 | ) 47 | linear_custom.state.force_no_igemmlt = True 48 | 49 | linear_custom.weight = bnb.nn.Int8Params( 50 | linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False 51 | ).to(linear.weight.dtype) 52 | linear_custom.bias = linear.bias 53 | 
linear_custom = linear_custom.cuda() 54 | linear = linear.half().cuda() 55 | 56 | x_ref = x.clone().cuda().requires_grad_(True) 57 | x_ours = x.clone().cuda().requires_grad_(True) 58 | fx_ref = linear(x_ref).float() 59 | grad_proj = torch.randn_like(fx_ref) 60 | (fx_ref * grad_proj).mean().backward() 61 | 62 | fx_ours = linear_custom(x_ours).float() 63 | (fx_ours * grad_proj).mean().backward() 64 | assert torch.allclose(fx_ref, fx_ours, atol=0.02) 65 | assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01) 66 | assert not linear_custom.state.has_fp16_weights 67 | assert linear_custom.state.CB is not None 68 | assert linear_custom.state.CxB is None 69 | 70 | 71 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") 72 | @pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt", 73 | list(product([False, True], [False, True], [False, True], [False, True]))) 74 | def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt): 75 | linear = torch.nn.Linear(32, 96) 76 | x = torch.randn(3, 32, dtype=torch.half) 77 | 78 | linear_custom = Linear8bitLt( 79 | linear.in_features, 80 | linear.out_features, 81 | linear.bias is not None, 82 | has_fp16_weights=has_fp16_weights, 83 | threshold=6.0, 84 | ) 85 | if force_no_igemmlt: 86 | linear_custom.state.force_no_igemmlt = True 87 | 88 | linear_custom.weight = bnb.nn.Int8Params( 89 | linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights 90 | ) 91 | linear_custom.bias = linear.bias 92 | linear_custom = linear_custom.cuda() 93 | 94 | if serialize_before_forward: 95 | state_dict_8bit = linear_custom.state_dict() 96 | 97 | x_first = x.clone().cuda().requires_grad_(True) 98 | fx_first = linear_custom(x_first).float() 99 | grad_proj = torch.randn_like(fx_first) 100 | (fx_first * grad_proj).mean().backward() 101 | 102 | if not serialize_before_forward: 103 | state_dict_8bit = linear_custom.state_dict() 104 | 105 | with TemporaryDirectory() as tmpdir: 106 | state_path_8bit = os.path.join(tmpdir, "state_8bit.pth") 107 | state_path = os.path.join(tmpdir, "state.pth") 108 | 109 | torch.save(linear.state_dict(), state_path) 110 | torch.save(state_dict_8bit, state_path_8bit) 111 | 112 | if not has_fp16_weights: 113 | assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path) 114 | 115 | new_state_dict = torch.load(state_path_8bit) 116 | 117 | new_linear_custom = Linear8bitLt( 118 | linear.in_features, 119 | linear.out_features, 120 | linear.bias is not None, 121 | has_fp16_weights=has_fp16_weights, 122 | threshold=6.0, 123 | ) 124 | if force_no_igemmlt: 125 | new_linear_custom.state.force_no_igemmlt = True 126 | 127 | if deserialize_before_cuda: 128 | with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): 129 | new_linear_custom.load_state_dict(new_state_dict, strict=True) 130 | 131 | new_linear_custom = new_linear_custom.cuda() 132 | 133 | if not deserialize_before_cuda: 134 | new_linear_custom.load_state_dict(new_state_dict, strict=True) 135 | 136 | x_second = x.clone().cuda().requires_grad_(True) 137 | fx_second = new_linear_custom(x_second).float() 138 | (fx_second * grad_proj).mean().backward() 139 | 140 | # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised 141 | if has_fp16_weights or not deserialize_before_cuda: 142 | assert torch.allclose(fx_first, fx_second, atol=1e-5) 143 | assert 
torch.allclose(x_first.grad, x_second.grad, atol=1e-5) 144 | -------------------------------------------------------------------------------- /tests/test_triton.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from bitsandbytes.triton.triton_utils import is_triton_available 5 | from bitsandbytes.nn.triton_based_modules import SwitchBackLinear 6 | from bitsandbytes.nn import Linear8bitLt 7 | 8 | @pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, 9 | reason="This test requires triton and a GPU with compute capability 8.0 or higher.") 10 | @pytest.mark.parametrize("vector_wise_quantization", [False, True]) 11 | def test_switchback(vector_wise_quantization): 12 | for dim in [83]: 13 | for batch in [13]: 14 | 15 | standard = torch.nn.Linear(dim, 4 * dim).cuda().half() 16 | switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() 17 | baseline = Linear8bitLt(dim, 4 * dim).cuda().half() 18 | switchback.weight.data.copy_(standard.weight) 19 | switchback.bias.data.copy_(standard.bias) 20 | baseline.weight.data.copy_(standard.weight) 21 | baseline.bias.data.copy_(standard.bias) 22 | 23 | x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True) 24 | x2 = x1.clone().detach().requires_grad_(True) 25 | x3 = x1.clone().detach().requires_grad_(True) 26 | 27 | out_standard = standard(x1) 28 | (2**10 * out_standard.abs().mean()).backward() 29 | 30 | print(x2.dtype) 31 | out_sb = switchback(x2) 32 | (2**10 * out_sb.abs().mean()).backward() 33 | 34 | out_baseline = baseline(x3) 35 | (2**10 * out_baseline.abs().mean()).backward() 36 | 37 | err_sb = (out_standard - out_sb).abs().mean() 38 | err_baseline = (out_standard - out_baseline).abs().mean() 39 | print('OUT', err_sb, err_baseline) 40 | assert err_sb < 2 * err_baseline 41 | 42 | err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean() 43 | err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean() 44 | 45 | print('GW2', err_sb, err_baseline) 46 | assert err_sb < 2 * err_baseline 47 | 48 | err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean() 49 | err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean() 50 | 51 | print('GW1', err_sb, err_baseline) 52 | assert err_sb < 2 * err_baseline 53 | 54 | err_sb = (x1.grad - x2.grad).abs().mean() 55 | err_baseline = (x1.grad - x3.grad).abs().mean() 56 | 57 | print('GX1', err_sb, err_baseline) 58 | assert err_sb < 2 * err_baseline 59 | 60 | --------------------------------------------------------------------------------
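The include/ headers above form a self-contained, header-only library of scalar and SIMD search routines over a sorted array, but no usage example appears in this tree. The following sketch is not repository code; it is a minimal illustration that assumes BinAlgo's template parameters are (instruction set, element type, algorithm), which cannot be fully confirmed from the headers as reproduced here, so verify against the original headers before relying on it.

// Minimal usage sketch (assumption, not repository code): build the index over
// a sorted array once, then look up which segment contains each query value.
#include <cstdio>
#include <vector>
#include "BinSearch.h"

int main()
{
    using namespace BinSearch;

    // Strictly increasing boundaries, as required by DirectInfo::computeH.
    std::vector<float> xs = {0.0f, 0.1f, 0.25f, 0.5f, 1.0f};

    // Assumed template order <InstrSet, T, Algos>; the Scalar instruction set
    // exercises the AlgoScalarToVec fallback defined in Type.h.
    BinAlgo<Scalar, float, Direct2> searcher(xs.data(), (uint32)xs.size());

    // Single lookup: index i such that xs[i] <= z < xs[i+1] for in-range z.
    uint32 one = searcher.scalar(0.3f);

    // Batched lookup through BinAlgo::vectorial.
    std::vector<float>  queries = {0.05f, 0.3f, 0.7f};
    std::vector<uint32> results(queries.size());
    searcher.vectorial(results.data(), queries.data(), (uint32)queries.size());

    std::printf("%u %u %u %u\n", one, results[0], results[1], results[2]);
    return 0;
}

A similar instantiation (Scalar instruction set, float elements, the Direct2 algorithm) appears to be what the CPU quantization code under csrc/ relies on, though those files are outside this excerpt.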