├── .dev-scripts ├── .gitignore ├── README.md ├── basic_tests.sh ├── extract_install_cmd.py ├── publish-docker-internal.sh ├── publish-with-docker.sh └── publish.sh ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── bitorch_engine ├── __init__.py ├── extensions │ └── __init__.py ├── functions │ ├── __init__.py │ └── cuda │ │ ├── __init__.py │ │ ├── extension.py │ │ ├── functions.py │ │ ├── functions_cuda.cpp │ │ └── functions_cuda_kernel.cu ├── layers │ ├── __init__.py │ ├── qconv │ │ ├── __init__.py │ │ ├── binary │ │ │ ├── __init__.py │ │ │ ├── cpp │ │ │ │ ├── __init__.py │ │ │ │ ├── binary_conv.cpp │ │ │ │ ├── extension.py │ │ │ │ └── layer.py │ │ │ ├── cutlass │ │ │ │ ├── __init__.py │ │ │ │ ├── binary_conv2d_cutlass.cpp │ │ │ │ ├── binary_conv2d_cutlass_kernel.cu │ │ │ │ ├── extension.py │ │ │ │ └── layer.py │ │ │ └── layer.py │ │ └── nbit │ │ │ ├── __init__.py │ │ │ ├── cutlass │ │ │ ├── __init__.py │ │ │ ├── extension.py │ │ │ ├── layer.py │ │ │ ├── q4_conv_cutlass.cpp │ │ │ └── q4_conv_cutlass_kernel.cu │ │ │ └── layer.py │ ├── qembedding │ │ ├── __init__.py │ │ └── binary │ │ │ ├── __init__.py │ │ │ └── layer.py │ ├── qlinear │ │ ├── __init__.py │ │ ├── binary │ │ │ ├── __init__.py │ │ │ ├── binary_implementation.py │ │ │ ├── cpp │ │ │ │ ├── __init__.py │ │ │ │ ├── binary_linear.cpp │ │ │ │ ├── extension.py │ │ │ │ └── layer.py │ │ │ ├── cuda │ │ │ │ ├── __init__.py │ │ │ │ ├── binary_linear_cuda.cpp │ │ │ │ ├── binary_linear_cuda_kernel.cu │ │ │ │ ├── bmm.py │ │ │ │ ├── extension.py │ │ │ │ └── layer.py │ │ │ ├── cutlass │ │ │ │ ├── __init__.py │ │ │ │ ├── binary_linear_cutlass.cpp │ │ │ │ ├── binary_linear_cutlass_kernel.cu │ │ │ │ ├── binary_linear_cutlass_kernel.h │ │ │ │ ├── extension.py │ │ │ │ ├── kernel_selection.h │ │ │ │ └── layer.py │ │ │ └── layer.py │ │ ├── layer.py │ │ ├── nbit │ │ │ ├── __init__.py │ │ │ ├── cuda │ │ │ │ ├── __init__.py │ │ │ │ ├── exl2 │ │ │ │ │ ├── compat.cuh │ │ │ │ │ ├── config.h │ │ │ │ │ ├── kernel_select.cuh │ │ │ │ │ ├── matrix_view.cuh │ │ │ │ │ ├── q_gemm_kernel.cuh │ │ │ │ │ ├── q_gemm_kernel_gptq.cuh │ │ │ │ │ ├── quant │ │ │ │ │ │ ├── qdq_2.cuh │ │ │ │ │ │ ├── qdq_3.cuh │ │ │ │ │ │ ├── qdq_4.cuh │ │ │ │ │ │ ├── qdq_5.cuh │ │ │ │ │ │ ├── qdq_6.cuh │ │ │ │ │ │ ├── qdq_8.cuh │ │ │ │ │ │ └── qdq_util.cuh │ │ │ │ │ └── util.cuh │ │ │ │ ├── extension.py │ │ │ │ ├── mbwq_layer.py │ │ │ │ ├── mbwq_linear_cuda_kernel.cu │ │ │ │ ├── mpq_layer.py │ │ │ │ ├── mpq_linear_cuda_kernel.cu │ │ │ │ ├── q_linear_cuda.cpp │ │ │ │ └── utils.py │ │ │ ├── cutlass │ │ │ │ ├── __init__.py │ │ │ │ ├── extension.py │ │ │ │ ├── q4_layer.py │ │ │ │ ├── q4_linear_cutlass_kernel.cu │ │ │ │ ├── q8_layer.py │ │ │ │ ├── q8_linear_cutlass_kernel.cu │ │ │ │ └── q_linear_cutlass.cpp │ │ │ ├── layer.py │ │ │ └── mps │ │ │ │ ├── __init__.py │ │ │ │ ├── extension.py │ │ │ │ ├── mlx_bindings.cpp │ │ │ │ ├── mpq_layer.py │ │ │ │ ├── mpq_linear_mlx.cpp │ │ │ │ └── mpq_linear_mlx.h │ │ └── qlinear_implementation.py │ └── qmha │ │ ├── __init__.py │ │ └── binary │ │ ├── __init__.py │ │ └── layer.py ├── optim │ ├── __init__.py │ ├── diode_beta.py │ └── galore_projector.py └── utils │ ├── __init__.py │ ├── arch_helper.py │ ├── convert.py │ ├── cpp_extension.py │ ├── cuda_extension.py │ ├── cutlass_path.py │ ├── mlx_extension.py │ ├── mlx_path.py │ ├── model_helper.py │ ├── quant_operators.py │ └── safe_import.py ├── docker ├── Dockerfile ├── README.md └── build_scripts │ └── install_modified_pytorch.sh ├── docs ├── .gitignore ├── Makefile ├── README.md 
├── make_docs.sh ├── requirements.txt ├── scripts │ └── convert_docs.py └── source │ ├── _templates │ ├── class.rst │ └── module.rst │ ├── build_options.rst │ ├── conf.py │ ├── documentation.rst │ ├── index.rst │ └── installation.rst ├── examples ├── __init__.py ├── mnist-lightning │ ├── main.py │ ├── mlp.py │ └── requirements.txt └── mnist │ ├── README.md │ ├── __init__.py │ ├── datasets │ ├── __init__.py │ ├── base.py │ ├── dummy_dataset.py │ └── mnist.py │ ├── requirements.txt │ └── train_mnist.py ├── licenses ├── LICENSE.GPTQ-for-LLaMa.txt ├── LICENSE.cutlass.txt ├── LICENSE.exllamav2.txt ├── LICENSE.mlx.txt ├── LICENSE.pytorch.txt └── LICENSE.tcbnn.txt ├── requirements-dev.txt ├── requirements.txt ├── setup.py ├── structure.md ├── tests ├── __init__.py ├── functions │ ├── __init__.py │ └── test_quant_ops.py ├── layers │ ├── __init__.py │ ├── test_binary_conv.py │ ├── test_binary_embedding.py │ ├── test_binary_linear.py │ ├── test_custom_binary_linear.py │ ├── test_nbit_conv.py │ ├── test_nbit_linear.py │ ├── test_nbit_linear_mixbits.py │ ├── test_nbit_linear_mps.py │ └── util.py └── util │ ├── __init__.py │ └── binary_mse_loss.py └── version.txt /.dev-scripts/.gitignore: -------------------------------------------------------------------------------- 1 | test_*.sh 2 | -------------------------------------------------------------------------------- /.dev-scripts/README.md: -------------------------------------------------------------------------------- 1 | # Development Scripts 2 | 3 | To publish a binary release with docker (for CUDA), run (replace engine version and cuda version): 4 | ```bash 5 | ./.dev-scripts/publish-with-docker.sh v0.2.3 11.8 6 | ``` 7 | 8 | ## Fixing RPaths 9 | 10 | See [this github issue](https://github.com/pytorch/builder/issues/468). 11 | -------------------------------------------------------------------------------- /.dev-scripts/basic_tests.sh: -------------------------------------------------------------------------------- 1 | 2 | # bash code to run after installation to test for correct package installation 3 | 4 | if [ -n "${SKIP_TESTS}" ]; then 5 | echo "Skipping tests. Done." 6 | exit 0 7 | fi 8 | 9 | # try basic importing, should detect errors of .so loading 10 | python -c "from bitorch_engine.layers.qlinear.binary.cpp import BinaryLinearCPP" 11 | python -c "from bitorch_engine.layers.qembedding.binary import BinaryEmbeddingCuda" 12 | python -c "from bitorch_engine.layers.qlinear.nbit.cutlass import Q4LinearCutlass, Q8LinearCutlass, Q4MatMul" 13 | python -c "from bitorch_engine.layers.qlinear.nbit.cuda import MPQLinearCuda, MBWQLinearCuda" 14 | python -c "from bitorch_engine.layers.qlinear.nbit.cuda.utils import pack_fp_weight, unpack_qweight" 15 | echo "Imports successful!" 16 | 17 | set +o errexit 18 | echo "Testing..." 19 | ( 20 | rm -rf bitorch_install_tmp_test_dir 21 | mkdir bitorch_install_tmp_test_dir 22 | cd bitorch_install_tmp_test_dir 23 | git clone https://github.com/GreenBitAI/bitorch-engine.git --depth 1 --branch "v0.2.6" bitorch_engine_git 24 | mv bitorch_engine_git/tests . 
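# note: the pytest runs below cover the n-bit linear, mixed-bit linear, quantization-op
# and binary linear tests from the repository checked out above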
25 | pip install pytest numpy 26 | pytest tests/layers/test_nbit_linear.py 27 | pytest tests/layers/test_nbit_linear_mixbits.py 28 | pytest tests/functions/test_quant_ops.py 29 | pytest tests/layers/test_binary_linear.py 30 | ) 31 | rm -rf bitorch_install_tmp_test_dir 32 | -------------------------------------------------------------------------------- /.dev-scripts/extract_install_cmd.py: -------------------------------------------------------------------------------- 1 | # take caution: everything is quite hardcoded here 2 | # any changes to the readme could break this code 3 | # run it from root directory: python extract_install_cmd.py path/to/custom/torch-xxx.whl 4 | 5 | import argparse 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("custom_pytorch_path", help="Path to custom PyTorch wheel") 8 | parser.add_argument("custom_bitorch_engine_path", help="Path to built bitorch engine wheel file") 9 | args = parser.parse_args() 10 | 11 | BLOCK_HEADER_START_BINARY = "### Binary Release" 12 | BLOCK_HEADER_START_FROM_SOURCE = "#### Conda on Linux" 13 | BLOCK_END = "##########" 14 | 15 | with open("README.md") as infile: 16 | content = infile.readlines() 17 | 18 | with open(".dev-scripts/basic_tests.sh") as infile: 19 | test_appendix = infile.readlines() 20 | 21 | 22 | def write_file(filepath, main_content): 23 | with open(filepath, "w") as outfile: 24 | outfile.write(FILE_INTRO) 25 | outfile.writelines(main_content) 26 | outfile.writelines(test_appendix) 27 | 28 | 29 | source_local_install_instructions = [] 30 | source_global_install_instructions = [] 31 | binary_local_install_instructions = [] 32 | binary_global_install_instructions = [] 33 | 34 | in_code_block = False 35 | reading_instructions = False 36 | insert_block_pause = False 37 | instruction_type = "" 38 | 39 | FILE_INTRO = """#!/usr/bin/env bash 40 | 41 | trap exit INT 42 | set -o errexit 43 | set -o xtrace 44 | 45 | """ 46 | EXTRA_CONDA_INSTRUCTION = """# extra step for bash script (not required in a proper command line): 47 | eval "$(conda shell.bash hook)" 48 | """ 49 | 50 | 51 | for line in content: 52 | if line.startswith("```"): 53 | in_code_block = not in_code_block 54 | continue 55 | if line.startswith(BLOCK_HEADER_START_FROM_SOURCE): 56 | reading_instructions = True 57 | instruction_type = "source-global" 58 | BLOCK_END = BLOCK_HEADER_START_FROM_SOURCE.split()[0] 59 | continue 60 | if line.startswith(BLOCK_HEADER_START_BINARY): 61 | reading_instructions = True 62 | instruction_type = "binary-global" 63 | BLOCK_END = BLOCK_HEADER_START_BINARY.split()[0] 64 | continue 65 | if line.startswith("
"): 66 | if "source" in instruction_type: 67 | instruction_type = "source-local" 68 | if "binary" in instruction_type: 69 | instruction_type = "binary-local" 70 | continue 71 | if line.startswith("
"): 72 | if "source" in instruction_type: 73 | instruction_type = "source-both" 74 | if "binary" in instruction_type: 75 | instruction_type = "binary-both" 76 | continue 77 | if line.startswith(BLOCK_END): 78 | reading_instructions = False 79 | continue 80 | if not reading_instructions: 81 | continue 82 | if not in_code_block: 83 | insert_block_pause = True 84 | continue 85 | 86 | # deal with comments 87 | if line.startswith("# export CC="): 88 | line = line[2:] 89 | if line.startswith("#"): 90 | continue 91 | 92 | # replace some line contents and add some lines 93 | if "conda activate" in line: 94 | line = EXTRA_CONDA_INSTRUCTION + line 95 | if "export BITORCH_WORKSPACE" in line: 96 | line = line.replace("${HOME}", "$(pwd)") 97 | if line.startswith("pip install torch-"): 98 | line = "pip install {}\n".format(args.custom_pytorch_path) 99 | if line.startswith("pip install bitorch_engine"): 100 | line = "pip install {}\n".format(args.custom_bitorch_engine_path) 101 | 102 | # decide how to write line 103 | line_format = "{line}" 104 | if line.startswith("#"): 105 | line_format = "{line}" 106 | if insert_block_pause: 107 | insert_block_pause = False 108 | line_format = "\n" + line_format 109 | 110 | # write result line(s) 111 | if instruction_type == "source-global" or instruction_type == "source-both": 112 | source_global_install_instructions.append(line_format.format(line=line)) 113 | if instruction_type == "source-local" or instruction_type == "source-both": 114 | source_local_install_instructions.append(line_format.format(line=line)) 115 | if instruction_type == "binary-global" or instruction_type == "binary-both": 116 | binary_global_install_instructions.append(line_format.format(line=line)) 117 | if instruction_type == "binary-local" or instruction_type == "binary-both": 118 | binary_local_install_instructions.append(line_format.format(line=line)) 119 | 120 | write_file(".dev-scripts/test_source_local_conda_install.sh", source_local_install_instructions) 121 | write_file(".dev-scripts/test_source_global_conda_install.sh", source_global_install_instructions) 122 | write_file(".dev-scripts/test_binary_local_conda_install.sh", binary_local_install_instructions) 123 | write_file(".dev-scripts/test_binary_global_conda_install.sh", binary_global_install_instructions) 124 | 125 | binary_local_cu118 = [line.replace("cu121", "cu118").replace("cuda-12.1.0", "cuda-11.8.0") for line in binary_local_install_instructions] 126 | write_file(".dev-scripts/test_binary_local_conda_install_cu118.sh", binary_local_cu118) 127 | binary_local_no_cuda = filter(lambda x: "nvidia/label/cuda-12.1.0" not in x, binary_local_install_instructions) 128 | write_file(".dev-scripts/test_binary_local_conda_install_no_cuda.sh", binary_local_no_cuda) 129 | -------------------------------------------------------------------------------- /.dev-scripts/publish-with-docker.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -o xtrace 4 | 5 | function usage() { 6 | echo "./.dev-scripts/publish-with-docker.sh BIE_VERSION [CUDA_VERSION]" 7 | echo "builds a package and publishes it to (test-)pypi" 8 | echo 9 | echo "BIE_VERSION must be a version string like 'v1.2.3'." 10 | echo "optional: CUDA_VERSION can be either '11.8' (not yet supported) or '12.1'. (default)" 11 | } 12 | 13 | trap exit INT 14 | 15 | if ! 
((1 <= $# && $# <= 2)) || [ "${1}" = "-h" ]; then 16 | usage 17 | exit 18 | fi 19 | 20 | export PUBLISH_BIE_VERSION="${1}" 21 | CUDA_VERSION="${2:-12.1}" 22 | 23 | if ! [[ "${PUBLISH_BIE_VERSION}" =~ ^v[0-9].[0-9].[0-9]$ ]]; then 24 | echo "Invalid BIE_VERSION '${PUBLISH_BIE_VERSION}' given." 25 | echo 26 | usage 27 | exit 28 | fi 29 | 30 | cuda_known="false" 31 | build_args="" 32 | cuda_abbrev="unknown" 33 | # TODO: check support for 11.8: 34 | if [ "${CUDA_VERSION}" = "11.8" ]; then 35 | cuda_known="true" 36 | cuda_abbrev="cu118" 37 | torch_requirement="torch==2.3.0" 38 | build_args="${build_args} --build-arg FROM_IMAGE=pytorch/manylinux-builder:cuda11.8-2.3" 39 | build_args="${build_args} --build-arg CUSTOM_TORCH_URL=https://packages.greenbit.ai/whl/cu118/torch/torch-2.3.0-cp310-cp310-linux_x86_64.whl" 40 | build_args="${build_args} --build-arg TORCHVISION_INDEX_URL=https://download.pytorch.org/whl/cu118" 41 | fi 42 | if [ "${CUDA_VERSION}" = "12.1" ]; then 43 | cuda_known="true" 44 | cuda_abbrev="cu121" 45 | torch_requirement="torch==2.3.0" 46 | fi 47 | if [ "${cuda_known}" = "false" ]; then 48 | echo "Unknown CUDA_VERSION '${CUDA_VERSION}' given." 49 | echo 50 | usage 51 | exit 52 | fi 53 | 54 | echo "building bitorch engine ${PUBLISH_BIE_VERSION}" 55 | echo "building for cuda ${CUDA_VERSION}" 56 | 57 | bie_image_tag="bitorch/engine:publish-${cuda_abbrev}-${PUBLISH_BIE_VERSION}" 58 | bie_container_name="bie-${cuda_abbrev}-${PUBLISH_BIE_VERSION}" 59 | output_folder="./dist/${cuda_abbrev}" 60 | 61 | # build/tag docker image 62 | pushd docker 63 | docker build --target no-examples ${build_args} --build-arg GIT_BRANCH="${PUBLISH_BIE_VERSION}" -t "${bie_image_tag}" . 64 | popd 65 | 66 | mkdir -p "${output_folder}" 67 | 68 | docker container create -it \ 69 | --rm \ 70 | -it \ 71 | -v "${output_folder}:/bitorch-engine/dist" \ 72 | --name "${bie_container_name}" \ 73 | -e PUBLISH_BIE_VERSION \ 74 | -e BIE_FORCE_CUDA="true" \ 75 | -e BIE_SKIP_BUILD="true" \ 76 | -e USER_ID="$(id -u)" \ 77 | -e GROUP_ID="$(id -g)" \ 78 | -e BIE_TORCH_REQUIREMENT="${torch_requirement}" \ 79 | -e BIE_WHEEL_PLATFORM="linux_x86_64" \ 80 | -w /bitorch-engine \ 81 | "${bie_image_tag}" \ 82 | /workspace/publish-docker-internal.sh release 83 | 84 | # make sure correct version is set 85 | echo "${PUBLISH_BIE_VERSION#v}" > version.txt && docker container cp version.txt "${bie_container_name}":/bitorch-engine 86 | docker container cp .dev-scripts/publish-docker-internal.sh "${bie_container_name}":/workspace 87 | 88 | # for previous versions, we need to manually overwrite setup.py: 89 | # TODO: can (hopefully) be removed later on 90 | docker container cp setup.py "${bie_container_name}":/bitorch-engine 91 | 92 | docker start -ai "${bie_container_name}" 93 | -------------------------------------------------------------------------------- /.dev-scripts/publish.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | function usage() { 4 | echo "./.dev-scripts/publish.sh VERSION" 5 | echo "builds a package and publishes it to (test-)pypi" 6 | echo 7 | echo "VERSION must be either 'pre-release' or 'release'." 8 | } 9 | 10 | if ! [ "$#" = "1" ] || [ "${1}" = "-h" ]; then 11 | usage 12 | exit 13 | fi 14 | 15 | if ! [ "${1}" = "release" ] && ! 
[ "${1}" = "pre-release" ]; then 16 | usage 17 | exit 18 | fi 19 | 20 | # set SCRIPT_ROOT: 21 | SCRIPT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" 22 | 23 | # set SRC_ROOT and make sure that is our working directory 24 | SRC_ROOT="$(readlink -f "${SCRIPT_ROOT}/..")" 25 | cd "${SRC_ROOT}" 26 | 27 | trap exit INT 28 | 29 | function check_yes() { 30 | # asks the given yes or no question, returns true if they answer YES 31 | # usage: 32 | # if check_yes "Do you really want to delete foo?"; then 33 | # rm foo 34 | # fi 35 | 36 | local prompt="${1}" 37 | read -p "${prompt} [y/N] " REPLY 38 | echo "" 39 | if [[ ! "${REPLY}" =~ ^[Yy]$ ]]; then 40 | return 1 41 | fi 42 | return 0 43 | } 44 | 45 | function check_no() { 46 | # asks the given yes or no question, returns false if they answer NO 47 | # usage: 48 | # if check_no "Do you want to exit the script?"; then 49 | # exit 0 50 | # fi 51 | 52 | local prompt="${1}" 53 | read -p "${prompt} [Y/n] " REPLY 54 | echo "" 55 | if [[ "${REPLY}" =~ ^[Nn]$ ]]; then 56 | return 1 57 | fi 58 | return 0 59 | } 60 | 61 | function check_error() { 62 | # shows and then runs a command. if the exit code is not zero, asks the user whether to continue 63 | # usage: check_error mv foo bar 64 | 65 | echo + $@ 66 | "$@" 67 | local exit_code=$? 68 | if [ "${exit_code}" -ne 0 ]; then 69 | if ! check_yes "! > An error occurred, continue with the script?"; then 70 | if [ "${1}" = "pre-release" ]; then 71 | git checkout "${version_file}" 72 | fi 73 | exit 1 74 | fi 75 | fi 76 | } 77 | 78 | # main script content 79 | 80 | if [ -z "$(git status --porcelain)" ]; then 81 | echo "Git seems clean." 82 | else 83 | echo "There are uncommitted changes, aborting." 84 | exit 1 85 | fi 86 | 87 | if [ "${1}" = "release" ]; then 88 | version_file="${SRC_ROOT}/version.txt" 89 | version_content="$(cat "${version_file}")" 90 | major_minor_patch="$(cut -d '.' -f 1,2,3 <<< "${version_content}")" 91 | version_str="${major_minor_patch}" 92 | else 93 | version_file="${SRC_ROOT}/version.txt" 94 | version_content="$(cat "${version_file}")" 95 | major_minor_patch="$(cut -d '.' -f 1,2,3 <<< "${version_content}")" 96 | date_str="$(date +"%Y%m%d")" 97 | git_ref="$(git rev-parse --short HEAD)" 98 | version_str="${major_minor_patch}.dev${date_str}+${git_ref}" 99 | fi 100 | 101 | if [ "${1}" = "release" ] && ! [ "${version_content}" = "${version_str}" ]; then 102 | echo "The file version.txt does not seem to contain a release version." 103 | exit 1 104 | else 105 | if ! check_no "Publish version ${version_str} ?"; then 106 | exit 0 107 | fi 108 | fi 109 | 110 | if [ "${1}" = "pre-release" ]; then 111 | echo "${version_str}" > "${version_file}" 112 | fi 113 | 114 | check_error pip uninstall -y -r <(pip freeze) 115 | check_error pip install --upgrade pip 116 | check_error pip install -e ".[dev]" -v 117 | 118 | pip install build twine 119 | 120 | check_error pytest . 
121 | check_error python3 -m build --sdist 122 | 123 | if check_yes "Publish to real PyPI?"; then 124 | echo "Do:" 125 | echo ' python3 -m twine upload dist/*' 126 | else 127 | echo "Do:" 128 | echo ' python3 -m twine upload --repository testpypi dist/*' 129 | fi 130 | 131 | if [ "${1}" = "pre-release" ]; then 132 | git checkout "${version_file}" 133 | fi 134 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | README.rst 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | .coverage 142 | .idea/ 143 | .vscode/ 144 | .mypy_cache 145 | 146 | # downloaded dataset files 147 | train/ 148 | test/ 149 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/) 5 | and this project adheres to [Semantic Versioning](http://semver.org/). 6 | 7 | 8 | ## [0.2.6] - 2024/06/14 9 | 10 | ### Added 11 | 12 | - Installation instructions for binary releases 13 | - Warning if non-customized PyTorch version is detected which can not calculate gradients for non-complex tensor types 14 | 15 | ### Changed 16 | 17 | - Updated development scripts for binary releases 18 | - Adjusting rpaths in .so files (based on PyTorch's implemented solution) 19 | - Docker base image changed to manywheel builder image 20 | 21 | ## [0.2.5] - 2024/05/24 22 | 23 | ### Added 24 | 25 | - Development scripts for preparing binary releases 26 | 27 | ### Changed 28 | 29 | - Updated build instructions to clarify torchvision installation 30 | - Adapted `setup.py` logic for preparing binary releases 31 | 32 | ### Fixed 33 | 34 | - Broken build process by setting setuptools version 35 | 36 | ## [0.2.4] - 2024/05/23 37 | 38 | ### Added 39 | 40 | - Tuned the hyperparameters of DiodeMix optimizer for sft. 41 | - Added sft-support for the classical gptq-style models. 42 | - Implemented qzeros update in finetuning process. 43 | 44 | ### Updated 45 | 46 | - Extended pack_fp_weight function. 47 | - Enhanced the performance of MPQLinearCUDA layer. 48 | 49 | ### Fixed 50 | 51 | - Fixed various errors in DiodeMix update function. 52 | 53 | ## [0.2.3] - 2024/05/01 54 | 55 | ### Updated 56 | 57 | - Enhanced the performance of the MBWQ linear layer for processing long sequences, addressing previous inefficiencies. 
58 | 59 | ## [0.2.2] - 2024/04/29 60 | 61 | ### Updated 62 | 63 | - Building instructions (adding a section for cutlass) 64 | - Checksums for custom torch builds (within docker) 65 | 66 | ### Fixed 67 | 68 | - An error in `pack_fp_weight` 69 | 70 | ## [0.2.1] - 2024/04/27 71 | 72 | ### Fixed 73 | 74 | - Broken links in README.md and index.rst 75 | 76 | ## [0.2.0] - 2024/03/10 77 | 78 | ### Added 79 | 80 | - Quantized layers with different acceleration options 81 | - QConv (binary, quantized) - CPU, Cutlass 82 | - QLinear (binary, quantized, mixed bit-width) - CUDA, Cutlass, MPS 83 | - QEmbedding (binary) 84 | - Optimizer(s) for quantized layers 85 | - Hybrid optimizer `diode_beta` based on Diode v1 (binary) and AdamW (quantized) for memory-efficient training 86 | - Initial support for galore projection 87 | - Examples 88 | - MNIST training script with and without PyTorch Lightning 89 | 90 | ## [0.1.0] - 2023/01/13 91 | 92 | The first release of basic functionality. 93 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Bitorch Engine 2 | Your contributions are welcomed and appreciated. We strive to make the process straightforward and transparent. 3 | 4 | ## Development 5 | 1. Propose significant changes via a Request for Comments (RFC) for discussion. 6 | 2. Add features by submitting a PR with accompanying tests and documentation. 7 | 3. Fix bugs by submitting a PR that includes tests validating the fix and any necessary documentation updates. 8 | 9 | ## Pull Requests 10 | We encourage your pull requests. 11 | 12 | 1. Fork the repository and create your branch from main. 13 | 2. Include tests for any new code. 14 | 3. Update documentation for all related functions. 15 | 4. Ensure all tests are passing. 16 | 17 | ## Reporting Issues 18 | We track bugs with GitHub issues. Provide a clear description and instructions to reproduce the issue. 19 | 20 | ## Your Contributions 21 | Contributions are **licensed under the LICENSE** file found in the root directory of this source tree. -------------------------------------------------------------------------------- /bitorch_engine/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def initialize(): 4 | """ 5 | This function makes all custom layer implementations available in BITorch. 
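    Example (illustrative):

        import bitorch_engine
        bitorch_engine.initialize()
        # afterwards bitorch can use the engine's custom layer implementations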
6 | """ 7 | from .layers.qlinear import QLinearInf 8 | 9 | 10 | import os 11 | torch_int_gradients_support = None 12 | if os.environ.get("BIE_SKIP_TORCH_CHECK", "false") == "false" and torch_int_gradients_support is None: 13 | import warnings 14 | torch_int_gradients_support = False 15 | try: 16 | import torch 17 | x = torch.nn.Parameter(torch.zeros((1,), dtype=torch.uint8), requires_grad=True) 18 | torch_int_gradients_support = True 19 | except RuntimeError as e: 20 | if "dtype" in str(e).lower() and "only" in str(e).lower(): 21 | warnings.warn( 22 | "It seems a regular version of torch is installed.\n" 23 | " Please install the custom torch with enabled gradient calculation for integer tensors.\n" 24 | " Check the instructions at https://github.com/GreenBitAI/bitorch-engine for more information.") 25 | else: 26 | warnings.warn("There may be a problem with the currently installed version of torch:\n" + str(e)) 27 | except ModuleNotFoundError as e: 28 | # if torch is not installed, we can not check 29 | pass 30 | -------------------------------------------------------------------------------- /bitorch_engine/extensions/__init__.py: -------------------------------------------------------------------------------- 1 | EXTENSION_PREFIX = "bitorch_engine.extensions." 2 | -------------------------------------------------------------------------------- /bitorch_engine/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/bitorch_engine/functions/__init__.py -------------------------------------------------------------------------------- /bitorch_engine/functions/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * -------------------------------------------------------------------------------- /bitorch_engine/functions/cuda/extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension 4 | 5 | CUDA_REQUIRED = True 6 | 7 | 8 | def get_ext(path: Path): 9 | """ 10 | Generate CUDA extension for specified path. 11 | 12 | Args: 13 | path (Path): Path to the directory containing CUDA extension files. 14 | 15 | Returns: 16 | Extension: CUDA extension for specified path. 
17 | """ 18 | return get_cuda_extension( 19 | path, 20 | relative_name='functions_cuda', 21 | relative_sources=[ 22 | 'functions_cuda.cpp', 23 | 'functions_cuda_kernel.cu', 24 | ] 25 | ) 26 | -------------------------------------------------------------------------------- /bitorch_engine/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/bitorch_engine/layers/__init__.py -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/__init__.py: -------------------------------------------------------------------------------- 1 | from .binary import * -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/binary/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import BinaryConv2dBase, BinaryConvParameter -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/binary/cpp/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import BinaryConv2dCPP -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/binary/cpp/extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bitorch_engine.utils.cpp_extension import get_cpp_extension 4 | 5 | 6 | def get_ext(path: Path): 7 | """ 8 | Retrieves the C++ extension for binary convolution. 9 | 10 | Args: 11 | path (Path): The path to the directory containing the binary convolution C++ code. 12 | 13 | Returns: 14 | torch.utils.cpp_extension.CppExtension: The C++ extension for binary convolution. 15 | """ 16 | return get_cpp_extension( 17 | path, 18 | relative_name='binary_conv_cpp', 19 | relative_sources=['binary_conv.cpp'] 20 | ) 21 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/binary/cpp/layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from bitorch_engine.utils.safe_import import import_extension 5 | 6 | binary_conv_cpp = import_extension("binary_conv_cpp") 7 | 8 | 9 | from bitorch_engine.utils.quant_operators import get_binary_row 10 | from ..layer import BinaryConv2dBase 11 | 12 | class BinaryConv2dForward(Function): 13 | """ 14 | A custom autograd function to perform forward pass of a 2D binary convolution. 15 | 16 | This class implements a static method `forward` to carry out the convolution operation 17 | using binary weights and activations. The operation is performed using a custom C++ 18 | backend for efficiency. 19 | 20 | Attributes: 21 | - No class-level attributes. 22 | 23 | Methods: 24 | - forward: Performs the forward pass of the binary convolution. 25 | """ 26 | @staticmethod 27 | def forward(ctx, activations: torch.Tensor, weights: torch.Tensor, m: int, n: int, k: int, kernel_size: int, 28 | stride: int, padding: int, dilation: int, output_edge: int) -> torch.Tensor: 29 | """ 30 | Forward pass for the 2D binary convolution. 31 | 32 | Utilizes a C++ backend implemented in `binary_conv_cpp.forward` to perform the operation. 33 | This method is statically defined and automatically integrated with PyTorch's autograd mechanism. 
34 | 35 | Parameters: 36 | - ctx (torch.autograd.function.BackwardContext): Context object that can be used to stash information 37 | for backward computation. You can cache arbitrary objects for use in the backward pass using 38 | the `save_for_backward` method. 39 | - activations (Tensor): The input feature map or activation tensor. 40 | - weights (Tensor): The binary weights tensor. 41 | - m, n, k (int): Dimensions of the input, specifically: 42 | - m: The number of output channels. 43 | - n: The number of input channels. 44 | - k: The spatial size of the output feature map. 45 | - kernel_size (int or tuple): Size of the conv kernel. 46 | - stride (int or tuple): Stride of the convolution. 47 | - padding (int or tuple): Zero-padding added to both sides of the input. 48 | - dilation (int or tuple): Spacing between kernel elements. 49 | - output_edge (int): The size of the output edge to ensure the output dimension matches expectations. 50 | 51 | Returns: 52 | - Tensor: The output feature map resulting from the binary convolution operation. 53 | 54 | Note: 55 | This method is part of the forward pass and needs to be paired with a corresponding backward 56 | method to enable gradient computation. 57 | """ 58 | output = binary_conv_cpp.forward(activations, weights, m, n, k, kernel_size, stride, padding, 59 | dilation, output_edge) 60 | return output 61 | 62 | 63 | class BinaryConv2dCPP(BinaryConv2dBase): 64 | """ 65 | This class implements a binary convolutional layer in PyTorch, specifically optimized with C++ extensions. 66 | It inherits from BinaryConv2dBase to leverage common binary convolution functionalities with added 67 | optimizations for efficient computation. 68 | 69 | Attributes: 70 | bits_binary_word (int): Defines the size of the binary word, defaulting to 8 bits. 71 | """ 72 | def __init__(self, *args, **kwargs): 73 | """ 74 | Initializes the BinaryConv2dCPP layer with the given arguments, which are forwarded to the base class. 75 | Additionally, it sets up the binary word size for quantization. 76 | 77 | Args: 78 | *args: Variable length argument list to be passed to the BinaryConv2dBase class. 79 | **kwargs: Arbitrary keyword arguments to be passed to the BinaryConv2dBase class. 80 | """ 81 | super(BinaryConv2dCPP, self).__init__(*args, **kwargs) 82 | self.bits_binary_word = 8 83 | 84 | def prepare_params(self) -> None: 85 | """ 86 | Prepares and initializes the model parameters for training. 87 | One can use "prepare_bie_layers" method from project_root.utils.model_helper to call this function. 88 | """ 89 | pass 90 | 91 | def generate_quantized_weight(self, qweight_only: bool = False) -> None: 92 | """ 93 | Generates and stores quantized weights based on the current weights of the layer, utilizing a binary 94 | quantization method. Quantized weights are stored as a torch.nn.Parameter but are not set to require gradients. 95 | 96 | Args: 97 | qweight_only (bool): If True, the original weights are discarded to save memory. Defaults to False. 
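
        Example (illustrative; the constructor arguments are assumed to follow the
        base class and are hypothetical here):

            layer = BinaryConv2dCPP(in_channels=64, out_channels=128, kernel_size=3,
                                    stride=1, padding=1, dilation=1)
            layer.generate_quantized_weight(qweight_only=True)  # pack weights, drop fp copy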
98 | """ 99 | w_size = self.out_channels * self.in_channels/self.bits_binary_word * self.kernel_size * self.kernel_size 100 | self.qweight = torch.nn.Parameter( 101 | get_binary_row(self.weight.reshape(-1, ), 102 | torch.empty(int(w_size), dtype=torch.uint8), 103 | w_size * self.bits_binary_word, 104 | self.bits_binary_word), 105 | requires_grad=False 106 | ) 107 | if qweight_only: 108 | self.weight = None 109 | 110 | def forward(self, x: torch.Tensor) -> torch.Tensor: 111 | """ 112 | Defines the forward pass for the binary convolution operation using the quantized weights. 113 | 114 | Args: 115 | x (torch.Tensor): The input tensor for the convolution operation with shape (N, C_in, H, W), 116 | where N is the batch size, C_in is the number of channels, and H, W are the height 117 | and width of the input tensor. 118 | 119 | Returns: 120 | torch.Tensor: The output tensor of the convolution operation with shape determined by the layer's 121 | attributes and the input dimensions. 122 | """ 123 | self._check_forward(x) 124 | # pass m, n, k 125 | m = self.out_channels # number of output channel 126 | k = x.size(dim=1) * self.kernel_size * self.kernel_size; # number of input channels * kernel size 127 | # (Image_w – filter_w + 2*pad_w) / stride + 1 128 | output_edge = int((x.size(dim=2) - self.kernel_size + 2 * self.padding) / self.stride + 1) 129 | n = output_edge * output_edge # number of pixels of output images per channel 130 | return BinaryConv2dForward.apply(x, self.opt_weight, m, n, k, self.kernel_size, self.stride, self.padding, 131 | self.dilation, output_edge) -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/binary/cutlass/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import BinaryConv2dCutlass -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/binary/cutlass/binary_conv2d_cutlass.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | 6 | /** 7 | * Performs a forward pass of binary convolution using the CUTLASS library. 8 | * This function is optimized for binary convolutions, leveraging the efficiency of CUTLASS kernels. 9 | * 10 | * Args: 11 | * input (torch::Tensor): The input tensor with shape [batch_size, in_channels, in_height, in_width]. 12 | * weight (torch::Tensor): The filter weights tensor with shape [out_channels, kernel_size, kernel_size, in_channels]. 13 | * scale (float): A scaling factor applied to the output tensor. 14 | * is_train (bool): A flag indicating whether the operation is being performed during training. 15 | * This influences the processing of the weight tensor. 16 | * kernel_size (int): The size of the convolution kernel. 17 | * stride (int): The stride of the convolution. 18 | * padding (int): The padding added to the input tensor. 19 | * dilation (int): The dilation factor for the convolution. 20 | * device_id (int): The ID of the CUDA device on which to perform the operation. 21 | * 22 | * Returns: 23 | * torch::Tensor: The output tensor of the convolution, scaled by the 'scale' parameter. 24 | * The output tensor has shape [batch_size, out_edge, out_edge, out_channels], 25 | * where 'out_edge' is computed based on the input dimensions, padding, and stride. 26 | * 27 | * Notes: 28 | * - The function sets the CUDA device to 'device_id' at the beginning. 
29 | * - It calculates the output tensor dimensions based on the input size, kernel size, stride, and padding. 30 | * - The weights are optionally preprocessed (viewed and packed) based on the training mode. 31 | * - The input tensor is reshaped and packed for efficient processing. 32 | * - The actual convolution operation is performed by a call to the 'xnor_cutlass::_impl_conv_forward' function, 33 | * which utilizes CUTLASS kernels optimized for binary convolutions. 34 | * - Finally, the output tensor is scaled by the 'scale' parameter before being returned. 35 | */ 36 | torch::Tensor binary_conv2d_cutlass_forward( 37 | torch::Tensor input, 38 | torch::Tensor weight, 39 | float scale, 40 | bool is_train, 41 | int kernel_size, 42 | int stride, 43 | int padding, 44 | int dilation); 45 | 46 | 47 | /** 48 | * Performs binary convolution operation with weight packing using CUTLASS. 49 | * 50 | * This function adapts the input data tensor for binary convolution by rearranging its dimensions to match 51 | * the expected format {OHWC} (Output Channels, Height, Width, Input Channels) and then packs the data to optimize 52 | * the convolution operation. It leverages CUTLASS kernels for efficient computation. 53 | * 54 | * Args: 55 | * data (torch::Tensor): The input tensor for the convolution operation. Expected to have dimensions 56 | * {Output Channels, Input Channels, Kernel Height, Kernel Width}. 57 | * 58 | * Returns: 59 | * torch::Tensor: A tensor containing the packed data, ready for efficient binary convolution with CUTLASS. 60 | */ 61 | torch::Tensor binary_conv2d_w_pack_cutlass( 62 | torch::Tensor data); 63 | 64 | 65 | // C++ interface 66 | 67 | 68 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 69 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 70 | #define CHECK_INPUT(x) CHECK_CUDA(x); 71 | 72 | 73 | torch::Tensor binary_conv2d_forward( 74 | torch::Tensor input, 75 | torch::Tensor weight, 76 | float scale, 77 | bool is_train, 78 | int kernel_size, 79 | int stride, 80 | int padding, 81 | int dilation 82 | ) { 83 | CHECK_INPUT(input); 84 | CHECK_INPUT(weight); 85 | return binary_conv2d_cutlass_forward(input, weight, scale, is_train, kernel_size, stride, padding, dilation); 86 | } 87 | 88 | 89 | torch::Tensor binary_conv2d_w_pack( 90 | torch::Tensor data 91 | ){ 92 | CHECK_INPUT(data); 93 | return binary_conv2d_w_pack_cutlass(data); 94 | } 95 | 96 | 97 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 98 | m.def("forward", &binary_conv2d_forward, "scaled binary conv2d forward (CUTLASS)"); 99 | m.def("w_pack", &binary_conv2d_w_pack, "packing binary weight (CUTLASS)"); 100 | } 101 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/binary/cutlass/extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension 4 | 5 | CUDA_REQUIRED = True 6 | CUTLASS_REQUIRED = True 7 | 8 | 9 | def get_ext(path: Path): 10 | """ 11 | Obtains the CUDA extension for Cutlass-based binary convolution. 12 | 13 | Args: 14 | path (Path): The path to the directory containing the necessary source files 15 | for the Cutlass-based binary convolution operation. 16 | 17 | Returns: 18 | Extension: The CUDA extension for the Cutlass-based binary convolution. 
19 | """ 20 | return get_cuda_extension( 21 | path, 22 | relative_name='binary_conv2d_cutlass', 23 | relative_sources=[ 24 | 'binary_conv2d_cutlass.cpp', 25 | 'binary_conv2d_cutlass_kernel.cu', 26 | ] 27 | ) 28 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/nbit/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import nBitConv2dBase, nBitConvParameter -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/nbit/cutlass/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import Q4Conv2dCutlass -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/nbit/cutlass/extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension 4 | 5 | CUDA_REQUIRED = True 6 | CUTLASS_REQUIRED = True 7 | 8 | 9 | def get_ext(path: Path): 10 | """ 11 | Get CUDA extension for a specified path. 12 | 13 | Args: 14 | path (Path): The path to the directory containing CUDA extension files. 15 | 16 | Returns: 17 | Extension: The CUDA extension object. 18 | """ 19 | return get_cuda_extension( 20 | path, 21 | relative_name='q4_conv_cutlass', 22 | relative_sources=[ 23 | 'q4_conv_cutlass.cpp', 24 | 'q4_conv_cutlass_kernel.cu', 25 | ] 26 | ) 27 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qconv/nbit/cutlass/q4_conv_cutlass.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | 5 | // CUTLASS forward declarations 6 | 7 | /** 8 | * Performs a forward pass of the quantized 4-bit convolution (q4_conv2d) using the CUTLASS library. 9 | * This function takes a 4-bit quantized input and weight tensors, along with convolution parameters 10 | * like scale factors, kernel size, stride, padding, and dilation, to perform the convolution operation 11 | * optimized for CUDA. It's designed to work with NHWC tensor format for efficient computation. 12 | * 13 | * Parameters: 14 | * input - The input tensor in NCHW format that will be converted to NHWC internally. 15 | * weight - The weight tensor, which can be either pre-packed (in inference mode) or will be packed during training. 16 | * scale_a - The scale factor for the input tensor quantization. 17 | * scale_w - The scale factor for the weight tensor quantization. 18 | * is_train - A boolean flag indicating whether the operation is for training. Affects weight processing. 19 | * kernel_size - The size of the convolution kernel. 20 | * stride - The stride of the convolution. 21 | * padding - The padding added to both sides of the input tensor. 22 | * dilation - The spacing between kernel elements. 23 | * 24 | * Returns: 25 | * A tensor containing the result of the quantized convolution operation, scaled by the product of input 26 | * and weight scale factors. 27 | */ 28 | std::vector q4_conv2d_cutlass_forward( 29 | torch::Tensor input, 30 | torch::Tensor weight, 31 | float scale_a, 32 | float scale_w, 33 | bool is_train, 34 | int kernel_size, 35 | int stride, 36 | int padding, 37 | int dilation); 38 | 39 | 40 | /** 41 | * 42 | * This function prepares the weight tensor for a quantized 4-bit convolution operation. 
43 | * It takes a weight tensor and a scale factor as inputs, restructures the weight tensor for the 44 | * convolution operation, and quantizes it to 4 bits. This preparation is crucial for leveraging 45 | * CUTLASS's efficient low-bit computation capabilities. 46 | * 47 | * Parameters: 48 | * - weight: The original weight tensor of the convolutional layer. 49 | * - scale: The scaling factor used for quantization of the weights to 4-bit precision. 50 | * 51 | * Returns: 52 | * - A tensor representing the packed and quantized weights, ready for use in a 4-bit convolution operation. 53 | */ 54 | torch::Tensor q4_conv2d_w_pack_cutlass( 55 | torch::Tensor weight, 56 | float scale); 57 | 58 | 59 | // C++ interface 60 | 61 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 62 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 63 | #define CHECK_INPUT(x) CHECK_CUDA(x); 64 | 65 | 66 | std::vector q4_conv2d_forward( 67 | torch::Tensor input, 68 | torch::Tensor weight, 69 | float scale_a, 70 | float scale_w, 71 | bool is_train, 72 | int kernel_size, 73 | int stride, 74 | int padding, 75 | int dilation) { 76 | CHECK_INPUT(input); 77 | CHECK_INPUT(weight); 78 | return q4_conv2d_cutlass_forward(input, weight, scale_a, scale_w, is_train, 79 | kernel_size, stride, padding, dilation); 80 | } 81 | 82 | 83 | torch::Tensor q4_conv2d_w_pack( 84 | torch::Tensor weight, 85 | float scale 86 | ) { 87 | CHECK_INPUT(weight); 88 | return q4_conv2d_w_pack_cutlass(weight, scale); 89 | } 90 | 91 | 92 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 93 | m.def("forward", &q4_conv2d_forward, "4-bit conv forward (CUTLASS)"); 94 | m.def("w_pack", &q4_conv2d_w_pack, "pack q4 weight"); 95 | } 96 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qembedding/__init__.py: -------------------------------------------------------------------------------- 1 | from .binary import * 2 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qembedding/binary/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .layer import BinaryEmbeddingBag, BinaryEmbeddingParameter, BinaryEmbeddingCuda -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import QLinearInf -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import BinaryLinearBase, BinaryLinearParameter 2 | import torch.cuda 3 | 4 | 5 | def get_best_binary_implementation(): 6 | if torch.cuda.is_available(): 7 | from .cuda import BinaryLinearCuda 8 | return BinaryLinearCuda 9 | else: 10 | from .cpp import BinaryLinearCPP 11 | return BinaryLinearCPP 12 | 13 | 14 | BinaryLinear = get_best_binary_implementation() 15 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/binary_implementation.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Tuple 3 | 4 | from bitorch.layers import QLinearBase 5 | from bitorch.layers.extensions import LayerRecipe 6 | from bitorch.quantizations import Sign, SwishSign 7 
| 8 | from bitorch_engine.layers.qlinear.qlinear_implementation import QLinearImplementationMixin 9 | 10 | 11 | class BinaryLinearImplementationMixin(QLinearImplementationMixin, ABC): 12 | """ 13 | A mixin class for binary linear layer implementations that extends the quantized linear layer implementation mixin (QLinearImplementationMixin). 14 | This class provides specialized methods to determine if a layer can be cloned based on the quantization functions used for inputs and weights. 15 | 16 | The class supports binary quantization functions such as Sign and SwishSign for both inputs and weights. It leverages the `can_clone` class method 17 | to check if the specified quantization functions are supported for cloning a layer according to a given recipe. 18 | 19 | Attributes: 20 | None specified explicitly, but inherits from QLinearImplementationMixin and ABC. 21 | 22 | Methods: 23 | can_clone: Class method to determine if a layer can be cloned based on its quantization functions for inputs and weights. 24 | """ 25 | @classmethod 26 | def can_clone(cls, recipe: LayerRecipe) -> Tuple[bool, str]: 27 | """ 28 | Determines if a layer can be cloned based on its quantization functions for inputs and weights. 29 | 30 | This method checks if the layer's input and weight quantization functions are among the supported binary quantization functions. 31 | If either quantization function is not supported, the method returns False along with a message indicating which quantization 32 | function is not supported. 33 | 34 | Args: 35 | recipe (LayerRecipe): An object containing the configuration and parameters for the layer to be cloned. 36 | 37 | Returns: 38 | Tuple[bool, str]: A tuple containing a boolean indicating whether the layer can be cloned and a string message. 39 | If the layer can be cloned, the boolean is True, and the string is empty. If the layer cannot be cloned due to unsupported 40 | quantization functions, the boolean is False, and the string contains a message specifying the unsupported quantization function. 41 | """ 42 | supported_quantization_functions = (Sign, SwishSign) # Define supported quantization functions 43 | args = QLinearBase.get_args_as_kwargs(recipe) # Retrieve layer arguments as keyword arguments 44 | 45 | # Check if input quantization function is supported 46 | if args["input_quantization"].__class__ not in supported_quantization_functions: 47 | return False, f"the input quantization {args['input_quantization'].name} is not yet supported." 48 | 49 | # Check if weight quantization function is supported 50 | if args["weight_quantization"].__class__ not in supported_quantization_functions: 51 | return False, f"the weight quantization {args['weight_quantization'].name} is not yet supported." 52 | 53 | # Call superclass method to perform any additional checks 54 | return super().can_clone(recipe) 55 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/cpp/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import BinaryLinearCPP 2 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/cpp/extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bitorch_engine.utils.cpp_extension import get_cpp_extension 4 | 5 | 6 | def get_ext(path: Path): 7 | """Retrieve C++ extension details for binary linear module. 
8 | 9 | Args: 10 | path (Path): Path to the directory containing the extension module. 11 | 12 | Returns: 13 | Any: Extension module details. 14 | """ 15 | return get_cpp_extension( 16 | path, 17 | relative_name='binary_linear_cpp', 18 | relative_sources=['binary_linear.cpp'] 19 | ) 20 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/cpp/layer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from bitorch import RuntimeMode 5 | from bitorch.layers import QLinearBase 6 | from bitorch.layers.extensions import LayerRecipe 7 | from bitorch.layers.register import QLinearImplementation 8 | from torch.autograd import Function 9 | 10 | from bitorch_engine.utils.safe_import import import_extension 11 | from ..binary_implementation import BinaryLinearImplementationMixin 12 | from ..layer import BinaryLinearBase 13 | from bitorch_engine.utils.model_helper import flatten_x, unflatten_x 14 | 15 | binary_linear_cpp = import_extension("binary_linear_cpp") 16 | 17 | 18 | class BinaryLinearForward(Function): 19 | """ 20 | A custom autograd function for performing forward pass of binary linear layer. 21 | This function uses a custom C++ backend for efficient computation. 22 | 23 | Args: 24 | ctx (torch.autograd.function.FunctionCtx): The context for storing information for backward computation. 25 | input (torch.Tensor): The input tensor. 26 | weights (torch.Tensor): The binary weights tensor. 27 | m (int): The batch size. 28 | n (int): The number of output features. 29 | k (int): The number of input features. 30 | 31 | Returns: 32 | torch.Tensor: The output tensor after applying the binary linear transformation. 33 | """ 34 | @staticmethod 35 | def forward(ctx, input: torch.Tensor, weights: torch.Tensor, m: int, n: int, k: int) -> torch.Tensor: 36 | input, shape = flatten_x(input) 37 | output = binary_linear_cpp.forward(input, weights, m, n, k) 38 | output = unflatten_x(output, shape) 39 | return output 40 | 41 | 42 | @QLinearImplementation(RuntimeMode.CPU) 43 | class BinaryLinearCPP(BinaryLinearImplementationMixin, BinaryLinearBase): 44 | """ 45 | A class representing the binary linear layer implemented in C++ for CPU runtime mode. 46 | Inherits from BinaryLinearBase and mixes in BinaryLinearImplementationMixin for common functionality. 47 | 48 | This class supports creating a clone of itself from a given LayerRecipe, allowing for easy replication 49 | and modification of layer parameters. 50 | """ 51 | @classmethod 52 | def create_clone_from(cls, recipe: LayerRecipe) -> Any: 53 | """ 54 | Creates a clone of this layer based on the provided LayerRecipe. 55 | 56 | Args: 57 | recipe (LayerRecipe): The recipe containing the parameters for the clone. 58 | 59 | Returns: 60 | Any: A new instance of this class with parameters derived from the recipe. 61 | """ 62 | args = QLinearBase.get_args_as_kwargs(recipe) 63 | input_features, output_features = args["in_features"], args["out_features"] 64 | input_features //= 8 65 | new_layer = cls(input_features, output_features) 66 | new_layer.set_weight_data(recipe.layer.weight.data) 67 | new_layer.generate_quantized_weight(qweight_only=True) 68 | return new_layer 69 | 70 | def __init__( 71 | self, 72 | input_features: int, 73 | out_features: int, 74 | device: torch.device = None, 75 | ) -> None: 76 | """ 77 | Initializes the BinaryLinearCPP layer. 
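
        Example (illustrative; note that input_features is the packed size, i.e. the
        number of 1-bit input features divided by 8):

            layer = BinaryLinearCPP(1024 // 8, 2048)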
78 | 79 | Args: 80 | input_features (int): The number of input features (divided by 8 for binary). 81 | out_features (int): The number of output features. 82 | device (torch.device, optional): The device on which to perform computations. 83 | """ 84 | super().__init__(input_features, out_features, device) 85 | 86 | def prepare_params(self) -> None: 87 | """ 88 | Prepares and initializes the model parameters for training. 89 | One can use "prepare_bie_layers" method from project_root.utils.model_helper to call this function. 90 | """ 91 | pass 92 | 93 | def generate_quantized_weight(self, qweight_only: bool = False) -> None: 94 | """ 95 | Generates the quantized weight matrix for this layer and optionally clears the original weight. 96 | 97 | Args: 98 | qweight_only (bool, optional): If True, the original weight matrix is cleared to save memory. 99 | """ 100 | # Generate packed weight using custom C++ function 101 | self.qweight = binary_linear_cpp.w_pack( 102 | self.weight, # Original weight 103 | self.output_features, # n 104 | self.input_features, # k 105 | ) 106 | if qweight_only: 107 | self.weight = None # Clear the original weight matrix if specified 108 | 109 | def forward(self, x: torch.Tensor) -> torch.Tensor: 110 | """ 111 | Defines the forward pass of the binary linear layer. 112 | 113 | Args: 114 | x (torch.Tensor): The input tensor. 115 | 116 | Returns: 117 | torch.Tensor: The output tensor after applying the binary linear transformation. 118 | """ 119 | # Check input validity 120 | self._check_forward(x) 121 | # pass m, n, k 122 | m = x.size(dim=0) # batch size 123 | k = x.size(dim=1) # input features 124 | n = self.output_features # output features 125 | return BinaryLinearForward.apply(x, self.opt_weight, m, n, k) 126 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .bmm import BMM 4 | 5 | if torch.cuda.is_available(): 6 | from .layer import BinaryLinearCuda 7 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/cuda/binary_linear_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | 6 | /** 7 | * Performs a forward pass of the binary linear layer using CUDA. 8 | * 9 | * This function facilitates binary linear operations with support for different data types 10 | * for input and weight tensors, including floating point and quantized types. It leverages 11 | * CUDA for efficient computation, especially suited for deep learning models running on GPU. 12 | * 13 | * @param input The input tensor, which can be of type torch::kFloat32 (float), torch::kBFloat16 (bfloat16), 14 | * or torch::kHalf (half). 15 | * @param weight The weight tensor, which supports torch::kInt8 (int8), torch::kFloat32 (float), 16 | * and torch::kUInt8 (uint8) data types. 17 | * @param bmm_type An integer specifying the type of binary matrix multiplication to perform. 18 | * This parameter allows for customization of the operation based on the model's requirements. 19 | * @param transpose A boolean indicating whether the weight matrix should be transposed during the operation. 20 | * 21 | * @return A tensor containing the result of the binary linear operation. 
22 | * 23 | * @note This function dynamically dispatches to specialized template functions based on the data types of 24 | * the input and weight tensors. It supports a combination of float, bfloat16, half, int8, and uint8 25 | * types, ensuring flexibility in handling various neural network architectures. 26 | * If the data type of the input or weight tensor is not supported, the function will terminate the 27 | * program and print an error message. 28 | */ 29 | torch::Tensor binary_linear_cuda_forward( 30 | torch::Tensor input, 31 | torch::Tensor weights, 32 | int bmm_type, 33 | bool transpose); 34 | 35 | 36 | /** 37 | * Converts a given weight tensor to its binary representation based on the specified data type. 38 | * 39 | * This function supports weight tensors of different data types (int8, float, bfloat16, and half) 40 | * and converts them to a binary format suitable for certain binary matrix multiplication (BMM) operations. 41 | * The conversion process is dependent on the data type of the input tensor and whether the tensor 42 | * should be transposed as part of the conversion. 43 | * 44 | * @param weight The input weight tensor to be converted to binary format. 45 | * @param bmm_type An integer specifying the type of binary matrix multiplication operation. 46 | * This parameter can influence how the binary conversion is performed. 47 | * @param transpose A boolean indicating whether the weight tensor should be transposed 48 | * as part of the conversion process. 49 | * @return torch::Tensor A tensor containing the binary representation of the input weight tensor. 50 | * The specific format of the binary representation is determined by the 51 | * data type of the input tensor. 52 | * 53 | * @note This function is templated to handle different data types of the input tensor by 54 | * calling the appropriate specialized version of the _get_binary_weight_cuda function. 55 | * If the data type of the input tensor is not supported, the function prints an error message 56 | * and exits the program. 57 | */ 58 | torch::Tensor get_binary_weight_cuda( 59 | torch::Tensor weights, 60 | int bmm_type, 61 | bool transpose); 62 | 63 | 64 | /** 65 | * Performs binary matrix multiplication on CUDA using specified data types. 66 | * 67 | * This function dispatches the binary matrix multiplication operation to specialized 68 | * CUDA kernels based on the data type of the input tensors. It supports int8, float32, 69 | * bfloat16, and half (float16) data types. The function checks if the data types of both 70 | * input tensors match and then calls the appropriate templated CUDA kernel function. 71 | * 72 | * @param x A torch::Tensor representing the first matrix in the multiplication. 73 | * @param y A torch::Tensor representing the second matrix in the multiplication. 74 | * @param bmm_type An integer indicating the type of binary matrix multiplication to perform. 75 | * 76 | * @return A torch::Tensor containing the result of the binary matrix multiplication. 77 | * 78 | * @throws std::runtime_error If the input tensors have different data types or if an unsupported 79 | * data type is provided. 80 | */ 81 | torch::Tensor binary_mm_cuda( 82 | torch::Tensor x, 83 | torch::Tensor y, 84 | int bmm_type); 85 | 86 | // C++ interface 87 | 88 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 
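// Rough Python-side usage of the bindings defined at the bottom of this file (an
// illustrative sketch only; it assumes the extension is built under the name
// `binary_linear_cuda` as in extension.py, and uses BMM.ADAPTIVE == 3 from bmm.py
// as the bmm_type argument):
//
//   x  = torch.randn(16, 256, device="cuda")
//   w  = torch.randn(128, 256, device="cuda")
//   qw = binary_linear_cuda.w_pack(w, 3, False)      # pack weights to binary form once
//   y  = binary_linear_cuda.forward(x, w, 3, False)  # float weights are also accepted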
89 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 90 | #define CHECK_INPUT(x) CHECK_CUDA(x); 91 | 92 | torch::Tensor binary_linear_forward( 93 | torch::Tensor input, 94 | torch::Tensor weights, 95 | int bmm_type, 96 | bool transpose) { 97 | CHECK_INPUT(input); 98 | CHECK_INPUT(weights); 99 | return binary_linear_cuda_forward(input, weights, bmm_type, transpose); 100 | } 101 | 102 | torch::Tensor get_binary_weight( 103 | torch::Tensor weights, 104 | int bmm_type, 105 | bool transpose) { 106 | CHECK_INPUT(weights); 107 | return get_binary_weight_cuda(weights, bmm_type, transpose); 108 | } 109 | 110 | torch::Tensor binary_linear_mm( 111 | torch::Tensor x, 112 | torch::Tensor y, 113 | int bmm_type) { 114 | CHECK_INPUT(x); 115 | CHECK_INPUT(y); 116 | return binary_mm_cuda(x, y, bmm_type); 117 | } 118 | 119 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 120 | m.def("forward", &binary_linear_forward, "binary linear forward (CUDA)"); 121 | m.def("w_pack", &get_binary_weight, "get linear binary weight (CUDA)"); 122 | m.def("mm", &binary_linear_mm, "binary linear mm (CUDA)"); 123 | } 124 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/cuda/bmm.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class BMM(Enum): 5 | """ 6 | Enumeration for selecting the Bit-Matrix-Multiplication (BMM) kernel to be used during operations. 7 | This allows for the choice of different underlying implementations based on the requirements or 8 | optimizations desired for specific hardware or computational constraints. 9 | 10 | Attributes: 11 | BSTC32: Software-based Tensor Core implementation. This option utilizes a software-level implementation 12 | to simulate tensor core operations, potentially offering more flexibility at the cost of raw performance. 13 | BTC32: Bit-Matrix-Multiplication using NVIDIA Tensor Cores. This leverages hardware tensor cores for 14 | accelerated computation, suitable for NVIDIA GPUs that support tensor core operations, offering 15 | high performance for matrix multiplications. 16 | ADAPTIVE: Automatically selects the best combination of kernel implementations based on the specific dimension 17 | constraints of the inputs and weights. This option aims to optimize performance by considering 18 | the characteristics of the computation and available hardware capabilities. 19 | 20 | The choice of kernel can significantly affect the performance and efficiency of operations that involve 21 | matrix multiplications, especially in deep learning models where such operations are prevalent. 22 | """ 23 | BSTC32 = 1 # software based tensor core implementation 24 | BTC32 = 2 # Bit-Matrix-Multiplication using NVIDIA Tensor Cores 25 | ADAPTIVE = 3 # Chooses the best kernel based on input and weight dimensions -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/cuda/extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension 4 | 5 | CUDA_REQUIRED = True 6 | 7 | 8 | def get_ext(path: Path): 9 | """ 10 | Get the CUDA extension for binary linear operations. 11 | 12 | Args: 13 | path (Path): The path to the CUDA extension directory. 14 | 15 | Returns: 16 | Any: The CUDA extension module. 
17 | """ 18 | return get_cuda_extension( 19 | path, 20 | relative_name='binary_linear_cuda', 21 | relative_sources=[ 22 | 'binary_linear_cuda.cpp', 23 | 'binary_linear_cuda_kernel.cu', 24 | ] 25 | ) 26 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/cutlass/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import BinaryLinearCutlass, BinaryMatMul 2 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/binary/cutlass/extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension 4 | 5 | CUDA_REQUIRED = True 6 | CUTLASS_REQUIRED = True 7 | 8 | 9 | def get_ext(path: Path): 10 | """ 11 | Get the CUDA extension for binary linear cutlass. 12 | 13 | Args: 14 | path (Path): The path to the CUDA extension. 15 | 16 | Returns: 17 | Extension: The CUDA extension for binary linear cutlass. 18 | """ 19 | ext = get_cuda_extension( 20 | path, 21 | relative_name='binary_linear_cutlass', 22 | relative_sources=[ 23 | 'binary_linear_cutlass.cpp', 24 | 'binary_linear_cutlass_kernel.cu', 25 | ] 26 | ) 27 | ext.include_dirs.extend(['.']) 28 | return ext 29 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/layer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from bitorch import RuntimeMode 5 | from bitorch.layers.extensions import LayerRecipe 6 | from bitorch.layers.qlinear import QLinearImplementation, QLinearBase 7 | 8 | from .binary import BinaryLinear 9 | from .binary.layer import BinaryLinearBase 10 | from .nbit import nBitLinearBase 11 | from .qlinear_implementation import QLinearImplementationMixin 12 | 13 | 14 | @QLinearImplementation(RuntimeMode.INFERENCE_AUTO) 15 | class QLinearInf(QLinearImplementationMixin, BinaryLinearBase): 16 | """ 17 | QLinearInf is a class for quantized linear layers optimized for inference. 18 | It inherits from QLinearImplementationMixin and BinaryLinearBase to utilize 19 | quantization functionalities and binary linear operations. 20 | 21 | This class specifically handles inference operations with quantized weights, 22 | potentially using different bit widths for activations and weights. 23 | """ 24 | @classmethod 25 | def create_clone_from(cls, recipe: LayerRecipe, device: torch.device = None) -> Any: 26 | """ 27 | Creates a clone of the layer from a given recipe, adjusting input feature dimensions 28 | and setting up quantization parameters based on the recipe's specifications. 29 | 30 | Args: 31 | recipe (LayerRecipe): A configuration object containing layer specifications. 32 | device (torch.device, optional): The device on which to create the layer. Defaults to None. 33 | 34 | Returns: 35 | Any: An instance of the cloned layer with quantization applied. 
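        Roughly equivalent manual construction (an illustrative sketch of the steps performed
        below, here for the 1-bit case; the actual bit widths are read from the recipe's
        input quantization):

            layer = QLinearInf(in_features // 32, out_features, device=device,
                               a_bit=1, w_bit=1)
            layer.set_weight_data(weight.to(device=device))
            layer.generate_quantized_weight(qweight_only=True)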
36 | """ 37 | args = QLinearBase.get_args_as_kwargs(recipe) 38 | input_features, output_features = args["in_features"], args["out_features"] 39 | input_features //= 32 40 | new_layer = cls( 41 | input_features, 42 | output_features, 43 | device=device, 44 | a_bit=args["input_quantization"].bit_width, 45 | w_bit=args["input_quantization"].bit_width, 46 | ) 47 | new_layer.set_weight_data(recipe.layer.weight.data.to(device=device)) 48 | new_layer.generate_quantized_weight(qweight_only=True) 49 | return new_layer 50 | 51 | def __init__( 52 | self, 53 | input_features: int, 54 | out_features: int, 55 | device=None, 56 | a_bit: int = 1, 57 | w_bit: int = 1, 58 | bias=False, 59 | ) -> None: 60 | """ 61 | Initializes the QLinearInf layer with specified input and output feature dimensions, 62 | quantization bit widths, and device. Currently, bias is not supported and must be False. 63 | 64 | Args: 65 | input_features (int): The dimension of input features after bit-packing. 66 | out_features (int): The dimension of output features (hidden states). 67 | device (optional): The device on which to initialize the layer. Defaults to None. 68 | a_bit (int, optional): Bit width for activation quantization. Defaults to 1. 69 | w_bit (int, optional): Bit width for weight quantization. Defaults to 1. 70 | bias (bool, optional): Indicates if bias is used. Currently must be False. 71 | 72 | Raises: 73 | AssertionError: If bias is set to True. 74 | """ 75 | super().__init__(input_features, out_features, device) 76 | assert not bias, "currently QLinearInf only supports bias = False" 77 | self.layer = None 78 | if a_bit == 1 and w_bit == 1: 79 | self.layer = BinaryLinear(input_features, out_features, device=device) 80 | else: 81 | self.layer = nBitLinearBase( 82 | input_features, out_features, a_bit, w_bit, device 83 | ) 84 | 85 | def prepare_params(self) -> None: 86 | """ 87 | Prepares the parameters of the layer for quantization and inference, 88 | calling the corresponding method of the underlying binary or n-bit linear layer. 89 | """ 90 | self.layer.prepare_params() 91 | 92 | def generate_quantized_weight(self, qweight_only: bool = False) -> None: 93 | """ 94 | Generates and sets the quantized weights for the layer, optionally focusing 95 | only on the quantized weights without affecting the original weights. 96 | 97 | Args: 98 | qweight_only (bool, optional): If True, only quantized weights are generated. Defaults to False. 99 | """ 100 | self.layer.generate_quantized_weight(qweight_only=qweight_only) 101 | 102 | def set_weight_data(self, x: torch.Tensor): 103 | """ 104 | Sets the weight data for the layer. 105 | 106 | Args: 107 | x (torch.Tensor): The tensor containing the weight data. 108 | """ 109 | self.layer.set_weight_data(x) 110 | 111 | def set_quantized_weight_data(self, x: torch.Tensor): 112 | """ 113 | Sets the quantized weight data for the layer. 114 | 115 | Args: 116 | x (torch.Tensor): The tensor containing the quantized weight data. 117 | """ 118 | self.layer.set_quantized_weight_data(x) 119 | 120 | @property 121 | def weight(self): 122 | """ 123 | Property to access the weight tensor of the layer. 124 | 125 | Returns: 126 | torch.Tensor: The weight tensor. 127 | """ 128 | return self.layer.weight 129 | 130 | @property 131 | def opt_weight(self): 132 | """ 133 | Property to access the optimized weight tensor of the layer, which may 134 | include quantized or otherwise transformed weights for efficient inference. 135 | 136 | Returns: 137 | torch.Tensor: The optimized weight tensor. 
138 | """ 139 | return self.layer.opt_weight 140 | 141 | def forward(self, x: torch.Tensor) -> torch.Tensor: 142 | """ 143 | Forwards the input tensor x through the quantized linear layer, performing 144 | the linear operation with quantized weights. 145 | 146 | Args: 147 | x (torch.Tensor): The input tensor to forward through the layer. 148 | 149 | Returns: 150 | torch.Tensor: The output tensor after passing through the layer. 151 | """ 152 | return self.layer(x) 153 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import nBitLinearBase, MPQLinearBase, MPQWeightParameter, nBitLinearParameter -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | from .mpq_layer import MPQLinearCuda 2 | from .mbwq_layer import MBWQLinearCuda 3 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/exl2/compat.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _compat_cuh 2 | #define _compat_cuh 3 | 4 | // atomicAdd for half types, to support CC < 7.x 5 | 6 | __device__ __forceinline__ void atomicAdd_half(half* address, half val) 7 | { 8 | unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); 9 | unsigned int old = *address_as_ui; 10 | unsigned int assumed; 11 | 12 | do 13 | { 14 | assumed = old; 15 | __half_raw hsum; 16 | hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); 17 | half tmpres = __hadd(hsum, val); 18 | hsum = __half_raw(tmpres); 19 | old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; 20 | old = atomicCAS(address_as_ui, assumed, old); 21 | } 22 | while (assumed != old); 23 | } 24 | 25 | // atomicAdd for half2 types 26 | 27 | __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) 28 | { 29 | unsigned int* address_as_ui = (unsigned int*)address; 30 | unsigned int old = *address_as_ui; 31 | unsigned int assumed; 32 | do 33 | { 34 | assumed = old; 35 | half2 old_val = *((half2*)&old); 36 | half2 new_val = __hadd2(old_val, val); 37 | old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); 38 | } 39 | while (assumed != old); 40 | } 41 | 42 | // 43 | 44 | #if defined(__CUDA_ARCH__) || defined(USE_ROCM) 45 | #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) 46 | 47 | __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } 48 | 49 | #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) 50 | __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } 51 | #endif 52 | 53 | #endif 54 | #endif 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/exl2/config.h: -------------------------------------------------------------------------------- 1 | #ifndef _config_h 2 | #define _config_h 3 | 4 | #define MAX_Q_GEMM_ROWS 32 5 | #define MAX_Q_GEMM_ROWS_KERNEL 4 6 | #define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS 7 | 8 | /* 9 | The macros defined by QMODE_*BIT determine whether qweight should be rearranged. 
This rearrangement should, 10 | to some extent, enhance computational efficiency. Please note that in `quant/qdq_*.cuh`, 11 | two methods of dequantization are implemented for each bit level. When QMODE_*BIT=1, 12 | the rearrangement method `shuffle_*bit_*()` and the corresponding `dequant_*bit_*()` method are implemented. 13 | Therefore, it is important to note that if QMODE_*BIT=1, the qweight tensor needs to be rearranged by calling 14 | the `shuffle_*bit_*()` method. 15 | */ 16 | #define QMODE_2BIT 0 17 | #define QMODE_3BIT 0 18 | #define QMODE_4BIT 0 19 | #define QMODE_5BIT 0 20 | #define QMODE_6BIT 0 21 | #define QMODE_8BIT 0 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/exl2/matrix_view.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _matrix_view_cuh 2 | #define _matrix_view_cuh 3 | 4 | #include 5 | #include 6 | 7 | #include "quant/qdq_util.cuh" 8 | 9 | class MatrixView_half 10 | { 11 | public: 12 | const half* data; 13 | const int height; 14 | const int width; 15 | 16 | __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) 17 | : data(data), height(height), width(width) 18 | { } 19 | 20 | __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } 21 | __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } 22 | __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } 23 | __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } 24 | 25 | __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const 26 | { 27 | half2* ptr = (half2*) item_ptr(row, column); 28 | half2 i01 = ptr[0]; 29 | half2 i23 = ptr[1]; 30 | items[0] = __low2half(i01); 31 | items[1] = __high2half(i01); 32 | items[2] = __low2half(i23); 33 | items[3] = __high2half(i23); 34 | } 35 | __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const 36 | { 37 | half2* ptr = (half2*)item_ptr(row, column); 38 | half2 i01 = ptr[0]; 39 | half2 i23 = ptr[1]; 40 | items[0] = __half2float(__low2half(i01)); 41 | items[1] = __half2float(__high2half(i01)); 42 | items[2] = __half2float(__low2half(i23)); 43 | items[3] = __half2float(__high2half(i23)); 44 | } 45 | 46 | __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const 47 | { 48 | half2* ptr = (half2*)item_ptr(row, column); 49 | half2 i01 = ptr[0]; 50 | half2 i23 = ptr[1]; 51 | items[0] = __half2half2(__low2half(i01)); 52 | items[1] = __half2half2(__high2half(i01)); 53 | items[2] = __half2half2(__low2half(i23)); 54 | items[3] = __half2half2(__high2half(i23)); 55 | } 56 | }; 57 | 58 | class MatrixView_half_rw 59 | { 60 | public: 61 | half* data; 62 | const int height; 63 | const int width; 64 | 65 | __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) 66 | : data(data), height(height), width(width) 67 | { } 68 | 69 | __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } 70 | __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } 71 | __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + 
column]; } 72 | __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } 73 | __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } 74 | 75 | __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) 76 | { 77 | half2 v01 = __halves2half2(v0, v1); 78 | half2 v23 = __halves2half2(v2, v3); 79 | half2* ptr = (half2*) item_ptr(row, column); 80 | ptr[0] = v01; 81 | ptr[1] = v23; 82 | } 83 | }; 84 | 85 | class MatrixView_q4_row 86 | { 87 | public: 88 | const uint32_t* data; 89 | const int height; 90 | const int width; 91 | 92 | __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) 93 | : data(data), height(height), width(width) 94 | { } 95 | 96 | __device__ __forceinline__ int item(int row, int column) const 97 | { 98 | int shift = (column & 0x07) * 4; 99 | return (data[row * width / 8 + column / 8] >> shift) & 0x0f; 100 | } 101 | 102 | __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const 103 | { 104 | int shift = (column & 0x07) * 4; 105 | uint32_t d = data[row * width / 8 + column / 8] >> shift; 106 | items[0] = d & 0x0f; 107 | items[1] = (d >> 4) & 0x0f; 108 | } 109 | 110 | __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const 111 | { 112 | int shift = (column & 0x07) * 4; 113 | uint32_t d = data[row * width / 8 + column / 8] >> shift; 114 | items[0] = d & 0x0f; 115 | items[1] = (d >> 4) & 0x0f; 116 | items[2] = (d >> 8) & 0x0f; 117 | items[3] = (d >> 12) & 0x0f; 118 | } 119 | }; 120 | 121 | #endif -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/exl2/quant/qdq_2.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _qdq_2_cuh 2 | #define _qdq_2_cuh 3 | 4 | #include "qdq_util.cuh" 5 | #include "../config.h" 6 | 7 | #if QMODE_2BIT == 1 8 | 9 | // Permutation: 10 | // 11 | // ffddbb99 77553311 eeccaa88 66442200 12 | 13 | __forceinline__ __device__ void shuffle_2bit_16 14 | ( 15 | uint32_t* q, 16 | int stride 17 | ) 18 | { 19 | uint32_t qa = q[0]; 20 | uint32_t qb = 0; 21 | 22 | #pragma unroll 23 | for (int i = 0; i < 8; i++) 24 | { 25 | uint32_t qa0 = qa & 0x03; 26 | uint32_t qa1 = (qa & 0x0c) >> 2; 27 | qa >>= 4; 28 | qb |= (qa1 << (i * 2 + 16)); 29 | qb |= (qa0 << (i * 2)); 30 | } 31 | q[0] = qb; 32 | } 33 | 34 | __forceinline__ __device__ void dequant_2bit_16 35 | ( 36 | const uint32_t q_0, 37 | half2 (&dq)[8], 38 | int stride 39 | ) 40 | { 41 | const uint32_t c0 = 0x64006400; 42 | const half y4_ = __float2half_rn(1.0f / 4.0f); 43 | const half y16_ = __float2half_rn(1.0f / 16.0f); 44 | const half y64_ = __float2half_rn(1.0f / 64.0f); 45 | const half2 y4 = __halves2half2(y4_, y4_); 46 | const half2 y16 = __halves2half2(y16_, y16_); 47 | const half2 y64 = __halves2half2(y64_, y64_); 48 | const half z1_ = __float2half_rn(-1024.0f); 49 | const half z4_ = __float2half_rn(-1024.0f / 4.0f); 50 | const half z16_ = __float2half_rn(-1024.0f / 16.0f); 51 | const half z64_ = __float2half_rn(-1024.0f / 64.0f); 52 | const half2 z1 = __halves2half2(z1_, z1_); 53 | const half2 z4 = __halves2half2(z4_, z4_); 54 | const half2 z16 = __halves2half2(z16_, z16_); 55 | const half2 z64 = __halves2half2(z64_, z64_); 56 | 57 | uint32_t qa = q_0; 58 | half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 59 | 
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 60 | half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 61 | half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 62 | qa >>= 8; 63 | half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 64 | half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 65 | half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 66 | half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 67 | 68 | dq[0] = __hadd2(q0.as_half2, z1); 69 | dq[1] = __hfma2(q1.as_half2, y4, z4); 70 | dq[2] = __hfma2(q2.as_half2, y16, z16); 71 | dq[3] = __hfma2(q3.as_half2, y64, z64); 72 | dq[4] = __hadd2(q4.as_half2, z1); 73 | dq[5] = __hfma2(q5.as_half2, y4, z4); 74 | dq[6] = __hfma2(q6.as_half2, y16, z16); 75 | dq[7] = __hfma2(q7.as_half2, y64, z64); 76 | } 77 | 78 | #else 79 | 80 | __forceinline__ __device__ void shuffle_2bit_16 81 | ( 82 | uint32_t* q, 83 | int stride 84 | ) 85 | { 86 | } 87 | 88 | __forceinline__ __device__ void dequant_2bit_16 89 | ( 90 | const uint32_t q_0, 91 | half2 (&dq)[8], 92 | int stride 93 | ) 94 | { 95 | half dqh[16]; 96 | for (int i = 0; i < 16; i++) 97 | dqh[i] = __uint2half_rn(exb(q_0, i * 2, 0x03)); 98 | 99 | for (int i = 0; i < 8; i++) 100 | dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); 101 | } 102 | 103 | #endif 104 | 105 | #endif -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/exl2/quant/qdq_3.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _qdq_3_cuh 2 | #define _qdq_3_cuh 3 | 4 | #include "qdq_util.cuh" 5 | #include "../config.h" 6 | 7 | #if QMODE_3BIT == 1 8 | 9 | // Permutation: 10 | // 11 | // v9997775 55333111 u8886664 44222000 (u, v lsb) 12 | // vjjjhhhf ffdddbbb uiiiggge eecccaaa 13 | // vtttrrrp ppnnnlll usssqqqo oommmkkk 14 | 15 | __forceinline__ __device__ void shuffle_3bit_32 16 | ( 17 | uint32_t* q, 18 | int stride 19 | ) 20 | { 21 | uint32_t qa = q[0 * stride]; 22 | uint32_t qb = q[1 * stride]; 23 | uint32_t qc = q[2 * stride]; 24 | 25 | // qa: aa999888 77766655 54443332 22111000 26 | // qb: lkkkjjji iihhhggg fffeeedd dcccbbba 27 | // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll 28 | 29 | uint32_t qd = qc >> 26; 30 | qc <<= 4; 31 | qc |= qb >> 28; 32 | qb <<= 2; 33 | qb |= qa >> 30; 34 | 35 | // qa: ..999888 77766655 54443332 22111000 36 | // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa 37 | // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk 38 | // qd: vvvuuu 39 | 40 | uint32_t za = 0; 41 | uint32_t zb = 0; 42 | uint32_t zc = 0; 43 | 44 | for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } 45 | for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } 46 | for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } 47 | 48 | // za: 9997775 55333111 8886664 44222000 49 | // zb: jjjhhhf ffdddbbb iiiggge eecccaaa 50 | // zc: tttrrrp ppnnnlll sssqqqo oommmkkk 51 | // qd: vvvuuu 52 | 53 | za |= ((qd & 0x01) >> 0) << 15; 54 | zb |= ((qd & 0x02) >> 1) << 15; 55 | zc |= ((qd & 0x04) >> 2) << 15; 56 | za |= ((qd & 0x08) >> 3) << 31; 57 | zb |= ((qd & 0x10) >> 4) << 31; 58 | zc |= ((qd & 0x20) >> 5) << 31; 59 
| 60 | // za: v9997775 55333111 u8886664 44222000 (u, v lsb) 61 | // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa 62 | // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk 63 | 64 | q[0 * stride] = za; 65 | q[1 * stride] = zb; 66 | q[2 * stride] = zc; 67 | } 68 | 69 | __forceinline__ __device__ void dequant_3bit_32 70 | ( 71 | const uint32_t q_0, 72 | const uint32_t q_1, 73 | const uint32_t q_2, 74 | half2 (&dq)[16], 75 | int stride 76 | ) 77 | { 78 | const uint32_t c0 = 0x64006400; 79 | const half y8_ = __float2half_rn(1.0f / 8.0f); 80 | const half y64_ = __float2half_rn(1.0f / 64.0f); 81 | const half2 y8 = __halves2half2(y8_, y8_); 82 | const half2 y64 = __halves2half2(y64_, y64_); 83 | const half z1_ = __float2half_rn(-1024.0f); 84 | const half z8_ = __float2half_rn(-1024.0f / 8.0f); 85 | const half z64_ = __float2half_rn(-1024.0f / 64.0f); 86 | const half2 z1 = __halves2half2(z1_, z1_); 87 | const half2 z8 = __halves2half2(z8_, z8_); 88 | const half2 z64 = __halves2half2(z64_, z64_); 89 | 90 | uint32_t qa = q_0; 91 | uint32_t qb = q_1; 92 | uint32_t qc = q_2; 93 | 94 | half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 95 | half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 96 | qa >>= 6; 97 | half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 98 | half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 99 | half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 100 | qa >>= 9; 101 | qa &= 0x00010001; 102 | half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 103 | half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 104 | qb >>= 6; 105 | half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 106 | half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 107 | half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 108 | qb >>= 8; 109 | qb &= 0x00020002; 110 | half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 111 | half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 112 | qc >>= 6; 113 | half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 114 | half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 115 | half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 116 | qc >>= 7; 117 | qc &= 0x00040004; 118 | half2_uint32 q15((qa | qb | qc) | c0); 119 | 120 | dq[ 0] = __hadd2( q0.as_half2, z1); 121 | dq[ 1] = __hfma2( q1.as_half2, y8, z8); 122 | dq[ 2] = __hadd2( q2.as_half2, z1); 123 | dq[ 3] = __hfma2( q3.as_half2, y8, z8); 124 | dq[ 4] = __hfma2( q4.as_half2, y64, z64); 125 | dq[ 5] = __hadd2( q5.as_half2, z1); 126 | dq[ 6] = __hfma2( q6.as_half2, y8, z8); 127 | dq[ 7] = __hadd2( q7.as_half2, z1); 128 | dq[ 8] = __hfma2( q8.as_half2, y8, z8); 129 | dq[ 9] = __hfma2( q9.as_half2, y64, z64); 130 | dq[10] = __hadd2(q10.as_half2, z1); 131 | dq[11] = __hfma2(q11.as_half2, y8, z8); 132 | dq[12] = __hadd2(q12.as_half2, z1); 133 | dq[13] = __hfma2(q13.as_half2, y8, z8); 134 | dq[14] = __hfma2(q14.as_half2, y64, z64); 135 | dq[15] = __hadd2(q15.as_half2, z1); 136 | } 137 | 138 | #else 139 | 140 | __forceinline__ __device__ void shuffle_3bit_32 141 | ( 142 | uint32_t* q, 143 | int stride 144 | ) 145 | { 146 | } 147 | 148 | __forceinline__ __device__ void dequant_3bit_32 149 | ( 150 | const uint32_t q_0, 151 | const uint32_t q_1, 152 | const uint32_t q_2, 153 | half2 (&dq)[16], 154 | int stride 155 | ) 156 | { 
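    // Fallback path for QMODE_3BIT == 0: the 32 3-bit values are read directly out of three
    // packed 32-bit words, without the shuffle used in the branch above. The two values that
    // straddle a word boundary (indices 10 and 21) use the two-word exb() overload, which
    // reads across q_0/q_1 and q_1/q_2 via __funnelshift_rc.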
157 | half dqh[32]; 158 | for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 0); 159 | dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 0); 160 | for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 0); 161 | dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 0); 162 | for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 0); 163 | 164 | for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); 165 | } 166 | 167 | #endif 168 | 169 | #endif 170 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/exl2/quant/qdq_6.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _qdq_6_cuh 2 | #define _qdq_6_cuh 3 | 4 | #include "qdq_util.cuh" 5 | #include "../config.h" 6 | 7 | #if QMODE_6BIT == 1 8 | 9 | // Permutation: 10 | // 11 | // dddd3333 33111111 cccc2222 22000000 12 | // ffff7777 77555555 eeee6666 66444444 13 | // ffddbbbb bb999999 eeccaaaa aa888888 14 | 15 | __forceinline__ __device__ void shuffle_6bit_16 16 | ( 17 | uint32_t* q, 18 | int stride 19 | ) 20 | { 21 | uint32_t qa = q[0 * stride]; 22 | uint32_t qb = q[1 * stride]; 23 | uint32_t qc = q[2 * stride]; 24 | 25 | // qa: 55444444 33333322 22221111 11000000 26 | // qb: aaaa9999 99888888 77777766 66665555 27 | // qc: ffffffee eeeedddd ddcccccc bbbbbbaa 28 | 29 | uint32_t q00 = (qa ) & 0b111111; 30 | uint32_t q01 = (qa >> 6) & 0b111111; 31 | uint32_t q02 = (qa >> 12) & 0b111111; 32 | uint32_t q03 = (qa >> 18) & 0b111111; 33 | uint32_t q04 = (qa >> 24) & 0b111111; 34 | uint32_t q05 = ((qa >> 30) & 0b11) | ((qb & 0b1111) << 2); 35 | uint32_t q06 = (qb >> 4) & 0b111111; 36 | uint32_t q07 = (qb >> 10) & 0b111111; 37 | uint32_t q08 = (qb >> 16) & 0b111111; 38 | uint32_t q09 = (qb >> 22) & 0b111111; 39 | uint32_t q0a = ((qb >> 28) & 0b1111) | ((qc & 0b11) << 4); 40 | uint32_t q0b = (qc >> 2) & 0b111111; 41 | uint32_t q0c = (qc >> 8) & 0b111111; 42 | uint32_t q0d = (qc >> 14) & 0b111111; 43 | uint32_t q0e = (qc >> 20) & 0b111111; 44 | uint32_t q0f = (qc >> 26) & 0b111111; 45 | 46 | qa = q00 | (q01 << 16) | (q02 << 6) | (q03 << 22); 47 | qb = q04 | (q05 << 16) | (q06 << 6) | (q07 << 22); 48 | qc = q08 | (q09 << 16) | (q0a << 6) | (q0b << 22); 49 | 50 | // qa: ....3333 33111111 ....2222 22000000 51 | // qb: ....7777 77555555 ....6666 66444444 52 | // qc: ....bbbb bb999999 ....aaaa aa888888 53 | 54 | qa |= (q0c & 0b001111) << 12; 55 | qc |= (q0c & 0b110000) << 8; 56 | qa |= (q0d & 0b001111) << 28; 57 | qc |= (q0d & 0b110000) << 24; 58 | 59 | // qa: dddd3333 33111111 cccc2222 22000000 60 | // qb: ....7777 77555555 ....6666 66444444 61 | // qc: ..ddbbbb bb999999 ..ccaaaa aa888888 62 | 63 | qb |= (q0e & 0b001111) << 12; 64 | qc |= (q0e & 0b110000) << 10; 65 | qb |= (q0f & 0b001111) << 28; 66 | qc |= (q0f & 0b110000) << 26; 67 | 68 | // qa: dddd3333 33111111 cccc2222 22000000 69 | // qb: ffff7777 77555555 eeee6666 66444444 70 | // qc: ffddbbbb bb999999 eeccaaaa aa888888 71 | 72 | q[0 * stride] = qa; 73 | q[1 * stride] = qb; 74 | q[2 * stride] = qc; 75 | } 76 | 77 | __forceinline__ __device__ void dequant_6bit_16 78 | ( 79 | const uint32_t q_0, 80 | const uint32_t q_1, 81 | const uint32_t q_2, 82 | half2 (&dq)[8], 83 | int stride 84 | ) 85 | { 86 | const uint32_t c0 = 0x64006400; 87 | const half z1_ = __float2half_rn(-1024.0f); 88 | const half2 z1 = __halves2half2(z1_, z1_); 89 | 90 | uint32_t qa = q_0; 91 | uint32_t qb = q_1; 92 | uint32_t qc = q_2; 93 | 94 | 
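    // Explanatory note on the extraction below: each masked 6-bit field is OR-ed into
    // c0 = 0x64006400, i.e. half2(1024, 1024). Since the spacing between adjacent half
    // values at 1024 is exactly 1.0, the bit pattern (q | 0x6400) equals the half value
    // 1024 + q, and the final __hadd2 with z1 = half2(-1024, -1024) recovers q without an
    // explicit integer-to-float conversion.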
half2_uint32 q0((qa & 0x003f003f) | c0); // half2(q[ 0], q[ 1]) + 1024 95 | qa >>= 6; 96 | half2_uint32 q1((qa & 0x003f003f) | c0); // half2(q[ 2], q[ 3]) + 1024 97 | qa >>= 6; 98 | half2_uint32 q2((qb & 0x003f003f) | c0); // half2(q[ 4], q[ 5]) + 1024 99 | qb >>= 6; 100 | half2_uint32 q3((qb & 0x003f003f) | c0); // half2(q[ 6], q[ 7]) + 1024 101 | qb >>= 6; 102 | half2_uint32 q4((qc & 0x003f003f) | c0); // half2(q[ 8], q[ 9]) + 1024 103 | qc >>= 6; 104 | half2_uint32 q5((qc & 0x003f003f) | c0); // half2(q[10], q[11]) + 1024 105 | qc >>= 2; 106 | half2_uint32 q6((qa & 0x000f000f) | (qc & 0x00300030) | c0); // half2(q[12], q[13]) + 1024 107 | qc >>= 2; 108 | half2_uint32 q7((qb & 0x000f000f) | (qc & 0x00300030) | c0); // half2(q[14], q[15]) + 1024 109 | 110 | dq[0] = __hadd2(q0.as_half2, z1); 111 | dq[1] = __hadd2(q1.as_half2, z1); 112 | dq[2] = __hadd2(q2.as_half2, z1); 113 | dq[3] = __hadd2(q3.as_half2, z1); 114 | dq[4] = __hadd2(q4.as_half2, z1); 115 | dq[5] = __hadd2(q5.as_half2, z1); 116 | dq[6] = __hadd2(q6.as_half2, z1); 117 | dq[7] = __hadd2(q7.as_half2, z1); 118 | } 119 | 120 | #else 121 | 122 | __forceinline__ __device__ void shuffle_6bit_16 123 | ( 124 | uint32_t* q, 125 | int stride 126 | ) 127 | { 128 | } 129 | 130 | __forceinline__ __device__ void dequant_6bit_16 131 | ( 132 | const uint32_t q_0, 133 | const uint32_t q_1, 134 | const uint32_t q_2, 135 | half2 (&dq)[8], 136 | int stride 137 | ) 138 | { 139 | half dqh[16]; 140 | for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 0); 141 | dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 0); 142 | for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 0); 143 | dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 0); 144 | for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 0); 145 | 146 | for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); 147 | } 148 | 149 | #endif 150 | 151 | #endif 152 | 153 | 154 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/exl2/quant/qdq_8.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _qdq_8_cuh 2 | #define _qdq_8_cuh 3 | 4 | #include "qdq_util.cuh" 5 | #include "../config.h" 6 | 7 | #if QMODE_8BIT == 1 8 | 9 | // Not implemented 10 | 11 | #else 12 | 13 | __forceinline__ __device__ void shuffle_8bit_4 14 | ( 15 | uint32_t* q, 16 | int stride 17 | ) 18 | { 19 | } 20 | 21 | __forceinline__ __device__ void dequant_8bit_8 22 | ( 23 | const uint32_t q_0, 24 | const uint32_t q_1, 25 | half2 (&dq)[4], 26 | int stride 27 | ) 28 | { 29 | half dqh[8]; 30 | for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 0); 31 | for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 0); 32 | 33 | for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); 34 | } 35 | 36 | #endif 37 | 38 | #endif -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/exl2/quant/qdq_util.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _qdq_util_cuh 2 | #define _qdq_util_cuh 3 | 4 | union half2_uint32 5 | { 6 | uint32_t as_uint32; 7 | half2 as_half2; 8 | __device__ half2_uint32(uint32_t val) : as_uint32(val) {} 9 | __device__ half2_uint32(half2 val) : as_half2(val) {} 10 | __device__ half2_uint32() : as_uint32(0) {} 11 | }; 12 | 13 | union half_uint16 14 | { 15 | uint16_t as_uint16; 16 
| half as_half; 17 | __device__ half_uint16(uint16_t val) : as_uint16(val) {} 18 | __device__ half_uint16(half val) : as_half(val) {} 19 | __device__ half_uint16() : as_uint16(0) {} 20 | }; 21 | 22 | // Max_scale premultiplied by 1/256 23 | 24 | __forceinline__ __device__ half dq_scale(const int qs, const half max_scale) 25 | { 26 | int qs_i = qs + 1; 27 | half qs_h = __int2half_rn(qs_i * qs_i); 28 | qs_h = __hmul(qs_h, max_scale); 29 | return qs_h; 30 | } 31 | 32 | __forceinline__ __device__ half dq_scale_q_zero(const int qs, const half scale, const half zero) 33 | { 34 | half qs_h = __int2half_rn(qs); 35 | // qs_h = __hfma(qs_h, scale, __hneg(zero)); 36 | qs_h = __hadd(qs_h, __hneg(zero)); 37 | qs_h = __hmul(qs_h, scale); 38 | return qs_h; 39 | } 40 | 41 | __forceinline__ __device__ half dq(const int q, const int qzero, const half scale) 42 | { 43 | return __hmul(__int2half_rn(q - qzero), scale); 44 | } 45 | 46 | __forceinline__ __device__ half dq_ns(const int q, const int qzero) 47 | { 48 | //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); 49 | return __int2half_rn(q - qzero); 50 | } 51 | 52 | __forceinline__ __device__ unsigned int as_unsigned(int i) { 53 | return *reinterpret_cast(&i); 54 | } 55 | 56 | __forceinline__ __device__ unsigned int exb(const uint32_t q, const int shift, const int mask) 57 | { 58 | return as_unsigned((q >> shift) & mask); 59 | } 60 | 61 | __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) 62 | { 63 | return (int)(__funnelshift_rc(q0, q1, shift) & mask); 64 | } 65 | 66 | #endif 67 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/exl2/util.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _util_cuh 2 | #define _util_cuh 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) 11 | 12 | #define DBGS(__x) printf("%s\n", __x) 13 | #define DBGI(__x) printf("%s: %i\n", #__x, __x) 14 | #define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) 15 | #define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) 16 | #define DBGX(__x) printf("%s: %x\n", #__x, __x) 17 | #define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y) 18 | #define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z) 19 | #define DBGF(__x) printf("%s: %f\n", #__x, __x) 20 | #define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) 21 | #define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) 22 | #define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x)) 23 | #define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y)) 24 | #define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z)) 25 | 26 | #define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y)) 27 | #define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z)) 28 | 29 | __forceinline__ __device__ half dq_scale_(const int qs, const half max_scale) 30 | { 31 | half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f)); 32 | qs_h = __hmul(qs_h, qs_h); 33 | qs_h = __hmul(qs_h, max_scale); 34 | return qs_h; 35 | } 
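// Worked example for dq_scale_ (illustrative): the function computes ((qs + 1) / 16)^2 * max_scale,
// so qs = 7 gives (8/16)^2 * max_scale = 0.25 * max_scale and qs = 15 gives exactly max_scale.
// The quadratic mapping spends more of the code range on small scale values.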
36 | 37 | __forceinline__ __device__ float clamp(float x, float a, float b) 38 | { 39 | return fmaxf(a, fminf(b, x)); 40 | } 41 | 42 | #define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); } 43 | inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true) 44 | { 45 | if (code != cudaSuccess) 46 | { 47 | fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line); 48 | if (abort) exit(code); 49 | } 50 | } 51 | 52 | void print_global_mem(const half* ptr, int rows, int columns, int stride); 53 | 54 | #endif -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cuda/extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension 4 | 5 | CUDA_REQUIRED = True 6 | 7 | def get_ext(path: Path): 8 | """ 9 | Get the CUDA extension for quantized linear operations. 10 | 11 | Args: 12 | path (Path): The path to the CUDA extension directory. 13 | 14 | Returns: 15 | Any: The CUDA extension module. 16 | """ 17 | ext = get_cuda_extension( 18 | path, 19 | relative_name='q_linear_cuda', 20 | relative_sources=[ 21 | 'q_linear_cuda.cpp', 22 | 'mpq_linear_cuda_kernel.cu', 23 | 'mbwq_linear_cuda_kernel.cu', 24 | ] 25 | ) 26 | ext.include_dirs.extend(['exl2']) 27 | return ext -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cutlass/__init__.py: -------------------------------------------------------------------------------- 1 | from .q4_layer import Q4LinearCutlass, Q4MatMul 2 | from .q8_layer import Q8LinearCutlass -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/cutlass/extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension 4 | 5 | CUDA_REQUIRED = True 6 | CUTLASS_REQUIRED = True 7 | 8 | def get_ext(path: Path): 9 | """ 10 | Return CUDA extension for Q Linear Cutlass. 11 | 12 | Args: 13 | path (Path): Path to the directory containing the extension files. 14 | 15 | Returns: 16 | Extension: CUDA extension for Q Linear Cutlass. 
17 | """ 18 | return get_cuda_extension( 19 | path, 20 | relative_name='q_linear_cutlass', 21 | relative_sources=[ 22 | 'q_linear_cutlass.cpp', 23 | 'q4_linear_cutlass_kernel.cu', 24 | 'q8_linear_cutlass_kernel.cu', 25 | ] 26 | ) 27 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/mps/__init__.py: -------------------------------------------------------------------------------- 1 | from .mpq_layer import MPQLinearMlx 2 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/mps/extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bitorch_engine.utils.mlx_extension import get_mlx_extension 4 | 5 | MLX_REQUIRED = True 6 | 7 | def get_ext(path: Path): 8 | return get_mlx_extension( 9 | path, 10 | relative_name='mpq_linear_mlx', 11 | relative_sources=[ 12 | 'mlx_bindings.cpp', 13 | 'mpq_linear_mlx.cpp', 14 | ] 15 | ) -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/mps/mlx_bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "mpq_linear_mlx.h" 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 5 | { 6 | m.def("mpq_forward", &mpq_linear_mlx_forward, "Forward call for mlx acclelerated MPS quantized linear layer."); 7 | } -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/mps/mpq_linear_mlx.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | typedef float16_t half_t; 8 | 9 | template 10 | mlx::core::array mlx_from_torch_tensor(torch::Tensor tensor, mlx::core::Dtype valueDtype) 11 | { 12 | return mlx::core::array( 13 | reinterpret_cast(tensor.mutable_data_ptr()), // Pointer to the data. 14 | std::vector({ // Shape of the array as a vector of int32 (mlx requires 32 bit int). 15 | static_cast(tensor.size(0)), 16 | static_cast(tensor.size(1)) 17 | }), 18 | valueDtype // Set dtype. 19 | ); 20 | } 21 | 22 | torch::Tensor mlx_to_torch_tensor(mlx::core::array arr, torch::Dtype dtype) 23 | { 24 | return torch::from_blob( 25 | arr.data(), // Pointer to the data. 26 | { // Shape of the array as a vector of int64 (torch requires longs). 27 | static_cast(arr.shape()[0]), 28 | static_cast(arr.shape()[1]) 29 | }, 30 | torch::TensorOptions().dtype(dtype)); // Set dtype. 31 | } 32 | 33 | 34 | /** 35 | * Performs forward pass for mixed precision quantized (MPQ) linear multiplication using a custom MLX library. 36 | * This function processes an input tensor `x` with quantized weights `qweight`, applying scale factors `scales` 37 | * and zero points `zeros` for quantization, within specified group sizes and weight bit precision. It's designed 38 | * for CPU execution, enforcing inputs and computations to reside on the CPU. 39 | * 40 | * @param x The input tensor, expected to be on CPU, containing the features to be processed. 41 | * @param qweight The quantized weights tensor, also on CPU, to be used in the matrix multiplication. 42 | * @param scales The scale factors for the quantized weights, aiding in de-quantization to real values during the computation. 43 | * @param zeros The zero points for the quantized weights, also used in de-quantization process. 
44 | * @param group_size The size of groups for performing the quantized matrix multiplication, affecting how inputs are partitioned. 45 | * @param w_bit The precision of the quantized weights in bits, supporting 2, 4, or 8 bits for the computation. 46 | * 47 | * @details 48 | * - The function begins by validating the input tensors to ensure they are CPU tensors, contiguous, and that the weight bits 49 | * are within the supported range. 50 | * - It then converts PyTorch tensors to MLX core arrays, using appropriate data types for the MLX backend computations. 51 | * - Quantized matrix multiplication is performed with the MLX library, leveraging the given scale factors, zero points, 52 | * group size, and weight bit precision. 53 | * - The MLX library uses lazy evaluation for computations; thus, the function explicitly evaluates the output before 54 | * converting it back to a PyTorch tensor. 55 | * - The result is a tensor of the computed output in float16 format, ready for further processing in PyTorch pipelines. 56 | * 57 | * @return A torch::Tensor containing the result of the mixed precision quantized linear multiplication in float16 format. 58 | * 59 | * @note 60 | * - This function is designed to operate exclusively on CPU tensors and will verify the device type of its inputs. 61 | * - It assumes the MLX backend for computation, which requires inputs to be converted to MLX core arrays. 62 | * - The use of float16 and uint32 data types for MLX core arrays is based on the precision and requirements of the inputs and computation. 63 | */ 64 | torch::Tensor mpq_linear_mlx_forward( 65 | torch::Tensor x, 66 | torch::Tensor qweight, 67 | torch::Tensor scales, 68 | torch::Tensor zeros, 69 | int group_size, 70 | int w_bit) 71 | { 72 | // Check the input parameters 73 | TORCH_CHECK((w_bit == 2) || (w_bit == 4) || (w_bit == 8), "weights must have {2,4,8} bits."); 74 | TORCH_CHECK(x.device().type() == torch::kCPU, "x must be a CPU tensor. Cannot read from MPS."); 75 | TORCH_CHECK(qweight.device().type() == torch::kCPU, "qweight must be a CPU tensor. Cannot read from MPS."); 76 | TORCH_CHECK(scales.device().type() == torch::kCPU, "scales must be a CPU tensor. Cannot read from MPS."); 77 | TORCH_CHECK(zeros.device().type() == torch::kCPU, "zeros must be a CPU tensor. Cannot read from MPS."); 78 | TORCH_CHECK(x.is_contiguous(), "x must be contiguous."); 79 | TORCH_CHECK(qweight.is_contiguous(), "qweight must be contiguous."); 80 | TORCH_CHECK(scales.is_contiguous(), "scales must be contiguous."); 81 | TORCH_CHECK(zeros.is_contiguous(), "zeros must be contiguous."); 82 | 83 | mlx::core::array x_arr = mlx_from_torch_tensor(x, mlx::core::float16); 84 | mlx::core::array qweight_arr = mlx_from_torch_tensor(qweight, mlx::core::uint32); 85 | mlx::core::array scales_arr = mlx_from_torch_tensor(scales, mlx::core::float16); 86 | mlx::core::array zeros_arr = mlx_from_torch_tensor(zeros, mlx::core::float16); 87 | mlx::core::array output = mlx::core::quantized_matmul( 88 | x_arr, 89 | qweight_arr, 90 | scales_arr, 91 | zeros_arr, 92 | true, 93 | group_size, 94 | w_bit, 95 | mlx::core::default_stream(mlx::core::default_device()) 96 | ); 97 | 98 | mlx::core::eval(output); // Mlx uses lazy evaluation. Run and wait for the computation to finish. 
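    // Rough Python-side call into this binding (illustrative sketch; the module name
    // `mpq_linear_mlx` comes from mps/extension.py and, since no py::arg names are declared,
    // all arguments are positional):
    //   out = mpq_linear_mlx.mpq_forward(x_fp16, qweight_uint32, scales_fp16, zeros_fp16,
    //                                    group_size, w_bit)   # contiguous CPU tensors only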
99 | return mlx_to_torch_tensor(output, torch::kFloat16); 100 | } -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/nbit/mps/mpq_linear_mlx.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor mpq_linear_mlx_forward( 4 | torch::Tensor x, 5 | torch::Tensor qweight, 6 | torch::Tensor scales, 7 | torch::Tensor zeros, 8 | int group_size, 9 | int w_bit); -------------------------------------------------------------------------------- /bitorch_engine/layers/qlinear/qlinear_implementation.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Tuple 3 | 4 | from bitorch.layers import CustomImplementationMixin, QLinearBase 5 | from bitorch.layers.extensions import LayerRecipe 6 | 7 | 8 | class QLinearImplementationMixin(CustomImplementationMixin, ABC): 9 | """ 10 | A mixin class for QLinear layer implementations that provides common utility functions 11 | and checks specific to quantized linear layers. This mixin extends CustomImplementationMixin 12 | and implements the Abstract Base Class (ABC) to ensure that subclasses provide implementations 13 | for abstract methods defined in parent classes. 14 | 15 | The class provides a method to check if a given layer configuration can be cloned 16 | based on specific constraints related to quantized linear layers. 17 | """ 18 | @classmethod 19 | def can_clone(cls, recipe: LayerRecipe) -> Tuple[bool, str]: 20 | """ 21 | Determines if a QLinear layer described by the given recipe can be cloned. 22 | 23 | The method checks if the layer configuration meets certain criteria necessary 24 | for cloning a quantized linear layer. Specifically, it checks if the layer 25 | includes bias, and if the number of input features is divisible by 32, as 26 | these are current limitations for cloning such layers. 27 | 28 | Args: 29 | recipe (LayerRecipe): An object containing the configuration parameters 30 | of the layer to be cloned. 31 | 32 | Returns: 33 | Tuple[bool, str]: A tuple containing a boolean and a string. The boolean 34 | indicates whether the layer can be cloned (True if it can, 35 | False otherwise). The string provides a message explaining 36 | why the layer cannot be cloned if the boolean is False. 37 | """ 38 | # Extract layer arguments from the recipe 39 | args = QLinearBase.get_args_as_kwargs(recipe) 40 | if args["bias"]: 41 | return False, f"bias is not yet supported." 42 | # Check if the number of input features is divisible by 32 43 | if args["in_features"] % 32 != 0: 44 | return False, f"in_features ({args['in_features']}) is not divisible by 32." 
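        # (Explanatory note: the factor of 32 presumably reflects the engine's bit-packing,
        #  in which 32 binary values are packed into one 32-bit word; compare the
        #  `input_features //= 32` adjustment in QLinearInf.create_clone_from.)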
45 | # Layer can be cloned if all checks pass 46 | return True, "" 47 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qmha/__init__.py: -------------------------------------------------------------------------------- 1 | from .binary import * 2 | -------------------------------------------------------------------------------- /bitorch_engine/layers/qmha/binary/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import BMHA 2 | -------------------------------------------------------------------------------- /bitorch_engine/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .diode_beta import DiodeMix -------------------------------------------------------------------------------- /bitorch_engine/optim/galore_projector.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the original Galore optimizer implementation from `Galore repo `_ 3 | with `Apache-2.0 License `_ 4 | 5 | @misc{zhao2024galore, 6 | title={GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection}, 7 | author={Jiawei Zhao and Zhenyu Zhang and Beidi Chen and Zhangyang Wang and Anima Anandkumar and Yuandong Tian}, 8 | year={2024}, 9 | eprint={2403.03507}, 10 | archivePrefix={arXiv}, 11 | primaryClass={cs.LG} 12 | } 13 | """ 14 | 15 | import torch 16 | 17 | class GaLoreProjector: 18 | def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'): 19 | self.rank = rank 20 | self.verbose = verbose 21 | self.update_proj_gap = update_proj_gap 22 | self.scale = scale 23 | self.ortho_matrix = None 24 | self.proj_type = proj_type 25 | 26 | def project(self, full_rank_grad, iter): 27 | 28 | if self.proj_type == 'std': 29 | if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: 30 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 31 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') 32 | low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) 33 | else: 34 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 35 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') 36 | low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) 37 | elif self.proj_type == 'reverse_std': 38 | if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: 39 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 40 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') 41 | low_rank_grad = torch.matmul(self.ortho_matrix.t(),full_rank_grad) 42 | else: 43 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 44 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') 45 | low_rank_grad = torch.matmul(full_rank_grad,self.ortho_matrix.t()) 46 | elif self.proj_type == 'right': 47 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 48 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right') 49 | low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) 50 | elif self.proj_type == 'left': 51 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 52 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left') 53 | low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) 54 | elif 
self.proj_type == 'full': 55 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0: 56 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full') 57 | low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t() 58 | 59 | return low_rank_grad 60 | 61 | def project_back(self, low_rank_grad): 62 | 63 | if self.proj_type == 'std': 64 | if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: 65 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) 66 | else: 67 | full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) 68 | elif self.proj_type == 'reverse_std': 69 | if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std 70 | full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) 71 | else: 72 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) 73 | elif self.proj_type == 'right': 74 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) 75 | elif self.proj_type == 'left': 76 | full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) 77 | elif self.proj_type == 'full': 78 | full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1] 79 | 80 | 81 | return full_rank_grad * self.scale 82 | 83 | 84 | # svd decomposition 85 | def get_orthogonal_matrix(self, weights, rank, type): 86 | module_params = weights 87 | 88 | if module_params.data.dtype != torch.float: 89 | float_data = False 90 | original_type = module_params.data.dtype 91 | original_device = module_params.data.device 92 | matrix = module_params.data.float() 93 | else: 94 | float_data = True 95 | matrix = module_params.data 96 | 97 | U, s, Vh = torch.linalg.svd(matrix, full_matrices = False) 98 | 99 | #make the smaller matrix always to be orthogonal matrix 100 | if type=='right': 101 | A = U[:, :rank] @ torch.diag(s[:rank]) 102 | B = Vh[:rank, :] 103 | 104 | if not float_data: 105 | B = B.to(original_device).type(original_type) 106 | return B 107 | elif type=='left': 108 | A = U[:, :rank] 109 | B = torch.diag(s[:rank]) @ Vh[:rank, :] 110 | if not float_data: 111 | A = A.to(original_device).type(original_type) 112 | return A 113 | elif type=='full': 114 | A = U[:, :rank] 115 | B = Vh[:rank, :] 116 | if not float_data: 117 | A = A.to(original_device).type(original_type) 118 | B = B.to(original_device).type(original_type) 119 | return [A, B] 120 | else: 121 | raise ValueError('type should be left, right or full') 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /bitorch_engine/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/bitorch_engine/utils/__init__.py -------------------------------------------------------------------------------- /bitorch_engine/utils/arch_helper.py: -------------------------------------------------------------------------------- 1 | import platform 2 | from enum import Enum 3 | import subprocess 4 | import os 5 | 6 | class ARCH_CPU(Enum): 7 | ''' 8 | Indicates which CPU architecture using for computation 9 | ''' 10 | ARM_A76 = 1 11 | ARM_A55 = 2 12 | 13 | 14 | class linux_arch_ident(): 15 | """ 16 | A utility class for identifying Linux architecture, specifically aimed at ARM architectures. 
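A minimal sketch, not from the repository, of how the `GaLoreProjector` above is typically driven inside an optimizer step; the gradient shape, rank and number of steps are arbitrary assumptions:

```python
import torch
from bitorch_engine.optim.galore_projector import GaLoreProjector

projector = GaLoreProjector(rank=4, update_proj_gap=200, scale=1.0, proj_type="std")

full_rank_grad = torch.randn(128, 64)  # e.g. the gradient of a 128x64 weight matrix
for step in range(3):
    # rows >= cols, so a 'right' projection is (re)computed every `update_proj_gap` steps
    low_rank = projector.project(full_rank_grad, step)   # shape (128, 4)
    # ... optimizer statistics (e.g. Adam moments) would be updated on `low_rank` here ...
    update = projector.project_back(low_rank)            # back to (128, 64), scaled by `scale`
```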
17 | 18 | This class provides methods to check if the current system is running on an ARM architecture 19 | and to determine the specific model of the ARM CPU. 20 | """ 21 | @staticmethod 22 | def is_arm() -> bool: 23 | """ 24 | Determines if the current system's architecture is ARM-based. 25 | 26 | :return: True if the system is ARM-based, False otherwise. 27 | """ 28 | return platform.machine().lower().startswith('arm') or platform.machine().lower().startswith('aarch64') 29 | 30 | @staticmethod 31 | def get_arm_model() -> ARCH_CPU: 32 | """ 33 | Fetches the model name of the ARM CPU from the system and maps it to a predefined ARCH_CPU enum. 34 | 35 | This method attempts to execute the 'lscpu' command to retrieve CPU information, then parses 36 | the output to identify the model name of the ARM CPU. 37 | 38 | :return: An ARCH_CPU enum value corresponding to the ARM CPU model. 39 | :raises Exception: If there's an error executing 'lscpu', or if the CPU model cannot be recognized. 40 | """ 41 | try: 42 | # Execute 'lscpu' command and decode the output to get CPU information 43 | cpuinfo = (subprocess.check_output("lscpu", shell=True).strip()).decode().lower() 44 | # Parse the output to find the model name 45 | s = cpuinfo.index("model name:") 46 | s += len("model name:") 47 | n = cpuinfo[s:].index("\n") 48 | model_n = cpuinfo[s:s+n].strip() 49 | except: 50 | # Raise an exception if 'lscpu' command fails or if parsing fails 51 | raise Exception("Error occurred while running 'lscpu', please check if your OS supports this command.") 52 | 53 | # Map the model name to a specific ARCH_CPU enum value 54 | if model_n.__contains__("cortex-a55"): 55 | return ARCH_CPU.ARM_A55 56 | elif model_n.__contains__("cortex-a76"): 57 | return ARCH_CPU.ARM_A76 58 | # Raise an exception if the model name does not match known values 59 | raise Exception("Invalid architecture name obtained: {}.".format(model_n)) 60 | 61 | 62 | def check_cpu_instruction_support(search_term): 63 | """ 64 | Checks if the CPU supports a specific instruction set. 65 | 66 | This function utilizes the `cpuinfo` library to fetch detailed CPU information, 67 | then searches for a specific term within the CPU flags to determine if a given 68 | CPU instruction set is supported. 69 | 70 | Args: 71 | search_term (str): The CPU instruction set or feature to search for, e.g., "sse4_2", "avx2". 72 | 73 | Returns: 74 | bool: True if the search term is found within the CPU flags, indicating support for the instruction set. 75 | False otherwise. 76 | 77 | Example: 78 | >>> check_cpu_instruction_support("sse4_2") 79 | True # Indicates that "sse4_2" is supported by the CPU. 80 | 81 | Note: 82 | The function prints the entire CPU information fetched by `cpuinfo.get_cpu_info()` which can be quite verbose. 83 | Consider commenting out the `print` statement for production use. 
84 | """ 85 | import cpuinfo 86 | print(cpuinfo.get_cpu_info()) 87 | # Check if the search term is present in the CPU flags 88 | if search_term in cpuinfo.get_cpu_info()["flags"]: 89 | return True 90 | else: 91 | return False 92 | -------------------------------------------------------------------------------- /bitorch_engine/utils/cpp_extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, Any 3 | from torch.utils.cpp_extension import CppExtension, IS_MACOS 4 | from bitorch_engine.extensions import EXTENSION_PREFIX 5 | from bitorch_engine.utils.arch_helper import linux_arch_ident, ARCH_CPU, check_cpu_instruction_support 6 | 7 | 8 | def get_cpp_extension(root_path: Path, relative_name: str, relative_sources) -> Any: 9 | return CppExtension( 10 | name=EXTENSION_PREFIX + relative_name, 11 | sources=[str(root_path / rel_path) for rel_path in relative_sources], 12 | **get_kwargs() 13 | ) 14 | 15 | 16 | def get_kwargs() -> Dict[str, Any]: 17 | """ 18 | Generates keyword arguments for compilation settings, tailored for specific system architectures and operating systems. 19 | 20 | This function configures compiler flags and arguments based on the operating system and CPU architecture. It ensures 21 | that the appropriate flags are used for MacOS, ARM architectures, and potentially checks for specific CPU instruction 22 | set support (commented out by default). 23 | 24 | Returns: 25 | A dictionary containing: 26 | - `include_dirs`: A list of directories for the compiler to look for header files. 27 | - `libraries`: A list of libraries to link against. This varies between MacOS (`omp`) and other systems (`gomp`). 28 | - `extra_compile_args`: A list of additional arguments to pass to the compiler. This includes flags for warnings, 29 | OpenMP support, and architecture-specific optimizations. 30 | 31 | Note: 32 | - The function checks if the operating system is MacOS and adjusts the compilation flags accordingly. 33 | - For ARM architectures on Linux (not MacOS), it adds flags for ARMv8.2-A features and sets the CPU model 34 | for further optimizations if the model is detected as ARM A55 or A76. 35 | - The commented code snippet shows how to conditionally add compilation flags based on CPU instruction support, 36 | such as AVX2. 
37 | """ 38 | extra_compile_args = [ 39 | '-Wall', 40 | '-Wno-deprecated-register', 41 | ] 42 | if IS_MACOS: 43 | extra_compile_args.append('-Xpreprocessor') 44 | 45 | extra_compile_args.append('-fopenmp') 46 | 47 | if linux_arch_ident.is_arm() and not IS_MACOS: 48 | extra_compile_args.append('-march=armv8.2-a+fp16+dotprod') 49 | if linux_arch_ident.get_arm_model() is ARCH_CPU.ARM_A55: 50 | extra_compile_args.append('-mcpu=cortex-a55') 51 | if linux_arch_ident.get_arm_model() is ARCH_CPU.ARM_A76: 52 | extra_compile_args.append('-mcpu=cortex-a76') 53 | 54 | ## can use this code to check the cpu support for instructions 55 | # if check_cpu_instruction_support('avx2'): 56 | # extra_compile_args.append('-mavx2') 57 | 58 | return { 59 | "include_dirs": [ 60 | "/usr/local/opt/llvm/include", 61 | ], 62 | "libraries": [ 63 | "omp" if IS_MACOS else "gomp", 64 | ], 65 | "extra_compile_args": extra_compile_args 66 | } 67 | -------------------------------------------------------------------------------- /bitorch_engine/utils/cuda_extension.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import warnings 4 | from pathlib import Path 5 | from typing import Dict, Any 6 | 7 | from torch.utils.cpp_extension import CUDAExtension 8 | 9 | from bitorch_engine.extensions import EXTENSION_PREFIX 10 | 11 | 12 | SUPPORTED_CUDA_ARCHS = ["sm_80", "sm_75"] 13 | 14 | 15 | def get_cuda_arch(): 16 | cuda_arch = os.environ.get("BIE_CUDA_ARCH", "sm_80") 17 | if cuda_arch not in SUPPORTED_CUDA_ARCHS: 18 | warnings.warn(f"Warning: BIE_CUDA_ARCH={cuda_arch} may not be supported yet.") 19 | return cuda_arch 20 | 21 | 22 | def get_cuda_extension(root_path: Path, relative_name: str, relative_sources) -> Any: 23 | if os.environ.get("BIE_BUILD_SEPARATE_CUDA_ARCH", "false") == "true": 24 | relative_name = f"{relative_name}-{get_cuda_arch()}" 25 | return CUDAExtension( 26 | name=EXTENSION_PREFIX + relative_name, 27 | sources=[str(root_path / rel_path) for rel_path in relative_sources], 28 | **get_kwargs() 29 | ) 30 | 31 | 32 | def gcc_version(): 33 | """ 34 | Determines the version of GCC (GNU Compiler Collection) installed on the system. 35 | 36 | This function executes the 'gcc --version' command using subprocess.run and parses the output 37 | to extract the GCC version number. If GCC is not found or an error occurs during parsing, 38 | it returns a default version number of 0.0.0. 39 | 40 | The function checks if the output contains the string "clang" to identify if clang is masquerading 41 | as gcc, in which case it also returns 0.0.0, assuming GCC is not installed. 42 | 43 | Returns: 44 | tuple: A tuple containing three integers (major, minor, patch) representing the version of GCC. 45 | Returns (0, 0, 0) if GCC is not found or an error occurs. 46 | """ 47 | output = subprocess.run(['gcc', '--version'], check=True, capture_output=True, text=True) 48 | if output.returncode > 0 or "clang" in output.stdout: 49 | return 0, 0, 0 50 | first_line = output.stdout.split("\n")[0] 51 | try: 52 | version = first_line.split(" ")[-1] 53 | major, minor, patch = list(map(int, version.split("."))) 54 | return major, minor, patch 55 | except: 56 | return 0, 0, 0 57 | 58 | 59 | def get_kwargs() -> Dict[str, Any]: 60 | """ 61 | Generates keyword arguments for compilation based on the GCC version and CUDA architecture. 62 | 63 | This function dynamically constructs a dictionary of extra compilation arguments 64 | for both C++ and CUDA (nvcc) compilers. 
It includes flags to suppress deprecated 65 | declarations warnings, specify the OpenMP library path, and set the CUDA architecture. 66 | Additionally, it conditionally adjusts the nvcc host compiler to GCC 11 if the detected 67 | GCC version is greater than 11. 68 | 69 | Returns: 70 | Dict[str, Any]: A dictionary containing the 'extra_compile_args' key with nested 71 | dictionaries for 'cxx' and 'nvcc' compilers specifying their 72 | respective extra compilation arguments. 73 | """ 74 | # Retrieve the current GCC version to determine compatibility and required flags 75 | major, minor, patch = gcc_version() 76 | 77 | # Initialize the base kwargs dict with common compile arguments for cxx and nvcc 78 | kwargs = { 79 | "extra_compile_args": { 80 | "cxx": [ 81 | "-Wno-deprecated-declarations", # Suppress warnings for deprecated declarations 82 | "-L/usr/lib/gcc/x86_64-pc-linux-gnu/10.3.0/libgomp.so", # Specify libgomp library path for OpenMP 83 | "-fopenmp", # Enable OpenMP support 84 | ], 85 | "nvcc": [ 86 | "-Xcompiler", 87 | "-fopenmp", # Pass -fopenmp to the host compiler via nvcc 88 | f"-arch={get_cuda_arch()}", # Set CUDA architecture, default to 'sm_80' 89 | "-DARCH_SM_75" if get_cuda_arch() == 'sm_75' else "-DARCH_SM_80", # set architecture macro 90 | ], 91 | }, 92 | } 93 | # If the GCC version is greater than 11, adjust the nvcc host compiler settings 94 | if major > 11: 95 | print("Using GCC 11 host compiler for nvcc.") 96 | # Specify GCC 11 as the host compiler for nvcc: 97 | kwargs["extra_compile_args"]["nvcc"].append("-ccbin=/usr/bin/gcc-11") 98 | return kwargs 99 | -------------------------------------------------------------------------------- /bitorch_engine/utils/cutlass_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Union, Tuple 4 | 5 | 6 | def check_path(p: str) -> Tuple[bool, str]: 7 | """ 8 | Checks the provided path for the existence of the 'cutlass.h' file in expected directories. 9 | 10 | This function attempts to locate the 'cutlass.h' header file in three potential locations 11 | relative to the provided path: 12 | 1. Directly within the provided path. 13 | 2. Within a 'cutlass' directory inside the provided path. 14 | 3. Within an 'include/cutlass' directory structure inside the provided path. 15 | 16 | Args: 17 | p (str): The base path as a string where the search for 'cutlass.h' will begin. 18 | 19 | Returns: 20 | Tuple[bool, str]: A tuple containing: 21 | - A boolean indicating whether 'cutlass.h' was found. 22 | - A string representing the path where 'cutlass.h' was found, or an empty string if not found. 23 | """ 24 | p = Path(p) 25 | if (p / "cutlass.h").is_file(): 26 | return True, str(p) 27 | if (p / "cutlass" / "cutlass.h").is_file(): 28 | return True, str(p) 29 | if (p / "include" / "cutlass" / "cutlass.h").is_file(): 30 | return True, str(p / "include") 31 | return False, "" 32 | 33 | 34 | def find_cutlass(check_only: bool = True) -> Union[bool, str]: 35 | """ 36 | Searches for the Cutlass library in predefined and environment-specified directories. 37 | 38 | This function iterates through a list of potential directories where Cutlass might be located. 39 | It checks each directory to see if Cutlass exists there. The search paths include '/usr/local/include' 40 | and any paths specified in the 'CPATH' environment variable. 41 | 42 | Args: 43 | check_only (bool): Determines the behavior of the function upon finding Cutlass. 
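Referring back to `get_cuda_arch` above, a short sketch (not repository code) of how the `BIE_CUDA_ARCH` environment variable is read and checked against `SUPPORTED_CUDA_ARCHS`:

```python
import os
from bitorch_engine.utils.cuda_extension import get_cuda_arch, SUPPORTED_CUDA_ARCHS

os.environ["BIE_CUDA_ARCH"] = "sm_86"        # not in SUPPORTED_CUDA_ARCHS ("sm_80", "sm_75")
arch = get_cuda_arch()                        # still returns "sm_86", but emits a warning
print(arch, arch in SUPPORTED_CUDA_ARCHS)     # sm_86 False
```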
44 | If True, the function returns a boolean indicating the presence of Cutlass. 45 | If False, it returns the path where Cutlass is found. 46 | 47 | Returns: 48 | Union[bool, str]: Depending on the value of `check_only`, this function either returns: 49 | - A boolean value indicating whether Cutlass was found (True) or not (False). 50 | - The string path to the directory where Cutlass is located. If not found, returns an empty string. 51 | 52 | Note: 53 | The function utilizes `check_path(p)`, a separate function not shown here, to determine 54 | if Cutlass is present in each directory `p`. It is assumed that `check_path(p)` returns 55 | a tuple (bool, str), where the boolean indicates success, and the string represents the path. 56 | """ 57 | success, path = False, "" 58 | search_paths = ["/usr/local/include"] + os.environ.get("CPATH", "").split(":") 59 | for p in search_paths: 60 | if check_only: 61 | print("Searching Cutlass in:", p) 62 | success, path = check_path(p) 63 | if success: 64 | if check_only: 65 | print("Found Cutlass in:", p) 66 | break 67 | return success if check_only else path 68 | 69 | 70 | def is_cutlass_available() -> bool: 71 | return find_cutlass() 72 | 73 | 74 | def get_cutlass_include_path() -> str: 75 | return find_cutlass(check_only=False) 76 | -------------------------------------------------------------------------------- /bitorch_engine/utils/mlx_extension.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from torch.utils.cpp_extension import CppExtension 4 | 5 | from bitorch_engine.extensions import EXTENSION_PREFIX 6 | 7 | from .mlx_path import get_mlx_include_path, get_mlx_lib_path 8 | 9 | 10 | def get_mlx_extension(root_path: Path, relative_name: str, relative_sources) -> CppExtension: 11 | """ 12 | Creates and returns a CppExtension for compiling a C++ extension with MLX library support. 13 | 14 | This function is designed to simplify the configuration of a C++ extension that depends on the MLX library by 15 | automatically setting up include directories, library directories, and other necessary compilation and runtime settings. 16 | 17 | Parameters: 18 | root_path (Path): The root directory path where the C++ source files are located. This path is used to resolve the full paths to the source files specified in `relative_sources`. 19 | relative_name (str): A relative name for the extension. This name is prefixed with a predefined prefix and used as the extension's name. 20 | relative_sources (Iterable[str]): A list or iterable of relative paths to the C++ source files, relative to `root_path`. These source files constitute the extension being compiled. 21 | 22 | Returns: 23 | CppExtension: An instance of CppExtension configured with paths to include directories, library directories, and other settings needed to compile and link the extension with the MLX library. 
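Referring back to `cutlass_path.py` just above, a short sketch (not repository code) of pointing the search at a custom CUTLASS install via `CPATH`; the `/opt/cutlass/install/include` prefix mirrors the Dockerfile further below, but any directory containing `cutlass/cutlass.h` would do:

```python
import os
from bitorch_engine.utils.cutlass_path import find_cutlass, get_cutlass_include_path

# find_cutlass() looks in /usr/local/include plus every entry of CPATH
os.environ["CPATH"] = "/opt/cutlass/install/include"
if find_cutlass():                       # check_only=True: prints the searched paths, returns a bool
    print(get_cutlass_include_path())    # check_only=False: returns the include directory instead
```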
24 | """ 25 | include_path = get_mlx_include_path() 26 | lib_path = get_mlx_lib_path() 27 | return CppExtension( 28 | name=EXTENSION_PREFIX + relative_name, 29 | sources=[str(root_path / rel_path) for rel_path in relative_sources], 30 | include_dirs=[include_path], 31 | library_dirs=[lib_path], # To find mlx during compilation 32 | runtime_library_dirs=[lib_path], # To find mlx during runtime 33 | libraries=['mlx'], 34 | extra_compile_args=['-std=c++17'], 35 | ) -------------------------------------------------------------------------------- /bitorch_engine/utils/mlx_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from importlib.util import find_spec 4 | from pathlib import Path 5 | from typing import Union 6 | 7 | 8 | def get_mlx_include_path() -> Union[str, None]: 9 | """ 10 | Attempts to find the include path for the mlx library. 11 | 12 | This function searches for the 'mlx.h' header file associated with the mlx library 13 | in various possible locations where the library might be installed. The search follows this order: 14 | 1. Looks within the package's submodule search locations if the mlx package is installed. 15 | 2. Checks the system's default include path under the Python environment's prefix. 16 | 3. Scans paths specified in the 'CPATH' environment variable. 17 | 18 | Returns: 19 | str: The absolute path to the directory containing 'mlx.h' if found. 20 | None: If the 'mlx.h' file cannot be found in any of the searched locations. 21 | """ 22 | filename = "mlx.h" 23 | 24 | mlx_spec = find_spec("mlx") 25 | if mlx_spec is not None: 26 | for search_path in mlx_spec.submodule_search_locations: 27 | path = Path(search_path) / "include" 28 | if path.exists() and ((path / filename).exists() or (path / "mlx" / filename).exists()): 29 | return str(path.resolve()) 30 | 31 | prefix_path = Path(sys.prefix).resolve() 32 | if (prefix_path / "include" / "mlx" / filename).exists(): 33 | return str(prefix_path / "include") 34 | 35 | for path in os.environ.get("CPATH", "").split(":"): 36 | path = Path(path).resolve() 37 | if (path / "include" / filename).exists(): 38 | return str(path / "include") 39 | elif (path / filename).exists(): 40 | return str(path) 41 | 42 | return None 43 | 44 | def get_mlx_lib_path() -> Union[str, None]: 45 | """ 46 | Attempts to find the library path for 'libmlx.dylib'. 47 | 48 | This function searches for the 'libmlx.dylib' file in various locations to determine 49 | the library path for mlx (a hypothetical library). The search follows this order: 50 | 51 | 1. Within the 'mlx' package's installation directory, if the package is installed. 52 | It looks for a 'lib' directory under the package's submodule search locations. 53 | 2. In the 'lib' directory under the Python environment's prefix directory. This is 54 | typically where libraries are installed for the current Python environment. 55 | 3. In directories specified in the 'LIBRARY_PATH' environment variable. This is a 56 | colon-separated list of directories where libraries are searched for on Unix-like 57 | systems. 58 | 59 | The function returns the path as a string if 'libmlx.dylib' is found, or None if the 60 | library cannot be found in any of the searched locations. 61 | 62 | Returns: 63 | str or None: The path to 'libmlx.dylib' if found, otherwise None. 
64 | """ 65 | filename = "libmlx.dylib" 66 | 67 | mlx_spec = find_spec("mlx") 68 | if mlx_spec is not None: 69 | for search_path in mlx_spec.submodule_search_locations: 70 | path = Path(search_path) / "lib" 71 | if path.exists() and ((path / filename).exists() or (path / "mlx" / filename).exists()): 72 | return str(path.resolve()) 73 | 74 | prefix_path = Path(sys.prefix).resolve() 75 | if (prefix_path / "lib" / filename).exists(): 76 | return str(prefix_path / "lib") 77 | 78 | for path in os.environ.get("LIBRARY_PATH", "").split(":"): 79 | path = Path(path) 80 | if (path / "lib" / filename).exists(): 81 | return str(path / "lib") 82 | elif (path / filename).exists(): 83 | return str(path) 84 | 85 | return None 86 | 87 | def is_mlx_available() -> bool: 88 | """ 89 | Checks if the MLX library is available for use. 90 | 91 | This function determines the availability of the MLX library by verifying both the include path and the library path of MLX. 92 | It does this by calling two functions: `get_mlx_include_path` and `get_mlx_lib_path`. 93 | For the MLX library to be considered available, both of these functions must return a value that is not `None`. 94 | 95 | Returns: 96 | bool: True if both the MLX include path and the MLX library path are available (i.e., not None), 97 | indicating that the MLX library is available for use. Otherwise, returns False. 98 | """ 99 | return get_mlx_include_path() is not None and get_mlx_lib_path() is not None -------------------------------------------------------------------------------- /bitorch_engine/utils/safe_import.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import Any 3 | 4 | from bitorch_engine.extensions import EXTENSION_PREFIX 5 | 6 | MESSAGE = """The extension '{}' could not be imported. It is either not yet implemented or was not build correctly. 7 | This message is expected during the build process. If it appears later on, try installing the package again.""" 8 | 9 | 10 | class ExtensionModulePlaceholder: 11 | """ 12 | A placeholder class for extension modules. 13 | 14 | This class serves as a placeholder for dynamically loaded extension modules. It is designed to 15 | intercept attribute access and modifications, raising a runtime error if any operation other than 16 | setting the initial name is attempted. This behavior ensures that any misuse of the placeholder, 17 | such as accessing or modifying attributes that are not supposed to exist in the placeholder state, 18 | is promptly identified during development. 19 | 20 | Attributes: 21 | _name (str): The name of the extension module this placeholder represents. 22 | 23 | Methods: 24 | __init__: Initializes the placeholder with the name of the extension module. 25 | __getattr__: Intercepts attribute access attempts, raising a RuntimeError. 26 | __setattr__: Intercepts attribute modification attempts, allowing only the _name attribute to be set. 27 | """ 28 | def __init__(self, name: str) -> None: 29 | """ 30 | Initializes the ExtensionModulePlaceholder with a specified name. 31 | 32 | Args: 33 | name (str): The name of the extension module this placeholder represents. 34 | """ 35 | self._name = name 36 | 37 | def __getattr__(self, item: str) -> Any: 38 | """ 39 | Handles attribute access attempts for the placeholder. 40 | 41 | This method raises a RuntimeError to indicate that the attempted access is invalid, as the 42 | placeholder should not be used for accessing attributes. 
43 | 44 | Args: 45 | item (str): The name of the attribute being accessed. 46 | 47 | Returns: 48 | Any: This method does not return but raises RuntimeError instead. 49 | 50 | Raises: 51 | RuntimeError: Indicates an invalid attribute access attempt. 52 | """ 53 | raise RuntimeError(MESSAGE.format(self._name)) 54 | 55 | def __setattr__(self, key: Any, value: Any) -> None: 56 | """ 57 | Handles attribute modification attempts for the placeholder. 58 | 59 | This method allows setting the _name attribute only. Any other attempt to modify attributes 60 | raises a RuntimeError to indicate that the operation is invalid. 61 | 62 | Args: 63 | key (Any): The name of the attribute to be modified. 64 | value (Any): The new value for the attribute. 65 | 66 | Raises: 67 | RuntimeError: Indicates an invalid attribute modification attempt, except for the _name attribute. 68 | """ 69 | if key == "_name": 70 | self.__dict__["_name"] = value 71 | return 72 | raise RuntimeError(MESSAGE.format(self._name)) 73 | 74 | 75 | def import_extension(module_name: str, not_yet_implemented: bool = False) -> Any: 76 | """ 77 | Dynamically imports a Python extension module by name, providing a safe mechanism to handle cases 78 | where the module is not yet built or not yet implemented. This function is particularly useful for 79 | conditionally importing modules that provide optional functionality or are platform-specific. 80 | 81 | If the module is marked as not yet implemented (not_yet_implemented=True), or if the module cannot be 82 | found during import, the function returns a placeholder object instead of raising an ImportError. This 83 | allows the application to continue running and gracefully handle the absence of the module. 84 | 85 | Args: 86 | module_name (str): The name of the module to be imported. The actual module name will be prefixed 87 | with a predefined prefix defined in `EXTENSION_PREFIX` to form the full module name. 88 | not_yet_implemented (bool, optional): A flag indicating whether the module is known to be not yet 89 | implemented. If True, the function immediately returns a placeholder 90 | without attempting to import the module. Defaults to False. 91 | 92 | Returns: 93 | Any: An imported module if successful, or an instance of `ExtensionModulePlaceholder` if the module 94 | is not implemented or cannot be found. 95 | 96 | Example: 97 | binary_linear_cuda = import_extension("binary_linear_cuda") 98 | This example attempts to import a module named "binary_linear_cuda" (prefixed appropriately), 99 | returning the module if found, or a placeholder if not found or not implemented. 100 | 101 | Note: 102 | This function prints a warning message to the console if the module cannot be found, informing the 103 | user of the issue without interrupting the execution of the program. 
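To make the failure mode concrete, a short sketch (not repository code; the extension name is deliberately one that does not exist) of how `import_extension` degrades gracefully: it returns a placeholder, and only attribute access raises:

```python
from bitorch_engine.utils.safe_import import import_extension

mod = import_extension("some_missing_extension")   # prints the warning built from MESSAGE above
try:
    mod.forward                                     # any attribute access on the placeholder ...
except RuntimeError as err:
    print(err)                                      # ... raises with the same explanatory message
```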
104 | """ 105 | if not_yet_implemented: 106 | return ExtensionModulePlaceholder(module_name) 107 | 108 | try: 109 | return importlib.import_module(EXTENSION_PREFIX + module_name) 110 | except ModuleNotFoundError: 111 | print("Warning:", MESSAGE.format(module_name)) 112 | return ExtensionModulePlaceholder(module_name) 113 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG FROM_IMAGE=pytorch/manylinux-builder:cuda12.1-2.3 2 | FROM ${FROM_IMAGE} as builder-base 3 | RUN mkdir "/build_scripts" 4 | RUN mkdir "/workspace" 5 | 6 | FROM builder-base as pytorch-base 7 | #COPY "build_scripts/prepare_builder_image.sh" "/build_scripts/" 8 | #RUN bash "/build_scripts/prepare_builder_image.sh" "${FROM_IMAGE}" && \ 9 | # rm "/build_scripts/prepare_builder_image.sh" 10 | 11 | FROM pytorch-base as cutlass-install 12 | ARG CUTLASS_VERSION="2.8.0" 13 | ARG CUTLASS_HOME="/opt/cutlass" 14 | RUN git clone --depth 1 --branch "v${CUTLASS_VERSION}" "https://github.com/NVIDIA/cutlass.git" --recursive "${CUTLASS_HOME}/source" && \ 15 | mkdir "${CUTLASS_HOME}/build" && \ 16 | cd "${CUTLASS_HOME}/build" && \ 17 | cmake ../source \ 18 | -DCMAKE_INSTALL_PREFIX="${CUTLASS_HOME}/install" \ 19 | -DCUTLASS_ENABLE_HEADERS_ONLY=ON \ 20 | -DCUTLASS_ENABLE_TOOLS=ON \ 21 | -DCUTLASS_ENABLE_LIBRARY=OFF \ 22 | -DCUTLASS_ENABLE_PROFILER=OFF \ 23 | -DCUTLASS_NVCC_ARCHS='75;80;86' && \ 24 | cmake --install . && \ 25 | rm -rf "${CUTLASS_HOME}/build" "${CUTLASS_HOME}/source" 26 | 27 | FROM cutlass-install as build-ready 28 | ARG PYTHON_HOME="/opt/python/cp310-cp310" 29 | ENV PATH="${PYTHON_HOME}/bin:${PATH}" 30 | ARG CUSTOM_TORCH_URL="https://packages.greenbit.ai/whl/cu121/torch/torch-2.3.0-cp310-cp310-linux_x86_64.whl" 31 | ARG TORCHVISION_VERSION="0.18.0" 32 | ARG TORCHVISION_INDEX_URL="https://download.pytorch.org/whl/cu121" 33 | RUN pip install "${CUSTOM_TORCH_URL}" && \ 34 | pip install "torchvision==${TORCHVISION_VERSION}" --index-url "${TORCHVISION_INDEX_URL}" && \ 35 | pip cache purge && \ 36 | rm -rf /build_scripts 37 | 38 | # clone instead of mounting makes the code in the image independent from local changes 39 | # to mount your code before building, use the target above, and check the "For Development" section in docs/README.md 40 | FROM build-ready as no-examples 41 | ARG GIT_URL="https://github.com/GreenBitAI/bitorch-engine.git" 42 | ARG GIT_BRANCH="main" 43 | ARG BUILD_TARGET="." 44 | RUN git clone \ 45 | --depth 1 \ 46 | --branch "${GIT_BRANCH}" \ 47 | "${GIT_URL}" \ 48 | /bitorch-engine && \ 49 | cd /bitorch-engine && \ 50 | BIE_FORCE_CUDA="true" CPATH="${CUTLASS_HOME}/install/include" pip install -e ${BUILD_TARGET} -v && \ 51 | rm -rf build/ bitorch_engine.egg-info/ 52 | WORKDIR "/workspace" 53 | 54 | FROM no-examples as example-ready 55 | RUN pip install -r /bitorch-engine/examples/mnist-lightning/requirements.txt 56 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Project Setup with Docker 2 | 3 | ## Build Docker Image 4 | 5 | Use the following commands to build a [Docker](https://www.docker.com/) image with a Bitorch Engine installation. 6 | This is currently only targeted and tested for CUDA 11.8 or 12.1 and _Torch 2.2.x_. 7 | 8 | ```bash 9 | # cd docker 10 | # you should be in this `docker` directory 11 | docker build -t bitorch/engine . 
12 | # if you do not want to include installation of example requirements, use this instead: 13 | docker build --target no-examples -t bitorch/engine . 14 | ``` 15 | 16 | After building, the docker image should contain: 17 | - The selected torch package (limited to those that we modified to support gradients for non-floating-point tensors) 18 | - A ready-built bitorch engine, and its requirements 19 | - Everything is installed in a conda environment with Python (currently 3.10) 20 | 21 | ## Build Options 22 | 23 | Depending on your setup, you may want to adjust some options through build arguments: 24 | - CUDA version, e.g. for CUDA 11.8 add 25 | - `--build-arg FROM_IMAGE="pytorch/manylinux-builder:cuda11.8-2.3"` 26 | - `--build-arg CUSTOM_TORCH_URL="https://packages.greenbit.ai/whl/cu118/torch/torch-2.3.0-cp310-cp310-linux_x86_64.whl"` 27 | - `--build-arg TORCHVISION_INDEX_URL="https://download.pytorch.org/whl/cu118"` 28 | - repository URL, e.g. add `--build-arg GIT_URL="https://accesstoken:tokenpassword@github.com/MyFork/bitorch-engine.git"` 29 | - Bitorch Engine branch or tag, e.g. add `--build-arg GIT_BRANCH="v1.2.3"` 30 | - installing requirements for development, e.g. `--build-arg BUILD_TARGET=".[dev]"` 31 | - if there is a problem, set the environment variable `BUILDKIT_PROGRESS=plain` to see all output 32 | 33 | Here is an example: 34 | ```bash 35 | BUILDKIT_PROGRESS=plain docker build -t bitorch/engine --build-arg BUILD_TARGET=".[dev]" --build-arg GIT_BRANCH="mybranch" . 36 | ``` 37 | 38 | ## Run Docker Container 39 | 40 | After building the image you can run a container based on it with: 41 | ```bash 42 | docker run -it --rm --gpus all bitorch/engine:latest 43 | ``` 44 | 45 | ## For Development 46 | 47 | A docker image without the code cloned, e.g. for mounting a local copy of the code, can be made easily with the target `build-ready`: 48 | ```bash 49 | # cd docker 50 | # you should be in this `docker` directory 51 | docker build -t bitorch/engine:build-ready --target build-ready . 52 | docker run -it --rm --gpus all --volume "$(pwd)/..":/bitorch-engine bitorch/engine:build-ready 53 | # in docker container: 54 | cd /bitorch-engine 55 | pip install -e ".[dev]" -v 56 | ``` 57 | However, this means the build results will not be persisted in the image, so you probably want to mount the same directory every time. 
58 | -------------------------------------------------------------------------------- /docker/build_scripts/install_modified_pytorch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | from_image="${1}" 4 | action="${2:-install}" 5 | 6 | function usage() { 7 | echo "./install_modified_pytorch.sh DOCKER_IMAGE ACTION" 8 | echo "verify or install a modified torch version suitable for the chosen docker image" 9 | echo 10 | echo "ACTION can be 'install' or 'verify' and is optional (default: install)" 11 | } 12 | 13 | gdrive_id="unknown" 14 | file="custom_torch.whl" 15 | 16 | ## list of known docker images and the corresponding google drive id to download modified torch packages 17 | ## adding them here individually is tedious, but we need to build them manually and ensure compatibility anyway 18 | 19 | if [ "${from_image}" == "pytorch/pytorch:2.2.0-cuda11.8-cudnn8-devel" ]; then 20 | gdrive_id="1PoVor85-RF3s0KpOP19mFV5hNUnHERa1" 21 | file="torch-2.2.2-cp310-cp310-linux_x86_64.whl" 22 | checksum="6646519e5e7b4af8f99b79eb9be3e6460b0d05c4695bbf86de02568f37ff3fea" 23 | fi 24 | if [ "${from_image}" == "pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel" ]; then 25 | gdrive_id="1LjFNImboq8QeFSompMS2gPjBRYtP2Dsz" 26 | file="torch-2.2.2-cp310-cp310-linux_x86_64.whl" 27 | checksum="bcc0ba7f121ee2f42ed0a59f01d4e3d70f82a8981be0be25c5e0fe0635a54b2d" 28 | fi 29 | #if [ "${from_image}" == "pytorch/pytorch:X.X.X-cudaXX.X-cudnn8-devel" ]; then 30 | # gdrive_id="xxx" 31 | # file="torch-X.X.X-cp310-cp310-linux_x86_64.whl" 32 | # checksum="xxx" 33 | #fi 34 | 35 | function check_error() { 36 | # shows and then runs a command. if the exit code is not zero, aborts the script 37 | # usage: check_error mv foo bar 38 | 39 | echo + $@ 40 | "$@" 41 | local exit_code=$? 42 | if [ "${exit_code}" -ne 0 ]; then 43 | echo "! > An error occured, aborting." 44 | exit 1 45 | fi 46 | } 47 | 48 | if [ "${gdrive_id}" == "unknown" ]; then 49 | echo "Unknown image '${from_image}', could not choose modified torch accordingly." 50 | echo "Please add your base image to install_modified_pytorch.sh or request official support for your image via Github." 51 | echo 52 | usage 53 | exit 1 54 | fi 55 | 56 | check_error pip install gdown 57 | check_error gdown "${gdrive_id}" -O "${file}" 58 | check_error pip uninstall -y gdown 59 | 60 | if [ -n "${checksum}" ]; then 61 | check_error sha256sum --check --status <<< "${checksum} ${file}" 62 | fi 63 | 64 | if [ "${action}" == "verify" ]; then 65 | exit 0 66 | fi 67 | 68 | check_error pip install "${file}" 69 | check_error rm "${file}" 70 | check_error pip cache purge 71 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | source/_autosummary/ 2 | build/ 3 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Bitorch Engine Documentation 2 | 3 | ## Requirements 4 | 5 | First you should install BITorch Engine as normal. 6 | Then, install the additional requirements with: 7 | ```bash 8 | # in the docs folder 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ## Magic Build Script 13 | 14 | The script `make_docs.sh` will try to automagically build the documentation for you: 15 | ```bash 16 | ./docs/make_docs.sh 17 | ``` 18 | If there is a problem, see the script and manual steps below. 19 | 20 | ## Manual Build 21 | 22 | The docs for `bitorch_engine` are generated using the [sphinx](https://www.sphinx-doc.org/en/master/>) package. 23 | To build the docs, `cd` into the repository root and execute. 24 | 25 | ```bash 26 | sphinx-build -b html docs/source/ docs/build/ -a 27 | ``` 28 | 29 | The generated `HTML` files will be put into `docs/build`. 30 | 31 | ## Synchronize from Readmes 32 | 33 | To synchronize information from the Readme, we can use pandoc to convert the markdown file to RST: 34 | ```bash 35 | pip install pandoc 36 | pandoc --from=markdown --to=rst --output=README.rst README.md 37 | ``` 38 | 39 | Afterward, we need to fix a few issues, such as incorrect URLs, collapsible sections, etc. 40 | and then move those sections to the appropriate place in the documentation. 41 | You can try automatically doing so by running `python docs/scripts/convert_docs.py README.rst` before building with sphinx. 42 | -------------------------------------------------------------------------------- /docs/make_docs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if ! [ -d "docs" ]; then 4 | cd .. 5 | fi 6 | if ! [ -d "docs" ]; then 7 | echo "Could not locate docs directory. Please run the script from the root or the docs directory." 8 | exit 1 9 | fi 10 | 11 | function package_required() { 12 | pip freeze | grep "${1}" &> /dev/null 13 | if ! [ $? == "0" ]; then 14 | echo "Package '${1}' not found. Please install it, e.g. 
with: $ pip install ${1}" 15 | exit 1 16 | fi 17 | } 18 | # check new packages are installed 19 | package_required pandoc 20 | package_required sphinx_design 21 | 22 | pandoc --from=markdown --to=rst --output=README.rst README.md 23 | python docs/scripts/convert_docs.py README.rst 24 | sphinx-build -b html docs/source/ docs/build/ -a 25 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | autodocsumm 2 | nbsphinx 3 | pandoc 4 | sphinx_rtd_theme 5 | sphinx_gallery 6 | sphinx_toolbox 7 | sphinx_design 8 | -------------------------------------------------------------------------------- /docs/scripts/convert_docs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | class Converter: 5 | def __init__(self, verbose=False): 6 | self.active = False 7 | self.verbose = verbose 8 | 9 | def process_line(self, lines, line_num): 10 | self.check_start_stop(lines, line_num) 11 | if self.active: 12 | self.internal_process_line(lines, line_num) 13 | 14 | def check_start_stop(self, lines, line_num): 15 | raise NotImplementedError("Should be implemented in subclass") 16 | 17 | def internal_process_line(self, lines, line_num): 18 | raise NotImplementedError("Should be implemented in subclass") 19 | 20 | def replace_line(self, lines, line_num, new_line): 21 | if self.verbose: 22 | print(lines[line_num].rstrip()) 23 | print("vvv") 24 | print(new_line.rstrip()) 25 | lines[line_num] = new_line 26 | 27 | 28 | class Replacer(Converter): 29 | def __init__(self, replacements, **kwargs): 30 | super().__init__(**kwargs) 31 | self.replacements = replacements 32 | self.keys = [pair[0] for pair in replacements] 33 | 34 | def check_start_stop(self, lines, line_num): 35 | if any(key in lines[line_num] for key in self.keys): 36 | self.active = True 37 | return 38 | self.active = False 39 | 40 | def internal_process_line(self, lines, line_num): 41 | for pair in self.replacements: 42 | new_line = lines[line_num].replace(*pair) 43 | self.replace_line(lines, line_num, new_line) 44 | 45 | 46 | class SectionExtractor(Converter): 47 | heading_levels = ["=", "-", "~", "^"] 48 | 49 | def __init__(self, heading_contains, modify_by=-1, **kwargs): 50 | super().__init__(**kwargs) 51 | self.heading_contains = heading_contains 52 | self.modify_by = modify_by 53 | self.level_stop = False 54 | self.start_line = -1 55 | self.end_line = -1 56 | 57 | @staticmethod 58 | def determine_level(line): 59 | line = line.rstrip() 60 | if len(line) == 0: 61 | return False 62 | first_symbol = line[0] 63 | if not first_symbol in SectionExtractor.heading_levels: 64 | return False 65 | if all(c == first_symbol for c in line): 66 | return SectionExtractor.heading_levels.index(first_symbol) 67 | return False 68 | 69 | def check_start_stop(self, lines, line_num): 70 | level = self.determine_level(lines[line_num]) 71 | if level is not False and level == self.level_stop and self.active: 72 | if self.verbose: 73 | print("Inactive:", lines[line_num]) 74 | self.active = False 75 | self.end_line = line_num - 1 76 | if self.heading_contains in lines[line_num]: 77 | if self.verbose: 78 | print("Active:", lines[line_num]) 79 | self.active = True 80 | self.start_line = line_num 81 | 82 | def internal_process_line(self, lines, line_num): 83 | level = self.determine_level(lines[line_num]) 84 | if level is False: 85 | return 86 | if self.level_stop is False: 87 | 
self.level_stop = level 88 | new_line = lines[line_num].replace( 89 | self.heading_levels[level], 90 | self.heading_levels[level + self.modify_by], 91 | ) 92 | self.replace_line(lines, line_num, new_line) 93 | 94 | def write_to_file(self, lines, output_file, mode="w"): 95 | with open(output_file, mode) as f: 96 | f.writelines(lines[self.start_line:self.end_line]) 97 | 98 | 99 | class SectionCollapseFixer(Converter): 100 | heading_levels = ["=", "-", "~", "^"] 101 | 102 | def __init__(self, **kwargs): 103 | super().__init__(**kwargs) 104 | 105 | def check_start_stop(self, lines, line_num): 106 | line = lines[line_num] 107 | if ".. raw:: html" in line and "
" in lines[line_num + 2]: 108 | if self.verbose: 109 | print("Active:", line) 110 | self.active = True 111 | if "
" in line: 112 | if self.verbose: 113 | print("Inactive:", line) 114 | self.active = False 115 | self.replace_line(lines, line_num, "") 116 | 117 | def internal_process_line(self, lines, line_num): 118 | stripped_line = lines[line_num].strip() 119 | if stripped_line in ["
", ".. raw:: html", "
", ""]: 120 | self.replace_line(lines, line_num, "") 121 | return 122 | if stripped_line == "": 123 | summary_line = line_num + 1 124 | while lines[summary_line].strip() == "": 125 | summary_line += 1 126 | self.replace_line(lines, line_num, ".. dropdown:: " + lines[summary_line]) 127 | for i in range(line_num + 1, summary_line + 1): 128 | self.replace_line(lines, i, "") 129 | return 130 | self.replace_line(lines, line_num, " " + lines[line_num]) 131 | 132 | def main(args): 133 | # set build_options_separate=False to integrate build options into installation part: 134 | build_options_separate = True 135 | verbose = False 136 | 137 | install_section = SectionExtractor("Installation", verbose=verbose) 138 | build_section = SectionExtractor("Build options", modify_by=-1 if build_options_separate else 0, verbose=verbose) 139 | section_collapse_fixer = SectionCollapseFixer(verbose=verbose) 140 | replacer = Replacer([ 141 | [ 142 | r"`docker readme `__", 143 | r"`docker readme `__", 144 | ], 145 | ], verbose=verbose) 146 | 147 | content = None 148 | with open(args.input, "r") as f: 149 | content = f.readlines() 150 | if content is None: 151 | return 152 | 153 | i = 0 154 | while i < len(content): 155 | for converter in [install_section, build_section, section_collapse_fixer, replacer]: 156 | converter.process_line(content, i) 157 | i += 1 158 | 159 | install_section.write_to_file(content, "docs/source/installation.rst") 160 | if build_options_separate: 161 | build_section.write_to_file(content, "docs/source/build_options.rst") 162 | else: 163 | build_section.write_to_file(content, "docs/source/installation_test.rst", mode="a") 164 | 165 | 166 | if __name__ == "__main__": 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument("input", help="Input file") 169 | main(parser.parse_args()) 170 | -------------------------------------------------------------------------------- /docs/source/_templates/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | :no-inherited-members: 8 | :special-members: __call__, __add__, __mul__, forward, __init__ 9 | 10 | {% block methods %} 11 | {% if methods %} 12 | .. rubric:: {{ _('Methods') }} 13 | 14 | .. autosummary:: 15 | :nosignatures: 16 | {% for item in methods %} 17 | 18 | {%- if item.startswith('__init__') %} 19 | {%- if item not in inherited_members %} 20 | ~{{ name }}.{{ item }} 21 | {%- endif -%} 22 | {%- endif -%} 23 | {%- if not item.startswith('_') %} 24 | {%- if item not in inherited_members %} 25 | ~{{ name }}.{{ item }} 26 | {%- endif -%} 27 | {%- endif -%} 28 | {%- endfor %} 29 | {% endif %} 30 | {% endblock %} 31 | 32 | {% block attributes %} 33 | {% if attributes %} 34 | .. rubric:: {{ _('Attributes') }} 35 | 36 | .. autosummary:: 37 | {% for item in attributes %} 38 | {%- if item not in inherited_members %} 39 | ~{{ name }}.{{ item }} 40 | {%- endif -%} 41 | {%- endfor %} 42 | {% endif %} 43 | {% endblock %} 44 | -------------------------------------------------------------------------------- /docs/source/_templates/module.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block attributes %} 6 | {% if attributes %} 7 | .. rubric:: Module attributes 8 | 9 | .. 
autosummary:: 10 | :toctree: 11 | {% for item in attributes %} 12 | {{ item }} 13 | {%- endfor %} 14 | {% endif %} 15 | {% endblock %} 16 | 17 | {% block functions %} 18 | {% if functions and fullname != 'bitorch_engine' %} 19 | .. rubric:: {{ _('Functions') }} 20 | 21 | .. autosummary:: 22 | :toctree: 23 | :nosignatures: 24 | {% for item in functions %} 25 | {{ item }} 26 | {%- endfor %} 27 | {% endif %} 28 | {% endblock %} 29 | 30 | {% block classes %} 31 | {% if classes %} 32 | .. rubric:: {{ _('Classes') }} 33 | 34 | .. autosummary:: 35 | :template: class.rst 36 | :toctree: 37 | {% for item in classes %} 38 | {{ item }} 39 | {%- endfor %} 40 | {% endif %} 41 | {% endblock %} 42 | 43 | {% block exceptions %} 44 | {% if exceptions %} 45 | .. rubric:: {{ _('Exceptions') }} 46 | 47 | .. autosummary:: 48 | :toctree: 49 | {% for item in exceptions %} 50 | {{ item }} 51 | {%- endfor %} 52 | {% endif %} 53 | {% endblock %} 54 | 55 | {% block modules %} 56 | {% if modules %} 57 | .. rubric:: {{ _('Modules') }} 58 | 59 | .. autosummary:: 60 | :toctree: 61 | :template: module.rst 62 | :recursive: 63 | {% for item in modules %} 64 | {% if not item.endswith('.extensions') %}{{ item }}{%- endif %} 65 | {%- endfor %} 66 | {% endif %} 67 | {% endblock %} 68 | -------------------------------------------------------------------------------- /docs/source/build_options.rst: -------------------------------------------------------------------------------- 1 | Build options 2 | ============= 3 | 4 | Building Specific Extensions 5 | ---------------------------- 6 | 7 | While developing, a specific cpp/cuda extension can be (re-)build, by 8 | using the environment variable ``BIE_BUILD_ONLY``, like so: 9 | 10 | .. code:: bash 11 | 12 | BIE_BUILD_ONLY="bitorch_engine/layers/qlinear/binary/cpp" pip install -e . -v 13 | 14 | It needs to a relative path to one extension directory. 15 | 16 | Building for a Specific CUDA Architecture 17 | ----------------------------------------- 18 | 19 | To build for a different CUDA Arch, use the environment variable 20 | ``BIE_CUDA_ARCH`` (e.g. use ‘sm_75’, ‘sm_80’, ‘sm_86’): 21 | 22 | .. code:: bash 23 | 24 | BIE_CUDA_ARCH="sm_86" pip install -e . -v 25 | 26 | Force Building CUDA Modules 27 | --------------------------- 28 | 29 | If you have CUDA development libraries installed, but 30 | ``torch.cuda.is_available()`` is False, e.g. in HPC or docker 31 | environments, you can still build the extensions that depend on CUDA, by 32 | setting ``BIE_FORCE_CUDA="true"``: 33 | 34 | .. code:: bash 35 | 36 | BIE_FORCE_CUDA="true" pip install -e . -v 37 | 38 | Skip Library File Building 39 | -------------------------- 40 | 41 | If you just want to avoid rebuilding any files, you can set 42 | ``BIE_SKIP_BUILD``: 43 | 44 | .. code:: bash 45 | 46 | BIE_SKIP_BUILD="true" python3 -m build --no-isolation --wheel 47 | 48 | This would create a wheel and package ``.so`` files without trying to 49 | rebuild them. 50 | 51 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | import os 7 | import sys 8 | from pathlib import Path 9 | 10 | sys.path.insert(0, os.path.abspath('../../')) 11 | 12 | # -- Project information ----------------------------------------------------- 13 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 14 | 15 | project = 'Bitorch Engine' 16 | copyright = '2024, Haojin Yang, Joseph Bethge, Nianhui Guo, Maximilian Schulze, Hong Guo, Paul Mattes' 17 | author = 'Haojin Yang, Joseph Bethge, Nianhui Guo, Maximilian Schulze, Hong Guo, Paul Mattes' 18 | 19 | root_path = Path(__file__).resolve().parent.parent.parent 20 | release = "unknown" 21 | version_file = root_path / "version.txt" 22 | if version_file.exists(): 23 | with open(root_path / "version.txt") as handle: 24 | version_content = handle.read().strip() 25 | if version_content: 26 | release = version_content 27 | 28 | 29 | # -- General configuration --------------------------------------------------- 30 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 31 | 32 | extensions = [ 33 | "sphinx.ext.autodoc", 34 | "sphinx.ext.autosummary", 35 | "sphinx.ext.doctest", 36 | "sphinx.ext.todo", 37 | "sphinx.ext.mathjax", 38 | "sphinx.ext.ifconfig", 39 | "sphinx.ext.viewcode", 40 | "sphinx.ext.githubpages", 41 | "sphinx.ext.napoleon", 42 | "autodocsumm", 43 | "sphinx.ext.intersphinx", 44 | "sphinx_toolbox.code", 45 | "sphinx_toolbox.collapse", 46 | "sphinx_design", 47 | ] 48 | 49 | # List of patterns, relative to source directory, that match files and 50 | # directories to ignore when looking for source files. 51 | # This pattern also affects html_static_path and html_extra_path. 52 | exclude_patterns = ['.so', '_build', 'Thumbs.db', '.DS_Store'] 53 | 54 | # Add any paths that contain templates here, relative to this directory. 55 | templates_path = ['_templates'] 56 | autosummary_generate = True 57 | 58 | # -- Options for HTML output ------------------------------------------------- 59 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 60 | 61 | import sphinx_rtd_theme 62 | # html_theme = 'alabaster' 63 | html_theme = 'sphinx_rtd_theme' 64 | # html_static_path = ['_static'] 65 | -------------------------------------------------------------------------------- /docs/source/documentation.rst: -------------------------------------------------------------------------------- 1 | Full Documentation 2 | ================== 3 | 4 | .. autosummary:: 5 | :toctree: _autosummary 6 | :template: module.rst 7 | :recursive: 8 | 9 | bitorch_engine 10 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | 2 | Welcome to Bitorch Engine's documentation! 3 | ========================================== 4 | 5 | Welcome to the documentation of Bitorch Engine (BIE): a cutting-edge computation library for neural networks that enhances PyTorch by integrating specialized layers and functions for Low-Bit quantized neural network operations. This is where you can find all the information you need about how to use BIE. 6 | 7 | Building on the foundational strengths of Bitorch Engine, the technology has been employed in pioneering projects that push the boundaries of neural network training and inference. 
For instance, 8 | 9 | - `green-bit-llm-trainer `_: In this project, BIE represents a significant leap in the field of Large Language Model (LLM) fine-tuning. Unlike traditional approaches that either quantize a fully trained model or introduce a few additional trainable parameters for `LoRA `_ style fine-tuning, this project innovates by directly fine-tuning the quantized parameters of LLMs. This paradigm shift allows for the full-scale quantization fine-tuning of LLMs, ensuring that the training process tightly integrates with the quantization schema from the outset. 10 | - `green-bit-llm-inference `_ also showcases BIE's adeptness at supporting inference for models quantized from 4 down to 2 bits without any significant loss in accuracy compared to the original 32- or 16-bit models. It stands as a testament to BIE's capability to maintain the delicate balance between model size, computational efficiency, and accuracy, addressing one of the key challenges in deploying sophisticated neural networks in resource-constrained environments. 11 | 12 | All changes are tracked in the `changelog `_. 13 | 14 | 15 | .. toctree:: 16 | :maxdepth: 4 17 | :caption: Contents: 18 | 19 | installation 20 | build_options 21 | documentation 22 | 23 | Indices and tables 24 | ================== 25 | 26 | * :ref:`genindex` 27 | * :ref:`modindex` 28 | * :ref:`search` 29 | 30 | Enjoy exploring our documentation! 31 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/examples/__init__.py -------------------------------------------------------------------------------- /examples/mnist-lightning/mlp.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, hidden_features=1024, num_layers=2, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | assert 2 <= num_layers <= 3, "MLP currently only supports 2 or 3 layers."
10 | self.num_layers = num_layers 11 | self.flatten = nn.Flatten() 12 | 13 | self.first_linear = nn.Linear(28*28, hidden_features) 14 | self.act1 = nn.PReLU() 15 | self.bn1 = nn.BatchNorm1d(num_features=hidden_features) 16 | 17 | self.linear1 = nn.Linear(hidden_features, hidden_features) 18 | self.act2 = nn.PReLU() 19 | self.bn2 = nn.BatchNorm1d(num_features=hidden_features) 20 | 21 | if num_layers > 2: 22 | self.linear2 = nn.Linear(hidden_features, hidden_features) 23 | self.act3 = nn.PReLU() 24 | self.bn3 = nn.BatchNorm1d(num_features=hidden_features) 25 | 26 | self.last_linear = nn.Linear(hidden_features, 10) 27 | 28 | def forward(self, x): 29 | x = self.flatten(x) 30 | 31 | x = self.first_linear(x) 32 | x = self.bn1(x) 33 | x = self.act1(x) 34 | 35 | x = self.linear1(x) 36 | x = self.bn2(x) 37 | x = self.act2(x) 38 | 39 | if self.num_layers > 2: 40 | x = self.linear2(x) 41 | x = self.bn3(x) 42 | x = self.act3(x) 43 | 44 | return self.last_linear(x) 45 | 46 | 47 | class SequentialMLP(nn.Module): 48 | def __init__(self, hidden_features = 2048, *args, **kwargs): 49 | super().__init__(*args, **kwargs) 50 | self.flatten = nn.Flatten() 51 | self.first_linear = nn.Linear(28*28, hidden_features) 52 | self.body = nn.Sequential( 53 | nn.BatchNorm1d(num_features=hidden_features), 54 | nn.ReLU(), 55 | nn.Linear(hidden_features, hidden_features), 56 | nn.BatchNorm1d(num_features=hidden_features), 57 | nn.ReLU(), 58 | nn.Linear(hidden_features, hidden_features), 59 | nn.BatchNorm1d(num_features=hidden_features), 60 | nn.ReLU(), 61 | ) 62 | self.last_linear = nn.Linear(hidden_features, 10) 63 | 64 | def forward(self, x): 65 | x = self.flatten(x) 66 | x = self.first_linear(x) 67 | x = self.body(x) 68 | return self.last_linear(x) 69 | -------------------------------------------------------------------------------- /examples/mnist-lightning/requirements.txt: -------------------------------------------------------------------------------- 1 | bitorch 2 | galore-torch 3 | lightning 4 | -------------------------------------------------------------------------------- /examples/mnist/README.md: -------------------------------------------------------------------------------- 1 | # Example for MNIST 2 | 3 | In this example script we train a simple model for the MNIST dataset using [Bitorch Engine](https://github.com/GreenBitAI/bitorch-engine). 4 | 5 | First the requirements for this example need to be installed: 6 | ```bash 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | Then you can run the following to train an MLP with 3 layers (one of which is a binary layer), 11 | or add `--help` for more arguments: 12 | ```bash 13 | python train_mnist.py --epochs 10 --model q_mlp --log-interval 100 14 | ``` 15 | -------------------------------------------------------------------------------- /examples/mnist/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains an example for training an image classification model on the MNIST data set with BITorch 3 | and deploying it with the inference engine. 4 | """ 5 | -------------------------------------------------------------------------------- /examples/mnist/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This submodule contains data preparation code for some of the datasets used with our models, 3 | i.e. MNIST, CIFAR 10 and 100 and ImageNet. 
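A minimal usage sketch (illustrative only; it assumes this package is importable
as "datasets", e.g. when running from the examples/mnist directory):

    from datasets import dataset_from_name

    mnist_cls = dataset_from_name("mnist")
    train_set, test_set = mnist_cls.get_train_and_test("./data", download=True)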
4 | """ 5 | 6 | from typing import List, Type 7 | 8 | from .base import BasicDataset 9 | from .mnist import MNIST 10 | 11 | __all__ = [ 12 | "BasicDataset", 13 | "dataset_from_name", 14 | "dataset_names", 15 | "MNIST", 16 | ] 17 | 18 | 19 | def dataset_from_name(name: str) -> Type[BasicDataset]: 20 | """returns the dataset to which the name belongs to (name has to be the value of the datasets 21 | name-attribute) 22 | 23 | Args: 24 | name (str): name of the dataset 25 | 26 | Raises: 27 | ValueError: raised if no dataset under that name was found 28 | 29 | Returns: 30 | dataset: the dataset 31 | """ 32 | for dataset_class in [MNIST]: 33 | if dataset_class.name == name: 34 | return dataset_class 35 | raise Exception(f"unknown dataset: {name}") 36 | 37 | 38 | def dataset_names() -> List[str]: 39 | """getter for list of dataset names for argparse 40 | 41 | Returns: 42 | List: the dataset names 43 | """ 44 | return [dataset_class.name for dataset_class in [MNIST]] 45 | -------------------------------------------------------------------------------- /examples/mnist/datasets/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | from typing import Optional, Tuple, Any 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision.transforms import transforms 9 | 10 | from .dummy_dataset import DummyDataset 11 | 12 | 13 | class BasicDataset(Dataset): 14 | name = "None" 15 | num_classes = 0 16 | shape = (0, 0, 0, 0) 17 | mean: Any = None 18 | std_dev: Any = None 19 | num_train_samples = 0 20 | num_val_samples = 0 21 | 22 | def __init__(self, train: bool, root_directory: Optional[str] = None, download: bool = False) -> None: 23 | """initializes the dataset. 24 | 25 | Args: 26 | train (bool): whether the train or test dataset is wanted 27 | root_directory (str): path to main dataset storage directory 28 | download (bool): whether train/test should be downloaded if it does not exist 29 | 30 | Returns: 31 | Dataset: the created test/train dataset 32 | """ 33 | super(BasicDataset, self).__init__() 34 | self.is_train = train 35 | self._download = download 36 | self.root_directory = self.get_dataset_root_directory(root_directory) 37 | self.dataset = self.get_dataset(download) 38 | 39 | @classmethod 40 | def get_train_and_test(cls, root_directory: str, download: bool = False) -> Tuple["BasicDataset", "BasicDataset"]: 41 | """creates a pair of train and test dataset. 42 | 43 | Returns: 44 | Tuple: the train and test dataset 45 | """ 46 | return cls(True, root_directory, download), cls(False, root_directory, download) 47 | 48 | @classmethod 49 | def get_dummy_train_and_test_datasets(cls) -> Tuple[DummyDataset, DummyDataset]: 50 | train_set = DummyDataset(cls.shape, cls.num_classes, cls.num_train_samples) # type: ignore 51 | val_set = DummyDataset(cls.shape, cls.num_classes, cls.num_val_samples) # type: ignore 52 | return train_set, val_set 53 | 54 | def get_dataset_root_directory(self, root_directory_argument: Optional[str]) -> Path: 55 | """chooses the dataset root directory based on the passed argument or environment variables. 
56 | 57 | Returns: 58 | Tuple: the train and test dataset 59 | """ 60 | if root_directory_argument is not None: 61 | return Path(root_directory_argument) 62 | 63 | environment_variable_name = f"{self.name.upper()}_HOME" 64 | if os.environ.get(environment_variable_name) is not None: 65 | return Path(os.environ.get(environment_variable_name)) # type: ignore 66 | if os.environ.get("BITORCH_DATA_HOME") is not None: 67 | return Path(os.environ.get("BITORCH_DATA_HOME")) / self.name # type: ignore 68 | 69 | environment_variable_hint = ( 70 | f" To change this, set '{environment_variable_name}' or 'BITORCH_DATA_HOME' " 71 | f"(in the latter case, the data resides in the folder '{self.name}' in BITORCH_DATA_HOME)." 72 | f" Some datasets can be downloaded by adding the --download command line argument." 73 | ) 74 | if self._download: 75 | logging.warning("Dataset is being downloaded to the directory './data'." + environment_variable_hint) 76 | return Path("./data") 77 | else: 78 | raise ValueError(f"Dataset {self.name} not found." + environment_variable_hint) 79 | 80 | def get_dataset(self, download: bool) -> Dataset: 81 | """creates the actual dataset 82 | 83 | Args: 84 | download (bool): toggles if train/test shall be downloaded if possible 85 | 86 | Raises: 87 | NotImplementedError: thrown, because this method needs to be overwritten by subclasses 88 | 89 | Returns: 90 | Dataset: the created test/train dataset 91 | """ 92 | raise NotImplementedError() 93 | 94 | def get_transform(self) -> Any: 95 | if self.is_train: 96 | return self.train_transform() 97 | return self.test_transform() 98 | 99 | @classmethod 100 | def test_transform(cls) -> Any: 101 | """get the transform for the test data. 102 | 103 | Returns: 104 | transform: the transform pipeline 105 | """ 106 | return transforms.Compose([transforms.ToTensor(), cls.get_normalize_transform()]) 107 | 108 | @classmethod 109 | def train_transform(cls) -> Any: 110 | """get the transform for the training data. 111 | 112 | Returns: 113 | transform: the transform pipeline 114 | """ 115 | return transforms.Compose([transforms.ToTensor(), cls.get_normalize_transform()]) 116 | 117 | @classmethod 118 | def get_normalize_transform(cls) -> transforms.Normalize: 119 | return transforms.Normalize(cls.mean, cls.std_dev) 120 | 121 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: # type: ignore 122 | """returns the item at the given index of the dataset. 123 | 124 | Args: 125 | index (int): requested index 126 | 127 | Returns: 128 | Tuple[torch.Tensor, torch.Tensor]: data and label at the specified index 129 | """ 130 | return self.dataset[index] 131 | 132 | def __len__(self) -> int: 133 | return len(self.dataset) # type: ignore 134 | 135 | def num_samples(self) -> int: 136 | """returns the (theoretical) dataset size.""" 137 | return self.num_train_samples if self.is_train else self.num_val_samples 138 | -------------------------------------------------------------------------------- /examples/mnist/datasets/dummy_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | from typing import Tuple 4 | 5 | 6 | class DummyDataset(Dataset): 7 | """An iterator that produces repeated dummy data. 8 | Args: 9 | data_sample: a data sample that should be produced at each step. 10 | batch_size: the batch size for storing. 11 | sample_count: number of `data` samples in the dummy dataset. 
12 | """ 13 | 14 | def __init__(self, data_shape: torch.Size, num_classes: int, sample_count: int) -> None: 15 | self._data_sample = torch.zeros(data_shape) 16 | self._class_sample = torch.zeros((num_classes,), dtype=torch.int64) 17 | self._sample_count = sample_count 18 | 19 | def __len__(self) -> int: 20 | return self._sample_count 21 | 22 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: 23 | return self._data_sample, self._class_sample 24 | -------------------------------------------------------------------------------- /examples/mnist/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision.datasets import mnist 3 | 4 | from .base import BasicDataset 5 | 6 | 7 | class MNIST(BasicDataset): 8 | name = "mnist" 9 | num_classes = 10 10 | shape = (1, 1, 28, 28) 11 | 12 | mean = (0.1307,) 13 | std_dev = (0.3081,) 14 | num_train_samples = 60000 15 | num_val_samples = 10000 16 | 17 | def get_dataset(self, download: bool = True) -> Dataset: 18 | return mnist.MNIST( 19 | root=self.root_directory, 20 | train=self.is_train, 21 | transform=self.get_transform(), 22 | download=download, 23 | ) 24 | -------------------------------------------------------------------------------- /examples/mnist/requirements.txt: -------------------------------------------------------------------------------- 1 | bitorch 2 | -------------------------------------------------------------------------------- /licenses/LICENSE.cutlass.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | SPDX-License-Identifier: BSD-3-Clause 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /licenses/LICENSE.exllamav2.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE.mlx.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /licenses/LICENSE.pytorch.txt: -------------------------------------------------------------------------------- 1 | From PyTorch: 2 | 3 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 4 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 5 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 6 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 7 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 8 | Copyright (c) 2011-2013 NYU (Clement Farabet) 9 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 10 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 11 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 12 | 13 | From Caffe2: 14 | 15 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 16 | 17 | All contributions by Facebook: 18 | Copyright (c) 2016 Facebook Inc. 19 | 20 | All contributions by Google: 21 | Copyright (c) 2015 Google Inc. 22 | All rights reserved. 23 | 24 | All contributions by Yangqing Jia: 25 | Copyright (c) 2015 Yangqing Jia 26 | All rights reserved. 27 | 28 | All contributions by Kakao Brain: 29 | Copyright 2019-2020 Kakao Brain 30 | 31 | All contributions by Cruise LLC: 32 | Copyright (c) 2022 Cruise LLC. 33 | All rights reserved. 34 | 35 | All contributions from Caffe: 36 | Copyright(c) 2013, 2014, 2015, the respective contributors 37 | All rights reserved. 38 | 39 | All other contributions: 40 | Copyright(c) 2015, 2016 the respective contributors 41 | All rights reserved. 42 | 43 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 44 | copyright over their contributions to Caffe2. The project versioning records 45 | all such contribution and copyright details. If a contributor wants to further 46 | mark their specific copyright on a particular contribution, they should 47 | indicate their copyright solely in the commit message of the change when it is 48 | committed. 49 | 50 | All rights reserved. 51 | 52 | Redistribution and use in source and binary forms, with or without 53 | modification, are permitted provided that the following conditions are met: 54 | 55 | 1. Redistributions of source code must retain the above copyright 56 | notice, this list of conditions and the following disclaimer. 57 | 58 | 2. Redistributions in binary form must reproduce the above copyright 59 | notice, this list of conditions and the following disclaimer in the 60 | documentation and/or other materials provided with the distribution. 61 | 62 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 63 | and IDIAP Research Institute nor the names of its contributors may be 64 | used to endorse or promote products derived from this software without 65 | specific prior written permission. 66 | 67 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 68 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 69 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 70 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 71 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 72 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 73 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 74 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 75 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 76 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 77 | POSSIBILITY OF SUCH DAMAGE. 78 | -------------------------------------------------------------------------------- /licenses/LICENSE.tcbnn.txt: -------------------------------------------------------------------------------- 1 | TCBNN: 2 | Accelerating Binarized Neural Networks via Bit-Tensor-Cores in Turing GPUs 3 | 4 | 06/30/2020 by Ang Li from High-Performance-Computing Group, 5 | ACMD, PCSD, Pacific Northwest National Laboratory (PNNL), 6 | Richland, WA, 99354, USA. 7 | 8 | 9 | Copyright © 2020, Battelle Memorial Institute 10 | 11 | 1.Battelle Memorial Institute (hereinafter Battelle) hereby grants permission 12 | to any person or entity lawfully obtaining a copy of this software and associated 13 | documentation files (hereinafter “the Software”) to redistribute and use the 14 | Software in source and binary forms, with or without modification. Such person 15 | or entity may use, copy, modify, merge, publish, distribute, sublicense, and/or 16 | sell copies of the Software, and may permit others to do so, subject to the 17 | following conditions: 18 | 19 | - Redistributions of source code must retain the above copyright notice, this list 20 | of conditions and the following disclaimers. 21 | 22 | - Redistributions in binary form must reproduce the above copyright notice, this list 23 | of conditions and the following disclaimer in the documentation and/or other materials 24 | provided with the distribution. 25 | 26 | - Other than as used herein, neither the name Battelle Memorial Institute or Battelle 27 | may be used in any form whatsoever without the express written consent of Battelle. 28 | 29 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY 30 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 31 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT 32 | SHALL BATTELLE OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 33 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 35 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 36 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 37 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
38 | 39 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | numpy 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bitorch 2 | torch 3 | py-cpuinfo>=9.0.0 4 | setuptools~=69.0 5 | -------------------------------------------------------------------------------- /structure.md: -------------------------------------------------------------------------------- 1 | 2 | . 3 | ├── ... 4 | ├── layers # Low-bit layers 5 | │ ├── qconv # quantized Convolutional Layer 6 | │ ├── qembedding # currently only 1-bit embedding layer supported 7 | │ ├── ... 8 | │ └── qlinear 9 | │ ├── binary (1-bit) 10 | │ │ ├── cpp # x86 CPU 11 | │ │ ├── cuda # Nvidia GPU 12 | │ │ └── cutlass # Nvidia GPU 13 | │ └── n-bit (2/4/8-bit) 14 | │ ├── mps # Apple GPU 15 | │ ├── cuda # Nvidia GPU, e.g., weight-only quantized LLMs 16 | │ └── cutlass # Nvidia GPU, e.g., quantization-aware training for both activation and weight 17 | └── optim 18 | │ └── DiodeMix # dedicated optimizer for low-bit quantized models 19 | ├── functions 20 | └── ... -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/tests/__init__.py -------------------------------------------------------------------------------- /tests/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/tests/functions/__init__.py -------------------------------------------------------------------------------- /tests/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/tests/layers/__init__.py -------------------------------------------------------------------------------- /tests/layers/test_custom_binary_linear.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Function 5 | from torch.nn import Parameter 6 | 7 | 8 | class LinearFunction(Function): 9 | @staticmethod 10 | # ctx is the first argument to forward 11 | def forward(ctx, input, weight, bias=None): 12 | # The forward pass can use ctx.
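# ctx.save_for_backward stashes the tensors needed for gradient computation; backward() retrieves them later via ctx.saved_tensors.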
13 | ctx.save_for_backward(input, weight, bias) 14 | # TODO: instead of converting to float and use torch's mm we should use binary mm instead 15 | output = input.float().mm(weight.t().float()) 16 | if bias is not None: 17 | output += bias.unsqueeze(0).expand_as(output) 18 | return output 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | input, weight, bias = ctx.saved_tensors 23 | grad_input = grad_weight = grad_bias = None 24 | 25 | if ctx.needs_input_grad[0]: 26 | # TODO: instead of converting to float and use torch's mm we should use our binary mm instead 27 | grad_input = grad_output.mm(weight.float()) 28 | if ctx.needs_input_grad[1]: 29 | # TODO: instead of converting to float and use torch's mm we should use our binary mm instead 30 | grad_weight = grad_output.t().mm(input.float()) 31 | if bias is not None and ctx.needs_input_grad[2]: 32 | grad_bias = grad_output.sum(0) 33 | 34 | print("Weight grads calculated: ", grad_weight) 35 | 36 | # manually need to convert to our required formulation 37 | def binarize_grad(x): 38 | # 1.0 = True 39 | # 0.0 = False 40 | return torch.where(x >= 0.0, 1.0, 0.0) 41 | 42 | # TODO: we could inline this later 43 | grad_input = binarize_grad(grad_input) 44 | grad_weight = binarize_grad(grad_weight) 45 | 46 | return grad_input, grad_weight, grad_bias 47 | 48 | 49 | # Option 2: wrap in a function, to support default args and keyword args. 50 | def linear(input, weight, bias=None): 51 | return LinearFunction.apply(input, weight, bias) 52 | 53 | 54 | class TLinear(nn.Module): 55 | def __init__(self, input_features, output_features, bias=True): 56 | super().__init__() 57 | self.input_features = input_features 58 | self.output_features = output_features 59 | 60 | # nn.Parameter is a special kind of Tensor, that will get 61 | # automatically registered as Module's parameter once it's assigned 62 | # as an attribute. Parameters and buffers need to be registered, or 63 | # they won't appear in .parameters() (doesn't apply to buffers), and 64 | # won't be converted when e.g. .cuda() is called. You can use 65 | # .register_buffer() to register buffers. 66 | # nn.Parameters require gradients by default. 67 | self.weight = nn.Parameter(torch.empty(output_features, input_features)) 68 | if bias: 69 | self.bias = nn.Parameter(torch.empty(output_features)) 70 | else: 71 | # You should always register all possible parameters, but the 72 | # optional ones can be None if you want. 73 | self.register_parameter('bias', None) 74 | 75 | # Not a very smart way to initialize weights 76 | nn.init.uniform_(self.weight, -0.1, 0.1) 77 | if self.bias is not None: 78 | nn.init.uniform_(self.bias, -0.1, 0.1) 79 | 80 | def forward(self, input): 81 | # See the autograd section for explanation of what happens here. 82 | return LinearFunction.apply(input, self.weight, self.bias) 83 | 84 | def extra_repr(self): 85 | # (Optional)Set the extra information about this module. You can test 86 | # it by printing an object of this class. 
87 | return 'input_features={}, output_features={}, bias={}'.format( 88 | self.input_features, self.output_features, self.bias is not None 89 | ) 90 | 91 | 92 | def test_q_linear_with_binary_weights(): 93 | torch.manual_seed(42) 94 | # import pydevd_pycharm 95 | # pydevd_pycharm.settrace('localhost', port=11004, stdoutToServer=True, stderrToServer=True) 96 | 97 | batch_size = 10 98 | num_inputs = 32 99 | num_outputs = 64 100 | 101 | # option 1: set dtype 102 | # currently not possible, check uniform bounds fails: 103 | # layer = TstLinear(num_inputs, num_outputs, bias=False, dtype=torch.bool) 104 | 105 | # option 2: manually replace weight: 106 | # currently possible 107 | # but we have to 108 | layer = TLinear(num_inputs, num_outputs, bias=False) 109 | layer.weight = Parameter(torch.rand((num_outputs, num_inputs)) > 0.5, requires_grad=True) 110 | 111 | input = Parameter(torch.rand((batch_size, num_inputs)) > 0.5, requires_grad=True) 112 | 113 | result = layer(input) 114 | print(result) 115 | print("Result shape: ", result.size()) 116 | 117 | print("Grad before backward: ", layer.weight.grad) 118 | mse_loss = nn.MSELoss() 119 | dummy_loss = mse_loss(result, torch.ones_like(result) * 10) 120 | dummy_loss.backward() 121 | print("Grad after backward: ", layer.weight.grad) 122 | # technically it works, but we only get boolean gradients (all true currently) 123 | # could be fixed with custom backward pass? 124 | -------------------------------------------------------------------------------- /tests/layers/test_nbit_conv.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | import time 5 | from bitorch.layers import QConv2d 6 | from tests.layers.util import get_cuda_test_device 7 | from bitorch_engine.layers.qconv.nbit.cutlass import Q4Conv2dCutlass 8 | 9 | """ 10 | Test nbit inference layers 11 | """ 12 | 13 | # lower print threshold 14 | torch.set_printoptions(threshold=100) 15 | 16 | def to_device(data: torch.Tensor, device: torch.device) -> torch.Tensor: 17 | if isinstance(data, (list, tuple)): 18 | return [to_device(x, device) for x in data] 19 | return data.to(device) 20 | 21 | TEST_BATCH_SIZE = [1, 32, 64, 128] 22 | # Input shape: (batch size, num of input channels, h, w) 23 | TEST_INPUT_DATA = [ 24 | ((64, 64, 64), [64, 32], 25 | {"kernel_size": 3, "padding": 2, "stride": 2, "dilation": 1}), 26 | ((64, 56, 56), [64, 256], 27 | {"kernel_size": 1, "padding": 0, "stride": 1, "dilation": 1}), 28 | ((64, 56, 56), [64, 64], 29 | {"kernel_size": 3, "padding": 0, "stride": 1, "dilation": 1}), 30 | ((256, 56, 56), [256, 64], 31 | {"kernel_size": 1, "padding": 0, "stride": 1, "dilation": 1}), 32 | ((256, 56, 56), [256, 128], 33 | {"kernel_size": 1, "padding": 0, "stride": 1, "dilation": 1}), 34 | ((128, 28, 28), [128, 128], 35 | {"kernel_size": 3, "padding": 0, "stride": 1, "dilation": 1}), 36 | ((128, 28, 28), [128, 512], 37 | {"kernel_size": 1, "padding": 0, "stride": 1, "dilation": 1}), 38 | ((256, 14, 14), [256, 256], 39 | {"kernel_size": 3, "padding": 2, "stride": 1, "dilation": 1}), 40 | ((512, 7, 7), [512, 512], 41 | {"kernel_size": 3, "padding": 0, "stride": 1, "dilation": 1}), 42 | ] 43 | 44 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available, skipping CUDA-related test") 45 | @pytest.mark.parametrize("input_shape, args, kwargs", TEST_INPUT_DATA) 46 | @pytest.mark.parametrize("BS", TEST_BATCH_SIZE) 47 | def test_q4_conv_cuda(input_shape, args, kwargs, BS): 48 | 
print("\n") 49 | print("batch_size: {}, input_shape:{}, input/output nums:{}, other args:{}" 50 | .format(BS, input_shape, args, kwargs)) 51 | input_shape = (BS, input_shape[0], input_shape[1], input_shape[2]) 52 | input = np.random.uniform(-1, 1, input_shape) 53 | input_tensor = torch.tensor(input).float() 54 | 55 | num_runs = 10 56 | 57 | # to gpu 58 | device = get_cuda_test_device() 59 | torch.cuda.set_device(device) 60 | input_tensor_cuda = to_device(input_tensor, device) 61 | 62 | layer = QConv2d(*args, **kwargs) 63 | layer.to(device) 64 | start_time = time.time() 65 | for i in range(num_runs): 66 | result_bitorch = layer(input_tensor_cuda) 67 | time_engine = time.time() - start_time 68 | print("bitorch binary-conv: %.6f s" % (time_engine / num_runs)) 69 | 70 | padding = kwargs["padding"] 71 | kernel_size = kwargs["kernel_size"] 72 | stride = kwargs["stride"] 73 | dilation = kwargs["dilation"] 74 | nbit_conv_layer = Q4Conv2dCutlass(in_channels=int(args[0]), 75 | out_channels=args[1], 76 | kernel_size=kernel_size, 77 | stride=stride, 78 | padding=padding, 79 | dilation=dilation, 80 | device=device) 81 | nbit_conv_layer.to(device) 82 | nbit_conv_layer.prepare_params() 83 | result = nbit_conv_layer(input_tensor_cuda) 84 | # print(result) 85 | 86 | grad_input_data = np.random.uniform(-1, 1, result.shape) 87 | grad_input_tensor = torch.tensor(grad_input_data).float() 88 | grad_input_tensor_cuda = to_device(grad_input_tensor, device) 89 | 90 | # use quantized weight for inference 91 | nbit_conv_layer.generate_quantized_weight(qweight_only=True) 92 | nbit_conv_layer.eval() 93 | start_time = time.time() 94 | for i in range(num_runs): 95 | result_quantized_w = nbit_conv_layer(input_tensor_cuda) 96 | time_engine = time.time() - start_time 97 | print("engine q4-conv forward (CUTLASS): %.6f s" % (time_engine / num_runs)) 98 | # print(result_quantized_w) 99 | assert torch.equal(result, result_quantized_w) 100 | 101 | start_time = time.time() 102 | for i in range(num_runs): 103 | result.backward(grad_input_tensor_cuda, retain_graph=True) 104 | torch.cuda.synchronize() 105 | time_engine = time.time() - start_time 106 | print("bitorch-engine q4-conv backward (CUTLASS): %.6f s" % (time_engine/num_runs)) 107 | -------------------------------------------------------------------------------- /tests/layers/test_nbit_linear_mixbits.py: -------------------------------------------------------------------------------- 1 | import time, math 2 | import pytest 3 | import torch 4 | import json 5 | 6 | from bitorch_engine.layers.qlinear.nbit.cuda import MBWQLinearCuda 7 | from tests.layers.util import get_cuda_test_device, get_packed_info, get_q_groups 8 | 9 | 10 | """ 11 | Test nbit inference layers 12 | """ 13 | 14 | # lower print threshold 15 | torch.set_printoptions(threshold=100) 16 | 17 | 18 | def to_device(data: torch.Tensor, device: torch.device) -> torch.Tensor: 19 | if isinstance(data, (list, tuple)): 20 | return [to_device(x, device) for x in data] 21 | return data.to(device) 22 | 23 | 24 | # ========= testing data ========== # 25 | 26 | test_json_string = '{' \ 27 | '"q_proj": { "group_size": { "4": 32, "2": 32 }, "bits": [ 4, 2 ], "bits_prop": [ 0.75, 0.25 ], "scale_bits": 4 }, ' \ 28 | '"k_proj": { "group_size": { "4": 32, "2": 32 }, "bits": [ 4, 2 ], "bits_prop": [ 0.25, 0.75 ], "scale_bits": 4 } ' \ 29 | '}' 30 | 31 | 32 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available, skipping CUDA-related test") 33 | @pytest.mark.parametrize("num_input_features", [128, 4096]) 34 | 
@pytest.mark.parametrize("num_hidden_fc", [128, 4096]) 35 | @pytest.mark.parametrize("batch_size", [1, 2]) 36 | def test_mbwq_linear_exl2_cuda(num_input_features, num_hidden_fc, batch_size): 37 | 38 | # support 4-bit und bf16 only! 39 | dtype = torch.half 40 | num_runs = 1 41 | device = get_cuda_test_device() 42 | torch.cuda.set_device(device) 43 | 44 | # creating test data 45 | input_data = torch.normal(0, 1, size=(batch_size, num_input_features), requires_grad=False, dtype=dtype) 46 | # to gpu 47 | input_data_cuda = to_device(input_data, device) 48 | 49 | # Parsing the JSON string into a Python dictionary 50 | gbe_strategy = json.loads(test_json_string) 51 | 52 | for key, value in gbe_strategy.items(): 53 | # read attribute information from predefined json config 54 | groups, rows = get_packed_info(num_input_features, value["bits"], value["bits_prop"], value["group_size"]) 55 | 56 | print("\nM:{}, N:{}, K:{}, bits:{}, group_size:{}, bits_prop:{}, groups:{}, packed_rows:{}." 57 | .format(batch_size, num_hidden_fc, num_input_features, str(value["bits"]), str(value["group_size"]), str(value["bits_prop"]), 58 | groups, rows)) 59 | 60 | # creating int weights 61 | random_int4_tensor = torch.randint(0, 2 ** 16 - 1, size=(rows, num_hidden_fc)) 62 | int_weight_cuda = to_device(random_int4_tensor, device) 63 | 64 | mbwq_linear_layer = MBWQLinearCuda(in_channels=num_input_features, out_channels=num_hidden_fc, 65 | w_bit=4, dtype=dtype, group_size=32, 66 | dq_group_size=1, use_gba_quant=True, 67 | asym=False, dq_mode=2, use_mbw=True, 68 | groups=groups, rows_packed=rows) 69 | 70 | mbwq_linear_layer.set_qweight_data(int_weight_cuda) 71 | # random scales and zeros 72 | scales = torch.randn_like(mbwq_linear_layer.scales).half() 73 | zeros = torch.randn_like(mbwq_linear_layer.zeros).half() 74 | mbwq_linear_layer.set_scales(scales) 75 | mbwq_linear_layer.set_zeros(zeros) 76 | mbwq_linear_layer.q_perm = torch.tensor([i for i in range(num_input_features)], dtype=torch.short) 77 | 78 | # get q_groups 79 | q_groups = get_q_groups(groups, value["bits"], value["group_size"], num_input_features, value["bits_prop"]) 80 | 81 | assert mbwq_linear_layer.q_groups.numel() == len(q_groups) 82 | 83 | mbwq_linear_layer.q_groups = torch.Tensor(q_groups).to(torch.short) 84 | 85 | # to device 86 | mbwq_linear_layer.to(device) 87 | 88 | # will perform qweight layout transformation 89 | mbwq_linear_layer.prepare_params() 90 | 91 | # Testing fp weight reconstruction 92 | reconstructed_fp_weights = MBWQLinearCuda.exl2fp_weight(mbwq_linear_layer.qweight, mbwq_linear_layer.scales, 93 | mbwq_linear_layer.zeros, mbwq_linear_layer.q_perm, 94 | mbwq_linear_layer.q_group_map, mbwq_linear_layer.rows).to(dtype) 95 | 96 | # pytorch result: 97 | result_pt = torch.matmul(input_data_cuda.mul(mbwq_linear_layer.channel_scale), reconstructed_fp_weights) 98 | 99 | # Testing inference output 100 | start_time = time.time() 101 | for i in range(num_runs): 102 | result = mbwq_linear_layer(input_data_cuda) 103 | torch.cuda.synchronize() 104 | time_engine = time.time() - start_time 105 | 106 | print("bitorch-engine mbwq_linear (CUDA) run time: %.6f s" % (time_engine/num_runs)) 107 | 108 | assert torch.all(torch.isclose(result, result_pt, rtol=2, atol=2, equal_nan=False)) -------------------------------------------------------------------------------- /tests/layers/test_nbit_linear_mps.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | import torch 5 | import numpy as np 6 | 
7 | from bitorch_engine.layers.qlinear.nbit.mps import MPQLinearMlx 8 | from tests.layers.util import get_mps_test_device 9 | 10 | """ 11 | Test nbit inference layers 12 | """ 13 | 14 | # lower print threshold 15 | torch.set_printoptions(threshold=100) 16 | 17 | 18 | def to_device(data: torch.Tensor, device: torch.device) -> torch.Tensor: 19 | if isinstance(data, (list, tuple)): 20 | return [to_device(x, device) for x in data] 21 | return data.to(device) 22 | 23 | 24 | @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available, skipping MPS-related test") 25 | @pytest.mark.parametrize("w_bit", [2, 4, 8]) 26 | @pytest.mark.parametrize("dtype", [torch.half]) 27 | @pytest.mark.parametrize("num_input_features", [1024, 8192]) 28 | @pytest.mark.parametrize("num_hidden_fc", [1024, 8192]) 29 | @pytest.mark.parametrize("batch_size", [64]) 30 | @pytest.mark.parametrize("group_size", [32, 128]) 31 | @pytest.mark.parametrize("llama_v", [2]) 32 | @pytest.mark.parametrize("model_size", ["1.1b", "7b", "30b", "70b"]) 33 | def test_mpq_linear_mps(w_bit, dtype, num_input_features, num_hidden_fc, batch_size, group_size, llama_v, model_size): 34 | import mlx.core 35 | 36 | if ((w_bit == 1 or w_bit == 8) and group_size == 32) \ 37 | or (model_size in ["30b", "70b"] and w_bit != 2) \ 38 | or (model_size in ["1.1b"] and w_bit != 2): 39 | pytest.skip() 40 | 41 | # === following configuration from low_bit_llama https://github.com/GreenBitAI/low_bit_llama === # 42 | double_groupsize = -1 43 | if group_size == 32 and model_size not in ["1.1b", "1.1B"]: 44 | asym = True 45 | else: 46 | asym = False 47 | 48 | if w_bit == 2: 49 | if asym: 50 | double_groupsize = -1 51 | else: 52 | if group_size == 32: 53 | double_groupsize = 32 54 | else: 55 | if llama_v == 1 and model_size not in ["30b", "30B"]: 56 | double_groupsize = 64 57 | else: 58 | double_groupsize = 32 59 | else: 60 | if model_size in ["3b", "3B"]: 61 | double_groupsize = 64 62 | elif model_size in ["7b", "7B"]: 63 | double_groupsize = 256 64 | 65 | v1 = (llama_v == 1) and model_size in ["7b", "7B"] 66 | 67 | if w_bit not in [2, 4, 8]: 68 | print("w_bit not supported") 69 | pytest.skip() 70 | if group_size not in [32, 64, 128]: 71 | print("group_size not supported") 72 | pytest.skip() 73 | if asym: 74 | print("asym not supported") 75 | pytest.skip() 76 | # =============================================================================================== # 77 | 78 | print("\nM:{}, N:{}, K:{}, bits:{}, dtype:{}, groupsize:{}, llama_v:{}, model_size:{}." 
79 | .format(batch_size, num_hidden_fc, num_input_features, w_bit, dtype, group_size, llama_v, model_size)) 80 | 81 | num_runs = 10 82 | # creating test data 83 | input_data = torch.normal(0, 1, size=(batch_size, num_input_features), requires_grad=False, dtype=dtype) 84 | grad_input_data = torch.normal(0, 1, size=(batch_size, num_hidden_fc), requires_grad=False, dtype=dtype) 85 | 86 | # to gpu 87 | device = get_mps_test_device() 88 | input_data_mps = to_device(input_data, device) 89 | grad_input_data_mps = to_device(grad_input_data, device) 90 | 91 | time_engine = 0 92 | for i in range(num_runs): 93 | b_linear = torch.nn.Linear(num_input_features, num_hidden_fc, bias=False, dtype=dtype) 94 | b_linear.to(device) 95 | start_time = time.time() 96 | result = b_linear(input_data_mps) 97 | time_engine += time.time() - start_time 98 | print(f"pytorch linear forward run time (device {device}): {time_engine / num_runs:.6f} s") 99 | 100 | mpq_linear_layer = MPQLinearMlx(in_channels=num_input_features, 101 | out_channels=num_hidden_fc, 102 | w_bit=w_bit, 103 | dtype=dtype, 104 | group_size=group_size, 105 | dq_group_size=double_groupsize, 106 | dq_mode=1 if v1 else 2, 107 | use_gba_quant=True, 108 | asym=asym, 109 | requires_grad=False) 110 | 111 | mpq_linear_layer.to(device) 112 | mpq_linear_layer.prepare_params() 113 | 114 | time_engine = 0 115 | time_mlx = 0 116 | for i in range(num_runs): 117 | random_int_tensor = torch.randint(1, 1000, size=mpq_linear_layer.qweight.shape, 118 | dtype=mpq_linear_layer.qweight.dtype, device=mpq_linear_layer.qweight.device) 119 | mpq_linear_layer.set_qweight_data(random_int_tensor) 120 | start_time = time.time() 121 | result1 = mpq_linear_layer(input_data_mps) 122 | time_engine += time.time() - start_time 123 | 124 | qweight_mlx = mlx.core.array(mpq_linear_layer.qweight.numpy().astype(np.uint32)) 125 | scales_mlx = mlx.core.array(mpq_linear_layer.scales.numpy()) 126 | zeros_mlx = mlx.core.array(mpq_linear_layer.zeros.numpy()) 127 | 128 | start_time = time.time() 129 | input_mlx = mlx.core.array(input_data_mps.cpu().numpy()) 130 | mlx_matmul = mlx.core.quantized_matmul( 131 | input_mlx, 132 | qweight_mlx, 133 | scales=scales_mlx, 134 | biases=zeros_mlx, 135 | transpose=True, 136 | group_size=group_size, 137 | bits=w_bit) 138 | mlx_matmul = torch.from_numpy(np.array(mlx_matmul)).to('mps') 139 | time_mlx += time.time() - start_time 140 | assert mlx_matmul.equal(result1), "mps matmul failed" 141 | print(f"bitorch-engine mpq linear forward (MPS) run time: {time_engine / num_runs:.6f} s") 142 | print(f"Mlx (MPS) run time: {time_mlx / num_runs:.6f} s") -------------------------------------------------------------------------------- /tests/layers/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from typing import Union, List 4 | 5 | import torch 6 | 7 | 8 | def activate_remote_pycharm_debug(port: int = 11004): 9 | import pydevd_pycharm 10 | pydevd_pycharm.settrace('localhost', port=port, stdoutToServer=True, stderrToServer=True) 11 | 12 | 13 | def to_device(data: torch.Tensor, device: torch.device) -> Union[torch.Tensor, List[torch.Tensor]]: 14 | if isinstance(data, (list, tuple)): 15 | return [to_device(x, device) for x in data] 16 | return data.to(device) 17 | 18 | 19 | def get_cuda_test_device_id(): 20 | return int(os.environ.get("BIE_DEVICE", "0")) 21 | 22 | 23 | def get_cuda_test_device(): 24 | return torch.device(f"cuda:{get_cuda_test_device_id()}") 25 | 26 | 27 | def get_mps_test_device(): 28 
| return torch.device("mps") 29 | 30 | 31 | def get_packed_info(channels, n_bits, bits_prop, bits_group_size): 32 | groups = 0 33 | rows = 0 34 | bits_channel = [] 35 | for idx in range(len(bits_prop)): 36 | if idx < len(bits_prop) - 1: 37 | minimal_channels = list(bits_group_size.values())[idx] 38 | channel_pre_pack = max(1, int(channels * (bits_prop[idx])) // minimal_channels) * minimal_channels 39 | bits_channel.append(channel_pre_pack) 40 | groups += channel_pre_pack // minimal_channels 41 | rows += channel_pre_pack // 32 * n_bits[idx] 42 | else: 43 | minimal_channels = list(bits_group_size.values())[idx] 44 | channel_pre_pack = channels - sum(bits_channel) 45 | bits_channel.append(channel_pre_pack) 46 | groups += channel_pre_pack // minimal_channels 47 | rows += channel_pre_pack // 32 * n_bits[idx] 48 | 49 | return groups, rows 50 | 51 | 52 | def get_q_groups(groups, n_bits, group_size, channels, bits_prop): 53 | qgroups = [] 54 | bits_column_end_index = [] 55 | 56 | for idx in range(len(bits_prop)): 57 | if idx < len(bits_prop) - 1: 58 | minimal_columns = list(group_size.values())[idx] 59 | columns_index = max(1, int(channels * ( 60 | bits_prop[idx])) // minimal_columns) * minimal_columns # TODO: determine the minimal bits columns 61 | if idx > 0: 62 | columns_index += bits_column_end_index[-1] 63 | bits_column_end_index.append(columns_index) 64 | else: 65 | bits_column_end_index.append(channels) 66 | 67 | for bits_idx, bits in enumerate(n_bits): 68 | if bits_idx == 0: 69 | rows_per_bit = bits_column_end_index[bits_idx] 70 | else: 71 | rows_per_bit = bits_column_end_index[bits_idx] - bits_column_end_index[bits_idx - 1] 72 | 73 | gs = group_size[str(bits)] 74 | groups_per_bit = rows_per_bit // gs 75 | 76 | for group in range(groups_per_bit): 77 | qgroups.append(bits) # record bits per group 78 | qgroups.append(0) 79 | 80 | out_row = 0 81 | rem_rows = channels 82 | for i in range(groups): 83 | bits = qgroups[2 * i] 84 | gs = group_size[str(bits)] 85 | 86 | rows_per_group = min(gs, rem_rows) # rows per group before packing 87 | wpqr = 32 / bits # INT32 elements per group for packing 88 | qrows = math.ceil(rows_per_group / wpqr) # rows per group after packing 89 | qgroups[2 * i + 1] = out_row # record packed rows start idx per group 90 | 91 | out_row += qrows 92 | 93 | return qgroups 94 | 95 | 96 | def pack_rows_4_pytorch(input_tensor, rows, columns): 97 | # Calculate the number of output columns and store 4 columns per 32 bits, so use columns * 4 / 32 98 | out_columns = columns * 4 // 32 99 | 100 | # Initialize the output tensor, the size is [rows, out_columns], the data type is uint32 101 | output_tensor = torch.zeros((rows, out_columns), dtype=torch.int64, device=input_tensor.device) 102 | 103 | # To simulate the behavior of the CUDA kernel, we need to perform the same operation for each output element 104 | # This is implemented using the broadcast and indexing functions of PyTorch 105 | for row in range(rows): 106 | for out_column in range(out_columns): 107 | packed = 0 108 | for i in range(8): 109 | # Simulate operations in the CUDA kernel 110 | x = input_tensor[row, out_column * 8 + i].item() - 1 111 | packed |= (x << (i * 4)) 112 | output_tensor[row, out_column] = packed 113 | # Use bitwise operators to extract the unsigned lower 32 bits 114 | # Note: 0xFFFFFFFF is a 32-bit all-1 mask, which is used to extract the lower 32 bits 115 | output_tensor = (output_tensor & 0xFFFFFFFF).to(torch.int32) 116 | return output_tensor 
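# Illustrative usage sketch (added for clarity; not part of the original test
# utilities): pack_rows_4_pytorch packs eight 4-bit values into each int32 entry,
# so it expects input values in [1, 16] (it subtracts 1 before packing) and a
# column count divisible by 8.
if __name__ == "__main__":
    demo_weights = torch.randint(1, 17, (4, 32))          # 4 rows, 32 unpacked columns
    demo_packed = pack_rows_4_pytorch(demo_weights, rows=4, columns=32)
    assert demo_packed.shape == (4, 4)                    # 32 * 4 // 32 = 4 packed columns
    assert demo_packed.dtype == torch.int32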
-------------------------------------------------------------------------------- /tests/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/tests/util/__init__.py -------------------------------------------------------------------------------- /tests/util/binary_mse_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def binary_mse_loss(output, target): 4 | print("Output:", output) 5 | print("Target:", target) 6 | loss = torch.mean((output.to(torch.float32) - target.to(torch.float32))**2) 7 | return loss 8 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.2.6 2 | --------------------------------------------------------------------------------