├── .dev-scripts
│   ├── .gitignore
│   ├── README.md
│   ├── basic_tests.sh
│   ├── extract_install_cmd.py
│   ├── publish-docker-internal.sh
│   ├── publish-with-docker.sh
│   └── publish.sh
├── .gitignore
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── bitorch_engine
│   ├── __init__.py
│   ├── extensions
│   │   └── __init__.py
│   ├── functions
│   │   ├── __init__.py
│   │   └── cuda
│   │       ├── __init__.py
│   │       ├── extension.py
│   │       ├── functions.py
│   │       ├── functions_cuda.cpp
│   │       └── functions_cuda_kernel.cu
│   ├── layers
│   │   ├── __init__.py
│   │   ├── qconv
│   │   │   ├── __init__.py
│   │   │   ├── binary
│   │   │   │   ├── __init__.py
│   │   │   │   ├── cpp
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── binary_conv.cpp
│   │   │   │   │   ├── extension.py
│   │   │   │   │   └── layer.py
│   │   │   │   ├── cutlass
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── binary_conv2d_cutlass.cpp
│   │   │   │   │   ├── binary_conv2d_cutlass_kernel.cu
│   │   │   │   │   ├── extension.py
│   │   │   │   │   └── layer.py
│   │   │   │   └── layer.py
│   │   │   └── nbit
│   │   │       ├── __init__.py
│   │   │       ├── cutlass
│   │   │       │   ├── __init__.py
│   │   │       │   ├── extension.py
│   │   │       │   ├── layer.py
│   │   │       │   ├── q4_conv_cutlass.cpp
│   │   │       │   └── q4_conv_cutlass_kernel.cu
│   │   │       └── layer.py
│   │   ├── qembedding
│   │   │   ├── __init__.py
│   │   │   └── binary
│   │   │       ├── __init__.py
│   │   │       └── layer.py
│   │   ├── qlinear
│   │   │   ├── __init__.py
│   │   │   ├── binary
│   │   │   │   ├── __init__.py
│   │   │   │   ├── binary_implementation.py
│   │   │   │   ├── cpp
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── binary_linear.cpp
│   │   │   │   │   ├── extension.py
│   │   │   │   │   └── layer.py
│   │   │   │   ├── cuda
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── binary_linear_cuda.cpp
│   │   │   │   │   ├── binary_linear_cuda_kernel.cu
│   │   │   │   │   ├── bmm.py
│   │   │   │   │   ├── extension.py
│   │   │   │   │   └── layer.py
│   │   │   │   ├── cutlass
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── binary_linear_cutlass.cpp
│   │   │   │   │   ├── binary_linear_cutlass_kernel.cu
│   │   │   │   │   ├── binary_linear_cutlass_kernel.h
│   │   │   │   │   ├── extension.py
│   │   │   │   │   ├── kernel_selection.h
│   │   │   │   │   └── layer.py
│   │   │   │   └── layer.py
│   │   │   ├── layer.py
│   │   │   ├── nbit
│   │   │   │   ├── __init__.py
│   │   │   │   ├── cuda
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── exl2
│   │   │   │   │   │   ├── compat.cuh
│   │   │   │   │   │   ├── config.h
│   │   │   │   │   │   ├── kernel_select.cuh
│   │   │   │   │   │   ├── matrix_view.cuh
│   │   │   │   │   │   ├── q_gemm_kernel.cuh
│   │   │   │   │   │   ├── q_gemm_kernel_gptq.cuh
│   │   │   │   │   │   ├── quant
│   │   │   │   │   │   │   ├── qdq_2.cuh
│   │   │   │   │   │   │   ├── qdq_3.cuh
│   │   │   │   │   │   │   ├── qdq_4.cuh
│   │   │   │   │   │   │   ├── qdq_5.cuh
│   │   │   │   │   │   │   ├── qdq_6.cuh
│   │   │   │   │   │   │   ├── qdq_8.cuh
│   │   │   │   │   │   │   └── qdq_util.cuh
│   │   │   │   │   │   └── util.cuh
│   │   │   │   │   ├── extension.py
│   │   │   │   │   ├── mbwq_layer.py
│   │   │   │   │   ├── mbwq_linear_cuda_kernel.cu
│   │   │   │   │   ├── mpq_layer.py
│   │   │   │   │   ├── mpq_linear_cuda_kernel.cu
│   │   │   │   │   ├── q_linear_cuda.cpp
│   │   │   │   │   └── utils.py
│   │   │   │   ├── cutlass
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── extension.py
│   │   │   │   │   ├── q4_layer.py
│   │   │   │   │   ├── q4_linear_cutlass_kernel.cu
│   │   │   │   │   ├── q8_layer.py
│   │   │   │   │   ├── q8_linear_cutlass_kernel.cu
│   │   │   │   │   └── q_linear_cutlass.cpp
│   │   │   │   ├── layer.py
│   │   │   │   └── mps
│   │   │   │       ├── __init__.py
│   │   │   │       ├── extension.py
│   │   │   │       ├── mlx_bindings.cpp
│   │   │   │       ├── mpq_layer.py
│   │   │   │       ├── mpq_linear_mlx.cpp
│   │   │   │       └── mpq_linear_mlx.h
│   │   │   └── qlinear_implementation.py
│   │   └── qmha
│   │       ├── __init__.py
│   │       └── binary
│   │           ├── __init__.py
│   │           └── layer.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── diode_beta.py
│   │   └── galore_projector.py
│   └── utils
│       ├── __init__.py
│       ├── arch_helper.py
│       ├── convert.py
│       ├── cpp_extension.py
│       ├── cuda_extension.py
│       ├── cutlass_path.py
│       ├── mlx_extension.py
│       ├── mlx_path.py
│       ├── model_helper.py
│       ├── quant_operators.py
│       └── safe_import.py
├── docker
│   ├── Dockerfile
│   ├── README.md
│   └── build_scripts
│       └── install_modified_pytorch.sh
├── docs
│   ├── .gitignore
│   ├── Makefile
│   ├── README.md
│   ├── make_docs.sh
│   ├── requirements.txt
│   ├── scripts
│   │   └── convert_docs.py
│   └── source
│       ├── _templates
│       │   ├── class.rst
│       │   └── module.rst
│       ├── build_options.rst
│       ├── conf.py
│       ├── documentation.rst
│       ├── index.rst
│       └── installation.rst
├── examples
│   ├── __init__.py
│   ├── mnist-lightning
│   │   ├── main.py
│   │   ├── mlp.py
│   │   └── requirements.txt
│   └── mnist
│       ├── README.md
│       ├── __init__.py
│       ├── datasets
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── dummy_dataset.py
│       │   └── mnist.py
│       ├── requirements.txt
│       └── train_mnist.py
├── licenses
│   ├── LICENSE.GPTQ-for-LLaMa.txt
│   ├── LICENSE.cutlass.txt
│   ├── LICENSE.exllamav2.txt
│   ├── LICENSE.mlx.txt
│   ├── LICENSE.pytorch.txt
│   └── LICENSE.tcbnn.txt
├── requirements-dev.txt
├── requirements.txt
├── setup.py
├── structure.md
├── tests
│   ├── __init__.py
│   ├── functions
│   │   ├── __init__.py
│   │   └── test_quant_ops.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── test_binary_conv.py
│   │   ├── test_binary_embedding.py
│   │   ├── test_binary_linear.py
│   │   ├── test_custom_binary_linear.py
│   │   ├── test_nbit_conv.py
│   │   ├── test_nbit_linear.py
│   │   ├── test_nbit_linear_mixbits.py
│   │   ├── test_nbit_linear_mps.py
│   │   └── util.py
│   └── util
│       ├── __init__.py
│       └── binary_mse_loss.py
└── version.txt
/.dev-scripts/.gitignore:
--------------------------------------------------------------------------------
1 | test_*.sh
2 |
--------------------------------------------------------------------------------
/.dev-scripts/README.md:
--------------------------------------------------------------------------------
1 | # Development Scripts
2 |
3 | To publish a binary release with Docker (for CUDA), run the following, replacing the engine version and CUDA version as needed:
4 | ```bash
5 | ./.dev-scripts/publish-with-docker.sh v0.2.3 11.8
6 | ```
7 |
8 | ## Fixing RPaths
9 |
10 | See [this github issue](https://github.com/pytorch/builder/issues/468).
11 |
--------------------------------------------------------------------------------
/.dev-scripts/basic_tests.sh:
--------------------------------------------------------------------------------
1 |
2 | # bash code to run after installation to test for correct package installation
3 |
4 | if [ -n "${SKIP_TESTS}" ]; then
5 | echo "Skipping tests. Done."
6 | exit 0
7 | fi
8 |
9 | # try basic importing, should detect errors of .so loading
10 | python -c "from bitorch_engine.layers.qlinear.binary.cpp import BinaryLinearCPP"
11 | python -c "from bitorch_engine.layers.qembedding.binary import BinaryEmbeddingCuda"
12 | python -c "from bitorch_engine.layers.qlinear.nbit.cutlass import Q4LinearCutlass, Q8LinearCutlass, Q4MatMul"
13 | python -c "from bitorch_engine.layers.qlinear.nbit.cuda import MPQLinearCuda, MBWQLinearCuda"
14 | python -c "from bitorch_engine.layers.qlinear.nbit.cuda.utils import pack_fp_weight, unpack_qweight"
15 | echo "Imports successful!"
16 |
17 | set +o errexit
18 | echo "Testing..."
19 | (
20 | rm -rf bitorch_install_tmp_test_dir
21 | mkdir bitorch_install_tmp_test_dir
22 | cd bitorch_install_tmp_test_dir
23 | git clone https://github.com/GreenBitAI/bitorch-engine.git --depth 1 --branch "v0.2.6" bitorch_engine_git
24 | mv bitorch_engine_git/tests .
25 | pip install pytest numpy
26 | pytest tests/layers/test_nbit_linear.py
27 | pytest tests/layers/test_nbit_linear_mixbits.py
28 | pytest tests/functions/test_quant_ops.py
29 | pytest tests/layers/test_binary_linear.py
30 | )
31 | rm -rf bitorch_install_tmp_test_dir
32 |
--------------------------------------------------------------------------------
/.dev-scripts/extract_install_cmd.py:
--------------------------------------------------------------------------------
1 | # take caution: everything is quite hardcoded here
2 | # any changes to the readme could break this code
3 | # run it from the root directory: python .dev-scripts/extract_install_cmd.py path/to/custom/torch-xxx.whl path/to/built/bitorch_engine-xxx.whl
4 |
5 | import argparse
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("custom_pytorch_path", help="Path to custom PyTorch wheel")
8 | parser.add_argument("custom_bitorch_engine_path", help="Path to built bitorch engine wheel file")
9 | args = parser.parse_args()
10 |
11 | BLOCK_HEADER_START_BINARY = "### Binary Release"
12 | BLOCK_HEADER_START_FROM_SOURCE = "#### Conda on Linux"
13 | BLOCK_END = "##########"
14 |
15 | with open("README.md") as infile:
16 | content = infile.readlines()
17 |
18 | with open(".dev-scripts/basic_tests.sh") as infile:
19 | test_appendix = infile.readlines()
20 |
21 |
22 | def write_file(filepath, main_content):
23 | with open(filepath, "w") as outfile:
24 | outfile.write(FILE_INTRO)
25 | outfile.writelines(main_content)
26 | outfile.writelines(test_appendix)
27 |
28 |
29 | source_local_install_instructions = []
30 | source_global_install_instructions = []
31 | binary_local_install_instructions = []
32 | binary_global_install_instructions = []
33 |
34 | in_code_block = False
35 | reading_instructions = False
36 | insert_block_pause = False
37 | instruction_type = ""
38 |
39 | FILE_INTRO = """#!/usr/bin/env bash
40 |
41 | trap exit INT
42 | set -o errexit
43 | set -o xtrace
44 |
45 | """
46 | EXTRA_CONDA_INSTRUCTION = """# extra step for bash script (not required in a proper command line):
47 | eval "$(conda shell.bash hook)"
48 | """
49 |
50 |
51 | for line in content:
52 | if line.startswith("```"):
53 | in_code_block = not in_code_block
54 | continue
55 | if line.startswith(BLOCK_HEADER_START_FROM_SOURCE):
56 | reading_instructions = True
57 | instruction_type = "source-global"
58 | BLOCK_END = BLOCK_HEADER_START_FROM_SOURCE.split()[0]
59 | continue
60 | if line.startswith(BLOCK_HEADER_START_BINARY):
61 | reading_instructions = True
62 | instruction_type = "binary-global"
63 | BLOCK_END = BLOCK_HEADER_START_BINARY.split()[0]
64 | continue
65 | if line.startswith(""):
66 | if "source" in instruction_type:
67 | instruction_type = "source-local"
68 | if "binary" in instruction_type:
69 | instruction_type = "binary-local"
70 | continue
71 | if line.startswith(""):
72 | if "source" in instruction_type:
73 | instruction_type = "source-both"
74 | if "binary" in instruction_type:
75 | instruction_type = "binary-both"
76 | continue
77 | if line.startswith(BLOCK_END):
78 | reading_instructions = False
79 | continue
80 | if not reading_instructions:
81 | continue
82 | if not in_code_block:
83 | insert_block_pause = True
84 | continue
85 |
86 | # deal with comments
87 | if line.startswith("# export CC="):
88 | line = line[2:]
89 | if line.startswith("#"):
90 | continue
91 |
92 | # replace some line contents and add some lines
93 | if "conda activate" in line:
94 | line = EXTRA_CONDA_INSTRUCTION + line
95 | if "export BITORCH_WORKSPACE" in line:
96 | line = line.replace("${HOME}", "$(pwd)")
97 | if line.startswith("pip install torch-"):
98 | line = "pip install {}\n".format(args.custom_pytorch_path)
99 | if line.startswith("pip install bitorch_engine"):
100 | line = "pip install {}\n".format(args.custom_bitorch_engine_path)
101 |
102 | # decide how to write line
103 | line_format = "{line}"
104 | if line.startswith("#"):
105 | line_format = "{line}"
106 | if insert_block_pause:
107 | insert_block_pause = False
108 | line_format = "\n" + line_format
109 |
110 | # write result line(s)
111 | if instruction_type == "source-global" or instruction_type == "source-both":
112 | source_global_install_instructions.append(line_format.format(line=line))
113 | if instruction_type == "source-local" or instruction_type == "source-both":
114 | source_local_install_instructions.append(line_format.format(line=line))
115 | if instruction_type == "binary-global" or instruction_type == "binary-both":
116 | binary_global_install_instructions.append(line_format.format(line=line))
117 | if instruction_type == "binary-local" or instruction_type == "binary-both":
118 | binary_local_install_instructions.append(line_format.format(line=line))
119 |
120 | write_file(".dev-scripts/test_source_local_conda_install.sh", source_local_install_instructions)
121 | write_file(".dev-scripts/test_source_global_conda_install.sh", source_global_install_instructions)
122 | write_file(".dev-scripts/test_binary_local_conda_install.sh", binary_local_install_instructions)
123 | write_file(".dev-scripts/test_binary_global_conda_install.sh", binary_global_install_instructions)
124 |
125 | binary_local_cu118 = [line.replace("cu121", "cu118").replace("cuda-12.1.0", "cuda-11.8.0") for line in binary_local_install_instructions]
126 | write_file(".dev-scripts/test_binary_local_conda_install_cu118.sh", binary_local_cu118)
127 | binary_local_no_cuda = filter(lambda x: "nvidia/label/cuda-12.1.0" not in x, binary_local_install_instructions)
128 | write_file(".dev-scripts/test_binary_local_conda_install_no_cuda.sh", binary_local_no_cuda)
129 |
--------------------------------------------------------------------------------
/.dev-scripts/publish-with-docker.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -o xtrace
4 |
5 | function usage() {
6 | echo "./.dev-scripts/publish-with-docker.sh BIE_VERSION [CUDA_VERSION]"
7 | echo "builds a package and publishes it to (test-)pypi"
8 | echo
9 | echo "BIE_VERSION must be a version string like 'v1.2.3'."
10 | echo "optional: CUDA_VERSION can be either '11.8' (not yet supported) or '12.1' (default)."
11 | }
12 |
13 | trap exit INT
14 |
15 | if ! ((1 <= $# && $# <= 2)) || [ "${1}" = "-h" ]; then
16 | usage
17 | exit
18 | fi
19 |
20 | export PUBLISH_BIE_VERSION="${1}"
21 | CUDA_VERSION="${2:-12.1}"
22 |
23 | if ! [[ "${PUBLISH_BIE_VERSION}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
24 | echo "Invalid BIE_VERSION '${PUBLISH_BIE_VERSION}' given."
25 | echo
26 | usage
27 | exit
28 | fi
29 |
30 | cuda_known="false"
31 | build_args=""
32 | cuda_abbrev="unknown"
33 | # TODO: check support for 11.8:
34 | if [ "${CUDA_VERSION}" = "11.8" ]; then
35 | cuda_known="true"
36 | cuda_abbrev="cu118"
37 | torch_requirement="torch==2.3.0"
38 | build_args="${build_args} --build-arg FROM_IMAGE=pytorch/manylinux-builder:cuda11.8-2.3"
39 | build_args="${build_args} --build-arg CUSTOM_TORCH_URL=https://packages.greenbit.ai/whl/cu118/torch/torch-2.3.0-cp310-cp310-linux_x86_64.whl"
40 | build_args="${build_args} --build-arg TORCHVISION_INDEX_URL=https://download.pytorch.org/whl/cu118"
41 | fi
42 | if [ "${CUDA_VERSION}" = "12.1" ]; then
43 | cuda_known="true"
44 | cuda_abbrev="cu121"
45 | torch_requirement="torch==2.3.0"
46 | fi
47 | if [ "${cuda_known}" = "false" ]; then
48 | echo "Unknown CUDA_VERSION '${CUDA_VERSION}' given."
49 | echo
50 | usage
51 | exit
52 | fi
53 |
54 | echo "building bitorch engine ${PUBLISH_BIE_VERSION}"
55 | echo "building for cuda ${CUDA_VERSION}"
56 |
57 | bie_image_tag="bitorch/engine:publish-${cuda_abbrev}-${PUBLISH_BIE_VERSION}"
58 | bie_container_name="bie-${cuda_abbrev}-${PUBLISH_BIE_VERSION}"
59 | output_folder="./dist/${cuda_abbrev}"
60 |
61 | # build/tag docker image
62 | pushd docker
63 | docker build --target no-examples ${build_args} --build-arg GIT_BRANCH="${PUBLISH_BIE_VERSION}" -t "${bie_image_tag}" .
64 | popd
65 |
66 | mkdir -p "${output_folder}"
67 |
68 | docker container create -it \
69 | --rm \
70 | -it \
71 | -v "${output_folder}:/bitorch-engine/dist" \
72 | --name "${bie_container_name}" \
73 | -e PUBLISH_BIE_VERSION \
74 | -e BIE_FORCE_CUDA="true" \
75 | -e BIE_SKIP_BUILD="true" \
76 | -e USER_ID="$(id -u)" \
77 | -e GROUP_ID="$(id -g)" \
78 | -e BIE_TORCH_REQUIREMENT="${torch_requirement}" \
79 | -e BIE_WHEEL_PLATFORM="linux_x86_64" \
80 | -w /bitorch-engine \
81 | "${bie_image_tag}" \
82 | /workspace/publish-docker-internal.sh release
83 |
84 | # make sure correct version is set
85 | echo "${PUBLISH_BIE_VERSION#v}" > version.txt && docker container cp version.txt "${bie_container_name}":/bitorch-engine
86 | docker container cp .dev-scripts/publish-docker-internal.sh "${bie_container_name}":/workspace
87 |
88 | # for previous versions, we need to manually overwrite setup.py:
89 | # TODO: can (hopefully) be removed later on
90 | docker container cp setup.py "${bie_container_name}":/bitorch-engine
91 |
92 | docker start -ai "${bie_container_name}"
93 |
--------------------------------------------------------------------------------
/.dev-scripts/publish.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | function usage() {
4 | echo "./.dev-scripts/publish.sh VERSION"
5 | echo "builds a package and publishes it to (test-)pypi"
6 | echo
7 | echo "VERSION must be either 'pre-release' or 'release'."
8 | }
9 |
10 | if ! [ "$#" = "1" ] || [ "${1}" = "-h" ]; then
11 | usage
12 | exit
13 | fi
14 |
15 | if ! [ "${1}" = "release" ] && ! [ "${1}" = "pre-release" ]; then
16 | usage
17 | exit
18 | fi
19 |
20 | # set SCRIPT_ROOT:
21 | SCRIPT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
22 |
23 | # set SRC_ROOT and make sure that is our working directory
24 | SRC_ROOT="$(readlink -f "${SCRIPT_ROOT}/..")"
25 | cd "${SRC_ROOT}"
26 |
27 | trap exit INT
28 |
29 | function check_yes() {
30 | # asks the given yes or no question, returns true if they answer YES
31 | # usage:
32 | # if check_yes "Do you really want to delete foo?"; then
33 | # rm foo
34 | # fi
35 |
36 | local prompt="${1}"
37 | read -p "${prompt} [y/N] " REPLY
38 | echo ""
39 | if [[ ! "${REPLY}" =~ ^[Yy]$ ]]; then
40 | return 1
41 | fi
42 | return 0
43 | }
44 |
45 | function check_no() {
46 | # asks the given yes or no question, returns false if they answer NO
47 | # usage:
48 | # if check_no "Do you want to exit the script?"; then
49 | # exit 0
50 | # fi
51 |
52 | local prompt="${1}"
53 | read -p "${prompt} [Y/n] " REPLY
54 | echo ""
55 | if [[ "${REPLY}" =~ ^[Nn]$ ]]; then
56 | return 1
57 | fi
58 | return 0
59 | }
60 |
61 | function check_error() {
62 | # shows and then runs a command. if the exit code is not zero, asks the user whether to continue
63 | # usage: check_error mv foo bar
64 |
65 | echo + $@
66 | "$@"
67 | local exit_code=$?
68 | if [ "${exit_code}" -ne 0 ]; then
69 | if ! check_yes "! > An error occurred, continue with the script?"; then
70 | if [ "${1}" = "pre-release" ]; then
71 | git checkout "${version_file}"
72 | fi
73 | exit 1
74 | fi
75 | fi
76 | }
77 |
78 | # main script content
79 |
80 | if [ -z "$(git status --porcelain)" ]; then
81 | echo "Git seems clean."
82 | else
83 | echo "There are uncommitted changes, aborting."
84 | exit 1
85 | fi
86 |
87 | if [ "${1}" = "release" ]; then
88 | version_file="${SRC_ROOT}/version.txt"
89 | version_content="$(cat "${version_file}")"
90 | major_minor_patch="$(cut -d '.' -f 1,2,3 <<< "${version_content}")"
91 | version_str="${major_minor_patch}"
92 | else
93 | version_file="${SRC_ROOT}/version.txt"
94 | version_content="$(cat "${version_file}")"
95 | major_minor_patch="$(cut -d '.' -f 1,2,3 <<< "${version_content}")"
96 | date_str="$(date +"%Y%m%d")"
97 | git_ref="$(git rev-parse --short HEAD)"
98 | version_str="${major_minor_patch}.dev${date_str}+${git_ref}"
99 | fi
100 |
101 | if [ "${1}" = "release" ] && ! [ "${version_content}" = "${version_str}" ]; then
102 | echo "The file version.txt does not seem to contain a release version."
103 | exit 1
104 | else
105 | if ! check_no "Publish version ${version_str} ?"; then
106 | exit 0
107 | fi
108 | fi
109 |
110 | if [ "${1}" = "pre-release" ]; then
111 | echo "${version_str}" > "${version_file}"
112 | fi
113 |
114 | check_error pip uninstall -y -r <(pip freeze)
115 | check_error pip install --upgrade pip
116 | check_error pip install -e ".[dev]" -v
117 |
118 | pip install build twine
119 |
120 | check_error pytest .
121 | check_error python3 -m build --sdist
122 |
123 | if check_yes "Publish to real PyPi?"; then
124 | echo "Do:"
125 | echo ' python3 -m twine upload dist/*'
126 | else
127 | echo "Do:"
128 | echo ' python3 -m twine upload --repository testpypi dist/*'
129 | fi
130 |
131 | if [ "${1}" = "pre-release" ]; then
132 | git checkout "${version_file}"
133 | fi
134 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 | README.rst
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # pytype static type analyzer
136 | .pytype/
137 |
138 | # Cython debug symbols
139 | cython_debug/
140 |
141 | .coverage
142 | .idea/
143 | .vscode/
144 | .mypy_cache
145 |
146 | # downloaded dataset files
147 | train/
148 | test/
149 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 | The format is based on [Keep a Changelog](http://keepachangelog.com/)
5 | and this project adheres to [Semantic Versioning](http://semver.org/).
6 |
7 |
8 | ## [0.2.6] - 2024/06/14
9 |
10 | ### Added
11 |
12 | - Installation instructions for binary releases
13 | - Warning if a non-customized PyTorch version is detected, which cannot calculate gradients for integer tensor types
14 |
15 | ### Changed
16 |
17 | - Updated development scripts for binary releases
18 | - Adjusting rpaths in .so files (based on PyTorch's implemented solution)
19 | - Docker base image changed to manywheel builder image
20 |
21 | ## [0.2.5] - 2024/05/24
22 |
23 | ### Added
24 |
25 | - Development scripts for preparing binary releases
26 |
27 | ### Changed
28 |
29 | - Updated build instructions to clarify torchvision installation
30 | - Adapted `setup.py` logic for preparing binary releases
31 |
32 | ### Fixed
33 |
34 | - Broken build process by setting setuptools version
35 |
36 | ## [0.2.4] - 2024/05/23
37 |
38 | ### Added
39 |
40 | - Tuned the hyperparameters of the DiodeMix optimizer for SFT (supervised fine-tuning).
41 | - Added SFT support for classical GPTQ-style models.
42 | - Implemented qzeros updates in the fine-tuning process.
43 |
44 | ### Updated
45 |
46 | - Extended the pack_fp_weight function.
47 | - Enhanced the performance of the MPQLinearCuda layer.
48 |
49 | ### Fixed
50 |
51 | - Fixed various errors in DiodeMix update function.
52 |
53 | ## [0.2.3] - 2024/05/01
54 |
55 | ### Updated
56 |
57 | - Enhanced the performance of the MBWQ linear layer for processing long sequences, addressing previous inefficiencies.
58 |
59 | ## [0.2.2] - 2024/04/29
60 |
61 | ### Updated
62 |
63 | - Building instructions (adding a section for cutlass)
64 | - Checksums for custom torch builds (within docker)
65 |
66 | ### Fixed
67 |
68 | - An error in `pack_fp_weight`
69 |
70 | ## [0.2.1] - 2024/04/27
71 |
72 | ### Fixed
73 |
74 | - Broken links in README.md and index.rst
75 |
76 | ## [0.2.0] - 2024/03/10
77 |
78 | ### Added
79 |
80 | - Quantized layers with different acceleration options
81 | - QConv (binary, quantized) - CPU, Cutlass
82 | - QLinear (binary, quantized, mixed bit-width) - CUDA, Cutlass, MPS
83 | - QEmbedding (binary)
84 | - Optimizer(s) for quantized layers
85 | - Hybrid optimizer `diode_beta` based on Diode v1 (binary) and AdamW (quantized) for memory-efficient training
86 | - Initial support for galore projection
87 | - Examples
88 | - MNIST training script with and without PyTorch Lightning
89 |
90 | ## [0.1.0] - 2023/01/13
91 |
92 | The first release of basic functionality.
93 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Bitorch Engine
2 | Your contributions are welcomed and appreciated. We strive to make the process straightforward and transparent.
3 |
4 | ## Development
5 | 1. Propose significant changes via a Request for Comments (RFC) for discussion.
6 | 2. Add features by submitting a PR with accompanying tests and documentation.
7 | 3. Fix bugs by submitting a PR that includes tests validating the fix and any necessary documentation updates.
8 |
9 | ## Pull Requests
10 | We encourage your pull requests.
11 |
12 | 1. Fork the repository and create your branch from main.
13 | 2. Include tests for any new code.
14 | 3. Update documentation for all related functions.
15 | 4. Ensure all tests are passing.
16 |
17 | ## Reporting Issues
18 | We track bugs with GitHub issues. Provide a clear description and instructions to reproduce the issue.
19 |
20 | ## Your Contributions
21 | Contributions are **licensed under the terms of the LICENSE file** found in the root directory of this source tree.
--------------------------------------------------------------------------------
/bitorch_engine/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def initialize():
4 | """
5 | This function makes all custom layer implementations available in BITorch.
6 | """
7 | from .layers.qlinear import QLinearInf
8 |
9 |
10 | import os
11 | torch_int_gradients_support = None
12 | if os.environ.get("BIE_SKIP_TORCH_CHECK", "false") == "false" and torch_int_gradients_support is None:
13 | import warnings
14 | torch_int_gradients_support = False
15 | try:
16 | import torch
17 | x = torch.nn.Parameter(torch.zeros((1,), dtype=torch.uint8), requires_grad=True)
18 | torch_int_gradients_support = True
19 | except RuntimeError as e:
20 | if "dtype" in str(e).lower() and "only" in str(e).lower():
21 | warnings.warn(
22 | "It seems a regular version of torch is installed.\n"
23 | " Please install the custom torch with enabled gradient calculation for integer tensors.\n"
24 | " Check the instructions at https://github.com/GreenBitAI/bitorch-engine for more information.")
25 | else:
26 | warnings.warn("There may be a problem with the currently installed version of torch:\n" + str(e))
27 | except ModuleNotFoundError as e:
28 | # if torch is not installed, we can not check
29 | pass
30 |
--------------------------------------------------------------------------------
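
A minimal usage sketch of the import-time check above (assuming the package is installed): `BIE_SKIP_TORCH_CHECK` must be set before the first import, and `initialize()` is the hook that pulls in the custom layer implementations.

```python
# Sketch only: silence the custom-PyTorch gradient check (e.g. when running with stock torch)
# and register the custom layer implementations with BITorch.
import os

os.environ["BIE_SKIP_TORCH_CHECK"] = "true"  # compared against "false" on first import

import bitorch_engine

bitorch_engine.initialize()  # imports e.g. QLinearInf so BITorch can use it
```
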
/bitorch_engine/extensions/__init__.py:
--------------------------------------------------------------------------------
1 | EXTENSION_PREFIX = "bitorch_engine.extensions."
2 |
--------------------------------------------------------------------------------
/bitorch_engine/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/bitorch_engine/functions/__init__.py
--------------------------------------------------------------------------------
/bitorch_engine/functions/cuda/__init__.py:
--------------------------------------------------------------------------------
1 | from .functions import *
--------------------------------------------------------------------------------
/bitorch_engine/functions/cuda/extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension
4 |
5 | CUDA_REQUIRED = True
6 |
7 |
8 | def get_ext(path: Path):
9 | """
10 | Generate CUDA extension for specified path.
11 |
12 | Args:
13 | path (Path): Path to the directory containing CUDA extension files.
14 |
15 | Returns:
16 | Extension: CUDA extension for specified path.
17 | """
18 | return get_cuda_extension(
19 | path,
20 | relative_name='functions_cuda',
21 | relative_sources=[
22 | 'functions_cuda.cpp',
23 | 'functions_cuda_kernel.cu',
24 | ]
25 | )
26 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/bitorch_engine/layers/__init__.py
--------------------------------------------------------------------------------
/bitorch_engine/layers/qconv/__init__.py:
--------------------------------------------------------------------------------
1 | from .binary import *
--------------------------------------------------------------------------------
/bitorch_engine/layers/qconv/binary/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import BinaryConv2dBase, BinaryConvParameter
--------------------------------------------------------------------------------
/bitorch_engine/layers/qconv/binary/cpp/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import BinaryConv2dCPP
--------------------------------------------------------------------------------
/bitorch_engine/layers/qconv/binary/cpp/extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from bitorch_engine.utils.cpp_extension import get_cpp_extension
4 |
5 |
6 | def get_ext(path: Path):
7 | """
8 | Retrieves the C++ extension for binary convolution.
9 |
10 | Args:
11 | path (Path): The path to the directory containing the binary convolution C++ code.
12 |
13 | Returns:
14 | torch.utils.cpp_extension.CppExtension: The C++ extension for binary convolution.
15 | """
16 | return get_cpp_extension(
17 | path,
18 | relative_name='binary_conv_cpp',
19 | relative_sources=['binary_conv.cpp']
20 | )
21 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qconv/binary/cpp/layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 |
4 | from bitorch_engine.utils.safe_import import import_extension
5 |
6 | binary_conv_cpp = import_extension("binary_conv_cpp")
7 |
8 |
9 | from bitorch_engine.utils.quant_operators import get_binary_row
10 | from ..layer import BinaryConv2dBase
11 |
12 | class BinaryConv2dForward(Function):
13 | """
14 | A custom autograd function to perform forward pass of a 2D binary convolution.
15 |
16 | This class implements a static method `forward` to carry out the convolution operation
17 | using binary weights and activations. The operation is performed using a custom C++
18 | backend for efficiency.
19 |
20 | Attributes:
21 | - No class-level attributes.
22 |
23 | Methods:
24 | - forward: Performs the forward pass of the binary convolution.
25 | """
26 | @staticmethod
27 | def forward(ctx, activations: torch.Tensor, weights: torch.Tensor, m: int, n: int, k: int, kernel_size: int,
28 | stride: int, padding: int, dilation: int, output_edge: int) -> torch.Tensor:
29 | """
30 | Forward pass for the 2D binary convolution.
31 |
32 | Utilizes a C++ backend implemented in `binary_conv_cpp.forward` to perform the operation.
33 | This method is statically defined and automatically integrated with PyTorch's autograd mechanism.
34 |
35 | Parameters:
36 | - ctx (torch.autograd.function.FunctionCtx): Context object that can be used to stash information
37 | for backward computation. You can cache arbitrary objects for use in the backward pass using
38 | the `save_for_backward` method.
39 | - activations (Tensor): The input feature map or activation tensor.
40 | - weights (Tensor): The binary weights tensor.
41 | - m, n, k (int): GEMM dimensions of the convolution, specifically:
42 | - m: The number of output channels.
43 | - n: The number of output pixels per channel (output_edge * output_edge).
44 | - k: The number of input channels multiplied by kernel_size * kernel_size.
45 | - kernel_size (int or tuple): Size of the conv kernel.
46 | - stride (int or tuple): Stride of the convolution.
47 | - padding (int or tuple): Zero-padding added to both sides of the input.
48 | - dilation (int or tuple): Spacing between kernel elements.
49 | - output_edge (int): The size of the output edge to ensure the output dimension matches expectations.
50 |
51 | Returns:
52 | - Tensor: The output feature map resulting from the binary convolution operation.
53 |
54 | Note:
55 | This method is part of the forward pass and needs to be paired with a corresponding backward
56 | method to enable gradient computation.
57 | """
58 | output = binary_conv_cpp.forward(activations, weights, m, n, k, kernel_size, stride, padding,
59 | dilation, output_edge)
60 | return output
61 |
62 |
63 | class BinaryConv2dCPP(BinaryConv2dBase):
64 | """
65 | This class implements a binary convolutional layer in PyTorch, specifically optimized with C++ extensions.
66 | It inherits from BinaryConv2dBase to leverage common binary convolution functionalities with added
67 | optimizations for efficient computation.
68 |
69 | Attributes:
70 | bits_binary_word (int): Defines the size of the binary word, defaulting to 8 bits.
71 | """
72 | def __init__(self, *args, **kwargs):
73 | """
74 | Initializes the BinaryConv2dCPP layer with the given arguments, which are forwarded to the base class.
75 | Additionally, it sets up the binary word size for quantization.
76 |
77 | Args:
78 | *args: Variable length argument list to be passed to the BinaryConv2dBase class.
79 | **kwargs: Arbitrary keyword arguments to be passed to the BinaryConv2dBase class.
80 | """
81 | super(BinaryConv2dCPP, self).__init__(*args, **kwargs)
82 | self.bits_binary_word = 8
83 |
84 | def prepare_params(self) -> None:
85 | """
86 | Prepares and initializes the model parameters for training.
87 | One can use the "prepare_bie_layers" method from bitorch_engine.utils.model_helper to call this function.
88 | """
89 | pass
90 |
91 | def generate_quantized_weight(self, qweight_only: bool = False) -> None:
92 | """
93 | Generates and stores quantized weights based on the current weights of the layer, utilizing a binary
94 | quantization method. Quantized weights are stored as a torch.nn.Parameter but are not set to require gradients.
95 |
96 | Args:
97 | qweight_only (bool): If True, the original weights are discarded to save memory. Defaults to False.
98 | """
99 | w_size = self.out_channels * self.in_channels/self.bits_binary_word * self.kernel_size * self.kernel_size
100 | self.qweight = torch.nn.Parameter(
101 | get_binary_row(self.weight.reshape(-1, ),
102 | torch.empty(int(w_size), dtype=torch.uint8),
103 | w_size * self.bits_binary_word,
104 | self.bits_binary_word),
105 | requires_grad=False
106 | )
107 | if qweight_only:
108 | self.weight = None
109 |
110 | def forward(self, x: torch.Tensor) -> torch.Tensor:
111 | """
112 | Defines the forward pass for the binary convolution operation using the quantized weights.
113 |
114 | Args:
115 | x (torch.Tensor): The input tensor for the convolution operation with shape (N, C_in, H, W),
116 | where N is the batch size, C_in is the number of channels, and H, W are the height
117 | and width of the input tensor.
118 |
119 | Returns:
120 | torch.Tensor: The output tensor of the convolution operation with shape determined by the layer's
121 | attributes and the input dimensions.
122 | """
123 | self._check_forward(x)
124 | # pass m, n, k
125 | m = self.out_channels # number of output channels
126 | k = x.size(dim=1) * self.kernel_size * self.kernel_size # number of input channels * kernel_size^2
127 | # output_edge = (image_w - filter_w + 2 * pad_w) / stride + 1
128 | output_edge = int((x.size(dim=2) - self.kernel_size + 2 * self.padding) / self.stride + 1)
129 | n = output_edge * output_edge # number of pixels of the output image per channel
130 | return BinaryConv2dForward.apply(x, self.opt_weight, m, n, k, self.kernel_size, self.stride, self.padding,
131 | self.dilation, output_edge)
--------------------------------------------------------------------------------
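
A short usage sketch of `BinaryConv2dCPP`. The constructor keywords are assumptions: `BinaryConv2dBase` is not shown here, so the names below (`in_channels`, `out_channels`, `kernel_size`, `stride`, `padding`, `dilation`) are inferred from the attributes used in `forward` and may differ from the real signature.

```python
# Illustrative sketch; constructor keywords are inferred, not confirmed by the base class.
import torch
from bitorch_engine.layers.qconv.binary.cpp import BinaryConv2dCPP

conv = BinaryConv2dCPP(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1)
conv.generate_quantized_weight(qweight_only=True)  # pack weights into 8-bit binary words, drop the fp copy
x = torch.randn(8, 64, 32, 32)                     # (N, C_in, H, W)
y = conv(x)                                        # m, n, k and output_edge are derived inside forward()
```
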
/bitorch_engine/layers/qconv/binary/cutlass/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import BinaryConv2dCutlass
--------------------------------------------------------------------------------
/bitorch_engine/layers/qconv/binary/cutlass/binary_conv2d_cutlass.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 |
6 | /**
7 | * Performs a forward pass of binary convolution using the CUTLASS library.
8 | * This function is optimized for binary convolutions, leveraging the efficiency of CUTLASS kernels.
9 | *
10 | * Args:
11 | * input (torch::Tensor): The input tensor with shape [batch_size, in_channels, in_height, in_width].
12 | * weight (torch::Tensor): The filter weights tensor with shape [out_channels, kernel_size, kernel_size, in_channels].
13 | * scale (float): A scaling factor applied to the output tensor.
14 | * is_train (bool): A flag indicating whether the operation is being performed during training.
15 | * This influences the processing of the weight tensor.
16 | * kernel_size (int): The size of the convolution kernel.
17 | * stride (int): The stride of the convolution.
18 | * padding (int): The padding added to the input tensor.
19 | * dilation (int): The dilation factor for the convolution.
20 | * device_id (int): The ID of the CUDA device on which to perform the operation.
21 | *
22 | * Returns:
23 | * torch::Tensor: The output tensor of the convolution, scaled by the 'scale' parameter.
24 | * The output tensor has shape [batch_size, out_edge, out_edge, out_channels],
25 | * where 'out_edge' is computed based on the input dimensions, padding, and stride.
26 | *
27 | * Notes:
28 | * - The function sets the CUDA device to 'device_id' at the beginning.
29 | * - It calculates the output tensor dimensions based on the input size, kernel size, stride, and padding.
30 | * - The weights are optionally preprocessed (viewed and packed) based on the training mode.
31 | * - The input tensor is reshaped and packed for efficient processing.
32 | * - The actual convolution operation is performed by a call to the 'xnor_cutlass::_impl_conv_forward' function,
33 | * which utilizes CUTLASS kernels optimized for binary convolutions.
34 | * - Finally, the output tensor is scaled by the 'scale' parameter before being returned.
35 | */
36 | torch::Tensor binary_conv2d_cutlass_forward(
37 | torch::Tensor input,
38 | torch::Tensor weight,
39 | float scale,
40 | bool is_train,
41 | int kernel_size,
42 | int stride,
43 | int padding,
44 | int dilation);
45 |
46 |
47 | /**
48 | * Performs binary convolution operation with weight packing using CUTLASS.
49 | *
50 | * This function adapts the input data tensor for binary convolution by rearranging its dimensions to match
51 | * the expected format {OHWC} (Output Channels, Height, Width, Input Channels) and then packs the data to optimize
52 | * the convolution operation. It leverages CUTLASS kernels for efficient computation.
53 | *
54 | * Args:
55 | * data (torch::Tensor): The input tensor for the convolution operation. Expected to have dimensions
56 | * {Output Channels, Input Channels, Kernel Height, Kernel Width}.
57 | *
58 | * Returns:
59 | * torch::Tensor: A tensor containing the packed data, ready for efficient binary convolution with CUTLASS.
60 | */
61 | torch::Tensor binary_conv2d_w_pack_cutlass(
62 | torch::Tensor data);
63 |
64 |
65 | // C++ interface
66 |
67 |
68 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
69 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
70 | #define CHECK_INPUT(x) CHECK_CUDA(x);
71 |
72 |
73 | torch::Tensor binary_conv2d_forward(
74 | torch::Tensor input,
75 | torch::Tensor weight,
76 | float scale,
77 | bool is_train,
78 | int kernel_size,
79 | int stride,
80 | int padding,
81 | int dilation
82 | ) {
83 | CHECK_INPUT(input);
84 | CHECK_INPUT(weight);
85 | return binary_conv2d_cutlass_forward(input, weight, scale, is_train, kernel_size, stride, padding, dilation);
86 | }
87 |
88 |
89 | torch::Tensor binary_conv2d_w_pack(
90 | torch::Tensor data
91 | ){
92 | CHECK_INPUT(data);
93 | return binary_conv2d_w_pack_cutlass(data);
94 | }
95 |
96 |
97 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
98 | m.def("forward", &binary_conv2d_forward, "scaled binary conv2d forward (CUTLASS)");
99 | m.def("w_pack", &binary_conv2d_w_pack, "packing binary weight (CUTLASS)");
100 | }
101 |
--------------------------------------------------------------------------------
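
The `PYBIND11_MODULE` block above exposes `forward` and `w_pack`; the supported entry point is the `BinaryConv2dCutlass` layer in `cutlass/layer.py`, but a hedged sketch of a direct call (shapes, the ±1 weights and the scale value are placeholders) looks roughly like this:

```python
# Illustrative direct call into the compiled extension; normally BinaryConv2dCutlass wraps this.
import torch
from bitorch_engine.utils.safe_import import import_extension

binary_conv2d_cutlass = import_extension("binary_conv2d_cutlass")

x = torch.randn(8, 64, 32, 32, device="cuda")              # input  [batch, in_channels, H, W]
w = torch.sign(torch.randn(128, 3, 3, 64, device="cuda"))  # weight [out_channels, kH, kW, in_channels]
out = binary_conv2d_cutlass.forward(
    x, w,
    1.0, True,     # scale, is_train (weights are packed on the fly in training mode)
    3, 1, 1, 1,    # kernel_size, stride, padding, dilation
)
```
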
/bitorch_engine/layers/qconv/binary/cutlass/extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension
4 |
5 | CUDA_REQUIRED = True
6 | CUTLASS_REQUIRED = True
7 |
8 |
9 | def get_ext(path: Path):
10 | """
11 | Obtains the CUDA extension for Cutlass-based binary convolution.
12 |
13 | Args:
14 | path (Path): The path to the directory containing the necessary source files
15 | for the Cutlass-based binary convolution operation.
16 |
17 | Returns:
18 | Extension: The CUDA extension for the Cutlass-based binary convolution.
19 | """
20 | return get_cuda_extension(
21 | path,
22 | relative_name='binary_conv2d_cutlass',
23 | relative_sources=[
24 | 'binary_conv2d_cutlass.cpp',
25 | 'binary_conv2d_cutlass_kernel.cu',
26 | ]
27 | )
28 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qconv/nbit/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import nBitConv2dBase, nBitConvParameter
--------------------------------------------------------------------------------
/bitorch_engine/layers/qconv/nbit/cutlass/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import Q4Conv2dCutlass
--------------------------------------------------------------------------------
/bitorch_engine/layers/qconv/nbit/cutlass/extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension
4 |
5 | CUDA_REQUIRED = True
6 | CUTLASS_REQUIRED = True
7 |
8 |
9 | def get_ext(path: Path):
10 | """
11 | Get CUDA extension for a specified path.
12 |
13 | Args:
14 | path (Path): The path to the directory containing CUDA extension files.
15 |
16 | Returns:
17 | Extension: The CUDA extension object.
18 | """
19 | return get_cuda_extension(
20 | path,
21 | relative_name='q4_conv_cutlass',
22 | relative_sources=[
23 | 'q4_conv_cutlass.cpp',
24 | 'q4_conv_cutlass_kernel.cu',
25 | ]
26 | )
27 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qconv/nbit/cutlass/q4_conv_cutlass.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 |
5 | // CUTLASS forward declarations
6 |
7 | /**
8 | * Performs a forward pass of the quantized 4-bit convolution (q4_conv2d) using the CUTLASS library.
9 | * This function takes a 4-bit quantized input and weight tensors, along with convolution parameters
10 | * like scale factors, kernel size, stride, padding, and dilation, to perform the convolution operation
11 | * optimized for CUDA. It's designed to work with NHWC tensor format for efficient computation.
12 | *
13 | * Parameters:
14 | * input - The input tensor in NCHW format that will be converted to NHWC internally.
15 | * weight - The weight tensor, which can be either pre-packed (in inference mode) or will be packed during training.
16 | * scale_a - The scale factor for the input tensor quantization.
17 | * scale_w - The scale factor for the weight tensor quantization.
18 | * is_train - A boolean flag indicating whether the operation is for training. Affects weight processing.
19 | * kernel_size - The size of the convolution kernel.
20 | * stride - The stride of the convolution.
21 | * padding - The padding added to both sides of the input tensor.
22 | * dilation - The spacing between kernel elements.
23 | *
24 | * Returns:
25 | * A tensor containing the result of the quantized convolution operation, scaled by the product of input
26 | * and weight scale factors.
27 | */
28 | std::vector<torch::Tensor> q4_conv2d_cutlass_forward(
29 | torch::Tensor input,
30 | torch::Tensor weight,
31 | float scale_a,
32 | float scale_w,
33 | bool is_train,
34 | int kernel_size,
35 | int stride,
36 | int padding,
37 | int dilation);
38 |
39 |
40 | /**
41 | *
42 | * This function prepares the weight tensor for a quantized 4-bit convolution operation.
43 | * It takes a weight tensor and a scale factor as inputs, restructures the weight tensor for the
44 | * convolution operation, and quantizes it to 4 bits. This preparation is crucial for leveraging
45 | * CUTLASS's efficient low-bit computation capabilities.
46 | *
47 | * Parameters:
48 | * - weight: The original weight tensor of the convolutional layer.
49 | * - scale: The scaling factor used for quantization of the weights to 4-bit precision.
50 | *
51 | * Returns:
52 | * - A tensor representing the packed and quantized weights, ready for use in a 4-bit convolution operation.
53 | */
54 | torch::Tensor q4_conv2d_w_pack_cutlass(
55 | torch::Tensor weight,
56 | float scale);
57 |
58 |
59 | // C++ interface
60 |
61 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
62 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
63 | #define CHECK_INPUT(x) CHECK_CUDA(x);
64 |
65 |
66 | std::vector<torch::Tensor> q4_conv2d_forward(
67 | torch::Tensor input,
68 | torch::Tensor weight,
69 | float scale_a,
70 | float scale_w,
71 | bool is_train,
72 | int kernel_size,
73 | int stride,
74 | int padding,
75 | int dilation) {
76 | CHECK_INPUT(input);
77 | CHECK_INPUT(weight);
78 | return q4_conv2d_cutlass_forward(input, weight, scale_a, scale_w, is_train,
79 | kernel_size, stride, padding, dilation);
80 | }
81 |
82 |
83 | torch::Tensor q4_conv2d_w_pack(
84 | torch::Tensor weight,
85 | float scale
86 | ) {
87 | CHECK_INPUT(weight);
88 | return q4_conv2d_w_pack_cutlass(weight, scale);
89 | }
90 |
91 |
92 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
93 | m.def("forward", &q4_conv2d_forward, "4-bit conv forward (CUTLASS)");
94 | m.def("w_pack", &q4_conv2d_w_pack, "pack q4 weight");
95 | }
96 |
--------------------------------------------------------------------------------
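
Analogously, a hedged sketch of the surface this module exposes: `forward` takes separate quantization scales for activations and weights and returns a vector of tensors, while `w_pack` pre-packs weights for inference. Values and shapes below are placeholders; `Q4Conv2dCutlass` in `cutlass/layer.py` is the supported wrapper.

```python
# Illustrative only; use Q4Conv2dCutlass from layer.py in real code.
import torch
from bitorch_engine.utils.safe_import import import_extension

q4_conv_cutlass = import_extension("q4_conv_cutlass")

x = torch.randn(8, 64, 32, 32, device="cuda")   # NCHW input, converted to NHWC internally
w = torch.randn(128, 64, 3, 3, device="cuda")   # fp weights, packed on the fly when is_train=True
outputs = q4_conv_cutlass.forward(
    x, w,
    0.1, 0.1, True,   # scale_a, scale_w, is_train
    3, 1, 1, 1,       # kernel_size, stride, padding, dilation
)
```
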
/bitorch_engine/layers/qembedding/__init__.py:
--------------------------------------------------------------------------------
1 | from .binary import *
2 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qembedding/binary/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .layer import BinaryEmbeddingBag, BinaryEmbeddingParameter, BinaryEmbeddingCuda
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import QLinearInf
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/binary/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import BinaryLinearBase, BinaryLinearParameter
2 | import torch.cuda
3 |
4 |
5 | def get_best_binary_implementation():
6 | if torch.cuda.is_available():
7 | from .cuda import BinaryLinearCuda
8 | return BinaryLinearCuda
9 | else:
10 | from .cpp import BinaryLinearCPP
11 | return BinaryLinearCPP
12 |
13 |
14 | BinaryLinear = get_best_binary_implementation()
15 |
--------------------------------------------------------------------------------
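
Since the alias is resolved once at import time, downstream code can stay device-agnostic; a minimal sketch:

```python
# BinaryLinear points to BinaryLinearCuda when a GPU is visible, otherwise to BinaryLinearCPP.
from bitorch_engine.layers.qlinear.binary import BinaryLinear, get_best_binary_implementation

assert BinaryLinear is get_best_binary_implementation()
print(BinaryLinear.__name__)
```
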
/bitorch_engine/layers/qlinear/binary/binary_implementation.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Tuple
3 |
4 | from bitorch.layers import QLinearBase
5 | from bitorch.layers.extensions import LayerRecipe
6 | from bitorch.quantizations import Sign, SwishSign
7 |
8 | from bitorch_engine.layers.qlinear.qlinear_implementation import QLinearImplementationMixin
9 |
10 |
11 | class BinaryLinearImplementationMixin(QLinearImplementationMixin, ABC):
12 | """
13 | A mixin class for binary linear layer implementations that extends the quantized linear layer implementation mixin (QLinearImplementationMixin).
14 | This class provides specialized methods to determine if a layer can be cloned based on the quantization functions used for inputs and weights.
15 |
16 | The class supports binary quantization functions such as Sign and SwishSign for both inputs and weights. It leverages the `can_clone` class method
17 | to check if the specified quantization functions are supported for cloning a layer according to a given recipe.
18 |
19 | Attributes:
20 | None specified explicitly, but inherits from QLinearImplementationMixin and ABC.
21 |
22 | Methods:
23 | can_clone: Class method to determine if a layer can be cloned based on its quantization functions for inputs and weights.
24 | """
25 | @classmethod
26 | def can_clone(cls, recipe: LayerRecipe) -> Tuple[bool, str]:
27 | """
28 | Determines if a layer can be cloned based on its quantization functions for inputs and weights.
29 |
30 | This method checks if the layer's input and weight quantization functions are among the supported binary quantization functions.
31 | If either quantization function is not supported, the method returns False along with a message indicating which quantization
32 | function is not supported.
33 |
34 | Args:
35 | recipe (LayerRecipe): An object containing the configuration and parameters for the layer to be cloned.
36 |
37 | Returns:
38 | Tuple[bool, str]: A tuple containing a boolean indicating whether the layer can be cloned and a string message.
39 | If the layer can be cloned, the boolean is True, and the string is empty. If the layer cannot be cloned due to unsupported
40 | quantization functions, the boolean is False, and the string contains a message specifying the unsupported quantization function.
41 | """
42 | supported_quantization_functions = (Sign, SwishSign) # Define supported quantization functions
43 | args = QLinearBase.get_args_as_kwargs(recipe) # Retrieve layer arguments as keyword arguments
44 |
45 | # Check if input quantization function is supported
46 | if args["input_quantization"].__class__ not in supported_quantization_functions:
47 | return False, f"the input quantization {args['input_quantization'].name} is not yet supported."
48 |
49 | # Check if weight quantization function is supported
50 | if args["weight_quantization"].__class__ not in supported_quantization_functions:
51 | return False, f"the weight quantization {args['weight_quantization'].name} is not yet supported."
52 |
53 | # Call superclass method to perform any additional checks
54 | return super().can_clone(recipe)
55 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/binary/cpp/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import BinaryLinearCPP
2 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/binary/cpp/extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from bitorch_engine.utils.cpp_extension import get_cpp_extension
4 |
5 |
6 | def get_ext(path: Path):
7 | """Retrieve C++ extension details for binary linear module.
8 |
9 | Args:
10 | path (Path): Path to the directory containing the extension module.
11 |
12 | Returns:
13 | Any: Extension module details.
14 | """
15 | return get_cpp_extension(
16 | path,
17 | relative_name='binary_linear_cpp',
18 | relative_sources=['binary_linear.cpp']
19 | )
20 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/binary/cpp/layer.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import torch
4 | from bitorch import RuntimeMode
5 | from bitorch.layers import QLinearBase
6 | from bitorch.layers.extensions import LayerRecipe
7 | from bitorch.layers.register import QLinearImplementation
8 | from torch.autograd import Function
9 |
10 | from bitorch_engine.utils.safe_import import import_extension
11 | from ..binary_implementation import BinaryLinearImplementationMixin
12 | from ..layer import BinaryLinearBase
13 | from bitorch_engine.utils.model_helper import flatten_x, unflatten_x
14 |
15 | binary_linear_cpp = import_extension("binary_linear_cpp")
16 |
17 |
18 | class BinaryLinearForward(Function):
19 | """
20 | A custom autograd function for performing forward pass of binary linear layer.
21 | This function uses a custom C++ backend for efficient computation.
22 |
23 | Args:
24 | ctx (torch.autograd.function.FunctionCtx): The context for storing information for backward computation.
25 | input (torch.Tensor): The input tensor.
26 | weights (torch.Tensor): The binary weights tensor.
27 | m (int): The batch size.
28 | n (int): The number of output features.
29 | k (int): The number of input features.
30 |
31 | Returns:
32 | torch.Tensor: The output tensor after applying the binary linear transformation.
33 | """
34 | @staticmethod
35 | def forward(ctx, input: torch.Tensor, weights: torch.Tensor, m: int, n: int, k: int) -> torch.Tensor:
36 | input, shape = flatten_x(input)
37 | output = binary_linear_cpp.forward(input, weights, m, n, k)
38 | output = unflatten_x(output, shape)
39 | return output
40 |
41 |
42 | @QLinearImplementation(RuntimeMode.CPU)
43 | class BinaryLinearCPP(BinaryLinearImplementationMixin, BinaryLinearBase):
44 | """
45 | A class representing the binary linear layer implemented in C++ for CPU runtime mode.
46 | Inherits from BinaryLinearBase and mixes in BinaryLinearImplementationMixin for common functionality.
47 |
48 | This class supports creating a clone of itself from a given LayerRecipe, allowing for easy replication
49 | and modification of layer parameters.
50 | """
51 | @classmethod
52 | def create_clone_from(cls, recipe: LayerRecipe) -> Any:
53 | """
54 | Creates a clone of this layer based on the provided LayerRecipe.
55 |
56 | Args:
57 | recipe (LayerRecipe): The recipe containing the parameters for the clone.
58 |
59 | Returns:
60 | Any: A new instance of this class with parameters derived from the recipe.
61 | """
62 | args = QLinearBase.get_args_as_kwargs(recipe)
63 | input_features, output_features = args["in_features"], args["out_features"]
64 | input_features //= 8
65 | new_layer = cls(input_features, output_features)
66 | new_layer.set_weight_data(recipe.layer.weight.data)
67 | new_layer.generate_quantized_weight(qweight_only=True)
68 | return new_layer
69 |
70 | def __init__(
71 | self,
72 | input_features: int,
73 | out_features: int,
74 | device: torch.device = None,
75 | ) -> None:
76 | """
77 | Initializes the BinaryLinearCPP layer.
78 |
79 | Args:
80 | input_features (int): The number of input features (divided by 8 for binary).
81 | out_features (int): The number of output features.
82 | device (torch.device, optional): The device on which to perform computations.
83 | """
84 | super().__init__(input_features, out_features, device)
85 |
86 | def prepare_params(self) -> None:
87 | """
88 | Prepares and initializes the model parameters for training.
89 | One can use the "prepare_bie_layers" method from project_root.utils.model_helper to call this function.
90 | """
91 | pass
92 |
93 | def generate_quantized_weight(self, qweight_only: bool = False) -> None:
94 | """
95 | Generates the quantized weight matrix for this layer and optionally clears the original weight.
96 |
97 | Args:
98 | qweight_only (bool, optional): If True, the original weight matrix is cleared to save memory.
99 | """
100 | # Generate packed weight using custom C++ function
101 | self.qweight = binary_linear_cpp.w_pack(
102 | self.weight, # Original weight
103 | self.output_features, # n
104 | self.input_features, # k
105 | )
106 | if qweight_only:
107 | self.weight = None # Clear the original weight matrix if specified
108 |
109 | def forward(self, x: torch.Tensor) -> torch.Tensor:
110 | """
111 | Defines the forward pass of the binary linear layer.
112 |
113 | Args:
114 | x (torch.Tensor): The input tensor.
115 |
116 | Returns:
117 | torch.Tensor: The output tensor after applying the binary linear transformation.
118 | """
119 | # Check input validity
120 | self._check_forward(x)
121 | # pass m, n, k
122 | m = x.size(dim=0) # batch size
123 | k = x.size(dim=1) # input features
124 | n = self.output_features # output features
125 | return BinaryLinearForward.apply(x, self.opt_weight, m, n, k)
126 |
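
A minimal usage sketch for the layer above, following the same steps as its `create_clone_from` method. The packing factor of 8 and the call sequence come from the code; the feature sizes, batch size, random data and input dtype are illustrative assumptions, and the `binary_linear_cpp` extension must already be built.

    import torch
    from bitorch_engine.layers.qlinear.binary.cpp.layer import BinaryLinearCPP

    in_features, out_features = 512, 256                      # full (unpacked) sizes, assumed
    layer = BinaryLinearCPP(in_features // 8, out_features)   # packed input size, as in create_clone_from
    layer.set_weight_data(torch.randn(out_features, in_features))
    layer.generate_quantized_weight(qweight_only=True)        # packs weights via binary_linear_cpp.w_pack

    x = torch.randn(4, in_features)                           # (m, k) = (batch size, input features)
    y = layer(x)                                              # (4, out_features) via BinaryLinearForward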
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/binary/cuda/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .bmm import BMM
4 |
5 | if torch.cuda.is_available():
6 | from .layer import BinaryLinearCuda
7 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/binary/cuda/binary_linear_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 |
6 | /**
7 | * Performs a forward pass of the binary linear layer using CUDA.
8 | *
9 | * This function facilitates binary linear operations with support for different data types
10 | * for input and weight tensors, including floating point and quantized types. It leverages
11 | * CUDA for efficient computation, especially suited for deep learning models running on GPU.
12 | *
13 | * @param input The input tensor, which can be of type torch::kFloat32 (float), torch::kBFloat16 (bfloat16),
14 | * or torch::kHalf (half).
15 | * @param weight The weight tensor, which supports torch::kInt8 (int8), torch::kFloat32 (float),
16 | * and torch::kUInt8 (uint8) data types.
17 | * @param bmm_type An integer specifying the type of binary matrix multiplication to perform.
18 | * This parameter allows for customization of the operation based on the model's requirements.
19 | * @param transpose A boolean indicating whether the weight matrix should be transposed during the operation.
20 | *
21 | * @return A tensor containing the result of the binary linear operation.
22 | *
23 | * @note This function dynamically dispatches to specialized template functions based on the data types of
24 | * the input and weight tensors. It supports a combination of float, bfloat16, half, int8, and uint8
25 | * types, ensuring flexibility in handling various neural network architectures.
26 | * If the data type of the input or weight tensor is not supported, the function will terminate the
27 | * program and print an error message.
28 | */
29 | torch::Tensor binary_linear_cuda_forward(
30 | torch::Tensor input,
31 | torch::Tensor weights,
32 | int bmm_type,
33 | bool transpose);
34 |
35 |
36 | /**
37 | * Converts a given weight tensor to its binary representation based on the specified data type.
38 | *
39 | * This function supports weight tensors of different data types (int8, float, bfloat16, and half)
40 | * and converts them to a binary format suitable for certain binary matrix multiplication (BMM) operations.
41 | * The conversion process is dependent on the data type of the input tensor and whether the tensor
42 | * should be transposed as part of the conversion.
43 | *
44 | * @param weight The input weight tensor to be converted to binary format.
45 | * @param bmm_type An integer specifying the type of binary matrix multiplication operation.
46 | * This parameter can influence how the binary conversion is performed.
47 | * @param transpose A boolean indicating whether the weight tensor should be transposed
48 | * as part of the conversion process.
49 | * @return torch::Tensor A tensor containing the binary representation of the input weight tensor.
50 | * The specific format of the binary representation is determined by the
51 | * data type of the input tensor.
52 | *
53 | * @note This function is templated to handle different data types of the input tensor by
54 | * calling the appropriate specialized version of the _get_binary_weight_cuda function.
55 | * If the data type of the input tensor is not supported, the function prints an error message
56 | * and exits the program.
57 | */
58 | torch::Tensor get_binary_weight_cuda(
59 | torch::Tensor weights,
60 | int bmm_type,
61 | bool transpose);
62 |
63 |
64 | /**
65 | * Performs binary matrix multiplication on CUDA using specified data types.
66 | *
67 | * This function dispatches the binary matrix multiplication operation to specialized
68 | * CUDA kernels based on the data type of the input tensors. It supports int8, float32,
69 | * bfloat16, and half (float16) data types. The function checks if the data types of both
70 | * input tensors match and then calls the appropriate templated CUDA kernel function.
71 | *
72 | * @param x A torch::Tensor representing the first matrix in the multiplication.
73 | * @param y A torch::Tensor representing the second matrix in the multiplication.
74 | * @param bmm_type An integer indicating the type of binary matrix multiplication to perform.
75 | *
76 | * @return A torch::Tensor containing the result of the binary matrix multiplication.
77 | *
78 | * @throws std::runtime_error If the input tensors have different data types or if an unsupported
79 | * data type is provided.
80 | */
81 | torch::Tensor binary_mm_cuda(
82 | torch::Tensor x,
83 | torch::Tensor y,
84 | int bmm_type);
85 |
86 | // C++ interface
87 |
88 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
89 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
90 | #define CHECK_INPUT(x) CHECK_CUDA(x);
91 |
92 | torch::Tensor binary_linear_forward(
93 | torch::Tensor input,
94 | torch::Tensor weights,
95 | int bmm_type,
96 | bool transpose) {
97 | CHECK_INPUT(input);
98 | CHECK_INPUT(weights);
99 | return binary_linear_cuda_forward(input, weights, bmm_type, transpose);
100 | }
101 |
102 | torch::Tensor get_binary_weight(
103 | torch::Tensor weights,
104 | int bmm_type,
105 | bool transpose) {
106 | CHECK_INPUT(weights);
107 | return get_binary_weight_cuda(weights, bmm_type, transpose);
108 | }
109 |
110 | torch::Tensor binary_linear_mm(
111 | torch::Tensor x,
112 | torch::Tensor y,
113 | int bmm_type) {
114 | CHECK_INPUT(x);
115 | CHECK_INPUT(y);
116 | return binary_mm_cuda(x, y, bmm_type);
117 | }
118 |
119 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
120 | m.def("forward", &binary_linear_forward, "binary linear forward (CUDA)");
121 | m.def("w_pack", &get_binary_weight, "get linear binary weight (CUDA)");
122 | m.def("mm", &binary_linear_mm, "binary linear mm (CUDA)");
123 | }
124 |
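
The module compiled from this file exposes `forward`, `w_pack` and `mm` to Python. A rough sketch of calling the bindings directly is shown below; the dtypes follow the documented combinations (float input, float weight), while the shapes, `transpose=False` and the use of `BMM.ADAPTIVE.value` as the `bmm_type` integer are assumptions not confirmed by this dump.

    import torch
    from bitorch_engine.utils.safe_import import import_extension
    from bitorch_engine.layers.qlinear.binary.cuda.bmm import BMM

    binary_linear_cuda = import_extension("binary_linear_cuda")

    x = torch.randn(4, 512, device="cuda")                    # activations (float32)
    w = torch.randn(256, 512, device="cuda")                  # full-precision weights (float32)

    packed = binary_linear_cuda.w_pack(w, BMM.ADAPTIVE.value, False)        # -> get_binary_weight(...)
    out = binary_linear_cuda.forward(x, packed, BMM.ADAPTIVE.value, False)  # -> binary_linear_forward(...)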
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/binary/cuda/bmm.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class BMM(Enum):
5 | """
6 | Enumeration for selecting the Bit-Matrix-Multiplication (BMM) kernel to be used during operations.
7 | This allows for the choice of different underlying implementations based on the requirements or
8 | optimizations desired for specific hardware or computational constraints.
9 |
10 | Attributes:
11 | BSTC32: Software-based Tensor Core implementation. This option utilizes a software-level implementation
12 | to simulate tensor core operations, potentially offering more flexibility at the cost of raw performance.
13 | BTC32: Bit-Matrix-Multiplication using NVIDIA Tensor Cores. This leverages hardware tensor cores for
14 | accelerated computation, suitable for NVIDIA GPUs that support tensor core operations, offering
15 | high performance for matrix multiplications.
16 | ADAPTIVE: Automatically selects the best combination of kernel implementations based on the specific dimension
17 | constraints of the inputs and weights. This option aims to optimize performance by considering
18 | the characteristics of the computation and available hardware capabilities.
19 |
20 | The choice of kernel can significantly affect the performance and efficiency of operations that involve
21 | matrix multiplications, especially in deep learning models where such operations are prevalent.
22 | """
23 | BSTC32 = 1 # software based tensor core implementation
24 | BTC32 = 2 # Bit-Matrix-Multiplication using NVIDIA Tensor Cores
25 | ADAPTIVE = 3 # Chooses the best kernel based on input and weight dimensions
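
How the enum is threaded into `BinaryLinearCuda` is not shown in this dump; as a sketch, selecting a kernel amounts to picking a member (passing its `.value` as the C++ `bmm_type` integer is an assumption):

    from bitorch_engine.layers.qlinear.binary.cuda import BMM

    kernel = BMM.ADAPTIVE      # let the engine choose per shape
    # kernel = BMM.BTC32       # force the hardware tensor-core path
    # kernel = BMM.BSTC32      # force the software tensor-core path
    bmm_type = kernel.value    # assumed encoding for the C++ bmm_type argument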
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/binary/cuda/extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension
4 |
5 | CUDA_REQUIRED = True
6 |
7 |
8 | def get_ext(path: Path):
9 | """
10 | Get the CUDA extension for binary linear operations.
11 |
12 | Args:
13 | path (Path): The path to the CUDA extension directory.
14 |
15 | Returns:
16 | Any: The CUDA extension module.
17 | """
18 | return get_cuda_extension(
19 | path,
20 | relative_name='binary_linear_cuda',
21 | relative_sources=[
22 | 'binary_linear_cuda.cpp',
23 | 'binary_linear_cuda_kernel.cu',
24 | ]
25 | )
26 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/binary/cutlass/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import BinaryLinearCutlass, BinaryMatMul
2 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/binary/cutlass/extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension
4 |
5 | CUDA_REQUIRED = True
6 | CUTLASS_REQUIRED = True
7 |
8 |
9 | def get_ext(path: Path):
10 | """
11 | Get the CUDA extension for binary linear cutlass.
12 |
13 | Args:
14 | path (Path): The path to the CUDA extension.
15 |
16 | Returns:
17 | Extension: The CUDA extension for binary linear cutlass.
18 | """
19 | ext = get_cuda_extension(
20 | path,
21 | relative_name='binary_linear_cutlass',
22 | relative_sources=[
23 | 'binary_linear_cutlass.cpp',
24 | 'binary_linear_cutlass_kernel.cu',
25 | ]
26 | )
27 | ext.include_dirs.extend(['.'])
28 | return ext
29 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/layer.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import torch
4 | from bitorch import RuntimeMode
5 | from bitorch.layers.extensions import LayerRecipe
6 | from bitorch.layers.qlinear import QLinearImplementation, QLinearBase
7 |
8 | from .binary import BinaryLinear
9 | from .binary.layer import BinaryLinearBase
10 | from .nbit import nBitLinearBase
11 | from .qlinear_implementation import QLinearImplementationMixin
12 |
13 |
14 | @QLinearImplementation(RuntimeMode.INFERENCE_AUTO)
15 | class QLinearInf(QLinearImplementationMixin, BinaryLinearBase):
16 | """
17 | QLinearInf is a class for quantized linear layers optimized for inference.
18 | It inherits from QLinearImplementationMixin and BinaryLinearBase to utilize
19 | quantization functionalities and binary linear operations.
20 |
21 | This class specifically handles inference operations with quantized weights,
22 | potentially using different bit widths for activations and weights.
23 | """
24 | @classmethod
25 | def create_clone_from(cls, recipe: LayerRecipe, device: torch.device = None) -> Any:
26 | """
27 | Creates a clone of the layer from a given recipe, adjusting input feature dimensions
28 | and setting up quantization parameters based on the recipe's specifications.
29 |
30 | Args:
31 | recipe (LayerRecipe): A configuration object containing layer specifications.
32 | device (torch.device, optional): The device on which to create the layer. Defaults to None.
33 |
34 | Returns:
35 | Any: An instance of the cloned layer with quantization applied.
36 | """
37 | args = QLinearBase.get_args_as_kwargs(recipe)
38 | input_features, output_features = args["in_features"], args["out_features"]
39 | input_features //= 32
40 | new_layer = cls(
41 | input_features,
42 | output_features,
43 | device=device,
44 | a_bit=args["input_quantization"].bit_width,
45 | w_bit=args["input_quantization"].bit_width,
46 | )
47 | new_layer.set_weight_data(recipe.layer.weight.data.to(device=device))
48 | new_layer.generate_quantized_weight(qweight_only=True)
49 | return new_layer
50 |
51 | def __init__(
52 | self,
53 | input_features: int,
54 | out_features: int,
55 | device=None,
56 | a_bit: int = 1,
57 | w_bit: int = 1,
58 | bias=False,
59 | ) -> None:
60 | """
61 | Initializes the QLinearInf layer with specified input and output feature dimensions,
62 | quantization bit widths, and device. Currently, bias is not supported and must be False.
63 |
64 | Args:
65 | input_features (int): The dimension of input features after bit-packing.
66 | out_features (int): The dimension of output features (hidden states).
67 | device (optional): The device on which to initialize the layer. Defaults to None.
68 | a_bit (int, optional): Bit width for activation quantization. Defaults to 1.
69 | w_bit (int, optional): Bit width for weight quantization. Defaults to 1.
70 | bias (bool, optional): Indicates if bias is used. Currently must be False.
71 |
72 | Raises:
73 | AssertionError: If bias is set to True.
74 | """
75 | super().__init__(input_features, out_features, device)
76 | assert not bias, "currently QLinearInf only supports bias = False"
77 | self.layer = None
78 | if a_bit == 1 and w_bit == 1:
79 | self.layer = BinaryLinear(input_features, out_features, device=device)
80 | else:
81 | self.layer = nBitLinearBase(
82 | input_features, out_features, a_bit, w_bit, device
83 | )
84 |
85 | def prepare_params(self) -> None:
86 | """
87 | Prepares the parameters of the layer for quantization and inference,
88 | calling the corresponding method of the underlying binary or n-bit linear layer.
89 | """
90 | self.layer.prepare_params()
91 |
92 | def generate_quantized_weight(self, qweight_only: bool = False) -> None:
93 | """
94 | Generates and sets the quantized weights for the layer, optionally focusing
95 | only on the quantized weights without affecting the original weights.
96 |
97 | Args:
98 | qweight_only (bool, optional): If True, only quantized weights are generated. Defaults to False.
99 | """
100 | self.layer.generate_quantized_weight(qweight_only=qweight_only)
101 |
102 | def set_weight_data(self, x: torch.Tensor):
103 | """
104 | Sets the weight data for the layer.
105 |
106 | Args:
107 | x (torch.Tensor): The tensor containing the weight data.
108 | """
109 | self.layer.set_weight_data(x)
110 |
111 | def set_quantized_weight_data(self, x: torch.Tensor):
112 | """
113 | Sets the quantized weight data for the layer.
114 |
115 | Args:
116 | x (torch.Tensor): The tensor containing the quantized weight data.
117 | """
118 | self.layer.set_quantized_weight_data(x)
119 |
120 | @property
121 | def weight(self):
122 | """
123 | Property to access the weight tensor of the layer.
124 |
125 | Returns:
126 | torch.Tensor: The weight tensor.
127 | """
128 | return self.layer.weight
129 |
130 | @property
131 | def opt_weight(self):
132 | """
133 | Property to access the optimized weight tensor of the layer, which may
134 | include quantized or otherwise transformed weights for efficient inference.
135 |
136 | Returns:
137 | torch.Tensor: The optimized weight tensor.
138 | """
139 | return self.layer.opt_weight
140 |
141 | def forward(self, x: torch.Tensor) -> torch.Tensor:
142 | """
143 | Forwards the input tensor x through the quantized linear layer, performing
144 | the linear operation with quantized weights.
145 |
146 | Args:
147 | x (torch.Tensor): The input tensor to forward through the layer.
148 |
149 | Returns:
150 | torch.Tensor: The output tensor after passing through the layer.
151 | """
152 | return self.layer(x)
153 |
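
A hand-wired sketch of the inference layer above for the default 1-bit/1-bit case, mirroring `create_clone_from`: the packing factor of 32 and the call sequence come from the code, while the feature sizes, the CUDA device and the random weights are assumptions for illustration.

    import torch
    from bitorch_engine.layers.qlinear.layer import QLinearInf

    in_features, out_features = 1024, 256
    layer = QLinearInf(in_features // 32, out_features, device=torch.device("cuda"), a_bit=1, w_bit=1)
    layer.set_weight_data(torch.randn(out_features, in_features, device="cuda"))
    layer.generate_quantized_weight(qweight_only=True)

    x = torch.randn(8, in_features, device="cuda")
    y = layer(x)        # dispatched to the wrapped BinaryLinear layer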
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import nBitLinearBase, MPQLinearBase, MPQWeightParameter, nBitLinearParameter
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/__init__.py:
--------------------------------------------------------------------------------
1 | from .mpq_layer import MPQLinearCuda
2 | from .mbwq_layer import MBWQLinearCuda
3 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/exl2/compat.cuh:
--------------------------------------------------------------------------------
1 | #ifndef _compat_cuh
2 | #define _compat_cuh
3 |
4 | // atomicAdd for half types, to support CC < 7.x
5 |
6 | __device__ __forceinline__ void atomicAdd_half(half* address, half val)
7 | {
8 | unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
9 | unsigned int old = *address_as_ui;
10 | unsigned int assumed;
11 |
12 | do
13 | {
14 | assumed = old;
15 | __half_raw hsum;
16 | hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
17 | half tmpres = __hadd(hsum, val);
18 | hsum = __half_raw(tmpres);
19 | old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
20 | old = atomicCAS(address_as_ui, assumed, old);
21 | }
22 | while (assumed != old);
23 | }
24 |
25 | // atomicAdd for half2 types
26 |
27 | __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
28 | {
29 | unsigned int* address_as_ui = (unsigned int*)address;
30 | unsigned int old = *address_as_ui;
31 | unsigned int assumed;
32 | do
33 | {
34 | assumed = old;
35 | half2 old_val = *((half2*)&old);
36 | half2 new_val = __hadd2(old_val, val);
37 | old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
38 | }
39 | while (assumed != old);
40 | }
41 |
42 | //
43 |
44 | #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
45 | #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
46 |
47 | __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
48 |
49 | #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
50 | __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
51 | #endif
52 |
53 | #endif
54 | #endif
55 |
56 | #endif
57 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/exl2/config.h:
--------------------------------------------------------------------------------
1 | #ifndef _config_h
2 | #define _config_h
3 |
4 | #define MAX_Q_GEMM_ROWS 32
5 | #define MAX_Q_GEMM_ROWS_KERNEL 4
6 | #define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS
7 |
8 | /*
9 | The macros defined by QMODE_*BIT determine whether qweight should be rearranged. This rearrangement should,
10 | to some extent, enhance computational efficiency. Please note that in `quant/qdq_*.cuh`,
11 | two methods of dequantization are implemented for each bit level. When QMODE_*BIT=1,
12 | the rearrangement method `shuffle_*bit_*()` and the corresponding `dequant_*bit_*()` method are implemented.
13 | Therefore, it is important to note that if QMODE_*BIT=1, the qweight tensor needs to be rearranged by calling
14 | the `shuffle_*bit_*()` method.
15 | */
16 | #define QMODE_2BIT 0
17 | #define QMODE_3BIT 0
18 | #define QMODE_4BIT 0
19 | #define QMODE_5BIT 0
20 | #define QMODE_6BIT 0
21 | #define QMODE_8BIT 0
22 |
23 | #endif
24 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/exl2/matrix_view.cuh:
--------------------------------------------------------------------------------
1 | #ifndef _matrix_view_cuh
2 | #define _matrix_view_cuh
3 |
4 | #include <cuda_runtime.h>
5 | #include <cuda_fp16.h>
6 |
7 | #include "quant/qdq_util.cuh"
8 |
9 | class MatrixView_half
10 | {
11 | public:
12 | const half* data;
13 | const int height;
14 | const int width;
15 |
16 | __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
17 | : data(data), height(height), width(width)
18 | { }
19 |
20 | __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
21 | __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
22 | __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
23 | __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
24 |
25 | __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
26 | {
27 | half2* ptr = (half2*) item_ptr(row, column);
28 | half2 i01 = ptr[0];
29 | half2 i23 = ptr[1];
30 | items[0] = __low2half(i01);
31 | items[1] = __high2half(i01);
32 | items[2] = __low2half(i23);
33 | items[3] = __high2half(i23);
34 | }
35 | __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
36 | {
37 | half2* ptr = (half2*)item_ptr(row, column);
38 | half2 i01 = ptr[0];
39 | half2 i23 = ptr[1];
40 | items[0] = __half2float(__low2half(i01));
41 | items[1] = __half2float(__high2half(i01));
42 | items[2] = __half2float(__low2half(i23));
43 | items[3] = __half2float(__high2half(i23));
44 | }
45 |
46 | __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
47 | {
48 | half2* ptr = (half2*)item_ptr(row, column);
49 | half2 i01 = ptr[0];
50 | half2 i23 = ptr[1];
51 | items[0] = __half2half2(__low2half(i01));
52 | items[1] = __half2half2(__high2half(i01));
53 | items[2] = __half2half2(__low2half(i23));
54 | items[3] = __half2half2(__high2half(i23));
55 | }
56 | };
57 |
58 | class MatrixView_half_rw
59 | {
60 | public:
61 | half* data;
62 | const int height;
63 | const int width;
64 |
65 | __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
66 | : data(data), height(height), width(width)
67 | { }
68 |
69 | __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
70 | __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
71 | __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
72 | __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
73 | __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
74 |
75 | __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
76 | {
77 | half2 v01 = __halves2half2(v0, v1);
78 | half2 v23 = __halves2half2(v2, v3);
79 | half2* ptr = (half2*) item_ptr(row, column);
80 | ptr[0] = v01;
81 | ptr[1] = v23;
82 | }
83 | };
84 |
85 | class MatrixView_q4_row
86 | {
87 | public:
88 | const uint32_t* data;
89 | const int height;
90 | const int width;
91 |
92 | __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
93 | : data(data), height(height), width(width)
94 | { }
95 |
96 | __device__ __forceinline__ int item(int row, int column) const
97 | {
98 | int shift = (column & 0x07) * 4;
99 | return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
100 | }
101 |
102 | __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
103 | {
104 | int shift = (column & 0x07) * 4;
105 | uint32_t d = data[row * width / 8 + column / 8] >> shift;
106 | items[0] = d & 0x0f;
107 | items[1] = (d >> 4) & 0x0f;
108 | }
109 |
110 | __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
111 | {
112 | int shift = (column & 0x07) * 4;
113 | uint32_t d = data[row * width / 8 + column / 8] >> shift;
114 | items[0] = d & 0x0f;
115 | items[1] = (d >> 4) & 0x0f;
116 | items[2] = (d >> 8) & 0x0f;
117 | items[3] = (d >> 12) & 0x0f;
118 | }
119 | };
120 |
121 | #endif
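
A small Python model of `MatrixView_q4_row::item` above, only to illustrate the packing convention: eight 4-bit values per uint32 word, row-major, with the column index selecting a nibble.

    def q4_item(data, row, column, width):
        # Return the 4-bit value at (row, column) from a row-major packed uint32 buffer.
        shift = (column & 0x07) * 4
        word = data[row * width // 8 + column // 8]
        return (word >> shift) & 0x0F

    # one word holding the nibbles 0..7 (value v stored in nibble position v)
    packed = [sum(v << (4 * v) for v in range(8))]
    assert [q4_item(packed, 0, c, 8) for c in range(8)] == list(range(8))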
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/exl2/quant/qdq_2.cuh:
--------------------------------------------------------------------------------
1 | #ifndef _qdq_2_cuh
2 | #define _qdq_2_cuh
3 |
4 | #include "qdq_util.cuh"
5 | #include "../config.h"
6 |
7 | #if QMODE_2BIT == 1
8 |
9 | // Permutation:
10 | //
11 | // ffddbb99 77553311 eeccaa88 66442200
12 |
13 | __forceinline__ __device__ void shuffle_2bit_16
14 | (
15 | uint32_t* q,
16 | int stride
17 | )
18 | {
19 | uint32_t qa = q[0];
20 | uint32_t qb = 0;
21 |
22 | #pragma unroll
23 | for (int i = 0; i < 8; i++)
24 | {
25 | uint32_t qa0 = qa & 0x03;
26 | uint32_t qa1 = (qa & 0x0c) >> 2;
27 | qa >>= 4;
28 | qb |= (qa1 << (i * 2 + 16));
29 | qb |= (qa0 << (i * 2));
30 | }
31 | q[0] = qb;
32 | }
33 |
34 | __forceinline__ __device__ void dequant_2bit_16
35 | (
36 | const uint32_t q_0,
37 | half2 (&dq)[8],
38 | int stride
39 | )
40 | {
41 | const uint32_t c0 = 0x64006400;
42 | const half y4_ = __float2half_rn(1.0f / 4.0f);
43 | const half y16_ = __float2half_rn(1.0f / 16.0f);
44 | const half y64_ = __float2half_rn(1.0f / 64.0f);
45 | const half2 y4 = __halves2half2(y4_, y4_);
46 | const half2 y16 = __halves2half2(y16_, y16_);
47 | const half2 y64 = __halves2half2(y64_, y64_);
48 | const half z1_ = __float2half_rn(-1024.0f);
49 | const half z4_ = __float2half_rn(-1024.0f / 4.0f);
50 | const half z16_ = __float2half_rn(-1024.0f / 16.0f);
51 | const half z64_ = __float2half_rn(-1024.0f / 64.0f);
52 | const half2 z1 = __halves2half2(z1_, z1_);
53 | const half2 z4 = __halves2half2(z4_, z4_);
54 | const half2 z16 = __halves2half2(z16_, z16_);
55 | const half2 z64 = __halves2half2(z64_, z64_);
56 |
57 | uint32_t qa = q_0;
58 | half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
59 | half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
60 | half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
61 | half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
62 | qa >>= 8;
63 | half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9]) + 1024
64 | half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
65 | half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
66 | half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
67 |
68 | dq[0] = __hadd2(q0.as_half2, z1);
69 | dq[1] = __hfma2(q1.as_half2, y4, z4);
70 | dq[2] = __hfma2(q2.as_half2, y16, z16);
71 | dq[3] = __hfma2(q3.as_half2, y64, z64);
72 | dq[4] = __hadd2(q4.as_half2, z1);
73 | dq[5] = __hfma2(q5.as_half2, y4, z4);
74 | dq[6] = __hfma2(q6.as_half2, y16, z16);
75 | dq[7] = __hfma2(q7.as_half2, y64, z64);
76 | }
77 |
78 | #else
79 |
80 | __forceinline__ __device__ void shuffle_2bit_16
81 | (
82 | uint32_t* q,
83 | int stride
84 | )
85 | {
86 | }
87 |
88 | __forceinline__ __device__ void dequant_2bit_16
89 | (
90 | const uint32_t q_0,
91 | half2 (&dq)[8],
92 | int stride
93 | )
94 | {
95 | half dqh[16];
96 | for (int i = 0; i < 16; i++)
97 | dqh[i] = __uint2half_rn(exb(q_0, i * 2, 0x03));
98 |
99 | for (int i = 0; i < 8; i++)
100 | dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
101 | }
102 |
103 | #endif
104 |
105 | #endif
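
The fused path above leans on a half-precision trick: with `c0 = 0x64006400`, OR-ing a small integer q into the low mantissa bits of 0x6400 yields the fp16 value 1024 + q, so subtracting 1024 (the `z1` constant) dequantizes without an explicit int-to-float conversion; the shifted fields use the scaled variants (y4/z4 and so on). A quick numpy check of the basic identity, purely illustrative:

    import numpy as np

    q = np.arange(4, dtype=np.uint16)               # the possible 2-bit values 0..3
    h = (q | np.uint16(0x6400)).view(np.float16)    # exponent chosen so that 1 ulp == 1.0
    print(h - np.float16(1024.0))                   # -> [0. 1. 2. 3.]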
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/exl2/quant/qdq_3.cuh:
--------------------------------------------------------------------------------
1 | #ifndef _qdq_3_cuh
2 | #define _qdq_3_cuh
3 |
4 | #include "qdq_util.cuh"
5 | #include "../config.h"
6 |
7 | #if QMODE_3BIT == 1
8 |
9 | // Permutation:
10 | //
11 | // v9997775 55333111 u8886664 44222000 (u, v lsb)
12 | // vjjjhhhf ffdddbbb uiiiggge eecccaaa
13 | // vtttrrrp ppnnnlll usssqqqo oommmkkk
14 |
15 | __forceinline__ __device__ void shuffle_3bit_32
16 | (
17 | uint32_t* q,
18 | int stride
19 | )
20 | {
21 | uint32_t qa = q[0 * stride];
22 | uint32_t qb = q[1 * stride];
23 | uint32_t qc = q[2 * stride];
24 |
25 | // qa: aa999888 77766655 54443332 22111000
26 | // qb: lkkkjjji iihhhggg fffeeedd dcccbbba
27 | // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
28 |
29 | uint32_t qd = qc >> 26;
30 | qc <<= 4;
31 | qc |= qb >> 28;
32 | qb <<= 2;
33 | qb |= qa >> 30;
34 |
35 | // qa: ..999888 77766655 54443332 22111000
36 | // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
37 | // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
38 | // qd: vvvuuu
39 |
40 | uint32_t za = 0;
41 | uint32_t zb = 0;
42 | uint32_t zc = 0;
43 |
44 | for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
45 | for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
46 | for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
47 |
48 | // za: 9997775 55333111 8886664 44222000
49 | // zb: jjjhhhf ffdddbbb iiiggge eecccaaa
50 | // zc: tttrrrp ppnnnlll sssqqqo oommmkkk
51 | // qd: vvvuuu
52 |
53 | za |= ((qd & 0x01) >> 0) << 15;
54 | zb |= ((qd & 0x02) >> 1) << 15;
55 | zc |= ((qd & 0x04) >> 2) << 15;
56 | za |= ((qd & 0x08) >> 3) << 31;
57 | zb |= ((qd & 0x10) >> 4) << 31;
58 | zc |= ((qd & 0x20) >> 5) << 31;
59 |
60 | // za: v9997775 55333111 u8886664 44222000 (u, v lsb)
61 | // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
62 | // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
63 |
64 | q[0 * stride] = za;
65 | q[1 * stride] = zb;
66 | q[2 * stride] = zc;
67 | }
68 |
69 | __forceinline__ __device__ void dequant_3bit_32
70 | (
71 | const uint32_t q_0,
72 | const uint32_t q_1,
73 | const uint32_t q_2,
74 | half2 (&dq)[16],
75 | int stride
76 | )
77 | {
78 | const uint32_t c0 = 0x64006400;
79 | const half y8_ = __float2half_rn(1.0f / 8.0f);
80 | const half y64_ = __float2half_rn(1.0f / 64.0f);
81 | const half2 y8 = __halves2half2(y8_, y8_);
82 | const half2 y64 = __halves2half2(y64_, y64_);
83 | const half z1_ = __float2half_rn(-1024.0f);
84 | const half z8_ = __float2half_rn(-1024.0f / 8.0f);
85 | const half z64_ = __float2half_rn(-1024.0f / 64.0f);
86 | const half2 z1 = __halves2half2(z1_, z1_);
87 | const half2 z8 = __halves2half2(z8_, z8_);
88 | const half2 z64 = __halves2half2(z64_, z64_);
89 |
90 | uint32_t qa = q_0;
91 | uint32_t qb = q_1;
92 | uint32_t qc = q_2;
93 |
94 | half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
95 | half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
96 | qa >>= 6;
97 | half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
98 | half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
99 | half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
100 | qa >>= 9;
101 | qa &= 0x00010001;
102 | half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
103 | half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
104 | qb >>= 6;
105 | half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
106 | half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
107 | half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
108 | qb >>= 8;
109 | qb &= 0x00020002;
110 | half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
111 | half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
112 | qc >>= 6;
113 | half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
114 | half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
115 | half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
116 | qc >>= 7;
117 | qc &= 0x00040004;
118 | half2_uint32 q15((qa | qb | qc) | c0);
119 |
120 | dq[ 0] = __hadd2( q0.as_half2, z1);
121 | dq[ 1] = __hfma2( q1.as_half2, y8, z8);
122 | dq[ 2] = __hadd2( q2.as_half2, z1);
123 | dq[ 3] = __hfma2( q3.as_half2, y8, z8);
124 | dq[ 4] = __hfma2( q4.as_half2, y64, z64);
125 | dq[ 5] = __hadd2( q5.as_half2, z1);
126 | dq[ 6] = __hfma2( q6.as_half2, y8, z8);
127 | dq[ 7] = __hadd2( q7.as_half2, z1);
128 | dq[ 8] = __hfma2( q8.as_half2, y8, z8);
129 | dq[ 9] = __hfma2( q9.as_half2, y64, z64);
130 | dq[10] = __hadd2(q10.as_half2, z1);
131 | dq[11] = __hfma2(q11.as_half2, y8, z8);
132 | dq[12] = __hadd2(q12.as_half2, z1);
133 | dq[13] = __hfma2(q13.as_half2, y8, z8);
134 | dq[14] = __hfma2(q14.as_half2, y64, z64);
135 | dq[15] = __hadd2(q15.as_half2, z1);
136 | }
137 |
138 | #else
139 |
140 | __forceinline__ __device__ void shuffle_3bit_32
141 | (
142 | uint32_t* q,
143 | int stride
144 | )
145 | {
146 | }
147 |
148 | __forceinline__ __device__ void dequant_3bit_32
149 | (
150 | const uint32_t q_0,
151 | const uint32_t q_1,
152 | const uint32_t q_2,
153 | half2 (&dq)[16],
154 | int stride
155 | )
156 | {
157 | half dqh[32];
158 | for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 0);
159 | dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 0);
160 | for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 0);
161 | dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 0);
162 | for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 0);
163 |
164 | for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
165 | }
166 |
167 | #endif
168 |
169 | #endif
170 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/exl2/quant/qdq_6.cuh:
--------------------------------------------------------------------------------
1 | #ifndef _qdq_6_cuh
2 | #define _qdq_6_cuh
3 |
4 | #include "qdq_util.cuh"
5 | #include "../config.h"
6 |
7 | #if QMODE_6BIT == 1
8 |
9 | // Permutation:
10 | //
11 | // dddd3333 33111111 cccc2222 22000000
12 | // ffff7777 77555555 eeee6666 66444444
13 | // ffddbbbb bb999999 eeccaaaa aa888888
14 |
15 | __forceinline__ __device__ void shuffle_6bit_16
16 | (
17 | uint32_t* q,
18 | int stride
19 | )
20 | {
21 | uint32_t qa = q[0 * stride];
22 | uint32_t qb = q[1 * stride];
23 | uint32_t qc = q[2 * stride];
24 |
25 | // qa: 55444444 33333322 22221111 11000000
26 | // qb: aaaa9999 99888888 77777766 66665555
27 | // qc: ffffffee eeeedddd ddcccccc bbbbbbaa
28 |
29 | uint32_t q00 = (qa ) & 0b111111;
30 | uint32_t q01 = (qa >> 6) & 0b111111;
31 | uint32_t q02 = (qa >> 12) & 0b111111;
32 | uint32_t q03 = (qa >> 18) & 0b111111;
33 | uint32_t q04 = (qa >> 24) & 0b111111;
34 | uint32_t q05 = ((qa >> 30) & 0b11) | ((qb & 0b1111) << 2);
35 | uint32_t q06 = (qb >> 4) & 0b111111;
36 | uint32_t q07 = (qb >> 10) & 0b111111;
37 | uint32_t q08 = (qb >> 16) & 0b111111;
38 | uint32_t q09 = (qb >> 22) & 0b111111;
39 | uint32_t q0a = ((qb >> 28) & 0b1111) | ((qc & 0b11) << 4);
40 | uint32_t q0b = (qc >> 2) & 0b111111;
41 | uint32_t q0c = (qc >> 8) & 0b111111;
42 | uint32_t q0d = (qc >> 14) & 0b111111;
43 | uint32_t q0e = (qc >> 20) & 0b111111;
44 | uint32_t q0f = (qc >> 26) & 0b111111;
45 |
46 | qa = q00 | (q01 << 16) | (q02 << 6) | (q03 << 22);
47 | qb = q04 | (q05 << 16) | (q06 << 6) | (q07 << 22);
48 | qc = q08 | (q09 << 16) | (q0a << 6) | (q0b << 22);
49 |
50 | // qa: ....3333 33111111 ....2222 22000000
51 | // qb: ....7777 77555555 ....6666 66444444
52 | // qc: ....bbbb bb999999 ....aaaa aa888888
53 |
54 | qa |= (q0c & 0b001111) << 12;
55 | qc |= (q0c & 0b110000) << 8;
56 | qa |= (q0d & 0b001111) << 28;
57 | qc |= (q0d & 0b110000) << 24;
58 |
59 | // qa: dddd3333 33111111 cccc2222 22000000
60 | // qb: ....7777 77555555 ....6666 66444444
61 | // qc: ..ddbbbb bb999999 ..ccaaaa aa888888
62 |
63 | qb |= (q0e & 0b001111) << 12;
64 | qc |= (q0e & 0b110000) << 10;
65 | qb |= (q0f & 0b001111) << 28;
66 | qc |= (q0f & 0b110000) << 26;
67 |
68 | // qa: dddd3333 33111111 cccc2222 22000000
69 | // qb: ffff7777 77555555 eeee6666 66444444
70 | // qc: ffddbbbb bb999999 eeccaaaa aa888888
71 |
72 | q[0 * stride] = qa;
73 | q[1 * stride] = qb;
74 | q[2 * stride] = qc;
75 | }
76 |
77 | __forceinline__ __device__ void dequant_6bit_16
78 | (
79 | const uint32_t q_0,
80 | const uint32_t q_1,
81 | const uint32_t q_2,
82 | half2 (&dq)[8],
83 | int stride
84 | )
85 | {
86 | const uint32_t c0 = 0x64006400;
87 | const half z1_ = __float2half_rn(-1024.0f);
88 | const half2 z1 = __halves2half2(z1_, z1_);
89 |
90 | uint32_t qa = q_0;
91 | uint32_t qb = q_1;
92 | uint32_t qc = q_2;
93 |
94 | half2_uint32 q0((qa & 0x003f003f) | c0); // half2(q[ 0], q[ 1]) + 1024
95 | qa >>= 6;
96 | half2_uint32 q1((qa & 0x003f003f) | c0); // half2(q[ 2], q[ 3]) + 1024
97 | qa >>= 6;
98 | half2_uint32 q2((qb & 0x003f003f) | c0); // half2(q[ 4], q[ 5]) + 1024
99 | qb >>= 6;
100 | half2_uint32 q3((qb & 0x003f003f) | c0); // half2(q[ 6], q[ 7]) + 1024
101 | qb >>= 6;
102 | half2_uint32 q4((qc & 0x003f003f) | c0); // half2(q[ 8], q[ 9]) + 1024
103 | qc >>= 6;
104 | half2_uint32 q5((qc & 0x003f003f) | c0); // half2(q[10], q[11]) + 1024
105 | qc >>= 2;
106 | half2_uint32 q6((qa & 0x000f000f) | (qc & 0x00300030) | c0); // half2(q[12], q[13]) + 1024
107 | qc >>= 2;
108 | half2_uint32 q7((qb & 0x000f000f) | (qc & 0x00300030) | c0); // half2(q[14], q[15]) + 1024
109 |
110 | dq[0] = __hadd2(q0.as_half2, z1);
111 | dq[1] = __hadd2(q1.as_half2, z1);
112 | dq[2] = __hadd2(q2.as_half2, z1);
113 | dq[3] = __hadd2(q3.as_half2, z1);
114 | dq[4] = __hadd2(q4.as_half2, z1);
115 | dq[5] = __hadd2(q5.as_half2, z1);
116 | dq[6] = __hadd2(q6.as_half2, z1);
117 | dq[7] = __hadd2(q7.as_half2, z1);
118 | }
119 |
120 | #else
121 |
122 | __forceinline__ __device__ void shuffle_6bit_16
123 | (
124 | uint32_t* q,
125 | int stride
126 | )
127 | {
128 | }
129 |
130 | __forceinline__ __device__ void dequant_6bit_16
131 | (
132 | const uint32_t q_0,
133 | const uint32_t q_1,
134 | const uint32_t q_2,
135 | half2 (&dq)[8],
136 | int stride
137 | )
138 | {
139 | half dqh[16];
140 | for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 0);
141 | dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 0);
142 | for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 0);
143 | dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 0);
144 | for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 0);
145 |
146 | for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
147 | }
148 |
149 | #endif
150 |
151 | #endif
152 |
153 |
154 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/exl2/quant/qdq_8.cuh:
--------------------------------------------------------------------------------
1 | #ifndef _qdq_8_cuh
2 | #define _qdq_8_cuh
3 |
4 | #include "qdq_util.cuh"
5 | #include "../config.h"
6 |
7 | #if QMODE_8BIT == 1
8 |
9 | // Not implemented
10 |
11 | #else
12 |
13 | __forceinline__ __device__ void shuffle_8bit_4
14 | (
15 | uint32_t* q,
16 | int stride
17 | )
18 | {
19 | }
20 |
21 | __forceinline__ __device__ void dequant_8bit_8
22 | (
23 | const uint32_t q_0,
24 | const uint32_t q_1,
25 | half2 (&dq)[4],
26 | int stride
27 | )
28 | {
29 | half dqh[8];
30 | for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 0);
31 | for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 0);
32 |
33 | for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
34 | }
35 |
36 | #endif
37 |
38 | #endif
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/exl2/quant/qdq_util.cuh:
--------------------------------------------------------------------------------
1 | #ifndef _qdq_util_cuh
2 | #define _qdq_util_cuh
3 |
4 | union half2_uint32
5 | {
6 | uint32_t as_uint32;
7 | half2 as_half2;
8 | __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
9 | __device__ half2_uint32(half2 val) : as_half2(val) {}
10 | __device__ half2_uint32() : as_uint32(0) {}
11 | };
12 |
13 | union half_uint16
14 | {
15 | uint16_t as_uint16;
16 | half as_half;
17 | __device__ half_uint16(uint16_t val) : as_uint16(val) {}
18 | __device__ half_uint16(half val) : as_half(val) {}
19 | __device__ half_uint16() : as_uint16(0) {}
20 | };
21 |
22 | // Max_scale premultiplied by 1/256
23 |
24 | __forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
25 | {
26 | int qs_i = qs + 1;
27 | half qs_h = __int2half_rn(qs_i * qs_i);
28 | qs_h = __hmul(qs_h, max_scale);
29 | return qs_h;
30 | }
31 |
32 | __forceinline__ __device__ half dq_scale_q_zero(const int qs, const half scale, const half zero)
33 | {
34 | half qs_h = __int2half_rn(qs);
35 | // qs_h = __hfma(qs_h, scale, __hneg(zero));
36 | qs_h = __hadd(qs_h, __hneg(zero));
37 | qs_h = __hmul(qs_h, scale);
38 | return qs_h;
39 | }
40 |
41 | __forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
42 | {
43 | return __hmul(__int2half_rn(q - qzero), scale);
44 | }
45 |
46 | __forceinline__ __device__ half dq_ns(const int q, const int qzero)
47 | {
48 | //return __hsub(__int2half_rn(q), __int2half_rn(qzero));
49 | return __int2half_rn(q - qzero);
50 | }
51 |
52 | __forceinline__ __device__ unsigned int as_unsigned(int i) {
53 | return *reinterpret_cast<unsigned int*>(&i);
54 | }
55 |
56 | __forceinline__ __device__ unsigned int exb(const uint32_t q, const int shift, const int mask)
57 | {
58 | return as_unsigned((q >> shift) & mask);
59 | }
60 |
61 | __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
62 | {
63 | return (int)(__funnelshift_rc(q0, q1, shift) & mask);
64 | }
65 |
66 | #endif
67 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/exl2/util.cuh:
--------------------------------------------------------------------------------
1 | #ifndef _util_cuh
2 | #define _util_cuh
3 |
4 | #include <cuda_runtime.h>
5 | #include <cuda_fp16.h>
6 | #include <cstdint>
7 | #include <cstdio>
8 | #include <cstdlib>
9 |
10 | #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
11 |
12 | #define DBGS(__x) printf("%s\n", __x)
13 | #define DBGI(__x) printf("%s: %i\n", #__x, __x)
14 | #define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
15 | #define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
16 | #define DBGX(__x) printf("%s: %x\n", #__x, __x)
17 | #define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
18 | #define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
19 | #define DBGF(__x) printf("%s: %f\n", #__x, __x)
20 | #define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
21 | #define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
22 | #define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
23 | #define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
24 | #define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
25 |
26 | #define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
27 | #define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
28 |
29 | __forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
30 | {
31 | half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
32 | qs_h = __hmul(qs_h, qs_h);
33 | qs_h = __hmul(qs_h, max_scale);
34 | return qs_h;
35 | }
36 |
37 | __forceinline__ __device__ float clamp(float x, float a, float b)
38 | {
39 | return fmaxf(a, fminf(b, x));
40 | }
41 |
42 | #define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
43 | inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
44 | {
45 | if (code != cudaSuccess)
46 | {
47 | fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
48 | if (abort) exit(code);
49 | }
50 | }
51 |
52 | void print_global_mem(const half* ptr, int rows, int columns, int stride);
53 |
54 | #endif
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cuda/extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension
4 |
5 | CUDA_REQUIRED = True
6 |
7 | def get_ext(path: Path):
8 | """
9 | Get the CUDA extension for quantized linear operations.
10 |
11 | Args:
12 | path (Path): The path to the CUDA extension directory.
13 |
14 | Returns:
15 | Any: The CUDA extension module.
16 | """
17 | ext = get_cuda_extension(
18 | path,
19 | relative_name='q_linear_cuda',
20 | relative_sources=[
21 | 'q_linear_cuda.cpp',
22 | 'mpq_linear_cuda_kernel.cu',
23 | 'mbwq_linear_cuda_kernel.cu',
24 | ]
25 | )
26 | ext.include_dirs.extend(['exl2'])
27 | return ext
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cutlass/__init__.py:
--------------------------------------------------------------------------------
1 | from .q4_layer import Q4LinearCutlass, Q4MatMul
2 | from .q8_layer import Q8LinearCutlass
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/cutlass/extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from bitorch_engine.utils.cuda_extension import get_cuda_extension
4 |
5 | CUDA_REQUIRED = True
6 | CUTLASS_REQUIRED = True
7 |
8 | def get_ext(path: Path):
9 | """
10 | Return CUDA extension for Q Linear Cutlass.
11 |
12 | Args:
13 | path (Path): Path to the directory containing the extension files.
14 |
15 | Returns:
16 | Extension: CUDA extension for Q Linear Cutlass.
17 | """
18 | return get_cuda_extension(
19 | path,
20 | relative_name='q_linear_cutlass',
21 | relative_sources=[
22 | 'q_linear_cutlass.cpp',
23 | 'q4_linear_cutlass_kernel.cu',
24 | 'q8_linear_cutlass_kernel.cu',
25 | ]
26 | )
27 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/mps/__init__.py:
--------------------------------------------------------------------------------
1 | from .mpq_layer import MPQLinearMlx
2 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/mps/extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from bitorch_engine.utils.mlx_extension import get_mlx_extension
4 |
5 | MLX_REQUIRED = True
6 |
7 | def get_ext(path: Path):
8 | return get_mlx_extension(
9 | path,
10 | relative_name='mpq_linear_mlx',
11 | relative_sources=[
12 | 'mlx_bindings.cpp',
13 | 'mpq_linear_mlx.cpp',
14 | ]
15 | )
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/mps/mlx_bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "mpq_linear_mlx.h"
3 |
4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
5 | {
6 | m.def("mpq_forward", &mpq_linear_mlx_forward, "Forward call for mlx acclelerated MPS quantized linear layer.");
7 | }
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/mps/mpq_linear_mlx.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <mlx/mlx.h>
3 | #include <mlx/array.h>
4 | #include <mlx/ops.h>
5 | #include <mlx/transforms.h>
6 |
7 | typedef float16_t half_t;
8 |
9 | template <typename T>
10 | mlx::core::array mlx_from_torch_tensor(torch::Tensor tensor, mlx::core::Dtype valueDtype)
11 | {
12 | return mlx::core::array(
13 | reinterpret_cast<T*>(tensor.mutable_data_ptr()), // Pointer to the data.
14 | std::vector<int32_t>({ // Shape of the array as a vector of int32 (mlx requires 32 bit int).
15 | static_cast<int32_t>(tensor.size(0)),
16 | static_cast<int32_t>(tensor.size(1))
17 | }),
18 | valueDtype // Set dtype.
19 | );
20 | }
21 |
22 | torch::Tensor mlx_to_torch_tensor(mlx::core::array arr, torch::Dtype dtype)
23 | {
24 | return torch::from_blob(
25 | arr.data<half_t>(), // Pointer to the data.
26 | { // Shape of the array as a vector of int64 (torch requires longs).
27 | static_cast<int64_t>(arr.shape()[0]),
28 | static_cast<int64_t>(arr.shape()[1])
29 | },
30 | torch::TensorOptions().dtype(dtype)); // Set dtype.
31 | }
32 |
33 |
34 | /**
35 | * Performs forward pass for mixed precision quantized (MPQ) linear multiplication using a custom MLX library.
36 | * This function processes an input tensor `x` with quantized weights `qweight`, applying scale factors `scales`
37 | * and zero points `zeros` for quantization, within specified group sizes and weight bit precision. It's designed
38 | * for CPU execution, enforcing inputs and computations to reside on the CPU.
39 | *
40 | * @param x The input tensor, expected to be on CPU, containing the features to be processed.
41 | * @param qweight The quantized weights tensor, also on CPU, to be used in the matrix multiplication.
42 | * @param scales The scale factors for the quantized weights, aiding in de-quantization to real values during the computation.
43 | * @param zeros The zero points for the quantized weights, also used in de-quantization process.
44 | * @param group_size The size of groups for performing the quantized matrix multiplication, affecting how inputs are partitioned.
45 | * @param w_bit The precision of the quantized weights in bits, supporting 2, 4, or 8 bits for the computation.
46 | *
47 | * @details
48 | * - The function begins by validating the input tensors to ensure they are CPU tensors, contiguous, and that the weight bits
49 | * are within the supported range.
50 | * - It then converts PyTorch tensors to MLX core arrays, using appropriate data types for the MLX backend computations.
51 | * - Quantized matrix multiplication is performed with the MLX library, leveraging the given scale factors, zero points,
52 | * group size, and weight bit precision.
53 | * - The MLX library uses lazy evaluation for computations; thus, the function explicitly evaluates the output before
54 | * converting it back to a PyTorch tensor.
55 | * - The result is a tensor of the computed output in float16 format, ready for further processing in PyTorch pipelines.
56 | *
57 | * @return A torch::Tensor containing the result of the mixed precision quantized linear multiplication in float16 format.
58 | *
59 | * @note
60 | * - This function is designed to operate exclusively on CPU tensors and will verify the device type of its inputs.
61 | * - It assumes the MLX backend for computation, which requires inputs to be converted to MLX core arrays.
62 | * - The use of float16 and uint32 data types for MLX core arrays is based on the precision and requirements of the inputs and computation.
63 | */
64 | torch::Tensor mpq_linear_mlx_forward(
65 | torch::Tensor x,
66 | torch::Tensor qweight,
67 | torch::Tensor scales,
68 | torch::Tensor zeros,
69 | int group_size,
70 | int w_bit)
71 | {
72 | // Check the input parameters
73 | TORCH_CHECK((w_bit == 2) || (w_bit == 4) || (w_bit == 8), "weights must have {2,4,8} bits.");
74 | TORCH_CHECK(x.device().type() == torch::kCPU, "x must be a CPU tensor. Cannot read from MPS.");
75 | TORCH_CHECK(qweight.device().type() == torch::kCPU, "qweight must be a CPU tensor. Cannot read from MPS.");
76 | TORCH_CHECK(scales.device().type() == torch::kCPU, "scales must be a CPU tensor. Cannot read from MPS.");
77 | TORCH_CHECK(zeros.device().type() == torch::kCPU, "zeros must be a CPU tensor. Cannot read from MPS.");
78 | TORCH_CHECK(x.is_contiguous(), "x must be contiguous.");
79 | TORCH_CHECK(qweight.is_contiguous(), "qweight must be contiguous.");
80 | TORCH_CHECK(scales.is_contiguous(), "scales must be contiguous.");
81 | TORCH_CHECK(zeros.is_contiguous(), "zeros must be contiguous.");
82 |
83 | mlx::core::array x_arr = mlx_from_torch_tensor<half_t>(x, mlx::core::float16);
84 | mlx::core::array qweight_arr = mlx_from_torch_tensor<uint32_t>(qweight, mlx::core::uint32);
85 | mlx::core::array scales_arr = mlx_from_torch_tensor<half_t>(scales, mlx::core::float16);
86 | mlx::core::array zeros_arr = mlx_from_torch_tensor<half_t>(zeros, mlx::core::float16);
87 | mlx::core::array output = mlx::core::quantized_matmul(
88 | x_arr,
89 | qweight_arr,
90 | scales_arr,
91 | zeros_arr,
92 | true,
93 | group_size,
94 | w_bit,
95 | mlx::core::default_stream(mlx::core::default_device())
96 | );
97 |
98 | mlx::core::eval(output); // Mlx uses lazy evaluation. Run and wait for the computation to finish.
99 | return mlx_to_torch_tensor(output, torch::kFloat16);
100 | }
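
A rough sketch of calling the binding above from Python. Only the checks visible in the code (CPU, contiguous, fp16 input, w_bit in {2, 4, 8}) are relied on; the packed shapes of qweight/scales/zeros, the group size and the output shape are assumptions based on MLX's quantized_matmul convention, and in practice the MPQLinearMlx layer wraps this call.

    import torch
    from bitorch_engine.utils.safe_import import import_extension

    mpq_linear_mlx = import_extension("mpq_linear_mlx")

    w_bit, group_size = 4, 64                                  # assumed settings
    out_f, in_f = 256, 512
    x = torch.randn(8, in_f, dtype=torch.float16)              # CPU, contiguous, fp16

    # assumed MLX packing: 32 // w_bit values per uint32 column, one scale/zero per group
    qweight = torch.zeros(out_f, in_f * w_bit // 32, dtype=torch.int32)
    scales = torch.ones(out_f, in_f // group_size, dtype=torch.float16)
    zeros = torch.zeros(out_f, in_f // group_size, dtype=torch.float16)

    y = mpq_linear_mlx.mpq_forward(x, qweight, scales, zeros, group_size, w_bit)
    print(y.shape, y.dtype)                                    # expected: (8, out_f), torch.float16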
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/nbit/mps/mpq_linear_mlx.h:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | torch::Tensor mpq_linear_mlx_forward(
4 | torch::Tensor x,
5 | torch::Tensor qweight,
6 | torch::Tensor scales,
7 | torch::Tensor zeros,
8 | int group_size,
9 | int w_bit);
--------------------------------------------------------------------------------
/bitorch_engine/layers/qlinear/qlinear_implementation.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Tuple
3 |
4 | from bitorch.layers import CustomImplementationMixin, QLinearBase
5 | from bitorch.layers.extensions import LayerRecipe
6 |
7 |
8 | class QLinearImplementationMixin(CustomImplementationMixin, ABC):
9 | """
10 | A mixin class for QLinear layer implementations that provides common utility functions
11 | and checks specific to quantized linear layers. This mixin extends CustomImplementationMixin
12 | and implements the Abstract Base Class (ABC) to ensure that subclasses provide implementations
13 | for abstract methods defined in parent classes.
14 |
15 | The class provides a method to check if a given layer configuration can be cloned
16 | based on specific constraints related to quantized linear layers.
17 | """
18 | @classmethod
19 | def can_clone(cls, recipe: LayerRecipe) -> Tuple[bool, str]:
20 | """
21 | Determines if a QLinear layer described by the given recipe can be cloned.
22 |
23 | The method checks if the layer configuration meets certain criteria necessary
24 | for cloning a quantized linear layer. Specifically, it checks if the layer
25 | includes bias, and if the number of input features is divisible by 32, as
26 | these are current limitations for cloning such layers.
27 |
28 | Args:
29 | recipe (LayerRecipe): An object containing the configuration parameters
30 | of the layer to be cloned.
31 |
32 | Returns:
33 | Tuple[bool, str]: A tuple containing a boolean and a string. The boolean
34 | indicates whether the layer can be cloned (True if it can,
35 | False otherwise). The string provides a message explaining
36 | why the layer cannot be cloned if the boolean is False.
37 | """
38 | # Extract layer arguments from the recipe
39 | args = QLinearBase.get_args_as_kwargs(recipe)
40 | if args["bias"]:
 41 | return False, "bias is not yet supported."
42 | # Check if the number of input features is divisible by 32
43 | if args["in_features"] % 32 != 0:
44 | return False, f"in_features ({args['in_features']}) is not divisible by 32."
45 | # Layer can be cloned if all checks pass
46 | return True, ""
47 |
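Concrete layer backends typically subclass this mixin and layer their own constraints on top of the shared checks. Below is a minimal sketch of that pattern; the class name and the extra `out_features` check are hypothetical and not part of the engine:

```python
from typing import Tuple

from bitorch.layers import QLinearBase
from bitorch.layers.extensions import LayerRecipe

from bitorch_engine.layers.qlinear.qlinear_implementation import QLinearImplementationMixin


class MyQLinearImplementation(QLinearImplementationMixin):
    """Hypothetical backend adding one extra cloning constraint on top of the shared checks."""

    @classmethod
    def can_clone(cls, recipe: LayerRecipe) -> Tuple[bool, str]:
        # run the shared bias / in_features checks first
        supported, message = super().can_clone(recipe)
        if not supported:
            return supported, message
        args = QLinearBase.get_args_as_kwargs(recipe)
        # hypothetical backend-specific constraint
        if args["out_features"] % 32 != 0:
            return False, f"out_features ({args['out_features']}) is not divisible by 32."
        return True, ""
```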
--------------------------------------------------------------------------------
/bitorch_engine/layers/qmha/__init__.py:
--------------------------------------------------------------------------------
1 | from .binary import *
2 |
--------------------------------------------------------------------------------
/bitorch_engine/layers/qmha/binary/__init__.py:
--------------------------------------------------------------------------------
1 | from .layer import BMHA
2 |
--------------------------------------------------------------------------------
/bitorch_engine/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .diode_beta import DiodeMix
--------------------------------------------------------------------------------
/bitorch_engine/optim/galore_projector.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the original Galore optimizer implementation from `Galore repo `_
3 | with `Apache-2.0 License `_
4 |
5 | @misc{zhao2024galore,
6 | title={GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection},
7 | author={Jiawei Zhao and Zhenyu Zhang and Beidi Chen and Zhangyang Wang and Anima Anandkumar and Yuandong Tian},
8 | year={2024},
9 | eprint={2403.03507},
10 | archivePrefix={arXiv},
11 | primaryClass={cs.LG}
12 | }
13 | """
14 |
15 | import torch
16 |
17 | class GaLoreProjector:
18 | def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
19 | self.rank = rank
20 | self.verbose = verbose
21 | self.update_proj_gap = update_proj_gap
22 | self.scale = scale
23 | self.ortho_matrix = None
24 | self.proj_type = proj_type
25 |
26 | def project(self, full_rank_grad, iter):
27 |
28 | if self.proj_type == 'std':
29 | if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
30 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
31 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
32 | low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
33 | else:
34 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
35 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
36 | low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
37 | elif self.proj_type == 'reverse_std':
38 | if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
39 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
40 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
41 | low_rank_grad = torch.matmul(self.ortho_matrix.t(),full_rank_grad)
42 | else:
43 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
44 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
45 | low_rank_grad = torch.matmul(full_rank_grad,self.ortho_matrix.t())
46 | elif self.proj_type == 'right':
47 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
48 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
49 | low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
50 | elif self.proj_type == 'left':
51 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
52 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
53 | low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
54 | elif self.proj_type == 'full':
55 | if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
56 | self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
57 | low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
58 |
59 | return low_rank_grad
60 |
61 | def project_back(self, low_rank_grad):
62 |
63 | if self.proj_type == 'std':
64 | if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
65 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
66 | else:
67 | full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
68 | elif self.proj_type == 'reverse_std':
69 | if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
70 | full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
71 | else:
72 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
73 | elif self.proj_type == 'right':
74 | full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
75 | elif self.proj_type == 'left':
76 | full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
77 | elif self.proj_type == 'full':
78 | full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
79 |
80 |
81 | return full_rank_grad * self.scale
82 |
83 |
84 | # svd decomposition
85 | def get_orthogonal_matrix(self, weights, rank, type):
86 | module_params = weights
87 |
88 | if module_params.data.dtype != torch.float:
89 | float_data = False
90 | original_type = module_params.data.dtype
91 | original_device = module_params.data.device
92 | matrix = module_params.data.float()
93 | else:
94 | float_data = True
95 | matrix = module_params.data
96 |
97 | U, s, Vh = torch.linalg.svd(matrix, full_matrices = False)
98 |
99 | #make the smaller matrix always to be orthogonal matrix
100 | if type=='right':
101 | A = U[:, :rank] @ torch.diag(s[:rank])
102 | B = Vh[:rank, :]
103 |
104 | if not float_data:
105 | B = B.to(original_device).type(original_type)
106 | return B
107 | elif type=='left':
108 | A = U[:, :rank]
109 | B = torch.diag(s[:rank]) @ Vh[:rank, :]
110 | if not float_data:
111 | A = A.to(original_device).type(original_type)
112 | return A
113 | elif type=='full':
114 | A = U[:, :rank]
115 | B = Vh[:rank, :]
116 | if not float_data:
117 | A = A.to(original_device).type(original_type)
118 | B = B.to(original_device).type(original_type)
119 | return [A, B]
120 | else:
121 | raise ValueError('type should be left, right or full')
122 |
123 |
124 |
125 |
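A minimal usage sketch of the projector above (tensor shape and rank are illustrative): the gradient is projected into a low-rank subspace, the optimizer update would happen there, and the result is projected back to the full-rank shape.

```python
import torch

from bitorch_engine.optim.galore_projector import GaLoreProjector

full_rank_grad = torch.randn(256, 64)  # toy gradient of a 256x64 weight
projector = GaLoreProjector(rank=16, update_proj_gap=200, scale=1.0, proj_type='std')

for step in range(3):
    low_rank_grad = projector.project(full_rank_grad, step)  # shape (256, 16)
    # ... optimizer statistics would be updated in the low-rank space here ...
    update = projector.project_back(low_rank_grad)            # shape (256, 64), scaled by `scale`
    assert update.shape == full_rank_grad.shape
```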
--------------------------------------------------------------------------------
/bitorch_engine/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/bitorch_engine/utils/__init__.py
--------------------------------------------------------------------------------
/bitorch_engine/utils/arch_helper.py:
--------------------------------------------------------------------------------
1 | import platform
2 | from enum import Enum
3 | import subprocess
4 | import os
5 |
6 | class ARCH_CPU(Enum):
7 | '''
  8 | Indicates which CPU architecture is used for computation
9 | '''
10 | ARM_A76 = 1
11 | ARM_A55 = 2
12 |
13 |
14 | class linux_arch_ident():
15 | """
16 | A utility class for identifying Linux architecture, specifically aimed at ARM architectures.
17 |
18 | This class provides methods to check if the current system is running on an ARM architecture
19 | and to determine the specific model of the ARM CPU.
20 | """
21 | @staticmethod
22 | def is_arm() -> bool:
23 | """
24 | Determines if the current system's architecture is ARM-based.
25 |
26 | :return: True if the system is ARM-based, False otherwise.
27 | """
28 | return platform.machine().lower().startswith('arm') or platform.machine().lower().startswith('aarch64')
29 |
30 | @staticmethod
31 | def get_arm_model() -> ARCH_CPU:
32 | """
33 | Fetches the model name of the ARM CPU from the system and maps it to a predefined ARCH_CPU enum.
34 |
35 | This method attempts to execute the 'lscpu' command to retrieve CPU information, then parses
36 | the output to identify the model name of the ARM CPU.
37 |
38 | :return: An ARCH_CPU enum value corresponding to the ARM CPU model.
39 | :raises Exception: If there's an error executing 'lscpu', or if the CPU model cannot be recognized.
40 | """
41 | try:
42 | # Execute 'lscpu' command and decode the output to get CPU information
43 | cpuinfo = (subprocess.check_output("lscpu", shell=True).strip()).decode().lower()
44 | # Parse the output to find the model name
45 | s = cpuinfo.index("model name:")
46 | s += len("model name:")
47 | n = cpuinfo[s:].index("\n")
48 | model_n = cpuinfo[s:s+n].strip()
49 | except:
50 | # Raise an exception if 'lscpu' command fails or if parsing fails
51 | raise Exception("Error occurred while running 'lscpu', please check if your OS supports this command.")
52 |
53 | # Map the model name to a specific ARCH_CPU enum value
54 | if model_n.__contains__("cortex-a55"):
55 | return ARCH_CPU.ARM_A55
56 | elif model_n.__contains__("cortex-a76"):
57 | return ARCH_CPU.ARM_A76
58 | # Raise an exception if the model name does not match known values
59 | raise Exception("Invalid architecture name obtained: {}.".format(model_n))
60 |
61 |
62 | def check_cpu_instruction_support(search_term):
63 | """
64 | Checks if the CPU supports a specific instruction set.
65 |
66 | This function utilizes the `cpuinfo` library to fetch detailed CPU information,
67 | then searches for a specific term within the CPU flags to determine if a given
68 | CPU instruction set is supported.
69 |
70 | Args:
71 | search_term (str): The CPU instruction set or feature to search for, e.g., "sse4_2", "avx2".
72 |
73 | Returns:
74 | bool: True if the search term is found within the CPU flags, indicating support for the instruction set.
75 | False otherwise.
76 |
77 | Example:
78 | >>> check_cpu_instruction_support("sse4_2")
79 | True # Indicates that "sse4_2" is supported by the CPU.
80 |
81 | Note:
82 | The function prints the entire CPU information fetched by `cpuinfo.get_cpu_info()` which can be quite verbose.
83 | Consider commenting out the `print` statement for production use.
84 | """
85 | import cpuinfo
86 | print(cpuinfo.get_cpu_info())
87 | # Check if the search term is present in the CPU flags
88 | if search_term in cpuinfo.get_cpu_info()["flags"]:
89 | return True
90 | else:
91 | return False
92 |
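A short sketch of how these helpers can be combined, mirroring what `get_kwargs()` in `bitorch_engine/utils/cpp_extension.py` does; the printed flags are illustrative only:

```python
from bitorch_engine.utils.arch_helper import ARCH_CPU, linux_arch_ident

if linux_arch_ident.is_arm():
    model = linux_arch_ident.get_arm_model()  # raises if 'lscpu' fails or the core is unknown
    extra_flag = "-mcpu=cortex-a55" if model is ARCH_CPU.ARM_A55 else "-mcpu=cortex-a76"
    print("Detected ARM core, adding compile flag:", extra_flag)
else:
    print("Not an ARM machine; no extra tuning flag needed.")
```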
--------------------------------------------------------------------------------
/bitorch_engine/utils/cpp_extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Dict, Any
3 | from torch.utils.cpp_extension import CppExtension, IS_MACOS
4 | from bitorch_engine.extensions import EXTENSION_PREFIX
5 | from bitorch_engine.utils.arch_helper import linux_arch_ident, ARCH_CPU, check_cpu_instruction_support
6 |
7 |
8 | def get_cpp_extension(root_path: Path, relative_name: str, relative_sources) -> Any:
9 | return CppExtension(
10 | name=EXTENSION_PREFIX + relative_name,
11 | sources=[str(root_path / rel_path) for rel_path in relative_sources],
12 | **get_kwargs()
13 | )
14 |
15 |
16 | def get_kwargs() -> Dict[str, Any]:
17 | """
18 | Generates keyword arguments for compilation settings, tailored for specific system architectures and operating systems.
19 |
20 | This function configures compiler flags and arguments based on the operating system and CPU architecture. It ensures
21 | that the appropriate flags are used for MacOS, ARM architectures, and potentially checks for specific CPU instruction
22 | set support (commented out by default).
23 |
24 | Returns:
25 | A dictionary containing:
26 | - `include_dirs`: A list of directories for the compiler to look for header files.
27 | - `libraries`: A list of libraries to link against. This varies between MacOS (`omp`) and other systems (`gomp`).
28 | - `extra_compile_args`: A list of additional arguments to pass to the compiler. This includes flags for warnings,
29 | OpenMP support, and architecture-specific optimizations.
30 |
31 | Note:
32 | - The function checks if the operating system is MacOS and adjusts the compilation flags accordingly.
33 | - For ARM architectures on Linux (not MacOS), it adds flags for ARMv8.2-A features and sets the CPU model
34 | for further optimizations if the model is detected as ARM A55 or A76.
35 | - The commented code snippet shows how to conditionally add compilation flags based on CPU instruction support,
36 | such as AVX2.
37 | """
38 | extra_compile_args = [
39 | '-Wall',
40 | '-Wno-deprecated-register',
41 | ]
42 | if IS_MACOS:
43 | extra_compile_args.append('-Xpreprocessor')
44 |
45 | extra_compile_args.append('-fopenmp')
46 |
47 | if linux_arch_ident.is_arm() and not IS_MACOS:
48 | extra_compile_args.append('-march=armv8.2-a+fp16+dotprod')
49 | if linux_arch_ident.get_arm_model() is ARCH_CPU.ARM_A55:
50 | extra_compile_args.append('-mcpu=cortex-a55')
51 | if linux_arch_ident.get_arm_model() is ARCH_CPU.ARM_A76:
52 | extra_compile_args.append('-mcpu=cortex-a76')
53 |
54 | ## can use this code to check the cpu support for instructions
55 | # if check_cpu_instruction_support('avx2'):
56 | # extra_compile_args.append('-mavx2')
57 |
58 | return {
59 | "include_dirs": [
60 | "/usr/local/opt/llvm/include",
61 | ],
62 | "libraries": [
63 | "omp" if IS_MACOS else "gomp",
64 | ],
65 | "extra_compile_args": extra_compile_args
66 | }
67 |
--------------------------------------------------------------------------------
/bitorch_engine/utils/cuda_extension.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import warnings
4 | from pathlib import Path
5 | from typing import Dict, Any
6 |
7 | from torch.utils.cpp_extension import CUDAExtension
8 |
9 | from bitorch_engine.extensions import EXTENSION_PREFIX
10 |
11 |
12 | SUPPORTED_CUDA_ARCHS = ["sm_80", "sm_75"]
13 |
14 |
15 | def get_cuda_arch():
16 | cuda_arch = os.environ.get("BIE_CUDA_ARCH", "sm_80")
17 | if cuda_arch not in SUPPORTED_CUDA_ARCHS:
18 | warnings.warn(f"Warning: BIE_CUDA_ARCH={cuda_arch} may not be supported yet.")
19 | return cuda_arch
20 |
21 |
22 | def get_cuda_extension(root_path: Path, relative_name: str, relative_sources) -> Any:
23 | if os.environ.get("BIE_BUILD_SEPARATE_CUDA_ARCH", "false") == "true":
24 | relative_name = f"{relative_name}-{get_cuda_arch()}"
25 | return CUDAExtension(
26 | name=EXTENSION_PREFIX + relative_name,
27 | sources=[str(root_path / rel_path) for rel_path in relative_sources],
28 | **get_kwargs()
29 | )
30 |
31 |
32 | def gcc_version():
33 | """
34 | Determines the version of GCC (GNU Compiler Collection) installed on the system.
35 |
36 | This function executes the 'gcc --version' command using subprocess.run and parses the output
37 | to extract the GCC version number. If GCC is not found or an error occurs during parsing,
38 | it returns a default version number of 0.0.0.
39 |
40 | The function checks if the output contains the string "clang" to identify if clang is masquerading
41 | as gcc, in which case it also returns 0.0.0, assuming GCC is not installed.
42 |
43 | Returns:
44 | tuple: A tuple containing three integers (major, minor, patch) representing the version of GCC.
45 | Returns (0, 0, 0) if GCC is not found or an error occurs.
46 | """
 47 | output = subprocess.run(['gcc', '--version'], check=False, capture_output=True, text=True)
 48 | if output.returncode != 0 or "clang" in output.stdout:
49 | return 0, 0, 0
50 | first_line = output.stdout.split("\n")[0]
51 | try:
52 | version = first_line.split(" ")[-1]
53 | major, minor, patch = list(map(int, version.split(".")))
54 | return major, minor, patch
55 | except:
56 | return 0, 0, 0
57 |
58 |
59 | def get_kwargs() -> Dict[str, Any]:
60 | """
61 | Generates keyword arguments for compilation based on the GCC version and CUDA architecture.
62 |
63 | This function dynamically constructs a dictionary of extra compilation arguments
64 | for both C++ and CUDA (nvcc) compilers. It includes flags to suppress deprecated
65 | declarations warnings, specify the OpenMP library path, and set the CUDA architecture.
66 | Additionally, it conditionally adjusts the nvcc host compiler to GCC 11 if the detected
67 | GCC version is greater than 11.
68 |
69 | Returns:
70 | Dict[str, Any]: A dictionary containing the 'extra_compile_args' key with nested
71 | dictionaries for 'cxx' and 'nvcc' compilers specifying their
72 | respective extra compilation arguments.
73 | """
74 | # Retrieve the current GCC version to determine compatibility and required flags
75 | major, minor, patch = gcc_version()
76 |
77 | # Initialize the base kwargs dict with common compile arguments for cxx and nvcc
78 | kwargs = {
79 | "extra_compile_args": {
80 | "cxx": [
81 | "-Wno-deprecated-declarations", # Suppress warnings for deprecated declarations
82 | "-L/usr/lib/gcc/x86_64-pc-linux-gnu/10.3.0/libgomp.so", # Specify libgomp library path for OpenMP
83 | "-fopenmp", # Enable OpenMP support
84 | ],
85 | "nvcc": [
86 | "-Xcompiler",
87 | "-fopenmp", # Pass -fopenmp to the host compiler via nvcc
88 | f"-arch={get_cuda_arch()}", # Set CUDA architecture, default to 'sm_80'
89 | "-DARCH_SM_75" if get_cuda_arch() == 'sm_75' else "-DARCH_SM_80", # set architecture macro
90 | ],
91 | },
92 | }
93 | # If the GCC version is greater than 11, adjust the nvcc host compiler settings
94 | if major > 11:
95 | print("Using GCC 11 host compiler for nvcc.")
96 | # Specify GCC 11 as the host compiler for nvcc:
97 | kwargs["extra_compile_args"]["nvcc"].append("-ccbin=/usr/bin/gcc-11")
98 | return kwargs
99 |
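A minimal sketch of how the `BIE_CUDA_ARCH` environment variable steers `get_cuda_arch()`; `"sm_75"` is one of the values listed in `SUPPORTED_CUDA_ARCHS` above, anything else still builds but emits a warning:

```python
import os

from bitorch_engine.utils.cuda_extension import get_cuda_arch

os.environ["BIE_CUDA_ARCH"] = "sm_75"  # select the target architecture before building
print(get_cuda_arch())                 # -> "sm_75"
```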
--------------------------------------------------------------------------------
/bitorch_engine/utils/cutlass_path.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Union, Tuple
4 |
5 |
6 | def check_path(p: str) -> Tuple[bool, str]:
7 | """
8 | Checks the provided path for the existence of the 'cutlass.h' file in expected directories.
9 |
10 | This function attempts to locate the 'cutlass.h' header file in three potential locations
11 | relative to the provided path:
12 | 1. Directly within the provided path.
13 | 2. Within a 'cutlass' directory inside the provided path.
14 | 3. Within an 'include/cutlass' directory structure inside the provided path.
15 |
16 | Args:
17 | p (str): The base path as a string where the search for 'cutlass.h' will begin.
18 |
19 | Returns:
20 | Tuple[bool, str]: A tuple containing:
21 | - A boolean indicating whether 'cutlass.h' was found.
22 | - A string representing the path where 'cutlass.h' was found, or an empty string if not found.
23 | """
24 | p = Path(p)
25 | if (p / "cutlass.h").is_file():
26 | return True, str(p)
27 | if (p / "cutlass" / "cutlass.h").is_file():
28 | return True, str(p)
29 | if (p / "include" / "cutlass" / "cutlass.h").is_file():
30 | return True, str(p / "include")
31 | return False, ""
32 |
33 |
34 | def find_cutlass(check_only: bool = True) -> Union[bool, str]:
35 | """
36 | Searches for the Cutlass library in predefined and environment-specified directories.
37 |
38 | This function iterates through a list of potential directories where Cutlass might be located.
39 | It checks each directory to see if Cutlass exists there. The search paths include '/usr/local/include'
40 | and any paths specified in the 'CPATH' environment variable.
41 |
42 | Args:
43 | check_only (bool): Determines the behavior of the function upon finding Cutlass.
44 | If True, the function returns a boolean indicating the presence of Cutlass.
45 | If False, it returns the path where Cutlass is found.
46 |
47 | Returns:
48 | Union[bool, str]: Depending on the value of `check_only`, this function either returns:
49 | - A boolean value indicating whether Cutlass was found (True) or not (False).
50 | - The string path to the directory where Cutlass is located. If not found, returns an empty string.
51 |
52 | Note:
53 | The function utilizes `check_path(p)`, a separate function not shown here, to determine
54 | if Cutlass is present in each directory `p`. It is assumed that `check_path(p)` returns
55 | a tuple (bool, str), where the boolean indicates success, and the string represents the path.
56 | """
57 | success, path = False, ""
58 | search_paths = ["/usr/local/include"] + os.environ.get("CPATH", "").split(":")
59 | for p in search_paths:
60 | if check_only:
61 | print("Searching Cutlass in:", p)
62 | success, path = check_path(p)
63 | if success:
64 | if check_only:
65 | print("Found Cutlass in:", p)
66 | break
67 | return success if check_only else path
68 |
69 |
70 | def is_cutlass_available() -> bool:
71 | return find_cutlass()
72 |
73 |
74 | def get_cutlass_include_path() -> str:
75 | return find_cutlass(check_only=False)
76 |
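A hypothetical guard as it might be used when deciding whether to build CUTLASS-based extensions (the print statements are illustrative):

```python
from bitorch_engine.utils.cutlass_path import get_cutlass_include_path, is_cutlass_available

if is_cutlass_available():
    print("CUTLASS headers found in:", get_cutlass_include_path())
else:
    print("CUTLASS not found; add its include directory to CPATH to enable CUTLASS-based extensions.")
```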
--------------------------------------------------------------------------------
/bitorch_engine/utils/mlx_extension.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from torch.utils.cpp_extension import CppExtension
4 |
5 | from bitorch_engine.extensions import EXTENSION_PREFIX
6 |
7 | from .mlx_path import get_mlx_include_path, get_mlx_lib_path
8 |
9 |
10 | def get_mlx_extension(root_path: Path, relative_name: str, relative_sources) -> CppExtension:
11 | """
12 | Creates and returns a CppExtension for compiling a C++ extension with MLX library support.
13 |
14 | This function is designed to simplify the configuration of a C++ extension that depends on the MLX library by
15 | automatically setting up include directories, library directories, and other necessary compilation and runtime settings.
16 |
17 | Parameters:
18 | root_path (Path): The root directory path where the C++ source files are located. This path is used to resolve the full paths to the source files specified in `relative_sources`.
19 | relative_name (str): A relative name for the extension. This name is prefixed with a predefined prefix and used as the extension's name.
20 | relative_sources (Iterable[str]): A list or iterable of relative paths to the C++ source files, relative to `root_path`. These source files constitute the extension being compiled.
21 |
22 | Returns:
23 | CppExtension: An instance of CppExtension configured with paths to include directories, library directories, and other settings needed to compile and link the extension with the MLX library.
24 | """
25 | include_path = get_mlx_include_path()
26 | lib_path = get_mlx_lib_path()
27 | return CppExtension(
28 | name=EXTENSION_PREFIX + relative_name,
29 | sources=[str(root_path / rel_path) for rel_path in relative_sources],
30 | include_dirs=[include_path],
31 | library_dirs=[lib_path], # To find mlx during compilation
32 | runtime_library_dirs=[lib_path], # To find mlx during runtime
33 | libraries=['mlx'],
34 | extra_compile_args=['-std=c++17'],
35 | )
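A sketch of a hypothetical call to `get_mlx_extension`, assuming the MLX Python package is installed so the include and library paths resolve; the extension name is illustrative, while the listed source files exist in the MPS layer directory:

```python
from pathlib import Path

from bitorch_engine.utils.mlx_extension import get_mlx_extension

ext = get_mlx_extension(
    Path("bitorch_engine/layers/qlinear/nbit/mps"),
    relative_name="q_linear_mlx",  # hypothetical name, prefixed internally
    relative_sources=["mpq_linear_mlx.cpp", "mlx_bindings.cpp"],
)
print(ext.name, ext.sources)
```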
--------------------------------------------------------------------------------
/bitorch_engine/utils/mlx_path.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from importlib.util import find_spec
4 | from pathlib import Path
5 | from typing import Union
6 |
7 |
8 | def get_mlx_include_path() -> Union[str, None]:
9 | """
10 | Attempts to find the include path for the mlx library.
11 |
12 | This function searches for the 'mlx.h' header file associated with the mlx library
13 | in various possible locations where the library might be installed. The search follows this order:
14 | 1. Looks within the package's submodule search locations if the mlx package is installed.
15 | 2. Checks the system's default include path under the Python environment's prefix.
16 | 3. Scans paths specified in the 'CPATH' environment variable.
17 |
18 | Returns:
19 | str: The absolute path to the directory containing 'mlx.h' if found.
20 | None: If the 'mlx.h' file cannot be found in any of the searched locations.
21 | """
22 | filename = "mlx.h"
23 |
24 | mlx_spec = find_spec("mlx")
25 | if mlx_spec is not None:
26 | for search_path in mlx_spec.submodule_search_locations:
27 | path = Path(search_path) / "include"
28 | if path.exists() and ((path / filename).exists() or (path / "mlx" / filename).exists()):
29 | return str(path.resolve())
30 |
31 | prefix_path = Path(sys.prefix).resolve()
32 | if (prefix_path / "include" / "mlx" / filename).exists():
33 | return str(prefix_path / "include")
34 |
35 | for path in os.environ.get("CPATH", "").split(":"):
36 | path = Path(path).resolve()
37 | if (path / "include" / filename).exists():
38 | return str(path / "include")
39 | elif (path / filename).exists():
40 | return str(path)
41 |
42 | return None
43 |
44 | def get_mlx_lib_path() -> Union[str, None]:
45 | """
46 | Attempts to find the library path for 'libmlx.dylib'.
47 |
48 | This function searches for the 'libmlx.dylib' file in various locations to determine
 49 | the library path for mlx (Apple's MLX framework). The search follows this order:
50 |
51 | 1. Within the 'mlx' package's installation directory, if the package is installed.
52 | It looks for a 'lib' directory under the package's submodule search locations.
53 | 2. In the 'lib' directory under the Python environment's prefix directory. This is
54 | typically where libraries are installed for the current Python environment.
55 | 3. In directories specified in the 'LIBRARY_PATH' environment variable. This is a
56 | colon-separated list of directories where libraries are searched for on Unix-like
57 | systems.
58 |
59 | The function returns the path as a string if 'libmlx.dylib' is found, or None if the
60 | library cannot be found in any of the searched locations.
61 |
62 | Returns:
63 | str or None: The path to 'libmlx.dylib' if found, otherwise None.
64 | """
65 | filename = "libmlx.dylib"
66 |
67 | mlx_spec = find_spec("mlx")
68 | if mlx_spec is not None:
69 | for search_path in mlx_spec.submodule_search_locations:
70 | path = Path(search_path) / "lib"
71 | if path.exists() and ((path / filename).exists() or (path / "mlx" / filename).exists()):
72 | return str(path.resolve())
73 |
74 | prefix_path = Path(sys.prefix).resolve()
75 | if (prefix_path / "lib" / filename).exists():
76 | return str(prefix_path / "lib")
77 |
78 | for path in os.environ.get("LIBRARY_PATH", "").split(":"):
79 | path = Path(path)
80 | if (path / "lib" / filename).exists():
81 | return str(path / "lib")
82 | elif (path / filename).exists():
83 | return str(path)
84 |
85 | return None
86 |
87 | def is_mlx_available() -> bool:
88 | """
89 | Checks if the MLX library is available for use.
90 |
91 | This function determines the availability of the MLX library by verifying both the include path and the library path of MLX.
92 | It does this by calling two functions: `get_mlx_include_path` and `get_mlx_lib_path`.
93 | For the MLX library to be considered available, both of these functions must return a value that is not `None`.
94 |
95 | Returns:
96 | bool: True if both the MLX include path and the MLX library path are available (i.e., not None),
97 | indicating that the MLX library is available for use. Otherwise, returns False.
98 | """
99 | return get_mlx_include_path() is not None and get_mlx_lib_path() is not None
--------------------------------------------------------------------------------
/bitorch_engine/utils/safe_import.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from typing import Any
3 |
4 | from bitorch_engine.extensions import EXTENSION_PREFIX
5 |
  6 | MESSAGE = """The extension '{}' could not be imported. It is either not yet implemented or was not built correctly.
7 | This message is expected during the build process. If it appears later on, try installing the package again."""
8 |
9 |
10 | class ExtensionModulePlaceholder:
11 | """
12 | A placeholder class for extension modules.
13 |
14 | This class serves as a placeholder for dynamically loaded extension modules. It is designed to
15 | intercept attribute access and modifications, raising a runtime error if any operation other than
16 | setting the initial name is attempted. This behavior ensures that any misuse of the placeholder,
17 | such as accessing or modifying attributes that are not supposed to exist in the placeholder state,
18 | is promptly identified during development.
19 |
20 | Attributes:
21 | _name (str): The name of the extension module this placeholder represents.
22 |
23 | Methods:
24 | __init__: Initializes the placeholder with the name of the extension module.
25 | __getattr__: Intercepts attribute access attempts, raising a RuntimeError.
26 | __setattr__: Intercepts attribute modification attempts, allowing only the _name attribute to be set.
27 | """
28 | def __init__(self, name: str) -> None:
29 | """
30 | Initializes the ExtensionModulePlaceholder with a specified name.
31 |
32 | Args:
33 | name (str): The name of the extension module this placeholder represents.
34 | """
35 | self._name = name
36 |
37 | def __getattr__(self, item: str) -> Any:
38 | """
39 | Handles attribute access attempts for the placeholder.
40 |
41 | This method raises a RuntimeError to indicate that the attempted access is invalid, as the
42 | placeholder should not be used for accessing attributes.
43 |
44 | Args:
45 | item (str): The name of the attribute being accessed.
46 |
47 | Returns:
48 | Any: This method does not return but raises RuntimeError instead.
49 |
50 | Raises:
51 | RuntimeError: Indicates an invalid attribute access attempt.
52 | """
53 | raise RuntimeError(MESSAGE.format(self._name))
54 |
55 | def __setattr__(self, key: Any, value: Any) -> None:
56 | """
57 | Handles attribute modification attempts for the placeholder.
58 |
59 | This method allows setting the _name attribute only. Any other attempt to modify attributes
60 | raises a RuntimeError to indicate that the operation is invalid.
61 |
62 | Args:
63 | key (Any): The name of the attribute to be modified.
64 | value (Any): The new value for the attribute.
65 |
66 | Raises:
67 | RuntimeError: Indicates an invalid attribute modification attempt, except for the _name attribute.
68 | """
69 | if key == "_name":
70 | self.__dict__["_name"] = value
71 | return
72 | raise RuntimeError(MESSAGE.format(self._name))
73 |
74 |
75 | def import_extension(module_name: str, not_yet_implemented: bool = False) -> Any:
76 | """
77 | Dynamically imports a Python extension module by name, providing a safe mechanism to handle cases
78 | where the module is not yet built or not yet implemented. This function is particularly useful for
79 | conditionally importing modules that provide optional functionality or are platform-specific.
80 |
81 | If the module is marked as not yet implemented (not_yet_implemented=True), or if the module cannot be
82 | found during import, the function returns a placeholder object instead of raising an ImportError. This
83 | allows the application to continue running and gracefully handle the absence of the module.
84 |
85 | Args:
86 | module_name (str): The name of the module to be imported. The actual module name will be prefixed
87 | with a predefined prefix defined in `EXTENSION_PREFIX` to form the full module name.
88 | not_yet_implemented (bool, optional): A flag indicating whether the module is known to be not yet
89 | implemented. If True, the function immediately returns a placeholder
90 | without attempting to import the module. Defaults to False.
91 |
92 | Returns:
93 | Any: An imported module if successful, or an instance of `ExtensionModulePlaceholder` if the module
94 | is not implemented or cannot be found.
95 |
96 | Example:
97 | binary_linear_cuda = import_extension("binary_linear_cuda")
98 | This example attempts to import a module named "binary_linear_cuda" (prefixed appropriately),
99 | returning the module if found, or a placeholder if not found or not implemented.
100 |
101 | Note:
102 | This function prints a warning message to the console if the module cannot be found, informing the
103 | user of the issue without interrupting the execution of the program.
104 | """
105 | if not_yet_implemented:
106 | return ExtensionModulePlaceholder(module_name)
107 |
108 | try:
109 | return importlib.import_module(EXTENSION_PREFIX + module_name)
110 | except ModuleNotFoundError:
111 | print("Warning:", MESSAGE.format(module_name))
112 | return ExtensionModulePlaceholder(module_name)
113 |
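A minimal sketch of the fallback behaviour: if the compiled module is missing, a placeholder is returned instead of raising `ImportError`, and only attribute access on the placeholder fails. `"some_extension"` is an illustrative name.

```python
from bitorch_engine.utils.safe_import import import_extension

ext = import_extension("some_extension")

try:
    ext.forward  # any attribute access on the placeholder raises RuntimeError
except RuntimeError as error:
    print("Extension not usable:", error)
```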
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG FROM_IMAGE=pytorch/manylinux-builder:cuda12.1-2.3
2 | FROM ${FROM_IMAGE} as builder-base
3 | RUN mkdir "/build_scripts"
4 | RUN mkdir "/workspace"
5 |
6 | FROM builder-base as pytorch-base
7 | #COPY "build_scripts/prepare_builder_image.sh" "/build_scripts/"
8 | #RUN bash "/build_scripts/prepare_builder_image.sh" "${FROM_IMAGE}" && \
9 | # rm "/build_scripts/prepare_builder_image.sh"
10 |
11 | FROM pytorch-base as cutlass-install
12 | ARG CUTLASS_VERSION="2.8.0"
13 | ARG CUTLASS_HOME="/opt/cutlass"
14 | RUN git clone --depth 1 --branch "v${CUTLASS_VERSION}" "https://github.com/NVIDIA/cutlass.git" --recursive "${CUTLASS_HOME}/source" && \
15 | mkdir "${CUTLASS_HOME}/build" && \
16 | cd "${CUTLASS_HOME}/build" && \
17 | cmake ../source \
18 | -DCMAKE_INSTALL_PREFIX="${CUTLASS_HOME}/install" \
19 | -DCUTLASS_ENABLE_HEADERS_ONLY=ON \
20 | -DCUTLASS_ENABLE_TOOLS=ON \
21 | -DCUTLASS_ENABLE_LIBRARY=OFF \
22 | -DCUTLASS_ENABLE_PROFILER=OFF \
23 | -DCUTLASS_NVCC_ARCHS='75;80;86' && \
24 | cmake --install . && \
25 | rm -rf "${CUTLASS_HOME}/build" "${CUTLASS_HOME}/source"
26 |
27 | FROM cutlass-install as build-ready
28 | ARG PYTHON_HOME="/opt/python/cp310-cp310"
29 | ENV PATH="${PYTHON_HOME}/bin:${PATH}"
30 | ARG CUSTOM_TORCH_URL="https://packages.greenbit.ai/whl/cu121/torch/torch-2.3.0-cp310-cp310-linux_x86_64.whl"
31 | ARG TORCHVISION_VERSION="0.18.0"
32 | ARG TORCHVISION_INDEX_URL="https://download.pytorch.org/whl/cu121"
33 | RUN pip install "${CUSTOM_TORCH_URL}" && \
34 | pip install "torchvision==${TORCHVISION_VERSION}" --index-url "${TORCHVISION_INDEX_URL}" && \
35 | pip cache purge && \
36 | rm -rf /build_scripts
37 |
38 | # clone instead of mounting makes the code in the image independent from local changes
39 | # to mount your code before building, use the target above, and check the "For Development" section in docs/README.md
40 | FROM build-ready as no-examples
41 | ARG GIT_URL="https://github.com/GreenBitAI/bitorch-engine.git"
42 | ARG GIT_BRANCH="main"
43 | ARG BUILD_TARGET="."
44 | RUN git clone \
45 | --depth 1 \
46 | --branch "${GIT_BRANCH}" \
47 | "${GIT_URL}" \
48 | /bitorch-engine && \
49 | cd /bitorch-engine && \
50 | BIE_FORCE_CUDA="true" CPATH="${CUTLASS_HOME}/install/include" pip install -e ${BUILD_TARGET} -v && \
51 | rm -rf build/ bitorch_engine.egg-info/
52 | WORKDIR "/workspace"
53 |
54 | FROM no-examples as example-ready
55 | RUN pip install -r /bitorch-engine/examples/mnist-lightning/requirements.txt
56 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # Project Setup with Docker
2 |
3 | ## Build Docker Image
4 |
5 | Use the following commands to build a [Docker](https://www.docker.com/) image with a Bitorch Engine installation.
  6 | This is currently only targeted and tested for CUDA 11.8 or 12.1 and _Torch 2.3.x_.
7 |
8 | ```bash
9 | # cd docker
10 | # you should be in this `docker` directory
11 | docker build -t bitorch/engine .
12 | # if you do not want to include installation of example requirements, use this instead:
13 | docker build --target no-examples -t bitorch/engine .
14 | ```
15 |
16 | After building, the docker image should contain:
17 | - The selected torch package (limited to those that we modified to support gradients for non-floating-point tensors)
18 | - A ready-built bitorch engine, and its requirements
19 | - Everything is installed in a conda environment with Python (currently 3.10)
20 |
21 | ## Build Options
22 |
23 | Depending on your setup, you may want to adjust some options through build arguments:
24 | - CUDA version, e.g. for CUDA 11.8 add
25 | - `--build-arg FROM_IMAGE="pytorch/manylinux-builder:cuda11.8-2.3"`
26 | - `--build-arg CUSTOM_TORCH_URL="https://packages.greenbit.ai/whl/cu118/torch/torch-2.3.0-cp310-cp310-linux_x86_64.whl"`
27 | - `--build-arg TORCHVISION_INDEX_URL="https://download.pytorch.org/whl/cu118"`
28 | - repository URL, e.g. add `--build-arg GIT_URL="https://accesstoken:tokenpassword@github.com/MyFork/bitorch-engine.git"`
29 | - Bitorch Engine branch or tag, e.g. add `--build-arg GIT_BRANCH="v1.2.3"`
30 | - installing requirements for development, e.g. `--build-arg BUILD_TARGET=".[dev]"`
31 | - if there is a problem, set the environment variable `BUILDKIT_PROGRESS=plain` to see all output
32 |
33 | Here is an example:
34 | ```bash
35 | BUILDKIT_PROGRESS=plain docker build -t bitorch/engine --build-arg BUILD_TARGET=".[dev]" --build-arg GIT_BRANCH="mybranch" .
36 | ```
37 |
38 | ## Run Docker Container
39 |
40 | After building the image you can run a container based on it with:
41 | ```bash
42 | docker run -it --rm --gpus all bitorch/engine:latest
43 | ```
44 |
45 | ## For Development
46 |
47 | A docker image without the code cloned, e.g. for mounting a local copy of the code, can be made easily with the target `build-ready`:
48 | ```bash
49 | # cd docker
50 | # you should be in this `docker` directory
51 | docker build -t bitorch/engine:build-ready --target build-ready .
52 | docker run -it --rm --gpus all --volume "$(pwd)/..":/bitorch-engine bitorch/engine:build-ready
53 | # in docker container:
54 | cd /bitorch-engine
55 | pip install -e ".[dev]" -v
56 | ```
57 | However, this means the build results will not be persisted in the image, so you probably want to mount the same directory every time.
58 |
--------------------------------------------------------------------------------
/docker/build_scripts/install_modified_pytorch.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | from_image="${1}"
4 | action="${2:-install}"
5 |
6 | function usage() {
7 | echo "./install_modified_pytorch.sh DOCKER_IMAGE ACTION"
8 | echo "verify or install a modified torch version suitable for the chosen docker image"
9 | echo
10 | echo "ACTION can be 'install' or 'verify' and is optional (default: install)"
11 | }
12 |
13 | gdrive_id="unknown"
14 | file="custom_torch.whl"
15 |
16 | ## list of known docker images and the corresponding google drive id to download modified torch packages
17 | ## adding them here individually is tedious, but we need to build them manually and ensure compatibility anyway
18 |
19 | if [ "${from_image}" == "pytorch/pytorch:2.2.0-cuda11.8-cudnn8-devel" ]; then
20 | gdrive_id="1PoVor85-RF3s0KpOP19mFV5hNUnHERa1"
21 | file="torch-2.2.2-cp310-cp310-linux_x86_64.whl"
22 | checksum="6646519e5e7b4af8f99b79eb9be3e6460b0d05c4695bbf86de02568f37ff3fea"
23 | fi
24 | if [ "${from_image}" == "pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel" ]; then
25 | gdrive_id="1LjFNImboq8QeFSompMS2gPjBRYtP2Dsz"
26 | file="torch-2.2.2-cp310-cp310-linux_x86_64.whl"
27 | checksum="bcc0ba7f121ee2f42ed0a59f01d4e3d70f82a8981be0be25c5e0fe0635a54b2d"
28 | fi
29 | #if [ "${from_image}" == "pytorch/pytorch:X.X.X-cudaXX.X-cudnn8-devel" ]; then
30 | # gdrive_id="xxx"
31 | # file="torch-X.X.X-cp310-cp310-linux_x86_64.whl"
32 | # checksum="xxx"
33 | #fi
34 |
35 | function check_error() {
36 | # shows and then runs a command. if the exit code is not zero, aborts the script
37 | # usage: check_error mv foo bar
38 |
39 | echo + $@
40 | "$@"
41 | local exit_code=$?
42 | if [ "${exit_code}" -ne 0 ]; then
 43 | echo "! > An error occurred, aborting."
44 | exit 1
45 | fi
46 | }
47 |
48 | if [ "${gdrive_id}" == "unknown" ]; then
49 | echo "Unknown image '${from_image}', could not choose modified torch accordingly."
50 | echo "Please add your base image to install_modified_pytorch.sh or request official support for your image via Github."
51 | echo
52 | usage
53 | exit 1
54 | fi
55 |
56 | check_error pip install gdown
57 | check_error gdown "${gdrive_id}" -O "${file}"
58 | check_error pip uninstall -y gdown
59 |
60 | if [ -n "${checksum}" ]; then
61 | check_error sha256sum --check --status <<< "${checksum} ${file}"
62 | fi
63 |
64 | if [ "${action}" == "verify" ]; then
65 | exit 0
66 | fi
67 |
68 | check_error pip install "${file}"
69 | check_error rm "${file}"
70 | check_error pip cache purge
71 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | source/_autosummary/
2 | build/
3 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Bitorch Engine Documentation
2 |
3 | ## Requirements
4 |
  5 | First, you should install Bitorch Engine as normal.
6 | Then, install the additional requirements with:
7 | ```bash
8 | # in the docs folder
9 | pip install -r requirements.txt
10 | ```
11 |
12 | ## Magic Build Script
13 |
14 | The script `make_docs.sh` will try to automagically build the documentation for you:
15 | ```bash
16 | ./docs/make_docs.sh
17 | ```
18 | If there is a problem, see the script and manual steps below.
19 |
20 | ## Manual Build
21 |
 22 | The docs for `bitorch_engine` are generated using the [sphinx](https://www.sphinx-doc.org/en/master/) package.
 23 | To build the docs, `cd` into the repository root and execute:
24 |
25 | ```bash
26 | sphinx-build -b html docs/source/ docs/build/ -a
27 | ```
28 |
29 | The generated `HTML` files will be put into `docs/build`.
30 |
31 | ## Synchronize from Readmes
32 |
33 | To synchronize information from the Readme, we can use pandoc to convert the markdown file to RST:
34 | ```bash
35 | pip install pandoc
36 | pandoc --from=markdown --to=rst --output=README.rst README.md
37 | ```
38 |
39 | Afterward, we need to fix a few issues, such as incorrect URLs, collapsible sections, etc.
40 | and then move those sections to the appropriate place in the documentation.
41 | You can try automatically doing so by running `python docs/scripts/convert_docs.py README.rst` before building with sphinx.
42 |
--------------------------------------------------------------------------------
/docs/make_docs.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | if ! [ -d "docs" ]; then
4 | cd ..
5 | fi
6 | if ! [ -d "docs" ]; then
7 | echo "Could not locate docs directory. Please run the script from the root or the docs directory."
8 | exit 1
9 | fi
10 |
11 | function package_required() {
12 | pip freeze | grep "${1}" &> /dev/null
13 | if ! [ $? == "0" ]; then
14 | echo "Package '${1}' not found. Please install it, e.g. with: $ pip install ${1}"
15 | exit 1
16 | fi
17 | }
18 | # check new packages are installed
19 | package_required pandoc
20 | package_required sphinx_design
21 |
22 | pandoc --from=markdown --to=rst --output=README.rst README.md
23 | python docs/scripts/convert_docs.py README.rst
24 | sphinx-build -b html docs/source/ docs/build/ -a
25 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | autodocsumm
2 | nbsphinx
3 | pandoc
4 | sphinx_rtd_theme
5 | sphinx_gallery
6 | sphinx_toolbox
7 | sphinx_design
8 |
--------------------------------------------------------------------------------
/docs/scripts/convert_docs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | class Converter:
5 | def __init__(self, verbose=False):
6 | self.active = False
7 | self.verbose = verbose
8 |
9 | def process_line(self, lines, line_num):
10 | self.check_start_stop(lines, line_num)
11 | if self.active:
12 | self.internal_process_line(lines, line_num)
13 |
14 | def check_start_stop(self, lines, line_num):
15 | raise NotImplementedError("Should be implemented in subclass")
16 |
17 | def internal_process_line(self, lines, line_num):
18 | raise NotImplementedError("Should be implemented in subclass")
19 |
20 | def replace_line(self, lines, line_num, new_line):
21 | if self.verbose:
22 | print(lines[line_num].rstrip())
23 | print("vvv")
24 | print(new_line.rstrip())
25 | lines[line_num] = new_line
26 |
27 |
28 | class Replacer(Converter):
29 | def __init__(self, replacements, **kwargs):
30 | super().__init__(**kwargs)
31 | self.replacements = replacements
32 | self.keys = [pair[0] for pair in replacements]
33 |
34 | def check_start_stop(self, lines, line_num):
35 | if any(key in lines[line_num] for key in self.keys):
36 | self.active = True
37 | return
38 | self.active = False
39 |
40 | def internal_process_line(self, lines, line_num):
41 | for pair in self.replacements:
42 | new_line = lines[line_num].replace(*pair)
43 | self.replace_line(lines, line_num, new_line)
44 |
45 |
46 | class SectionExtractor(Converter):
47 | heading_levels = ["=", "-", "~", "^"]
48 |
49 | def __init__(self, heading_contains, modify_by=-1, **kwargs):
50 | super().__init__(**kwargs)
51 | self.heading_contains = heading_contains
52 | self.modify_by = modify_by
53 | self.level_stop = False
54 | self.start_line = -1
55 | self.end_line = -1
56 |
57 | @staticmethod
58 | def determine_level(line):
59 | line = line.rstrip()
60 | if len(line) == 0:
61 | return False
62 | first_symbol = line[0]
63 | if not first_symbol in SectionExtractor.heading_levels:
64 | return False
65 | if all(c == first_symbol for c in line):
66 | return SectionExtractor.heading_levels.index(first_symbol)
67 | return False
68 |
69 | def check_start_stop(self, lines, line_num):
70 | level = self.determine_level(lines[line_num])
71 | if level is not False and level == self.level_stop and self.active:
72 | if self.verbose:
73 | print("Inactive:", lines[line_num])
74 | self.active = False
75 | self.end_line = line_num - 1
76 | if self.heading_contains in lines[line_num]:
77 | if self.verbose:
78 | print("Active:", lines[line_num])
79 | self.active = True
80 | self.start_line = line_num
81 |
82 | def internal_process_line(self, lines, line_num):
83 | level = self.determine_level(lines[line_num])
84 | if level is False:
85 | return
86 | if self.level_stop is False:
87 | self.level_stop = level
88 | new_line = lines[line_num].replace(
89 | self.heading_levels[level],
90 | self.heading_levels[level + self.modify_by],
91 | )
92 | self.replace_line(lines, line_num, new_line)
93 |
94 | def write_to_file(self, lines, output_file, mode="w"):
95 | with open(output_file, mode) as f:
96 | f.writelines(lines[self.start_line:self.end_line])
97 |
98 |
99 | class SectionCollapseFixer(Converter):
100 | heading_levels = ["=", "-", "~", "^"]
101 |
102 | def __init__(self, **kwargs):
103 | super().__init__(**kwargs)
104 |
105 | def check_start_stop(self, lines, line_num):
106 | line = lines[line_num]
107 | if ".. raw:: html" in line and "" in lines[line_num + 2]:
108 | if self.verbose:
109 | print("Active:", line)
110 | self.active = True
111 | if " " in line:
112 | if self.verbose:
113 | print("Inactive:", line)
114 | self.active = False
115 | self.replace_line(lines, line_num, "")
116 |
117 | def internal_process_line(self, lines, line_num):
118 | stripped_line = lines[line_num].strip()
119 | if stripped_line in ["", ".. raw:: html", " ", ""]:
120 | self.replace_line(lines, line_num, "")
121 | return
122 | if stripped_line == "":
123 | summary_line = line_num + 1
124 | while lines[summary_line].strip() == "":
125 | summary_line += 1
126 | self.replace_line(lines, line_num, ".. dropdown:: " + lines[summary_line])
127 | for i in range(line_num + 1, summary_line + 1):
128 | self.replace_line(lines, i, "")
129 | return
130 | self.replace_line(lines, line_num, " " + lines[line_num])
131 |
132 | def main(args):
133 | # set build_options_separate=False to integrate build options into installation part:
134 | build_options_separate = True
135 | verbose = False
136 |
137 | install_section = SectionExtractor("Installation", verbose=verbose)
138 | build_section = SectionExtractor("Build options", modify_by=-1 if build_options_separate else 0, verbose=verbose)
139 | section_collapse_fixer = SectionCollapseFixer(verbose=verbose)
140 | replacer = Replacer([
141 | [
142 | r"`docker readme `__",
143 | r"`docker readme `__",
144 | ],
145 | ], verbose=verbose)
146 |
147 | content = None
148 | with open(args.input, "r") as f:
149 | content = f.readlines()
150 | if content is None:
151 | return
152 |
153 | i = 0
154 | while i < len(content):
155 | for converter in [install_section, build_section, section_collapse_fixer, replacer]:
156 | converter.process_line(content, i)
157 | i += 1
158 |
159 | install_section.write_to_file(content, "docs/source/installation.rst")
160 | if build_options_separate:
161 | build_section.write_to_file(content, "docs/source/build_options.rst")
162 | else:
163 | build_section.write_to_file(content, "docs/source/installation.rst", mode="a")
164 |
165 |
166 | if __name__ == "__main__":
167 | parser = argparse.ArgumentParser()
168 | parser.add_argument("input", help="Input file")
169 | main(parser.parse_args())
170 |
--------------------------------------------------------------------------------
/docs/source/_templates/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :members:
7 | :no-inherited-members:
8 | :special-members: __call__, __add__, __mul__, forward, __init__
9 |
10 | {% block methods %}
11 | {% if methods %}
12 | .. rubric:: {{ _('Methods') }}
13 |
14 | .. autosummary::
15 | :nosignatures:
16 | {% for item in methods %}
17 |
18 | {%- if item.startswith('__init__') %}
19 | {%- if item not in inherited_members %}
20 | ~{{ name }}.{{ item }}
21 | {%- endif -%}
22 | {%- endif -%}
23 | {%- if not item.startswith('_') %}
24 | {%- if item not in inherited_members %}
25 | ~{{ name }}.{{ item }}
26 | {%- endif -%}
27 | {%- endif -%}
28 | {%- endfor %}
29 | {% endif %}
30 | {% endblock %}
31 |
32 | {% block attributes %}
33 | {% if attributes %}
34 | .. rubric:: {{ _('Attributes') }}
35 |
36 | .. autosummary::
37 | {% for item in attributes %}
38 | {%- if item not in inherited_members %}
39 | ~{{ name }}.{{ item }}
40 | {%- endif -%}
41 | {%- endfor %}
42 | {% endif %}
43 | {% endblock %}
44 |
--------------------------------------------------------------------------------
/docs/source/_templates/module.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. automodule:: {{ fullname }}
4 |
5 | {% block attributes %}
6 | {% if attributes %}
7 | .. rubric:: Module attributes
8 |
9 | .. autosummary::
10 | :toctree:
11 | {% for item in attributes %}
12 | {{ item }}
13 | {%- endfor %}
14 | {% endif %}
15 | {% endblock %}
16 |
17 | {% block functions %}
18 | {% if functions and fullname != 'bitorch_engine' %}
19 | .. rubric:: {{ _('Functions') }}
20 |
21 | .. autosummary::
22 | :toctree:
23 | :nosignatures:
24 | {% for item in functions %}
25 | {{ item }}
26 | {%- endfor %}
27 | {% endif %}
28 | {% endblock %}
29 |
30 | {% block classes %}
31 | {% if classes %}
32 | .. rubric:: {{ _('Classes') }}
33 |
34 | .. autosummary::
35 | :template: class.rst
36 | :toctree:
37 | {% for item in classes %}
38 | {{ item }}
39 | {%- endfor %}
40 | {% endif %}
41 | {% endblock %}
42 |
43 | {% block exceptions %}
44 | {% if exceptions %}
45 | .. rubric:: {{ _('Exceptions') }}
46 |
47 | .. autosummary::
48 | :toctree:
49 | {% for item in exceptions %}
50 | {{ item }}
51 | {%- endfor %}
52 | {% endif %}
53 | {% endblock %}
54 |
55 | {% block modules %}
56 | {% if modules %}
57 | .. rubric:: {{ _('Modules') }}
58 |
59 | .. autosummary::
60 | :toctree:
61 | :template: module.rst
62 | :recursive:
63 | {% for item in modules %}
64 | {% if not item.endswith('.extensions') %}{{ item }}{%- endif %}
65 | {%- endfor %}
66 | {% endif %}
67 | {% endblock %}
68 |
--------------------------------------------------------------------------------
/docs/source/build_options.rst:
--------------------------------------------------------------------------------
1 | Build options
2 | =============
3 |
4 | Building Specific Extensions
5 | ----------------------------
6 |
7 | While developing, a specific cpp/cuda extension can be (re-)build, by
8 | using the environment variable ``BIE_BUILD_ONLY``, like so:
9 |
10 | .. code:: bash
11 |
12 | BIE_BUILD_ONLY="bitorch_engine/layers/qlinear/binary/cpp" pip install -e . -v
13 |
 14 | It needs to be a relative path to one extension directory.
15 |
16 | Building for a Specific CUDA Architecture
17 | -----------------------------------------
18 |
19 | To build for a different CUDA Arch, use the environment variable
20 | ``BIE_CUDA_ARCH`` (e.g. use ‘sm_75’, ‘sm_80’, ‘sm_86’):
21 |
22 | .. code:: bash
23 |
24 | BIE_CUDA_ARCH="sm_86" pip install -e . -v
25 |
26 | Force Building CUDA Modules
27 | ---------------------------
28 |
29 | If you have CUDA development libraries installed, but
30 | ``torch.cuda.is_available()`` is False, e.g. in HPC or docker
31 | environments, you can still build the extensions that depend on CUDA, by
32 | setting ``BIE_FORCE_CUDA="true"``:
33 |
34 | .. code:: bash
35 |
36 | BIE_FORCE_CUDA="true" pip install -e . -v
37 |
38 | Skip Library File Building
39 | --------------------------
40 |
41 | If you just want to avoid rebuilding any files, you can set
42 | ``BIE_SKIP_BUILD``:
43 |
44 | .. code:: bash
45 |
46 | BIE_SKIP_BUILD="true" python3 -m build --no-isolation --wheel
47 |
48 | This would create a wheel and package ``.so`` files without trying to
49 | rebuild them.
50 |
51 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | import os
7 | import sys
8 | from pathlib import Path
9 |
10 | sys.path.insert(0, os.path.abspath('../../'))
11 |
12 | # -- Project information -----------------------------------------------------
13 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
14 |
15 | project = 'Bitorch Engine'
16 | copyright = '2024, Haojin Yang, Joseph Bethge, Nianhui Guo, Maximilian Schulze, Hong Guo, Paul Mattes'
17 | author = 'Haojin Yang, Joseph Bethge, Nianhui Guo, Maximilian Schulze, Hong Guo, Paul Mattes'
18 |
19 | root_path = Path(__file__).resolve().parent.parent.parent
20 | release = "unknown"
21 | version_file = root_path / "version.txt"
22 | if version_file.exists():
23 | with open(root_path / "version.txt") as handle:
24 | version_content = handle.read().strip()
25 | if version_content:
26 | release = version_content
27 |
28 |
29 | # -- General configuration ---------------------------------------------------
30 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
31 |
32 | extensions = [
33 | "sphinx.ext.autodoc",
34 | "sphinx.ext.autosummary",
35 | "sphinx.ext.doctest",
36 | "sphinx.ext.todo",
37 | "sphinx.ext.mathjax",
38 | "sphinx.ext.ifconfig",
39 | "sphinx.ext.viewcode",
40 | "sphinx.ext.githubpages",
41 | "sphinx.ext.napoleon",
42 | "autodocsumm",
43 | "sphinx.ext.intersphinx",
44 | "sphinx_toolbox.code",
45 | "sphinx_toolbox.collapse",
46 | "sphinx_design",
47 | ]
48 |
49 | # List of patterns, relative to source directory, that match files and
50 | # directories to ignore when looking for source files.
51 | # This pattern also affects html_static_path and html_extra_path.
52 | exclude_patterns = ['*.so', '_build', 'Thumbs.db', '.DS_Store']
53 |
54 | # Add any paths that contain templates here, relative to this directory.
55 | templates_path = ['_templates']
56 | autosummary_generate = True
57 |
58 | # -- Options for HTML output -------------------------------------------------
59 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
60 |
61 | import sphinx_rtd_theme
62 | # html_theme = 'alabaster'
63 | html_theme = 'sphinx_rtd_theme'
64 | # html_static_path = ['_static']
65 |
--------------------------------------------------------------------------------
/docs/source/documentation.rst:
--------------------------------------------------------------------------------
1 | Full Documentation
2 | ==================
3 |
4 | .. autosummary::
5 | :toctree: _autosummary
6 | :template: module.rst
7 | :recursive:
8 |
9 | bitorch_engine
10 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 |
2 | Welcome to Bitorch Engine's documentation!
3 | ==========================================
4 |
5 | Welcome to the documentation of Bitorch Engine (BIE): a cutting-edge computation library for neural networks that enhances PyTorch by integrating specialized layers and functions for Low-Bit quantized neural network operations. This is where you can find all the information you need about how to use BIE.
6 |
7 | Building on the foundational strengths of Bitorch Engine, the technology has been employed in pioneering projects that push the boundaries of neural network training and inference. For instance,
8 |
9 | - `green-bit-llm-trainer `_: In this project, BIE represents a significant leap in the field of Large Language Model (LLM) fine-tuning. Unlike traditional approaches that either quantize a fully trained model or introduce a few additional trainable parameters for `LoRA `_-style fine-tuning, this project innovates by directly fine-tuning the quantized parameters of LLMs. This paradigm shift allows for the full-scale quantization fine-tuning of LLMs, ensuring that the training process tightly integrates with the quantization scheme from the outset.
10 | - `green-bit-llm-inference `_ also showcases BIE's adeptness at supporting inference for models quantized from 4 down to 2 bits without any significant loss in accuracy compared to the original 32- or 16-bit models. It stands as a testament to BIE's capability to maintain the delicate balance between model size, computational efficiency, and accuracy, addressing one of the key challenges in deploying sophisticated neural networks in resource-constrained environments.
11 |
12 | All changes are tracked in the `changelog `_.
13 |
14 |
15 | .. toctree::
16 | :maxdepth: 4
17 | :caption: Contents:
18 |
19 | installation
20 | build_options
21 | documentation
22 |
23 | Indices and tables
24 | ==================
25 |
26 | * :ref:`genindex`
27 | * :ref:`modindex`
28 | * :ref:`search`
29 |
30 | Enjoy exploring our documentation!
31 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/examples/__init__.py
--------------------------------------------------------------------------------
/examples/mnist-lightning/mlp.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MLP(nn.Module):
7 | def __init__(self, hidden_features=1024, num_layers=2, *args, **kwargs):
8 | super().__init__(*args, **kwargs)
9 | assert 2 <= num_layers <= 3, "MLP currently only supports 2 or 3 layers."
10 | self.num_layers = num_layers
11 | self.flatten = nn.Flatten()
12 |
13 | self.first_linear = nn.Linear(28*28, hidden_features)
14 | self.act1 = nn.PReLU()
15 | self.bn1 = nn.BatchNorm1d(num_features=hidden_features)
16 |
17 | self.linear1 = nn.Linear(hidden_features, hidden_features)
18 | self.act2 = nn.PReLU()
19 | self.bn2 = nn.BatchNorm1d(num_features=hidden_features)
20 |
21 | if num_layers > 2:
22 | self.linear2 = nn.Linear(hidden_features, hidden_features)
23 | self.act3 = nn.PReLU()
24 | self.bn3 = nn.BatchNorm1d(num_features=hidden_features)
25 |
26 | self.last_linear = nn.Linear(hidden_features, 10)
27 |
28 | def forward(self, x):
29 | x = self.flatten(x)
30 |
31 | x = self.first_linear(x)
32 | x = self.bn1(x)
33 | x = self.act1(x)
34 |
35 | x = self.linear1(x)
36 | x = self.bn2(x)
37 | x = self.act2(x)
38 |
39 | if self.num_layers > 2:
40 | x = self.linear2(x)
41 | x = self.bn3(x)
42 | x = self.act3(x)
43 |
44 | return self.last_linear(x)
45 |
46 |
47 | class SequentialMLP(nn.Module):
48 | def __init__(self, hidden_features = 2048, *args, **kwargs):
49 | super().__init__(*args, **kwargs)
50 | self.flatten = nn.Flatten()
51 | self.first_linear = nn.Linear(28*28, hidden_features)
52 | self.body = nn.Sequential(
53 | nn.BatchNorm1d(num_features=hidden_features),
54 | nn.ReLU(),
55 | nn.Linear(hidden_features, hidden_features),
56 | nn.BatchNorm1d(num_features=hidden_features),
57 | nn.ReLU(),
58 | nn.Linear(hidden_features, hidden_features),
59 | nn.BatchNorm1d(num_features=hidden_features),
60 | nn.ReLU(),
61 | )
62 | self.last_linear = nn.Linear(hidden_features, 10)
63 |
64 | def forward(self, x):
65 | x = self.flatten(x)
66 | x = self.first_linear(x)
67 | x = self.body(x)
68 | return self.last_linear(x)
69 |
--------------------------------------------------------------------------------
/examples/mnist-lightning/requirements.txt:
--------------------------------------------------------------------------------
1 | bitorch
2 | galore-torch
3 | lightning
4 |
--------------------------------------------------------------------------------
/examples/mnist/README.md:
--------------------------------------------------------------------------------
1 | # Example for MNIST
2 |
3 | In this example script we train a simple model for the MNIST dataset using [Bitorch Engine](https://github.com/GreenBitAI/bitorch-engine).
4 |
5 | First the requirements for this example need to be installed:
6 | ```bash
7 | pip install -r requirements.txt
8 | ```
9 |
10 | Then you can run the following to train an MLP with 3 layers (one of which is a binary layer),
11 | or add `--help` for more arguments:
12 | ```bash
13 | python train_mnist.py --epochs 10 --model q_mlp --log-interval 100
14 | ```
15 |
--------------------------------------------------------------------------------
/examples/mnist/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This package contains an example for training an image classification model on the MNIST data set with BITorch
3 | and deploying it with the inference engine.
4 | """
5 |
--------------------------------------------------------------------------------
/examples/mnist/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This submodule contains data preparation code for some of the datasets used with our models,
3 | i.e. MNIST, CIFAR 10 and 100 and ImageNet.
4 | """
5 |
6 | from typing import List, Type
7 |
8 | from .base import BasicDataset
9 | from .mnist import MNIST
10 |
11 | __all__ = [
12 | "BasicDataset",
13 | "dataset_from_name",
14 | "dataset_names",
15 | "MNIST",
16 | ]
17 |
18 |
19 | def dataset_from_name(name: str) -> Type[BasicDataset]:
20 | """returns the dataset to which the name belongs to (name has to be the value of the datasets
21 | name-attribute)
22 |
23 | Args:
24 | name (str): name of the dataset
25 |
26 | Raises:
27 | ValueError: raised if no dataset under that name was found
28 |
29 | Returns:
30 | dataset: the dataset
31 | """
32 | for dataset_class in [MNIST]:
33 | if dataset_class.name == name:
34 | return dataset_class
35 | raise ValueError(f"unknown dataset: {name}")
36 |
37 |
38 | def dataset_names() -> List[str]:
39 | """getter for list of dataset names for argparse
40 |
41 | Returns:
42 | List: the dataset names
43 | """
44 | return [dataset_class.name for dataset_class in [MNIST]]
45 |
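46 |
47 | # Usage sketch (editorial illustration, not part of the original module):
48 | # resolve a dataset class by its registered name and build the train/test splits.
49 | #
50 | #     dataset_class = dataset_from_name("mnist")   # -> MNIST
51 | #     train_set, test_set = dataset_class.get_train_and_test("./data", download=True)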
--------------------------------------------------------------------------------
/examples/mnist/datasets/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 | from typing import Optional, Tuple, Any
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 | from torchvision.transforms import transforms
9 |
10 | from .dummy_dataset import DummyDataset
11 |
12 |
13 | class BasicDataset(Dataset):
14 | name = "None"
15 | num_classes = 0
16 | shape = (0, 0, 0, 0)
17 | mean: Any = None
18 | std_dev: Any = None
19 | num_train_samples = 0
20 | num_val_samples = 0
21 |
22 | def __init__(self, train: bool, root_directory: Optional[str] = None, download: bool = False) -> None:
23 | """initializes the dataset.
24 |
25 | Args:
26 | train (bool): whether the train or test dataset is wanted
27 | root_directory (str): path to main dataset storage directory
28 | download (bool): whether train/test should be downloaded if it does not exist
29 |
30 | Returns:
31 | Dataset: the created test/train dataset
32 | """
33 | super(BasicDataset, self).__init__()
34 | self.is_train = train
35 | self._download = download
36 | self.root_directory = self.get_dataset_root_directory(root_directory)
37 | self.dataset = self.get_dataset(download)
38 |
39 | @classmethod
40 | def get_train_and_test(cls, root_directory: str, download: bool = False) -> Tuple["BasicDataset", "BasicDataset"]:
41 | """creates a pair of train and test dataset.
42 |
43 | Returns:
44 | Tuple: the train and test dataset
45 | """
46 | return cls(True, root_directory, download), cls(False, root_directory, download)
47 |
48 | @classmethod
49 | def get_dummy_train_and_test_datasets(cls) -> Tuple[DummyDataset, DummyDataset]:
50 | train_set = DummyDataset(cls.shape, cls.num_classes, cls.num_train_samples) # type: ignore
51 | val_set = DummyDataset(cls.shape, cls.num_classes, cls.num_val_samples) # type: ignore
52 | return train_set, val_set
53 |
54 | def get_dataset_root_directory(self, root_directory_argument: Optional[str]) -> Path:
55 | """chooses the dataset root directory based on the passed argument or environment variables.
56 |
57 | Returns:
58 | Path: the chosen dataset root directory
59 | """
60 | if root_directory_argument is not None:
61 | return Path(root_directory_argument)
62 |
63 | environment_variable_name = f"{self.name.upper()}_HOME"
64 | if os.environ.get(environment_variable_name) is not None:
65 | return Path(os.environ.get(environment_variable_name)) # type: ignore
66 | if os.environ.get("BITORCH_DATA_HOME") is not None:
67 | return Path(os.environ.get("BITORCH_DATA_HOME")) / self.name # type: ignore
68 |
69 | environment_variable_hint = (
70 | f" To change this, set '{environment_variable_name}' or 'BITORCH_DATA_HOME' "
71 | f"(in the latter case, the data resides in the folder '{self.name}' in BITORCH_DATA_HOME)."
72 | f" Some datasets can be downloaded by adding the --download command line argument."
73 | )
74 | if self._download:
75 | logging.warning("Dataset is being downloaded to the directory './data'." + environment_variable_hint)
76 | return Path("./data")
77 | else:
78 | raise ValueError(f"Dataset {self.name} not found." + environment_variable_hint)
79 |
80 | def get_dataset(self, download: bool) -> Dataset:
81 | """creates the actual dataset
82 |
83 | Args:
84 | download (bool): toggles if train/test shall be downloaded if possible
85 |
86 | Raises:
87 | NotImplementedError: thrown, because this method needs to be overwritten by subclasses
88 |
89 | Returns:
90 | Dataset: the created test/train dataset
91 | """
92 | raise NotImplementedError()
93 |
94 | def get_transform(self) -> Any:
95 | if self.is_train:
96 | return self.train_transform()
97 | return self.test_transform()
98 |
99 | @classmethod
100 | def test_transform(cls) -> Any:
101 | """get the transform for the test data.
102 |
103 | Returns:
104 | transform: the transform pipeline
105 | """
106 | return transforms.Compose([transforms.ToTensor(), cls.get_normalize_transform()])
107 |
108 | @classmethod
109 | def train_transform(cls) -> Any:
110 | """get the transform for the training data.
111 |
112 | Returns:
113 | transform: the transform pipeline
114 | """
115 | return transforms.Compose([transforms.ToTensor(), cls.get_normalize_transform()])
116 |
117 | @classmethod
118 | def get_normalize_transform(cls) -> transforms.Normalize:
119 | return transforms.Normalize(cls.mean, cls.std_dev)
120 |
121 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: # type: ignore
122 | """returns the item at the given index of the dataset.
123 |
124 | Args:
125 | index (int): requested index
126 |
127 | Returns:
128 | Tuple[torch.Tensor, torch.Tensor]: data and label at the specified index
129 | """
130 | return self.dataset[index]
131 |
132 | def __len__(self) -> int:
133 | return len(self.dataset) # type: ignore
134 |
135 | def num_samples(self) -> int:
136 | """returns the (theoretical) dataset size."""
137 | return self.num_train_samples if self.is_train else self.num_val_samples
138 |
--------------------------------------------------------------------------------
/examples/mnist/datasets/dummy_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | from typing import Tuple
4 |
5 |
6 | class DummyDataset(Dataset):
7 | """An iterator that produces repeated dummy data.
8 | Args:
9 | data_sample: a data sample that should be produced at each step.
10 | batch_size: the batch size for storing.
11 | sample_count: number of `data` samples in the dummy dataset.
12 | """
13 |
14 | def __init__(self, data_shape: torch.Size, num_classes: int, sample_count: int) -> None:
15 | self._data_sample = torch.zeros(data_shape)
16 | self._class_sample = torch.zeros((num_classes,), dtype=torch.int64)
17 | self._sample_count = sample_count
18 |
19 | def __len__(self) -> int:
20 | return self._sample_count
21 |
22 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
23 | return self._data_sample, self._class_sample
24 |
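25 |
26 | # Usage sketch (editorial illustration): DummyDataset mimics the shapes of a real
27 | # dataset; e.g. with the MNIST constants from mnist.py one could write:
28 | #
29 | #     dummy_train = DummyDataset((1, 1, 28, 28), num_classes=10, sample_count=60000)
30 | #     x, y = dummy_train[0]   # zero image tensor of shape (1, 1, 28, 28) and zero label tensor of shape (10,)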
--------------------------------------------------------------------------------
/examples/mnist/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from torchvision.datasets import mnist
3 |
4 | from .base import BasicDataset
5 |
6 |
7 | class MNIST(BasicDataset):
8 | name = "mnist"
9 | num_classes = 10
10 | shape = (1, 1, 28, 28)
11 |
12 | mean = (0.1307,)
13 | std_dev = (0.3081,)
14 | num_train_samples = 60000
15 | num_val_samples = 10000
16 |
17 | def get_dataset(self, download: bool = True) -> Dataset:
18 | return mnist.MNIST(
19 | root=self.root_directory,
20 | train=self.is_train,
21 | transform=self.get_transform(),
22 | download=download,
23 | )
24 |
--------------------------------------------------------------------------------
/examples/mnist/requirements.txt:
--------------------------------------------------------------------------------
1 | bitorch
2 |
--------------------------------------------------------------------------------
/licenses/LICENSE.cutlass.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | SPDX-License-Identifier: BSD-3-Clause
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | 2. Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | 3. Neither the name of the copyright holder nor the names of its
15 | contributors may be used to endorse or promote products derived from
16 | this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/licenses/LICENSE.exllamav2.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
--------------------------------------------------------------------------------
/licenses/LICENSE.mlx.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright © 2023 Apple Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/licenses/LICENSE.pytorch.txt:
--------------------------------------------------------------------------------
1 | From PyTorch:
2 |
3 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
4 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
5 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
6 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
7 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
8 | Copyright (c) 2011-2013 NYU (Clement Farabet)
9 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
10 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
11 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
12 |
13 | From Caffe2:
14 |
15 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
16 |
17 | All contributions by Facebook:
18 | Copyright (c) 2016 Facebook Inc.
19 |
20 | All contributions by Google:
21 | Copyright (c) 2015 Google Inc.
22 | All rights reserved.
23 |
24 | All contributions by Yangqing Jia:
25 | Copyright (c) 2015 Yangqing Jia
26 | All rights reserved.
27 |
28 | All contributions by Kakao Brain:
29 | Copyright 2019-2020 Kakao Brain
30 |
31 | All contributions by Cruise LLC:
32 | Copyright (c) 2022 Cruise LLC.
33 | All rights reserved.
34 |
35 | All contributions from Caffe:
36 | Copyright(c) 2013, 2014, 2015, the respective contributors
37 | All rights reserved.
38 |
39 | All other contributions:
40 | Copyright(c) 2015, 2016 the respective contributors
41 | All rights reserved.
42 |
43 | Caffe2 uses a copyright model similar to Caffe: each contributor holds
44 | copyright over their contributions to Caffe2. The project versioning records
45 | all such contribution and copyright details. If a contributor wants to further
46 | mark their specific copyright on a particular contribution, they should
47 | indicate their copyright solely in the commit message of the change when it is
48 | committed.
49 |
50 | All rights reserved.
51 |
52 | Redistribution and use in source and binary forms, with or without
53 | modification, are permitted provided that the following conditions are met:
54 |
55 | 1. Redistributions of source code must retain the above copyright
56 | notice, this list of conditions and the following disclaimer.
57 |
58 | 2. Redistributions in binary form must reproduce the above copyright
59 | notice, this list of conditions and the following disclaimer in the
60 | documentation and/or other materials provided with the distribution.
61 |
62 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
63 | and IDIAP Research Institute nor the names of its contributors may be
64 | used to endorse or promote products derived from this software without
65 | specific prior written permission.
66 |
67 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
68 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
69 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
70 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
71 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
72 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
73 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
74 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
75 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
76 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
77 | POSSIBILITY OF SUCH DAMAGE.
78 |
--------------------------------------------------------------------------------
/licenses/LICENSE.tcbnn.txt:
--------------------------------------------------------------------------------
1 | TCBNN:
2 | Accelerating Binarized Neural Networks via Bit-Tensor-Cores in Turing GPUs
3 |
4 | 06/30/2020 by Ang Li from High-Performance-Computing Group,
5 | ACMD, PCSD, Pacific Northwest National Laboratory (PNNL),
6 | Richland, WA, 99354, USA.
7 |
8 |
9 | Copyright © 2020, Battelle Memorial Institute
10 |
11 | 1.Battelle Memorial Institute (hereinafter Battelle) hereby grants permission
12 | to any person or entity lawfully obtaining a copy of this software and associated
13 | documentation files (hereinafter “the Software”) to redistribute and use the
14 | Software in source and binary forms, with or without modification. Such person
15 | or entity may use, copy, modify, merge, publish, distribute, sublicense, and/or
16 | sell copies of the Software, and may permit others to do so, subject to the
17 | following conditions:
18 |
19 | - Redistributions of source code must retain the above copyright notice, this list
20 | of conditions and the following disclaimers.
21 |
22 | - Redistributions in binary form must reproduce the above copyright notice, this list
23 | of conditions and the following disclaimer in the documentation and/or other materials
24 | provided with the distribution.
25 |
26 | - Other than as used herein, neither the name Battelle Memorial Institute or Battelle
27 | may be used in any form whatsoever without the express written consent of Battelle.
28 |
29 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
30 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
31 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
32 | SHALL BATTELLE OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
33 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
35 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
36 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
37 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
38 |
39 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | numpy
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bitorch
2 | torch
3 | py-cpuinfo>=9.0.0
4 | setuptools~=69.0
5 |
--------------------------------------------------------------------------------
/structure.md:
--------------------------------------------------------------------------------
1 |
2 | .
3 | ├── ...
4 | ├── layers # Low-bit layers
5 | │ ├── qconv # quantized convolutional layer
6 | │ ├── qembedding # currently only 1-bit embedding layer supported
7 | │ ├── ...
8 | │ └── qlinear
9 | │ ├── binary (1-bit)
10 | │ │ ├── cpp # x86 CPU
11 | │ │ ├── cuda # Nvidia GPU
12 | │ │ └── cutlass # Nvidia GPU
13 | │ └── n-bit (2/4/8-bit)
14 | │ ├── mps # Apple GPU
15 | │ ├── cuda # Nvidia GPU, e.g., weight-only quantized LLMs
16 | │ └── cutlass # Nvidia GPU, e.g., quantization aware training for both activation and weight
17 | ├── optim
18 | │ └── DiodeMix # dedicated optimizer for low-bit quantized model
19 | ├── functions
20 | └── ...
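21 |
22 | The directory names map directly to Python import paths. As a rough illustration (the class names below are taken from this repository's tests; whether they import successfully depends on which extensions were built for your platform):
23 |
24 | ```python
25 | # layers/qconv/nbit/cutlass  -> 4-bit convolution (CUTLASS backend)
26 | from bitorch_engine.layers.qconv.nbit.cutlass import Q4Conv2dCutlass
27 | # layers/qlinear/nbit/cuda   -> mixed-bit-width, weight-only quantized linear layer (CUDA backend)
28 | from bitorch_engine.layers.qlinear.nbit.cuda import MBWQLinearCuda
29 | # layers/qlinear/nbit/mps    -> low-bit linear layer for Apple GPUs (MLX backend)
30 | from bitorch_engine.layers.qlinear.nbit.mps import MPQLinearMlx
31 | ```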
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/tests/__init__.py
--------------------------------------------------------------------------------
/tests/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/tests/functions/__init__.py
--------------------------------------------------------------------------------
/tests/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/tests/layers/__init__.py
--------------------------------------------------------------------------------
/tests/layers/test_custom_binary_linear.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch import nn
4 | from torch.autograd import Function
5 | from torch.nn import Parameter
6 |
7 |
8 | class LinearFunction(Function):
9 | @staticmethod
10 | # ctx is the first argument to forward
11 | def forward(ctx, input, weight, bias=None):
12 | # The forward pass can use ctx.
13 | ctx.save_for_backward(input, weight, bias)
14 | # TODO: instead of converting to float and use torch's mm we should use binary mm instead
15 | output = input.float().mm(weight.t().float())
16 | if bias is not None:
17 | output += bias.unsqueeze(0).expand_as(output)
18 | return output
19 |
20 | @staticmethod
21 | def backward(ctx, grad_output):
22 | input, weight, bias = ctx.saved_tensors
23 | grad_input = grad_weight = grad_bias = None
24 |
25 | if ctx.needs_input_grad[0]:
26 | # TODO: instead of converting to float and use torch's mm we should use our binary mm instead
27 | grad_input = grad_output.mm(weight.float())
28 | if ctx.needs_input_grad[1]:
29 | # TODO: instead of converting to float and use torch's mm we should use our binary mm instead
30 | grad_weight = grad_output.t().mm(input.float())
31 | if bias is not None and ctx.needs_input_grad[2]:
32 | grad_bias = grad_output.sum(0)
33 |
34 | print("Weight grads calculated: ", grad_weight)
35 |
36 | # manually need to convert to our required formulation
37 | def binarize_grad(x):
38 | # 1.0 = True
39 | # 0.0 = False
40 | return torch.where(x >= 0.0, 1.0, 0.0)
41 |
42 | # TODO: we could inline this later
43 | grad_input = binarize_grad(grad_input)
44 | grad_weight = binarize_grad(grad_weight)
45 |
46 | return grad_input, grad_weight, grad_bias
47 |
48 |
49 | # Option 2: wrap in a function, to support default args and keyword args.
50 | def linear(input, weight, bias=None):
51 | return LinearFunction.apply(input, weight, bias)
52 |
53 |
54 | class TLinear(nn.Module):
55 | def __init__(self, input_features, output_features, bias=True):
56 | super().__init__()
57 | self.input_features = input_features
58 | self.output_features = output_features
59 |
60 | # nn.Parameter is a special kind of Tensor, that will get
61 | # automatically registered as Module's parameter once it's assigned
62 | # as an attribute. Parameters and buffers need to be registered, or
63 | # they won't appear in .parameters() (doesn't apply to buffers), and
64 | # won't be converted when e.g. .cuda() is called. You can use
65 | # .register_buffer() to register buffers.
66 | # nn.Parameters require gradients by default.
67 | self.weight = nn.Parameter(torch.empty(output_features, input_features))
68 | if bias:
69 | self.bias = nn.Parameter(torch.empty(output_features))
70 | else:
71 | # You should always register all possible parameters, but the
72 | # optional ones can be None if you want.
73 | self.register_parameter('bias', None)
74 |
75 | # Not a very smart way to initialize weights
76 | nn.init.uniform_(self.weight, -0.1, 0.1)
77 | if self.bias is not None:
78 | nn.init.uniform_(self.bias, -0.1, 0.1)
79 |
80 | def forward(self, input):
81 | # See the autograd section for explanation of what happens here.
82 | return LinearFunction.apply(input, self.weight, self.bias)
83 |
84 | def extra_repr(self):
85 | # (Optional)Set the extra information about this module. You can test
86 | # it by printing an object of this class.
87 | return 'input_features={}, output_features={}, bias={}'.format(
88 | self.input_features, self.output_features, self.bias is not None
89 | )
90 |
91 |
92 | def test_q_linear_with_binary_weights():
93 | torch.manual_seed(42)
94 | # import pydevd_pycharm
95 | # pydevd_pycharm.settrace('localhost', port=11004, stdoutToServer=True, stderrToServer=True)
96 |
97 | batch_size = 10
98 | num_inputs = 32
99 | num_outputs = 64
100 |
101 | # option 1: set dtype
102 | # currently not possible, check uniform bounds fails:
103 | # layer = TLinear(num_inputs, num_outputs, bias=False, dtype=torch.bool)
104 |
105 | # option 2: manually replace weight:
106 | # currently possible
107 | # but we have to
108 | layer = TLinear(num_inputs, num_outputs, bias=False)
109 | layer.weight = Parameter(torch.rand((num_outputs, num_inputs)) > 0.5, requires_grad=True)
110 |
111 | input = Parameter(torch.rand((batch_size, num_inputs)) > 0.5, requires_grad=True)
112 |
113 | result = layer(input)
114 | print(result)
115 | print("Result shape: ", result.size())
116 |
117 | print("Grad before backward: ", layer.weight.grad)
118 | mse_loss = nn.MSELoss()
119 | dummy_loss = mse_loss(result, torch.ones_like(result) * 10)
120 | dummy_loss.backward()
121 | print("Grad after backward: ", layer.weight.grad)
122 | # technically it works, but we only get boolean gradients (all true currently)
123 | # could be fixed with custom backward pass?
124 |
--------------------------------------------------------------------------------
/tests/layers/test_nbit_conv.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | import time
5 | from bitorch.layers import QConv2d
6 | from tests.layers.util import get_cuda_test_device
7 | from bitorch_engine.layers.qconv.nbit.cutlass import Q4Conv2dCutlass
8 |
9 | """
10 | Test nbit inference layers
11 | """
12 |
13 | # lower print threshold
14 | torch.set_printoptions(threshold=100)
15 |
16 | def to_device(data: torch.Tensor, device: torch.device) -> torch.Tensor:
17 | if isinstance(data, (list, tuple)):
18 | return [to_device(x, device) for x in data]
19 | return data.to(device)
20 |
21 | TEST_BATCH_SIZE = [1, 32, 64, 128]
22 | # Input shape: (batch size, num of input channels, h, w)
23 | TEST_INPUT_DATA = [
24 | ((64, 64, 64), [64, 32],
25 | {"kernel_size": 3, "padding": 2, "stride": 2, "dilation": 1}),
26 | ((64, 56, 56), [64, 256],
27 | {"kernel_size": 1, "padding": 0, "stride": 1, "dilation": 1}),
28 | ((64, 56, 56), [64, 64],
29 | {"kernel_size": 3, "padding": 0, "stride": 1, "dilation": 1}),
30 | ((256, 56, 56), [256, 64],
31 | {"kernel_size": 1, "padding": 0, "stride": 1, "dilation": 1}),
32 | ((256, 56, 56), [256, 128],
33 | {"kernel_size": 1, "padding": 0, "stride": 1, "dilation": 1}),
34 | ((128, 28, 28), [128, 128],
35 | {"kernel_size": 3, "padding": 0, "stride": 1, "dilation": 1}),
36 | ((128, 28, 28), [128, 512],
37 | {"kernel_size": 1, "padding": 0, "stride": 1, "dilation": 1}),
38 | ((256, 14, 14), [256, 256],
39 | {"kernel_size": 3, "padding": 2, "stride": 1, "dilation": 1}),
40 | ((512, 7, 7), [512, 512],
41 | {"kernel_size": 3, "padding": 0, "stride": 1, "dilation": 1}),
42 | ]
43 |
44 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available, skipping CUDA-related test")
45 | @pytest.mark.parametrize("input_shape, args, kwargs", TEST_INPUT_DATA)
46 | @pytest.mark.parametrize("BS", TEST_BATCH_SIZE)
47 | def test_q4_conv_cuda(input_shape, args, kwargs, BS):
48 | print("\n")
49 | print("batch_size: {}, input_shape:{}, input/output nums:{}, other args:{}"
50 | .format(BS, input_shape, args, kwargs))
51 | input_shape = (BS, input_shape[0], input_shape[1], input_shape[2])
52 | input = np.random.uniform(-1, 1, input_shape)
53 | input_tensor = torch.tensor(input).float()
54 |
55 | num_runs = 10
56 |
57 | # to gpu
58 | device = get_cuda_test_device()
59 | torch.cuda.set_device(device)
60 | input_tensor_cuda = to_device(input_tensor, device)
61 |
62 | layer = QConv2d(*args, **kwargs)
63 | layer.to(device)
64 | start_time = time.time()
65 | for i in range(num_runs):
66 | result_bitorch = layer(input_tensor_cuda)
67 | time_engine = time.time() - start_time
68 | print("bitorch binary-conv: %.6f s" % (time_engine / num_runs))
69 |
70 | padding = kwargs["padding"]
71 | kernel_size = kwargs["kernel_size"]
72 | stride = kwargs["stride"]
73 | dilation = kwargs["dilation"]
74 | nbit_conv_layer = Q4Conv2dCutlass(in_channels=int(args[0]),
75 | out_channels=args[1],
76 | kernel_size=kernel_size,
77 | stride=stride,
78 | padding=padding,
79 | dilation=dilation,
80 | device=device)
81 | nbit_conv_layer.to(device)
82 | nbit_conv_layer.prepare_params()
83 | result = nbit_conv_layer(input_tensor_cuda)
84 | # print(result)
85 |
86 | grad_input_data = np.random.uniform(-1, 1, result.shape)
87 | grad_input_tensor = torch.tensor(grad_input_data).float()
88 | grad_input_tensor_cuda = to_device(grad_input_tensor, device)
89 |
90 | # use quantized weight for inference
91 | nbit_conv_layer.generate_quantized_weight(qweight_only=True)
92 | nbit_conv_layer.eval()
93 | start_time = time.time()
94 | for i in range(num_runs):
95 | result_quantized_w = nbit_conv_layer(input_tensor_cuda)
96 | time_engine = time.time() - start_time
97 | print("engine q4-conv forward (CUTLASS): %.6f s" % (time_engine / num_runs))
98 | # print(result_quantized_w)
99 | assert torch.equal(result, result_quantized_w)
100 |
101 | start_time = time.time()
102 | for i in range(num_runs):
103 | result.backward(grad_input_tensor_cuda, retain_graph=True)
104 | torch.cuda.synchronize()
105 | time_engine = time.time() - start_time
106 | print("bitorch-engine q4-conv backward (CUTLASS): %.6f s" % (time_engine/num_runs))
107 |
--------------------------------------------------------------------------------
/tests/layers/test_nbit_linear_mixbits.py:
--------------------------------------------------------------------------------
1 | import time, math
2 | import pytest
3 | import torch
4 | import json
5 |
6 | from bitorch_engine.layers.qlinear.nbit.cuda import MBWQLinearCuda
7 | from tests.layers.util import get_cuda_test_device, get_packed_info, get_q_groups
8 |
9 |
10 | """
11 | Test nbit inference layers
12 | """
13 |
14 | # lower print threshold
15 | torch.set_printoptions(threshold=100)
16 |
17 |
18 | def to_device(data: torch.Tensor, device: torch.device) -> torch.Tensor:
19 | if isinstance(data, (list, tuple)):
20 | return [to_device(x, device) for x in data]
21 | return data.to(device)
22 |
23 |
24 | # ========= testing data ========== #
25 |
26 | test_json_string = '{' \
27 | '"q_proj": { "group_size": { "4": 32, "2": 32 }, "bits": [ 4, 2 ], "bits_prop": [ 0.75, 0.25 ], "scale_bits": 4 }, ' \
28 | '"k_proj": { "group_size": { "4": 32, "2": 32 }, "bits": [ 4, 2 ], "bits_prop": [ 0.25, 0.75 ], "scale_bits": 4 } ' \
29 | '}'
30 |
31 |
32 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available, skipping CUDA-related test")
33 | @pytest.mark.parametrize("num_input_features", [128, 4096])
34 | @pytest.mark.parametrize("num_hidden_fc", [128, 4096])
35 | @pytest.mark.parametrize("batch_size", [1, 2])
36 | def test_mbwq_linear_exl2_cuda(num_input_features, num_hidden_fc, batch_size):
37 |
38 | # supports 4-bit weights and fp16 (torch.half) only!
39 | dtype = torch.half
40 | num_runs = 1
41 | device = get_cuda_test_device()
42 | torch.cuda.set_device(device)
43 |
44 | # creating test data
45 | input_data = torch.normal(0, 1, size=(batch_size, num_input_features), requires_grad=False, dtype=dtype)
46 | # to gpu
47 | input_data_cuda = to_device(input_data, device)
48 |
49 | # Parsing the JSON string into a Python dictionary
50 | gbe_strategy = json.loads(test_json_string)
51 |
52 | for key, value in gbe_strategy.items():
53 | # read attribute information from predefined json config
54 | groups, rows = get_packed_info(num_input_features, value["bits"], value["bits_prop"], value["group_size"])
55 |
56 | print("\nM:{}, N:{}, K:{}, bits:{}, group_size:{}, bits_prop:{}, groups:{}, packed_rows:{}."
57 | .format(batch_size, num_hidden_fc, num_input_features, str(value["bits"]), str(value["group_size"]), str(value["bits_prop"]),
58 | groups, rows))
59 |
60 | # creating int weights
61 | random_int4_tensor = torch.randint(0, 2 ** 16 - 1, size=(rows, num_hidden_fc))
62 | int_weight_cuda = to_device(random_int4_tensor, device)
63 |
64 | mbwq_linear_layer = MBWQLinearCuda(in_channels=num_input_features, out_channels=num_hidden_fc,
65 | w_bit=4, dtype=dtype, group_size=32,
66 | dq_group_size=1, use_gba_quant=True,
67 | asym=False, dq_mode=2, use_mbw=True,
68 | groups=groups, rows_packed=rows)
69 |
70 | mbwq_linear_layer.set_qweight_data(int_weight_cuda)
71 | # random scales and zeros
72 | scales = torch.randn_like(mbwq_linear_layer.scales).half()
73 | zeros = torch.randn_like(mbwq_linear_layer.zeros).half()
74 | mbwq_linear_layer.set_scales(scales)
75 | mbwq_linear_layer.set_zeros(zeros)
76 | mbwq_linear_layer.q_perm = torch.tensor([i for i in range(num_input_features)], dtype=torch.short)
77 |
78 | # get q_groups
79 | q_groups = get_q_groups(groups, value["bits"], value["group_size"], num_input_features, value["bits_prop"])
80 |
81 | assert mbwq_linear_layer.q_groups.numel() == len(q_groups)
82 |
83 | mbwq_linear_layer.q_groups = torch.Tensor(q_groups).to(torch.short)
84 |
85 | # to device
86 | mbwq_linear_layer.to(device)
87 |
88 | # will perform qweight layout transformation
89 | mbwq_linear_layer.prepare_params()
90 |
91 | # Testing fp weight reconstruction
92 | reconstructed_fp_weights = MBWQLinearCuda.exl2fp_weight(mbwq_linear_layer.qweight, mbwq_linear_layer.scales,
93 | mbwq_linear_layer.zeros, mbwq_linear_layer.q_perm,
94 | mbwq_linear_layer.q_group_map, mbwq_linear_layer.rows).to(dtype)
95 |
96 | # pytorch result:
97 | result_pt = torch.matmul(input_data_cuda.mul(mbwq_linear_layer.channel_scale), reconstructed_fp_weights)
98 |
99 | # Testing inference output
100 | start_time = time.time()
101 | for i in range(num_runs):
102 | result = mbwq_linear_layer(input_data_cuda)
103 | torch.cuda.synchronize()
104 | time_engine = time.time() - start_time
105 |
106 | print("bitorch-engine mbwq_linear (CUDA) run time: %.6f s" % (time_engine/num_runs))
107 |
108 | assert torch.all(torch.isclose(result, result_pt, rtol=2, atol=2, equal_nan=False))
--------------------------------------------------------------------------------
/tests/layers/test_nbit_linear_mps.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import pytest
4 | import torch
5 | import numpy as np
6 |
7 | from bitorch_engine.layers.qlinear.nbit.mps import MPQLinearMlx
8 | from tests.layers.util import get_mps_test_device
9 |
10 | """
11 | Test nbit inference layers
12 | """
13 |
14 | # lower print threshold
15 | torch.set_printoptions(threshold=100)
16 |
17 |
18 | def to_device(data: torch.Tensor, device: torch.device) -> torch.Tensor:
19 | if isinstance(data, (list, tuple)):
20 | return [to_device(x, device) for x in data]
21 | return data.to(device)
22 |
23 |
24 | @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available, skipping MPS-related test")
25 | @pytest.mark.parametrize("w_bit", [2, 4, 8])
26 | @pytest.mark.parametrize("dtype", [torch.half])
27 | @pytest.mark.parametrize("num_input_features", [1024, 8192])
28 | @pytest.mark.parametrize("num_hidden_fc", [1024, 8192])
29 | @pytest.mark.parametrize("batch_size", [64])
30 | @pytest.mark.parametrize("group_size", [32, 128])
31 | @pytest.mark.parametrize("llama_v", [2])
32 | @pytest.mark.parametrize("model_size", ["1.1b", "7b", "30b", "70b"])
33 | def test_mpq_linear_mps(w_bit, dtype, num_input_features, num_hidden_fc, batch_size, group_size, llama_v, model_size):
34 | import mlx.core
35 |
36 | if ((w_bit == 1 or w_bit == 8) and group_size == 32) \
37 | or (model_size in ["30b", "70b"] and w_bit != 2) \
38 | or (model_size in ["1.1b"] and w_bit != 2):
39 | pytest.skip()
40 |
41 | # === following configuration from low_bit_llama https://github.com/GreenBitAI/low_bit_llama === #
42 | double_groupsize = -1
43 | if group_size == 32 and model_size not in ["1.1b", "1.1B"]:
44 | asym = True
45 | else:
46 | asym = False
47 |
48 | if w_bit == 2:
49 | if asym:
50 | double_groupsize = -1
51 | else:
52 | if group_size == 32:
53 | double_groupsize = 32
54 | else:
55 | if llama_v == 1 and model_size not in ["30b", "30B"]:
56 | double_groupsize = 64
57 | else:
58 | double_groupsize = 32
59 | else:
60 | if model_size in ["3b", "3B"]:
61 | double_groupsize = 64
62 | elif model_size in ["7b", "7B"]:
63 | double_groupsize = 256
64 |
65 | v1 = (llama_v == 1) and model_size in ["7b", "7B"]
66 |
67 | if w_bit not in [2, 4, 8]:
68 | print("w_bit not supported")
69 | pytest.skip()
70 | if group_size not in [32, 64, 128]:
71 | print("group_size not supported")
72 | pytest.skip()
73 | if asym:
74 | print("asym not supported")
75 | pytest.skip()
76 | # =============================================================================================== #
77 |
78 | print("\nM:{}, N:{}, K:{}, bits:{}, dtype:{}, groupsize:{}, llama_v:{}, model_size:{}."
79 | .format(batch_size, num_hidden_fc, num_input_features, w_bit, dtype, group_size, llama_v, model_size))
80 |
81 | num_runs = 10
82 | # creating test data
83 | input_data = torch.normal(0, 1, size=(batch_size, num_input_features), requires_grad=False, dtype=dtype)
84 | grad_input_data = torch.normal(0, 1, size=(batch_size, num_hidden_fc), requires_grad=False, dtype=dtype)
85 |
86 | # to gpu
87 | device = get_mps_test_device()
88 | input_data_mps = to_device(input_data, device)
89 | grad_input_data_mps = to_device(grad_input_data, device)
90 |
91 | time_engine = 0
92 | for i in range(num_runs):
93 | b_linear = torch.nn.Linear(num_input_features, num_hidden_fc, bias=False, dtype=dtype)
94 | b_linear.to(device)
95 | start_time = time.time()
96 | result = b_linear(input_data_mps)
97 | time_engine += time.time() - start_time
98 | print(f"pytorch linear forward run time (device {device}): {time_engine / num_runs:.6f} s")
99 |
100 | mpq_linear_layer = MPQLinearMlx(in_channels=num_input_features,
101 | out_channels=num_hidden_fc,
102 | w_bit=w_bit,
103 | dtype=dtype,
104 | group_size=group_size,
105 | dq_group_size=double_groupsize,
106 | dq_mode=1 if v1 else 2,
107 | use_gba_quant=True,
108 | asym=asym,
109 | requires_grad=False)
110 |
111 | mpq_linear_layer.to(device)
112 | mpq_linear_layer.prepare_params()
113 |
114 | time_engine = 0
115 | time_mlx = 0
116 | for i in range(num_runs):
117 | random_int_tensor = torch.randint(1, 1000, size=mpq_linear_layer.qweight.shape,
118 | dtype=mpq_linear_layer.qweight.dtype, device=mpq_linear_layer.qweight.device)
119 | mpq_linear_layer.set_qweight_data(random_int_tensor)
120 | start_time = time.time()
121 | result1 = mpq_linear_layer(input_data_mps)
122 | time_engine += time.time() - start_time
123 |
124 | qweight_mlx = mlx.core.array(mpq_linear_layer.qweight.numpy().astype(np.uint32))
125 | scales_mlx = mlx.core.array(mpq_linear_layer.scales.numpy())
126 | zeros_mlx = mlx.core.array(mpq_linear_layer.zeros.numpy())
127 |
128 | start_time = time.time()
129 | input_mlx = mlx.core.array(input_data_mps.cpu().numpy())
130 | mlx_matmul = mlx.core.quantized_matmul(
131 | input_mlx,
132 | qweight_mlx,
133 | scales=scales_mlx,
134 | biases=zeros_mlx,
135 | transpose=True,
136 | group_size=group_size,
137 | bits=w_bit)
138 | mlx_matmul = torch.from_numpy(np.array(mlx_matmul)).to('mps')
139 | time_mlx += time.time() - start_time
140 | assert mlx_matmul.equal(result1), "mps matmul failed"
141 | print(f"bitorch-engine mpq linear forward (MPS) run time: {time_engine / num_runs:.6f} s")
142 | print(f"Mlx (MPS) run time: {time_mlx / num_runs:.6f} s")
--------------------------------------------------------------------------------
/tests/layers/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from typing import Union, List
4 |
5 | import torch
6 |
7 |
8 | def activate_remote_pycharm_debug(port: int = 11004):
9 | import pydevd_pycharm
10 | pydevd_pycharm.settrace('localhost', port=port, stdoutToServer=True, stderrToServer=True)
11 |
12 |
13 | def to_device(data: torch.Tensor, device: torch.device) -> Union[torch.Tensor, List[torch.Tensor]]:
14 | if isinstance(data, (list, tuple)):
15 | return [to_device(x, device) for x in data]
16 | return data.to(device)
17 |
18 |
19 | def get_cuda_test_device_id():
20 | return int(os.environ.get("BIE_DEVICE", "0"))
21 |
22 |
23 | def get_cuda_test_device():
24 | return torch.device(f"cuda:{get_cuda_test_device_id()}")
25 |
26 |
27 | def get_mps_test_device():
28 | return torch.device("mps")
29 |
30 |
31 | def get_packed_info(channels, n_bits, bits_prop, bits_group_size):
32 | groups = 0
33 | rows = 0
34 | bits_channel = []
35 | for idx in range(len(bits_prop)):
36 | if idx < len(bits_prop) - 1:
37 | minimal_channels = list(bits_group_size.values())[idx]
38 | channel_pre_pack = max(1, int(channels * (bits_prop[idx])) // minimal_channels) * minimal_channels
39 | bits_channel.append(channel_pre_pack)
40 | groups += channel_pre_pack // minimal_channels
41 | rows += channel_pre_pack // 32 * n_bits[idx]
42 | else:
43 | minimal_channels = list(bits_group_size.values())[idx]
44 | channel_pre_pack = channels - sum(bits_channel)
45 | bits_channel.append(channel_pre_pack)
46 | groups += channel_pre_pack // minimal_channels
47 | rows += channel_pre_pack // 32 * n_bits[idx]
48 |
49 | return groups, rows
50 |
51 |
52 | def get_q_groups(groups, n_bits, group_size, channels, bits_prop):
53 | qgroups = []
54 | bits_column_end_index = []
55 |
56 | for idx in range(len(bits_prop)):
57 | if idx < len(bits_prop) - 1:
58 | minimal_columns = list(group_size.values())[idx]
59 | columns_index = max(1, int(channels * (
60 | bits_prop[idx])) // minimal_columns) * minimal_columns # TODO: determine the minimal bits columns
61 | if idx > 0:
62 | columns_index += bits_column_end_index[-1]
63 | bits_column_end_index.append(columns_index)
64 | else:
65 | bits_column_end_index.append(channels)
66 |
67 | for bits_idx, bits in enumerate(n_bits):
68 | if bits_idx == 0:
69 | rows_per_bit = bits_column_end_index[bits_idx]
70 | else:
71 | rows_per_bit = bits_column_end_index[bits_idx] - bits_column_end_index[bits_idx - 1]
72 |
73 | gs = group_size[str(bits)]
74 | groups_per_bit = rows_per_bit // gs
75 |
76 | for group in range(groups_per_bit):
77 | qgroups.append(bits) # record bits per group
78 | qgroups.append(0)
79 |
80 | out_row = 0
81 | rem_rows = channels
82 | for i in range(groups):
83 | bits = qgroups[2 * i]
84 | gs = group_size[str(bits)]
85 |
86 | rows_per_group = min(gs, rem_rows) # rows per group before packing
87 | wpqr = 32 / bits # INT32 elements per group for packing
88 | qrows = math.ceil(rows_per_group / wpqr) # rows per group after packing
89 | qgroups[2 * i + 1] = out_row # record packed rows start idx per group
90 |
91 | out_row += qrows
92 |
93 | return qgroups
94 |
95 |
96 | def pack_rows_4_pytorch(input_tensor, rows, columns):
97 | # Calculate the number of output columns: eight 4-bit input values fit into one 32-bit word, so use columns * 4 // 32
98 | out_columns = columns * 4 // 32
99 |
100 | # Initialize the output tensor of size [rows, out_columns]; pack into int64 first, then truncate to the lower 32 bits (int32) at the end
101 | output_tensor = torch.zeros((rows, out_columns), dtype=torch.int64, device=input_tensor.device)
102 |
103 | # To simulate the behavior of the CUDA kernel, we need to perform the same operation for each output element
104 | # This is implemented using the broadcast and indexing functions of PyTorch
105 | for row in range(rows):
106 | for out_column in range(out_columns):
107 | packed = 0
108 | for i in range(8):
109 | # Simulate operations in the CUDA kernel
110 | x = input_tensor[row, out_column * 8 + i].item() - 1
111 | packed |= (x << (i * 4))
112 | output_tensor[row, out_column] = packed
113 | # Use bitwise operators to extract the unsigned lower 32 bits
114 | # Note: 0xFFFFFFFF is a 32-bit all-1 mask, which is used to extract the lower 32 bits
115 | output_tensor = (output_tensor & 0xFFFFFFFF).to(torch.int32)
116 | return output_tensor
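117 |
118 |
119 | # Worked example (editorial illustration) for pack_rows_4_pytorch: eight 4-bit
120 | # values are packed into one 32-bit word after subtracting 1 from each value.
121 | # For an input row [1, 2, 3, 4, 5, 6, 7, 8] the stored nibbles are 0..7 and the
122 | # packed word is 0x76543210 (column 0 ends up in the lowest nibble):
123 | #
124 | #     t = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
125 | #     pack_rows_4_pytorch(t, rows=1, columns=8)   # -> tensor([[1985229328]], dtype=torch.int32)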
--------------------------------------------------------------------------------
/tests/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GreenBitAI/bitorch-engine/0009b1f8178e86df6433383879c0005a49c0cc92/tests/util/__init__.py
--------------------------------------------------------------------------------
/tests/util/binary_mse_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def binary_mse_loss(output, target):
4 | print("Output:", output)
5 | print("Target:", target)
6 | loss = torch.mean((output.to(torch.float32) - target.to(torch.float32))**2)
7 | return loss
8 |
--------------------------------------------------------------------------------
/version.txt:
--------------------------------------------------------------------------------
1 | 0.2.6
2 |
--------------------------------------------------------------------------------