├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── assets
    ├── header_model_release.png
    ├── intel_performance.jpg
    ├── m2_performance.jpg
    ├── tl1.png
    └── tl2.png
├── docs
    └── codegen.md
├── gpu
    ├── README.md
    ├── bitnet_kernels
    │   ├── bitnet_kernels.cu
    │   ├── bitnet_kernels.h
    │   ├── compile.sh
    │   └── setup.py
    ├── convert_checkpoint.py
    ├── convert_safetensors.py
    ├── generate.py
    ├── model.py
    ├── pack_weight.py
    ├── requirements.txt
    ├── sample_utils.py
    ├── stats.py
    ├── test.py
    ├── tokenizer.model
    └── tokenizer.py
├── include
    └── ggml-bitnet.h
├── media
    ├── benchmark.png
    └── demo.mp4
├── preset_kernels
    ├── Llama3-8B-1.58-100B-tokens
    │   ├── bitnet-lut-kernels-tl1.h
    │   ├── bitnet-lut-kernels-tl2.h
    │   ├── kernel_config_tl1.ini
    │   └── kernel_config_tl2.ini
    ├── bitnet_b1_58-3B
    │   ├── bitnet-lut-kernels-tl1.h
    │   ├── bitnet-lut-kernels-tl2.h
    │   ├── kernel_config_tl1.ini
    │   └── kernel_config_tl2.ini
    └── bitnet_b1_58-large
        ├── bitnet-lut-kernels-tl1.h
        ├── bitnet-lut-kernels-tl2.h
        ├── kernel_config_tl1.ini
        └── kernel_config_tl2.ini
├── requirements.txt
├── run_inference.py
├── run_inference_server.py
├── setup_env.py
├── src
    ├── CMakeLists.txt
    ├── ggml-bitnet-lut.cpp
    └── ggml-bitnet-mad.cpp
└── utils
    ├── codegen_tl1.py
    ├── codegen_tl2.py
    ├── convert-helper-bitnet.py
    ├── convert-hf-to-gguf-bitnet.py
    ├── convert-ms-to-gguf-bitnet.py
    ├── convert.py
    ├── e2e_benchmark.py
    ├── generate-dummy-bitnet-model.py
    ├── kernel_tuning.py
    └── preprocess-huggingface-bitnet.py


/.gitignore:
--------------------------------------------------------------------------------
 1 | # Extensions
 2 | 
 3 | *.a
 4 | *.bat
 5 | *.bin
 6 | *.dll
 7 | *.dot
 8 | *.etag
 9 | *.exe
10 | *.gcda
11 | *.gcno
12 | *.gcov
13 | *.gguf
14 | *.gguf.json
15 | *.lastModified
16 | *.log
17 | *.metallib
18 | *.o
19 | *.so
20 | *.tmp
21 | 
22 | # IDE / OS
23 | 
24 | .cache/
25 | .ccls-cache/
26 | .direnv/
27 | .DS_Store
28 | .envrc
29 | .idea/
30 | .swiftpm
31 | .vs/
32 | .vscode/
33 | nppBackup
34 | 
35 | # Models
36 | models/*
37 | gpu/checkpoints/*
38 | 
39 | # Python
40 | 
41 | /.venv
42 | __pycache__/
43 | */poetry.lock
44 | poetry.toml
45 | 
46 | build/
47 | logs/


--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "3rdparty/llama.cpp"]
2 | 	path = 3rdparty/llama.cpp
3 | 	url = https://github.com/Eddie-Wang1120/llama.cpp.git
4 | 	branch = merge-dev
5 | 


--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
 1 | cmake_minimum_required(VERSION 3.14)  # for add_link_options and implicit target directories.
 2 | project("bitnet.cpp" C CXX)
 3 | include(CheckIncludeFileCXX)
 4 | 
 5 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 6 | 
 7 | if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
 8 |     set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
 9 |     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
10 | endif()
11 | 
12 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
13 | 
14 | # option list
15 | option(BITNET_ARM_TL1    "bitnet.cpp: use tl1 on arm platform"    OFF)
16 | option(BITNET_X86_TL2    "bitnet.cpp: use tl2 on x86 platform"    OFF)
17 | 
18 | 
19 | set(CMAKE_CXX_STANDARD_REQUIRED true)
20 | set(CMAKE_C_STANDARD 11)
21 | set(CMAKE_C_STANDARD_REQUIRED true)
22 | set(THREADS_PREFER_PTHREAD_FLAG ON)
23 | 
24 | # override ggml options
25 | set(GGML_BITNET_ARM_TL1    ${BITNET_ARM_TL1})
26 | set(GGML_BITNET_X86_TL2    ${BITNET_X86_TL2})
27 | 
28 | if (GGML_BITNET_ARM_TL1)
29 |     add_compile_definitions(GGML_BITNET_ARM_TL1)
30 | endif()
31 | if (GGML_BITNET_X86_TL2)
32 |     add_compile_definitions(GGML_BITNET_X86_TL2)
33 | endif()
34 | 
35 | if (CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
36 |     add_compile_options(-fpermissive)
37 | endif()
38 | 
39 | find_package(Threads REQUIRED)
40 | 
41 | add_subdirectory(src)
42 | set(LLAMA_BUILD_SERVER ON CACHE BOOL "Build llama.cpp server" FORCE)
43 | add_subdirectory(3rdparty/llama.cpp)
44 | 
45 | # install
46 | 
47 | include(GNUInstallDirs)
48 | include(CMakePackageConfigHelpers)
49 | 
50 | set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR}
51 |     CACHE PATH "Location of header files")
52 | set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR}
53 |     CACHE PATH "Location of library files")
54 | set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR}
55 |     CACHE PATH "Location of binary files")
56 | set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER})
57 | set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT})
58 | set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER})
59 | 
60 | get_target_property(GGML_DIRECTORY ggml SOURCE_DIR)
61 | get_directory_property(GGML_DIR_DEFINES DIRECTORY ${GGML_DIRECTORY} COMPILE_DEFINITIONS)
62 | get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS)
63 | set(GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES} ${GGML_DIR_DEFINES})
64 | get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES)
65 | 
66 | get_directory_property(LLAMA_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
67 | 
68 | write_basic_package_version_file(
69 |         ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake
70 |     VERSION ${LLAMA_INSTALL_VERSION}
71 |     COMPATIBILITY SameMajorVersion)
72 | 
73 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake
74 |               ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake
75 |         DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama)
76 | 
77 | set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/llama.h)
78 | install(TARGETS llama LIBRARY PUBLIC_HEADER)
79 | 


--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
 1 | # Microsoft Open Source Code of Conduct
 2 | 
 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
 4 | 
 5 | Resources:
 6 | 
 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 | 


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 |     MIT License
 2 | 
 3 |     Copyright (c) Microsoft Corporation.
 4 | 
 5 |     Permission is hereby granted, free of charge, to any person obtaining a copy
 6 |     of this software and associated documentation files (the "Software"), to deal
 7 |     in the Software without restriction, including without limitation the rights
 8 |     to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 |     copies of the Software, and to permit persons to whom the Software is
10 |     furnished to do so, subject to the following conditions:
11 | 
12 |     The above copyright notice and this permission notice shall be included in all
13 |     copies or substantial portions of the Software.
14 | 
15 |     THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 |     IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 |     FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 |     AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 |     LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 |     OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 |     SOFTWARE
22 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # bitnet.cpp
  2 | [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
  3 | ![version](https://img.shields.io/badge/version-1.0-blue)
  4 | 
  5 | [<img src="./assets/header_model_release.png" alt="BitNet Model on Hugging Face" width="800"/>](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
  6 | 
  7 | Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or build and run it on your own [CPU](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) or [GPU](https://github.com/microsoft/BitNet/blob/main/gpu/README.md).
  8 | 
  9 | bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels that support **fast** and **lossless** inference of 1.58-bit models on CPU and GPU (NPU support is coming next).
 10 | 
 11 | The first release of bitnet.cpp supports inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x**, with energy reductions between **71.9%** and **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.
 12 | 
 13 | <img src="./assets/m2_performance.jpg" alt="m2_performance" width="800"/>
 14 | <img src="./assets/intel_performance.jpg" alt="m2_performance" width="800"/>
 15 | 
 16 | >The tested models are dummy setups used in a research context to demonstrate the inference performance of bitnet.cpp.
 17 | 
 18 | ## Demo
 19 | 
 20 | A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2:
 21 | 
 22 | https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1
 23 | 
 24 | ## What's New:
 25 | - 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md) ![NEW](https://img.shields.io/badge/NEW-red)
 26 | - 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
 27 | - 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880)
 28 | - 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965)
 29 | - 10/21/2024 [1-bit AI Infra: Part 1.1, Fast and Lossless BitNet b1.58 Inference on CPUs](https://arxiv.org/abs/2410.16144)
 30 | - 10/17/2024 bitnet.cpp 1.0 released.
 31 | - 03/21/2024 [The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf)
 32 | - 02/27/2024 [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764)
 33 | - 10/17/2023 [BitNet: Scaling 1-bit Transformers for Large Language Models](https://arxiv.org/abs/2310.11453)
 34 | 
 35 | ## Acknowledgements
 36 | 
 37 | This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp) framework. We would like to thank all the authors for their contributions to the open-source community. Also, bitnet.cpp's kernels are built on top of the Lookup Table methodologies pioneered in [T-MAC](https://github.com/microsoft/T-MAC/). For inference of general low-bit LLMs beyond ternary models, we recommend using T-MAC.
 38 | ## Official Models
 39 | <table>
 40 |
 41 |     <tr>
 42 |         <th rowspan="2">Model</th>
 43 |         <th rowspan="2">Parameters</th>
 44 |         <th rowspan="2">CPU</th>
 45 |         <th colspan="3">Kernel</th>
 46 |     </tr>
 47 |     <tr>
 48 |         <th>I2_S</th>
 49 |         <th>TL1</th>
 50 |         <th>TL2</th>
 51 |     </tr>
 52 |     <tr>
 53 |         <td rowspan="2"><a href="https://huggingface.co/microsoft/BitNet-b1.58-2B-4T">BitNet-b1.58-2B-4T</a></td>
 54 |         <td rowspan="2">2.4B</td>
 55 |         <td>x86</td>
 56 |         <td>&#9989;</td>
 57 |         <td>&#10060;</td>
 58 |         <td>&#9989;</td>
 59 |     </tr>
 60 |     <tr>
 61 |         <td>ARM</td>
 62 |         <td>&#9989;</td>
 63 |         <td>&#9989;</td>
 64 |         <td>&#10060;</td>
 65 |     </tr>
 66 | </table>
 67 | 
 68 | ## Supported Models
 69 | ❗️**We use existing 1-bit LLMs available on [Hugging Face](https://huggingface.co/) to demonstrate the inference capabilities of bitnet.cpp. We hope the release of bitnet.cpp will inspire the development of 1-bit LLMs in large-scale settings in terms of model size and training tokens.**
 70 | 
 71 | <table>
 72 |
 73 |     <tr>
 74 |         <th rowspan="2">Model</th>
 75 |         <th rowspan="2">Parameters</th>
 76 |         <th rowspan="2">CPU</th>
 77 |         <th colspan="3">Kernel</th>
 78 |     </tr>
 79 |     <tr>
 80 |         <th>I2_S</th>
 81 |         <th>TL1</th>
 82 |         <th>TL2</th>
 83 |     </tr>
 84 |     <tr>
 85 |         <td rowspan="2"><a href="https://huggingface.co/1bitLLM/bitnet_b1_58-large">bitnet_b1_58-large</a></td>
 86 |         <td rowspan="2">0.7B</td>
 87 |         <td>x86</td>
 88 |         <td>&#9989;</td>
 89 |         <td>&#10060;</td>
 90 |         <td>&#9989;</td>
 91 |     </tr>
 92 |     <tr>
 93 |         <td>ARM</td>
 94 |         <td>&#9989;</td>
 95 |         <td>&#9989;</td>
 96 |         <td>&#10060;</td>
 97 |     </tr>
 98 |     <tr>
 99 |         <td rowspan="2"><a href="https://huggingface.co/1bitLLM/bitnet_b1_58-3B">bitnet_b1_58-3B</a></td>
100 |         <td rowspan="2">3.3B</td>
101 |         <td>x86</td>
102 |         <td>&#10060;</td>
103 |         <td>&#10060;</td>
104 |         <td>&#9989;</td>
105 |     </tr>
106 |     <tr>
107 |         <td>ARM</td>
108 |         <td>&#10060;</td>
109 |         <td>&#9989;</td>
110 |         <td>&#10060;</td>
111 |     </tr>
112 |     <tr>
113 |         <td rowspan="2"><a href="https://huggingface.co/HF1BitLLM/Llama3-8B-1.58-100B-tokens">Llama3-8B-1.58-100B-tokens</a></td>
114 |         <td rowspan="2">8.0B</td>
115 |         <td>x86</td>
116 |         <td>&#9989;</td>
117 |         <td>&#10060;</td>
118 |         <td>&#9989;</td>
119 |     </tr>
120 |     <tr>
121 |         <td>ARM</td>
122 |         <td>&#9989;</td>
123 |         <td>&#9989;</td>
124 |         <td>&#10060;</td>
125 |     </tr>
126 |     <tr>
127 |         <td rowspan="2"><a href="https://huggingface.co/collections/tiiuae/falcon3-67605ae03578be86e4e87026">Falcon3 Family</a></td>
128 |         <td rowspan="2">1B-10B</td>
129 |         <td>x86</td>
130 |         <td>&#9989;</td>
131 |         <td>&#10060;</td>
132 |         <td>&#9989;</td>
133 |     </tr>
134 |     <tr>
135 |         <td>ARM</td>
136 |         <td>&#9989;</td>
137 |         <td>&#9989;</td>
138 |         <td>&#10060;</td>
139 |     </tr>
140 |     <tr>
141 |         <td rowspan="2"><a href="https://huggingface.co/collections/tiiuae/falcon-edge-series-6804fd13344d6d8a8fa71130">Falcon-E Family</a></td>
142 |         <td rowspan="2">1B-3B</td>
143 |         <td>x86</td>
144 |         <td>&#9989;</td>
145 |         <td>&#10060;</td>
146 |         <td>&#9989;</td>
147 |     </tr>
148 |     <tr>
149 |         <td>ARM</td>
150 |         <td>&#9989;</td>
151 |         <td>&#9989;</td>
152 |         <td>&#10060;</td>
153 |     </tr>
154 | </table>
155 | 
156 | 
157 | 
158 | ## Installation
159 | 
160 | ### Requirements
161 | - python>=3.9
162 | - cmake>=3.22
163 | - clang>=18
 164 |     - For Windows users, install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/). In the installer, toggle on at least the following options (this also automatically installs the required additional tools, like CMake):
 165 |         -  Desktop development with C++
 166 |         -  C++ CMake tools for Windows
 167 |         -  Git for Windows
 168 |         -  C++ Clang Compiler for Windows
 169 |         -  MSBuild support for LLVM (clang-cl) toolset
 170 |     - For Debian/Ubuntu users, you can install it with the [automatic installation script](https://apt.llvm.org/):
171 | 
172 |         `bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)"`
 173 | - conda (highly recommended)
174 | 
175 | ### Build from source
176 | 
177 | > [!IMPORTANT]
178 | > If you are using Windows, please remember to always use a Developer Command Prompt / PowerShell for VS2022 for the following commands. Please refer to the FAQs below if you see any issues.
179 | 
180 | 1. Clone the repo
181 | ```bash
182 | git clone --recursive https://github.com/microsoft/BitNet.git
183 | cd BitNet
184 | ```
185 | 2. Install the dependencies
186 | ```bash
187 | # (Recommended) Create a new conda environment
188 | conda create -n bitnet-cpp python=3.9
189 | conda activate bitnet-cpp
190 | 
191 | pip install -r requirements.txt
192 | ```
193 | 3. Build the project
194 | ```bash
195 | # Manually download the model and run with local path
196 | huggingface-cli download microsoft/BitNet-b1.58-2B-4T-gguf --local-dir models/BitNet-b1.58-2B-4T
197 | python setup_env.py -md models/BitNet-b1.58-2B-4T -q i2_s
198 | 
199 | ```
200 | <pre>
201 | usage: setup_env.py [-h] [--hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}] [--model-dir MODEL_DIR] [--log-dir LOG_DIR] [--quant-type {i2_s,tl1}] [--quant-embd]
202 |                     [--use-pretuned]
203 | 
204 | Setup the environment for running inference
205 | 
206 | optional arguments:
207 |   -h, --help            show this help message and exit
208 |   --hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}, -hr {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}
209 |                         Model used for inference
210 |   --model-dir MODEL_DIR, -md MODEL_DIR
211 |                         Directory to save/load the model
212 |   --log-dir LOG_DIR, -ld LOG_DIR
213 |                         Directory to save the logging info
214 |   --quant-type {i2_s,tl1}, -q {i2_s,tl1}
215 |                         Quantization type
216 |   --quant-embd          Quantize the embeddings to f16
217 |   --use-pretuned, -p    Use the pretuned kernel parameters
218 | </pre>
219 | ## Usage
220 | ### Basic usage
221 | ```bash
222 | # Run inference with the quantized model
223 | python run_inference.py -m models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf -p "You are a helpful assistant" -cnv
224 | ```
225 | <pre>
226 | usage: run_inference.py [-h] [-m MODEL] [-n N_PREDICT] -p PROMPT [-t THREADS] [-c CTX_SIZE] [-temp TEMPERATURE] [-cnv]
227 | 
228 | Run inference
229 | 
230 | optional arguments:
231 |   -h, --help            show this help message and exit
232 |   -m MODEL, --model MODEL
233 |                         Path to model file
234 |   -n N_PREDICT, --n-predict N_PREDICT
235 |                         Number of tokens to predict when generating text
236 |   -p PROMPT, --prompt PROMPT
237 |                         Prompt to generate text from
238 |   -t THREADS, --threads THREADS
239 |                         Number of threads to use
240 |   -c CTX_SIZE, --ctx-size CTX_SIZE
241 |                         Size of the prompt context
242 |   -temp TEMPERATURE, --temperature TEMPERATURE
243 |                         Temperature, a hyperparameter that controls the randomness of the generated text
244 |   -cnv, --conversation  Whether to enable chat mode or not (for instruct models.)
245 |                         (When this option is turned on, the prompt specified by -p will be used as the system prompt.)
246 | </pre>
247 | 
248 | ### Benchmark
 249 | We provide a script to run the inference benchmark for a given model.
250 | 
251 | ```  
252 | usage: e2e_benchmark.py -m MODEL [-n N_TOKEN] [-p N_PROMPT] [-t THREADS]  
253 |    
254 | Setup the environment for running the inference  
255 |    
256 | required arguments:  
257 |   -m MODEL, --model MODEL  
258 |                         Path to the model file. 
259 |    
260 | optional arguments:  
261 |   -h, --help  
262 |                         Show this help message and exit. 
263 |   -n N_TOKEN, --n-token N_TOKEN  
264 |                         Number of generated tokens. 
265 |   -p N_PROMPT, --n-prompt N_PROMPT  
266 |                         Prompt to generate text from. 
267 |   -t THREADS, --threads THREADS  
268 |                         Number of threads to use. 
269 | ```  
270 |    
271 | Here's a brief explanation of each argument:  
272 |    
273 | - `-m`, `--model`: The path to the model file. This is a required argument that must be provided when running the script.  
274 | - `-n`, `--n-token`: The number of tokens to generate during the inference. It is an optional argument with a default value of 128.  
275 | - `-p`, `--n-prompt`: The number of prompt tokens to use for generating text. This is an optional argument with a default value of 512.  
276 | - `-t`, `--threads`: The number of threads to use for running the inference. It is an optional argument with a default value of 2.  
277 | - `-h`, `--help`: Show the help message and exit. Use this argument to display usage information.  
278 |    
279 | For example:  
280 |    
281 | ```sh  
282 | python utils/e2e_benchmark.py -m /path/to/model -n 200 -p 256 -t 4  
283 | ```  
284 |    
285 | This command would run the inference benchmark using the model located at `/path/to/model`, generating 200 tokens from a 256 token prompt, utilizing 4 threads.  
286 | 
 287 | For model layouts not covered by any public model, we provide a script to generate a dummy model with the given layout, so you can run the benchmark on your machine:
288 | 
289 | ```bash
290 | python utils/generate-dummy-bitnet-model.py models/bitnet_b1_58-large --outfile models/dummy-bitnet-125m.tl1.gguf --outtype tl1 --model-size 125M
291 | 
 292 | # Run the benchmark with the generated model; use -m to specify the model path, -p the number of prompt tokens to process, -n the number of tokens to generate
293 | python utils/e2e_benchmark.py -m models/dummy-bitnet-125m.tl1.gguf -p 512 -n 128
294 | ```
295 | 
296 | ### Convert from `.safetensors` Checkpoints
297 | 
298 | ```sh
299 | # Prepare the .safetensors model file
300 | huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./models/bitnet-b1.58-2B-4T-bf16
301 | 
302 | # Convert to gguf model
303 | python ./utils/convert-helper-bitnet.py ./models/bitnet-b1.58-2B-4T-bf16
304 | ```
305 | 
306 | ### FAQ (Frequently Asked Questions)📌 
307 | 
 308 | #### Q1: The build fails with errors in llama.cpp related to `std::chrono` in log.cpp. How do I fix it?
309 | 
310 | **A:**
 311 | This issue was introduced in a recent version of llama.cpp. Please refer to this [commit](https://github.com/tinglou/llama.cpp/commit/4e3db1e3d78cc1bcd22bcb3af54bd2a4628dd323) in the [discussion](https://github.com/abetlen/llama-cpp-python/issues/1942) to fix it.
312 | 
 313 | #### Q2: How do I build with clang in a conda environment on Windows?
314 | 
315 | **A:** 
316 | Before building the project, verify your clang installation and access to Visual Studio tools by running:
317 | ```
318 | clang -v
319 | ```
320 | 
321 | This command checks that you are using the correct version of clang and that the Visual Studio tools are available. If you see an error message such as:
322 | ```
323 | 'clang' is not recognized as an internal or external command, operable program or batch file.
324 | ```
325 | 
326 | It indicates that your command line window is not properly initialized for Visual Studio tools.
327 | 
328 | • If you are using Command Prompt, run:
329 | ```
330 | "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
331 | ```
332 | 
333 | • If you are using Windows PowerShell, run the following commands:
334 | ```
 335 | Import-Module "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\Microsoft.VisualStudio.DevShell.dll"; Enter-VsDevShell 3f0e31ad -SkipAutomaticLocation -DevCmdArguments "-arch=x64 -host_arch=x64"
336 | ```
337 | 
338 | These steps will initialize your environment and allow you to use the correct Visual Studio tools.
339 | 


--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
 1 | <!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
 2 | 
 3 | ## Security
 4 | 
 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
 6 | 
 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
 8 | 
 9 | ## Reporting Security Issues
10 | 
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 | 
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 | 
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 | 
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 
18 | 
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 | 
21 |   * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 |   * Full paths of source file(s) related to the manifestation of the issue
23 |   * The location of the affected source code (tag/branch/commit or direct URL)
24 |   * Any special configuration required to reproduce the issue
25 |   * Step-by-step instructions to reproduce the issue
26 |   * Proof-of-concept or exploit code (if possible)
27 |   * Impact of the issue, including how an attacker might exploit the issue
28 | 
29 | This information will help us triage your report more quickly.
30 | 
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 | 
33 | ## Preferred Languages
34 | 
35 | We prefer all communications to be in English.
36 | 
37 | ## Policy
38 | 
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 | 
41 | <!-- END MICROSOFT SECURITY.MD BLOCK -->
42 | 


--------------------------------------------------------------------------------
/assets/header_model_release.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/404980eecae38affa4871c3e419eae3f44536a95/assets/header_model_release.png


--------------------------------------------------------------------------------
/assets/intel_performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/404980eecae38affa4871c3e419eae3f44536a95/assets/intel_performance.jpg


--------------------------------------------------------------------------------
/assets/m2_performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/404980eecae38affa4871c3e419eae3f44536a95/assets/m2_performance.jpg


--------------------------------------------------------------------------------
/assets/tl1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/404980eecae38affa4871c3e419eae3f44536a95/assets/tl1.png


--------------------------------------------------------------------------------
/assets/tl2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/404980eecae38affa4871c3e419eae3f44536a95/assets/tl2.png


--------------------------------------------------------------------------------
/docs/codegen.md:
--------------------------------------------------------------------------------
 1 | Codegen for TL1 and TL2
 2 | ------------------------
 3 | 
 4 | codegen_tl1.py and codegen_tl2.py use per-device parameters to generate TL1 and TL2 kernel code that achieves the best performance on each device.
 5 | 
 6 | We cut the weight matrix into multiple compute blocks to best utilize the hardware.
 7 | 
 8 | ### Example
 9 | bitnet_b1_58-large:
10 | 
11 | - Determine the Matmul kernel shapes \
12 | For example, the bitnet_b1_58-large Matmul kernel shapes are:\
13 | [1536, 4096]\
14 | [1536, 1536]\
15 | [4096, 1536]
16 | 
17 | - Make sure BM, BK, and bm for each kernel meet the requirements below
18 | - Generate the code\
19 | For example, for bitnet_b1_58-large, we can generate the code like this:
20 | 
21 | ```bash
22 | # For TL1
23 | python utils/codegen_tl1.py --model bitnet_b1_58-large --BM 256,128,256 --BK 128,64,128 --bm 32,64,32
24 | 
25 | # For TL2
26 | python utils/codegen_tl2.py --model bitnet_b1_58-large --BM 256,128,256 --BK 96,192,96 --bm 32,32,32
27 | ```
28 | 
29 | ### TL1:
30 | ![TL1](../assets/tl1.png)
31 | 
32 | For TL1, we cut the weight into M / BM blocks, each of shape (BM, K). We then cut each of these along K into K / BK blocks of shape (BM, BK). Each (BM, BK) block is cut the same way into (bm, compute_num / bm) compute blocks, and the computation is completed within each block.
33 | 
34 | Thus, we need to make sure that (see the sketch after this list):
35 | - M % BM == 0
36 | - K % BK == 0
37 | - BM % bm == 0
38 | - bm is chosen from [32, 64]
39 | 
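A minimal Python sketch of these constraints (not part of the repo; `check_tl1_params` is a hypothetical helper), checked against the bitnet_b1_58-large shapes and the example flags above:

```python
def check_tl1_params(M: int, K: int, BM: int, BK: int, bm: int) -> None:
    """Validate TL1 tiling parameters for one (M, K) matmul kernel."""
    assert M % BM == 0, "M must be divisible by BM"
    assert K % BK == 0, "K must be divisible by BK"
    assert BM % bm == 0, "BM must be divisible by bm"
    assert bm in (32, 64), "bm must be 32 or 64"

# bitnet_b1_58-large kernel shapes with the example --BM/--BK/--bm values
for (M, K), BM, BK, bm in zip([(1536, 4096), (1536, 1536), (4096, 1536)],
                              [256, 128, 256], [128, 64, 128], [32, 64, 32]):
    check_tl1_params(M, K, BM, BK, bm)
```
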
40 | ### TL2:
41 | ![TL2](../assets/tl2.png)
42 | 
43 | For TL2, things are a little more complicated. Because TL2 requires BK % 6 == 0, we split K into three_K and two_K: the (M, three_K) part is computed with TL2, and the (M, two_K) part with TL1.
44 | 
45 | Thus, we need to make sure that (see the sketch after this list):
46 | - M % BM == 0
47 | - (K % BK) % 32 == 0
48 | - BM % bm == 0
49 | - bm is chosen from \[32\]
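
A minimal sketch of this split (`split_k_for_tl2` is a hypothetical helper; it reads the constraint `K % BK % 32 == 0` as `(K % BK) % 32 == 0`, and the real logic lives in utils/codegen_tl2.py):

```python
def split_k_for_tl2(K: int, BK: int):
    """Split K into a TL2 part (a multiple of BK) and a TL1 remainder."""
    assert BK % 6 == 0, "TL2 requires BK % 6 == 0"
    three_K = K - K % BK   # computed with TL2 as (M, three_K)
    two_K = K % BK         # remainder computed with TL1 as (M, two_K)
    assert two_K % 32 == 0, "the TL1 remainder must be a multiple of 32"
    return three_K, two_K

print(split_k_for_tl2(K=4096, BK=96))  # -> (4032, 64)
```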


--------------------------------------------------------------------------------
/gpu/README.md:
--------------------------------------------------------------------------------
  1 | # BitNet Inference Kernel
  2 | 
  3 | This repository provides a highly efficient GEMV kernel implementation for the BitNet model, optimized for W2A8 inference — 2-bit weights and 8-bit activations. It is tailored for use with the [BitNet-b1.58-2B-4T](https://arxiv.org/abs/2504.12285) model.
  4 | 
  5 | ## Features
  6 | 
  7 | - Support for W2A8 (2-bit weight × 8-bit activation) GEMV computation  
  8 | - Custom CUDA kernels with low-latency execution  
  9 | - Optimizations for memory access, decoding, and compute throughput  
 10 | 
 11 | ## Usage
 12 | 
 13 | Installation and kernel performance tests:
 14 | 
 15 | ```bash
 16 | # (Recommended) Create a new conda environment
 17 | conda create --name bitnet-gpu "python<3.13"
 18 | conda activate bitnet-gpu
 19 | 
 20 | # Install dependencies
 21 | pip install -r requirements.txt
 22 | 
 23 | # Build the kernel
 24 | cd bitnet_kernels
 25 | bash compile.sh
 26 | cd ..
 27 | 
 28 | # Run performance tests
 29 | python test.py
 30 | ```
 31 | 
 32 | End-to-end inference:
 33 | 
 34 | ```bash
 35 | # Download and convert the BitNet-b1.58-2B model
 36 | mkdir checkpoints
 37 | huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
 38 | python ./convert_safetensors.py --safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors --output checkpoints/model_state.pt --model_name 2B
 39 | python ./convert_checkpoint.py --input ./checkpoints/model_state.pt
 40 | rm ./checkpoints/model_state.pt
 41 | 
 42 | # Inference
 43 | python3 ./generate.py ./checkpoints/ --interactive --chat_format
 44 | ```
 45 | 
 46 | ## Optimizations
 47 | 
 48 | ### Weight Permutation
 49 | 
 50 | The weight matrix is divided into 16×32 blocks to optimize memory access patterns.  
 51 | 
 52 | Within each block, values are stored contiguously in memory and permuted to facilitate efficient access and processing.  
 53 | 
 54 | See `convert_checkpoint.py` for details.
 55 | 
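As a rough illustration only (the actual permutation applied by `convert_checkpoint.py` is more involved), the sketch below reorders an (N, K) int8 weight so that each 16×32 block becomes contiguous in memory:

```python
import torch

def to_16x32_blocks(w: torch.Tensor) -> torch.Tensor:
    """Reorder an (N, K) weight so each 16x32 block is contiguous in memory."""
    N, K = w.shape
    assert N % 16 == 0 and K % 32 == 0
    # (N, K) -> (N/16, 16, K/32, 32) -> (N/16, K/32, 16, 32); each trailing
    # (16, 32) block is now a contiguous run of 512 values
    return w.reshape(N // 16, 16, K // 32, 32).permute(0, 2, 1, 3).contiguous()

blocks = to_16x32_blocks(torch.zeros(2560, 2560, dtype=torch.int8))
print(blocks.shape)  # torch.Size([160, 80, 16, 32])
```
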
 56 | ### Fast Decoding
 57 | 
 58 | Every 16 two-bit values are packed into a single 32-bit integer using the following interleaving pattern:  
 59 | ```
 60 | [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
 61 | ```
 62 | 
 63 | This layout is designed to accelerate decoding by enabling efficient extraction of 4 values at a time into `int8`.
 64 | 
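A Python sketch of this packing scheme (hypothetical helpers mirroring `decode_i2s_to_i8s` in `bitnet_kernels.h`; the kernel additionally subtracts 2 from each decoded byte, via `__vsubss4`, to recover the signed ternary values):

```python
PERM = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]

def pack16(vals):
    """Pack 16 two-bit values (0..3) into one 32-bit int in the interleaved order."""
    assert len(vals) == 16 and all(0 <= v <= 3 for v in vals)
    word = 0
    for slot, src in enumerate(PERM):
        word |= vals[src] << (2 * slot)
    return word

def unpack4(word, i):
    """Extract original values 4*i .. 4*i+3 with a single shift, as in the kernel."""
    return [(word >> (2 * i + 8 * j)) & 0x3 for j in range(4)]

vals = [3, 0, 2, 1] * 4
word = pack16(vals)
assert [v for i in range(4) for v in unpack4(word, i)] == vals
```
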
 65 | ### `dp4a` Instruction
 66 | 
 67 | We use the `dp4a` instruction to accelerate low-precision dot product operations.  
 68 | 
 69 | This instruction performs a dot product between two 4-element vectors (each stored in a 32-bit word as 8-bit integers) and accumulates the result into a 32-bit integer.  
 70 | 
 71 | It significantly improves GEMV throughput when processing quantized weights and activations.
 72 | 
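For reference, a Python emulation of the arithmetic described above (the CUDA intrinsic is `__dp4a`; this sketch only models its semantics, not its performance):

```python
def dp4a(a, b, c):
    """Dot product of two 4-element int8 vectors, accumulated into a 32-bit int c."""
    assert len(a) == len(b) == 4
    assert all(-128 <= x <= 127 for x in list(a) + list(b))
    return c + sum(x * y for x, y in zip(a, b))

acc = 0
acc = dp4a([1, -2, 3, 4], [5, 6, -7, 8], acc)
print(acc)  # 1*5 - 2*6 - 3*7 + 4*8 = 4
```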
 73 | 
 74 | ## Performance
 75 | 
 76 | ### Kernel Benchmarks
 77 | 
 78 | Tested on NVIDIA A100 40GB GPU, our custom W2A8 kernel shows significant speedups over standard BF16 implementations:
 79 | 
 80 | | Shape (N×K)         | W2A8 Latency (us) | BF16 Latency (us) | Speedup Ratio        |
 81 | |---------------------|-------------------|-------------------|----------------------|
 82 | | 2560 × 2560         | 13.32             | 18.32             |   1.38               |
 83 | | 3840 × 2560         | 14.90             | 18.87             |   1.27               |
 84 | | 13824 × 2560        | 18.75             | 59.51             |   3.17               |
 85 | | 2560 × 6912         | 14.49             | 37.78             |   2.61               |
 86 | | 3200 × 3200         | 14.61             | 19.08             |   1.31               |
 87 | | 4800 × 3200         | 13.09             | 21.84             |   1.67               |
 88 | | 3200 × 10240        | 19.64             | 60.79             |   3.10               |
 89 | | 20480 × 3200        | 30.99             | 112.39            |   3.63               |
 90 | 
 91 | ### End-to-End Generation Latency
 92 | 
 93 | Compared to a similarly-sized BF16 model (Gemma-2-2B using vLLM), BitNet-b1.58-2B with our kernel achieves consistent speedups across workloads:
 94 | 
 95 | | Input Length | Output Length | BF16 Latency (ms) | W2A8 Latency (ms) | Speedup Ratio |
 96 | | --- | --- | --- | --- | --- |
 97 | | 64 | 16 | 187.64 | 57.40 | 3.27 |
 98 | | 64 | 32 | 353.50 | 112.22 | 3.15 |
 99 | | 64 | 64 | 683.23 | 221.08 | 3.09 |
100 | | 256 | 16 | 183.14 | 61.24 | 2.99 |
101 | | 256 | 32 | 353.14 | 115.47 | 3.06 |
102 | | 256 | 64 | 684.24 | 224.16 | 3.05 |
103 | | 512 | 16 | 208.99 | 68.06 | 3.07 |
104 | | 512 | 32 | 354.33 | 122.72 | 2.89 |
105 | | 512 | 64 | 709.65 | 231.82 | 3.06 |
106 | 
107 | *Note: Comparison uses equivalent-sized models (2B parameters) on NVIDIA A100 40GB GPU.*


--------------------------------------------------------------------------------
/gpu/bitnet_kernels/bitnet_kernels.cu:
--------------------------------------------------------------------------------
 1 | #include "bitnet_kernels.h"
 2 | 
 3 | extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){
 4 |     if (M == 1 && N == 3840 && K == 2560){
 5 |         ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<<dim3(240, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
 6 |     }
 7 |     else if (M == 1 && N == 2560 && K == 2560){
 8 |         ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
 9 |     }
10 |     else if (M == 1 && N == 13824 && K == 2560){
11 |         ladder_int8xint2_kernel<1, 13824, 2560, 2, 8, 16><<<dim3(864, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
12 |     }
13 |     else if (M == 1 && N == 2560 && K == 6912){
14 |         ladder_int8xint2_kernel<1, 2560, 6912, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
15 |     }
16 |     else if(M == 1 && N == 4800 && K == 3200){
17 |         ladder_int8xint2_kernel<1, 4800, 3200, 6, 8, 16><<<dim3(300, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
18 |     }
19 |     else if(M == 1 && N == 3200 && K == 3200){
20 |         ladder_int8xint2_kernel<1, 3200, 3200, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
21 |     }
22 |     else if(M == 1 && N == 20480 && K == 3200){
23 |         ladder_int8xint2_kernel<1, 20480, 3200, 2, 8, 16><<<dim3(1280, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
24 |     }
25 |     else if(M == 1 && N == 3200 && K == 10240){
26 |         ladder_int8xint2_kernel<1, 3200, 10240, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
27 |     }    
28 |     else if(M == 1 && N == 5120 && K == 27648){
29 |         ladder_int8xint2_kernel<1, 5120, 27648, 1, 8, 16><<<dim3(320, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
30 |     }
31 |     else if(M == 1 && N == 55296 && K == 5120){
32 |         ladder_int8xint2_kernel<1, 55296, 5120, 1, 8, 16><<<dim3(3456, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
33 |     }
34 |     else{
35 |         std::cout << "required ladder gemm kernel: M " << M << ", N " << N << ", K " << K << std::endl;
36 |     }
37 | }


--------------------------------------------------------------------------------
/gpu/bitnet_kernels/bitnet_kernels.h:
--------------------------------------------------------------------------------
 1 | #include <cuda_runtime.h>
 2 | #include <math_constants.h>
 3 | #include <math.h>
 4 | #include <mma.h>
 5 | #include <iostream>
 6 | #include <cuda.h>
 7 | #include <cuda_fp16.h>
 8 | #include <cuda_bf16.h>
 9 | 
10 | 
11 | #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || (__CUDACC_VER_MAJOR__ > 11))
12 | #define TVM_ENABLE_L2_PREFETCH 1
13 | #else
14 | #define TVM_ENABLE_L2_PREFETCH 0
15 | #endif
16 | 
17 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
18 | #define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1
19 | #else
20 | #define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0
21 | #endif
22 | 
23 | template <typename T1, typename T2>
24 | __device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16)
25 | {
26 |   // convert 8 int2b_t to 8 int8b_t -> 2 int32
27 |   uint *i8s = reinterpret_cast<uint *>(_i8s);
28 | 
29 |   // i2s = {e0, e4, e8, e12, e1, e5, e9, e13, e2, e6, e10, e14, e3, e7, e11, e15}
30 |   uint const i2s = *_i2s;
31 | 
32 |   static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;     // 0b11101010
33 |   static constexpr uint BOTTOM_MASK = 0x03030303;          // 0xf -> 0b11 select 0,3
34 |   static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000; 
35 | 
36 | #pragma unroll
37 |   for (int i = 0; i < (N / 4); i++)
38 |   {
39 |     asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
40 |                  : "=r"(i8s[i])
41 |                  : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut));
42 |     i8s[i] = __vsubss4(i8s[i], 0x02020202);
43 |   }
44 | }
45 | 
46 | template <int M, int N, int K, int ws_num, int K_block_size, int N_block_size>
47 | __global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restrict__ A, int8_t* __restrict__ B, __nv_bfloat16* __restrict__ dtype_transform, __nv_bfloat16* __restrict__ s, __nv_bfloat16* __restrict__ ws) {
48 |   constexpr int K_per_loop = 16;
49 |   constexpr int wmma_K = 32;
50 |   constexpr int wmma_N = 16;
51 |   int in_thread_C_local[1];
52 |   signed char A_local[K_per_loop];
53 |   int B_reshape_local[1];
54 |   signed char B_decode_local[K_per_loop];
55 |   int red_buf0[1];
56 |   in_thread_C_local[0] = 0;
57 |   #pragma unroll
58 |   for (int k_0 = 0; k_0 < K/(K_per_loop * K_block_size); ++k_0) {
59 |     *(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop)));
60 |     B_reshape_local[0] = *(int*)(B + 
61 |       (((int)blockIdx.x) * N_block_size * K / 4) + 
62 |       (k_0 * K_block_size * K_per_loop * wmma_N / 4) +
63 |       ((((int)threadIdx.x) >> 1) * wmma_K * wmma_N / 4) +
64 |       ((((int)threadIdx.y) >> 3) * (wmma_K * wmma_N / 2) / 4) + 
65 |       ((((int)threadIdx.x) & 1) * (wmma_K * wmma_N / 4) / 4) + 
66 |       ((((int)threadIdx.y) & 7) * (wmma_K / 2) / 4)
67 |       );
68 |     decode_i2s_to_i8s(B_reshape_local, B_decode_local, 16);
69 |     #pragma unroll
70 |     for (int k_2_0 = 0; k_2_0 < 4; ++k_2_0) {
71 |       in_thread_C_local[0] = __dp4a(*(int *)&A_local[((k_2_0 * 4))],*(int *)&B_decode_local[((k_2_0 * 4))], in_thread_C_local[0]);
72 |     }
73 |   }
74 |   red_buf0[0] = in_thread_C_local[0];
75 |   #pragma unroll
76 |   for (int offset = K_block_size/2; offset > 0; offset /= 2) {
77 |     red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size);
78 |   }
79 |   int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y));
80 |   int ws_idx = out_idx / (N / ws_num);
81 |   if (threadIdx.x == 0)
82 |     dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0])/(float)s[0]*(float)ws[ws_idx]);
83 | }


--------------------------------------------------------------------------------
/gpu/bitnet_kernels/compile.sh:
--------------------------------------------------------------------------------
1 | nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitnet_kernels.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libbitnet.so
2 | 
3 | 
4 | 


--------------------------------------------------------------------------------
/gpu/bitnet_kernels/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import setup
 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 3 | 
 4 | setup(
 5 |     name='bitlinear_cpp',
 6 |     ext_modules=[
 7 |         CUDAExtension('bitlinear_cuda', [
 8 |             'bitnet_kernels.cu',
 9 |         ])
10 |     ],
11 |     cmdclass={
12 |         'build_ext': BuildExtension
13 |     })


--------------------------------------------------------------------------------
/gpu/convert_checkpoint.py:
--------------------------------------------------------------------------------
  1 | import json
  2 | import os
  3 | import re
  4 | import sys
  5 | from pathlib import Path
  6 | from typing import Optional
  7 | from dataclasses import dataclass
  8 | import torch
  9 | from einops import rearrange
 10 | from safetensors.torch import save_file
 11 | import model
 12 | from pack_weight import convert_weight_int8_to_int2
 13 | 
 14 | @torch.inference_mode()
 15 | def convert_ts_checkpoint(
 16 |     *,
 17 |     input_path: str = "",
 18 | ) -> None:
 19 | 
 20 |     config = model.ModelArgs()
 21 |     print(f"Model config {config.__dict__}")
 22 | 
 23 |     def quant_weight_int8(weight):
 24 |         s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
 25 |         new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8)
 26 |         new_scale = (1.0 / s).to(torch.bfloat16)
 27 |         return new_weight, new_scale.reshape(1)
 28 | 
 29 |     def quant_weight_fp16(weight):
 30 |         s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
 31 |         new_weight = (weight * s).round().clamp(-1, 1) / s
 32 |         return new_weight
 33 | 
 34 |     def convert_int8_to_int2(weight):
 35 |         return convert_weight_int8_to_int2(weight)
 36 | 
 37 |     merged_result = torch.load(input_path, map_location="cpu", mmap=True)
 38 |     int2_result = {}
 39 |     fp16_result = {}
 40 |     zero = torch.zeros(1).to(torch.bfloat16)
 41 |     for key, value in merged_result.items():
 42 |         if 'wqkv' in key:
 43 |             wq = value[:config.dim]
 44 |             wk = value[config.dim:config.dim // config.n_heads * config.n_kv_heads + config.dim]
 45 |             wv = value[config.dim // config.n_heads * config.n_kv_heads + config.dim:]
 46 |             wq_weight, wa_scale = quant_weight_int8(wq)
 47 |             wk_weight, wb_scale = quant_weight_int8(wk)
 48 |             wv_weight, wc_scale = quant_weight_int8(wv)
 49 |             wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
 50 |             wqkv_scale = torch.cat([wa_scale, wb_scale, wc_scale, zero], dim=0)
 51 |             int2_result[key] = convert_int8_to_int2(wqkv_weight)
 52 |             int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale
 53 | 
 54 |             wq_weight = quant_weight_fp16(wq)
 55 |             wk_weight = quant_weight_fp16(wk)
 56 |             wv_weight = quant_weight_fp16(wv)
 57 |             wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
 58 |             fp16_result[key] = wqkv_weight
 59 |         elif 'w13' in key:
 60 |             w1 = value[:config.ffn_dim]
 61 |             w3 = value[config.ffn_dim:]
 62 |             w1_weight, w1_scale = quant_weight_int8(w1)
 63 |             w3_weight, w3_scale = quant_weight_int8(w3)
 64 |             w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
 65 |             w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0)
 66 |             int2_result[key] = convert_int8_to_int2(w13_weight)
 67 |             int2_result[key.replace('weight', 'weight_scale')] = w13_scale
 68 | 
 69 |             w1_weight = quant_weight_fp16(w1)
 70 |             w3_weight = quant_weight_fp16(w3)
 71 |             w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
 72 |             fp16_result[key] = w13_weight
 73 |         elif 'w2' in key or 'wo' in key:
 74 |             weight, scale = quant_weight_int8(value)
 75 |             scale = torch.cat([scale, zero, zero, zero], dim=0)
 76 |             int2_result[key] = convert_int8_to_int2(weight)
 77 |             int2_result[key.replace('weight', 'weight_scale')] = scale
 78 | 
 79 |             weight = quant_weight_fp16(value)
 80 |             fp16_result[key] = weight
 81 |         else:
 82 |             int2_result[key] = value.clone()
 83 |             fp16_result[key] = value.clone()
 84 | 
 85 |     output_dir = os.path.dirname(input_path)
 86 |     print(f"Saving checkpoint to {output_dir}/model_state_int2.pt")
 87 |     torch.save(int2_result, f"{output_dir}/model_state_int2.pt")
 88 | 
 89 |     print(f"Saving checkpoint to {output_dir}/model_state_fp16.pt")
 90 |     torch.save(fp16_result, f"{output_dir}/model_state_fp16.pt")
 91 | 
 92 | if __name__ == '__main__':
 93 |     import argparse
 94 |     parser = argparse.ArgumentParser(description='Convert TorchScale checkpoint.')
 95 |     parser.add_argument('--input', type=str)
 96 | 
 97 |     args = parser.parse_args()
 98 |     convert_ts_checkpoint(
 99 |         input_path=args.input,
100 |     )
101 | 


--------------------------------------------------------------------------------
/gpu/convert_safetensors.py:
--------------------------------------------------------------------------------
  1 | import re
  2 | import torch
  3 | from pathlib import Path
  4 | from safetensors.torch import load_file
  5 | from einops import rearrange
  6 | from dataclasses import dataclass
  7 | from typing import Optional
  8 | 
  9 | transformer_configs = {
 10 |     "2B": dict(n_layer=30, n_head=20, dim=2560, vocab_size=128256, n_local_heads=5, intermediate_size=6912),
 11 | }
 12 | 
 13 | @dataclass
 14 | class ModelArgs:
 15 |     block_size: int = 4096
 16 |     vocab_size: int = 32000
 17 |     n_layer: int = 32
 18 |     n_head: int = 32
 19 |     dim: int = 4096
 20 |     intermediate_size: int = None
 21 |     n_local_heads: int = -1
 22 |     head_dim: int = 64
 23 |     rope_base: float = 10000
 24 |     norm_eps: float = 1e-5
 25 | 
 26 |     def __post_init__(self):
 27 |         if self.n_local_heads == -1:
 28 |             self.n_local_heads = self.n_head
 29 |         if self.intermediate_size is None:
 30 |             hidden_dim = 4 * self.dim
 31 |             n_hidden = int(2 * hidden_dim / 3)
 32 |             self.intermediate_size = n_hidden + (256 - n_hidden % 256) if n_hidden % 256 else n_hidden
 33 |         self.head_dim = self.dim // self.n_head
 34 | 
 35 |     @classmethod
 36 |     def from_name(cls, name: str):
 37 |         if name in transformer_configs:
 38 |             return cls(**transformer_configs[name])
 39 |         config = [k for k in transformer_configs if k in name.upper() or k in name]
 40 |         assert len(config) == 1, f"Unknown model name: {name}"
 41 |         return cls(**transformer_configs[config[0]])
 42 | 
 43 | def invert_convert_q(w: torch.Tensor, config: ModelArgs) -> torch.Tensor:
 44 |     return rearrange(w, '(h l d) i -> (h d l) i', h=config.n_head, l=2)
 45 | 
 46 | def invert_convert_k(w: torch.Tensor, config: ModelArgs) -> torch.Tensor:
 47 |     return rearrange(w, '(h l d) i -> (h d l) i', h=config.n_local_heads, l=2)
 48 | 
 49 | def convert_back(
 50 |     safetensors_path: str,
 51 |     output_file: str,
 52 |     model_name: Optional[str] = None,
 53 | ):
 54 |     st_dict = load_file(safetensors_path)
 55 | 
 56 |     cfg = ModelArgs.from_name(model_name)
 57 |     print(f"Using model configurations: {cfg}")
 58 | 
 59 |     recovered: dict = {}
 60 | 
 61 |     for layer in range(cfg.n_layer):
 62 |         base = f"model.layers.{layer}."
 63 | 
 64 |         wq = st_dict[f"{base}self_attn.q_proj.weight"]
 65 |         wk = st_dict[f"{base}self_attn.k_proj.weight"]
 66 |         wv = st_dict[f"{base}self_attn.v_proj.weight"]
 67 | 
 68 |         wq = invert_convert_q(wq, cfg)
 69 |         wk = invert_convert_k(wk, cfg)
 70 | 
 71 |         wqkv = torch.cat([wq, wk, wv], dim=0)
 72 |         recovered[f"layers.{layer}.attention.wqkv.weight"] = wqkv
 73 | 
 74 |         recovered[f"layers.{layer}.attention.wo.weight"] = st_dict[f"{base}self_attn.o_proj.weight"]
 75 | 
 76 |         recovered[f"layers.{layer}.attention_norm.weight"] = st_dict[f"{base}input_layernorm.weight"]
 77 |         recovered[f"layers.{layer}.ffn_norm.weight"] = st_dict[f"{base}post_attention_layernorm.weight"]
 78 |         recovered[f"layers.{layer}.attention.attn_sub_norm.weight"] = st_dict[f"{base}self_attn.attn_sub_norm.weight"]
 79 |         recovered[f"layers.{layer}.feed_forward.ffn_sub_norm.weight"] = st_dict[f"{base}mlp.ffn_sub_norm.weight"]
 80 | 
 81 |         gate = st_dict[f"{base}mlp.gate_proj.weight"]
 82 |         up   = st_dict[f"{base}mlp.up_proj.weight"]
 83 |         w13  = torch.cat([gate, up], dim=0)
 84 |         recovered[f"layers.{layer}.feed_forward.w13.weight"] = w13
 85 | 
 86 |         recovered[f"layers.{layer}.feed_forward.w2.weight"] = st_dict[f"{base}mlp.down_proj.weight"]
 87 | 
 88 |     recovered["tok_embeddings.weight"] = st_dict["model.embed_tokens.weight"]
 89 |     recovered["output.weight"]         = st_dict["model.embed_tokens.weight"]
 90 |     recovered["norm.weight"]           = st_dict["model.norm.weight"]
 91 | 
 92 |     print(f"Saving to {output_file}")
 93 |     torch.save(recovered, output_file)
 94 | 
 95 | if __name__ == "__main__":
 96 |     import argparse
 97 |     parser = argparse.ArgumentParser(description="Convert Safetensors back to Torch .pth checkpoint")
 98 |     parser.add_argument(
 99 |         "--safetensors_file", type=str, required=True,
100 |         help="Path to input .safetensors file"
101 |     )
102 |     parser.add_argument(
103 |         "--output", type=str, default="./checkpoints/model_state.pt",
104 |         help="Path to output .pt file"
105 |     )
106 |     parser.add_argument(
107 |         "--model_name", type=str, default="2B",
108 |         help="Model configuration name to use (e.g. 2B)"
109 |     )
110 |     args = parser.parse_args()
111 | 
112 |     convert_back(
113 |         safetensors_path=args.safetensors_file,
114 |         output_file=args.output,
115 |         model_name=args.model_name,
116 |     )


--------------------------------------------------------------------------------
/gpu/generate.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  2 | #
  3 | # This source code is licensed under the BSD license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import json
  7 | import os
  8 | import readline  # type: ignore # noqa
  9 | import sys
 10 | import time
 11 | from dataclasses import dataclass
 12 | from pathlib import Path
 13 | from typing import Iterable, Optional, Tuple, Union
 14 | 
 15 | import fire
 16 | import model as fast
 17 | import torch
 18 | from stats import Stats
 19 | from tokenizer import Tokenizer, ChatFormat
 20 | import sample_utils
 21 | from xformers.ops.fmha.attn_bias import (
 22 |     BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
 23 | )
 24 | 
 25 | 
 26 | @dataclass
 27 | class GenArgs:
 28 |     gen_length: int = 32
 29 |     gen_bsz: int = 1
 30 |     prompt_length: int = 64
 31 | 
 32 |     use_sampling: bool = False
 33 |     temperature: float = 0.8
 34 |     top_p: float = 0.9
 35 | 
 36 | 
 37 | class FastGen:
 38 |     GRAPH_WARMUPS: int = 1
 39 |     tokenizer: Tokenizer
 40 | 
 41 |     @staticmethod
 42 |     def build(
 43 |         ckpt_dir: str,
 44 |         gen_args: GenArgs,
 45 |         device: Union[torch.device, str],
 46 |         tokenizer_path: Optional[str] = None,
 47 |         num_layers: int = 13,
 48 |         use_full_vocab: bool = False,
 49 |     ) -> "FastGen":
 50 |         """
 51 |         Load a Llama or Code Llama checkpoint and return a new
 52 |         generator for this model.
 53 |         """
 54 |         start_time = time.time()
 55 | 
 56 |         model_args_prefill = fast.ModelArgs(use_kernel=False)
 57 |         model_args_decode = fast.ModelArgs(use_kernel=True)
 58 |         tokenizer = Tokenizer("./tokenizer.model")
 59 | 
 60 |         torch.set_default_device(device)
 61 |         torch.set_default_dtype(torch.bfloat16)
 62 | 
 63 |         prefill_model = fast.Transformer(model_args_prefill)
 64 |         decode_model = fast.Transformer(model_args_decode)
 65 | 
 66 |         fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
 67 |         fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
 68 |         int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
 69 |         int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
 70 |         prefill_model.load_state_dict(fp16_checkpoint, strict=True)
 71 |         decode_model.load_state_dict(int2_checkpoint, strict=True)
 72 | 
 73 |         torch.cuda.synchronize()
 74 |         print(f"loaded model in {time.time() - start_time:.2f} seconds")
 75 |         start_time = time.time()
 76 | 
 77 |         return FastGen(gen_args, model_args_prefill, prefill_model, decode_model, tokenizer)
 78 | 
 79 |     def __init__(
 80 |         self,
 81 |         args: GenArgs,
 82 |         model_args: fast.ModelArgs,
 83 |         prefill_model: fast.Transformer,
 84 |         decode_model: fast.Transformer,
 85 |         tokenizer: Tokenizer,
 86 |     ):
 87 |         self.gen_args = args
 88 |         self.max_seq_length = args.prompt_length + args.gen_length
 89 |         self.model_args = model_args
 90 |         # self.model = model
 91 |         self.prefill_model = prefill_model
 92 |         self.decode_model = decode_model
 93 |         self.tokenizer = tokenizer
 94 |         self._prefill_cuda_graph, self._prefill_compile_model, self._prefill_inputs, self._prefill_logits = None, None, None, None
 95 |         self._generate_cuda_graph, self._generate_compile_model, self._generate_inputs, self._generate_logits = None, None, None, None
 96 |         self._cache = None
 97 |         start_time = time.time()
 98 |         self._prefill_compile_model = self.compile_prefill()
 99 |         self._generate_compile_model = self.compile_generate()
100 |         print(f"compiled model in {time.time() - start_time:.2f} seconds")
101 | 
102 |     def compile_prefill(self):
103 | 
104 |         if self._cache is None:
105 |             self._cache = fast.make_cache(
106 |                 args=self.model_args,
107 |                 length=self.gen_args.gen_bsz * self.max_seq_length,
108 |             )
109 | 
110 |         seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
111 | 
112 |         bias = AttnBias.from_seqlens(
113 |             q_seqlen=seq_lens,
114 |             kv_seqlen=seq_lens,
115 |             kv_padding=self.max_seq_length,
116 |         )
117 |         bias.q_seqinfo.to("cuda")
118 |         bias.k_seqinfo.to("cuda")
119 | 
120 |         tokens = torch.IntTensor([1] * self.gen_args.gen_bsz * self.gen_args.prompt_length).cuda()
121 |         self._prefill_inputs = (tokens, bias)
122 | 
123 |         s = torch.cuda.Stream()
124 |         s.wait_stream(torch.cuda.current_stream())
125 |         
126 |         with torch.cuda.stream(s):
127 |             _ = self.prefill_model.forward_with_attn_bias(
128 |                 token_values=self._prefill_inputs[0],
129 |                 attn_bias=self._prefill_inputs[1],
130 |                 cache=self._cache,
131 |             )
132 |         torch.cuda.current_stream().wait_stream(s)
133 | 
134 |         self._prefill_cuda_graph = torch.cuda.CUDAGraph()
135 |         recording_kwargs = {}
136 |         if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
137 |             # In PyTorch 2.1+ and nightlies from late Aug 2023,
138 |             # we can do this to maybe avoid watchdog-related crashes
139 |             recording_kwargs["capture_error_mode"] = "thread_local"
140 |         with torch.cuda.graph(self._prefill_cuda_graph, **recording_kwargs):
141 |             self._prefill_logits = self.prefill_model.forward_with_attn_bias(
142 |                 token_values=self._prefill_inputs[0],
143 |                 attn_bias=self._prefill_inputs[1],
144 |                 cache=self._cache,
145 |             )
146 | 
147 |         def replay(tokens, seq_lens=None):
148 |             self._prefill_inputs[0].copy_(tokens)
149 |             if seq_lens is not None:
150 |                 self._prefill_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
151 | 
152 |             self._prefill_cuda_graph.replay()
153 |             torch.cuda.synchronize()
154 | 
155 |             return self._prefill_logits
156 | 
157 |         return replay
158 | 
159 |     def compile_generate(self):
160 | 
161 |         if self._cache is None:
162 |             self._cache = fast.make_cache(
163 |                 args=self.model_args,
164 |                 length=self.gen_args.gen_bsz * self.max_seq_length,
165 |             )
166 | 
167 |         seq_lens = [1 for _ in range(self.gen_args.gen_bsz)]
168 |         kv_seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
169 | 
170 |         bias = AttnBias.from_seqlens(
171 |             q_seqlen=seq_lens,
172 |             kv_seqlen=kv_seq_lens,
173 |             kv_padding=self.max_seq_length,
174 |         )
175 |         bias.q_seqinfo.to("cuda")
176 |         bias.k_seqinfo.to("cuda")
177 | 
178 |         tokens = torch.IntTensor([1] * self.gen_args.gen_bsz).cuda()
179 |         self._generate_inputs = (tokens, bias)
180 | 
181 |         s = torch.cuda.Stream()
182 |         s.wait_stream(torch.cuda.current_stream())
183 |         
184 |         with torch.cuda.stream(s):
185 |             _ = self.decode_model.forward_with_attn_bias(
186 |                 token_values=self._generate_inputs[0],
187 |                 attn_bias=self._generate_inputs[1],
188 |                 cache=self._cache,
189 |             )
190 |         torch.cuda.current_stream().wait_stream(s)
191 | 
192 |         self._generate_cuda_graph = torch.cuda.CUDAGraph()
193 |         recording_kwargs = {}
194 |         if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
195 |             # In PyTorch 2.1+ and nightlies from late Aug 2023,
196 |             # we can do this to maybe avoid watchdog-related crashes
197 |             recording_kwargs["capture_error_mode"] = "thread_local"
198 |         with torch.cuda.graph(self._generate_cuda_graph, **recording_kwargs):
199 |             self._generate_logits = self.decode_model.forward_with_attn_bias(
200 |                 token_values=self._generate_inputs[0],
201 |                 attn_bias=self._generate_inputs[1],
202 |                 cache=self._cache,
203 |             )
204 | 
205 |         def replay(tokens, seq_lens):
206 |             self._generate_inputs[0].copy_(tokens)
207 |             self._generate_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
208 | 
209 |             self._generate_cuda_graph.replay()
210 | 
211 |             return self._generate_logits
212 | 
213 |         return replay
214 | 
215 | 
216 |     @torch.inference_mode()
217 |     def generate_all(
218 |         self, prompts: list[list[int]], use_cuda_graphs: bool, use_sampling: bool
219 |     ) -> Tuple[Stats, list[list[int]]]:
220 |         bs = len(prompts)
221 |         prompt_lens = [len(p) for p in prompts]
222 |         padded_prompt_lens = [self.gen_args.prompt_length] * bs
223 |         max_prompt_length = max(prompt_lens)
224 |         gen_length = self.gen_args.gen_length
225 |         max_seq_length = max_prompt_length + gen_length
226 |         print(f"max prompt length: {max_prompt_length}, gen length: {gen_length}")
227 | 
228 |         bias = AttnBias.from_seqlens(
229 |             q_seqlen=padded_prompt_lens,
230 |             kv_seqlen=prompt_lens,
231 |             kv_padding=max_seq_length,
232 |         )
233 |         bias.q_seqinfo.to("cuda")
234 |         bias.k_seqinfo.to("cuda")
235 | 
236 |         # Input tensors to the cuda graph
237 |         kv_seqlen = bias.k_seqinfo.seqlen
238 |         prompts = [prompt + [1] * (self.gen_args.prompt_length - len(prompt)) for prompt in prompts]  # right-pad each prompt to the fixed prefill length
239 |         tokens = torch.IntTensor(sum(prompts, [])).cuda()
240 |         out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)
241 | 
242 |         stats = Stats()
243 |         torch.cuda.synchronize()
244 |         stats.phase("prefill" if use_cuda_graphs else "total")
245 |         # stats.phase("total")
246 | 
247 |         output = self._prefill_compile_model(tokens, None)
248 | 
249 |         logits = output[kv_seqlen - 1, :]
250 |         logits = logits.view(bs, self.model_args.vocab_size)
251 | 
252 |         if use_sampling:
253 |             temp = 0.7
254 |             top_p = 0.95
255 |             probs = torch.softmax(logits / temp, dim=-1)
256 |             next_token = sample_utils.top_p(probs, top_p)
257 |         else:
258 |             next_token = torch.argmax(logits, dim=-1)        
259 | 
260 |         next_token = next_token.reshape(bs)
261 |         out_tokens[0, :] = next_token
262 | 
263 |         torch.cuda.synchronize()
264 |         stats.phase("decode" if use_cuda_graphs else "total")
265 | 
266 |         eos_id = self.tokenizer.eot_id
267 |         for niter in range(1, gen_length):
268 |             kv_seqlen.add_(kv_seqlen < max_seq_length)
269 |             output = self._generate_compile_model(next_token, kv_seqlen)
270 | 
271 |             logits = output.view(bs, self.model_args.vocab_size)
272 | 
273 |             if use_sampling:
274 |                 temp = 0.7
275 |                 top_p = 0.95
276 |                 probs = torch.softmax(logits / temp, dim=-1)
277 |                 next_token = sample_utils.top_p(probs, top_p)
278 |             else:
279 |                 next_token = torch.argmax(logits, dim=-1)
280 | 
281 |             next_token = next_token.reshape(bs)
282 |             out_tokens[niter, :] = next_token
283 | 
284 |             if next_token.eq(eos_id).any():
285 |                 break
286 | 
287 |         torch.cuda.synchronize()
288 |         stats.end_phase(tokens=niter * bs)
289 | 
290 |         def trim_answer(prompt_len, tokens):
291 |             # print(prompt, tokens)
292 |             """Trim the answer to end it on an eos token."""
293 |             tokens = tokens[: max_seq_length - prompt_len]
294 |             eos_id = self.tokenizer.eot_id
295 |             if eos_id in tokens:
296 |                 return tokens[: tokens.index(eos_id) + 1]
297 |             else:
298 |                 return tokens
299 | 
300 |         answers = [
301 |             trim_answer(prompt_len, answer)
302 |             for prompt_len, answer in zip(prompt_lens, out_tokens.t().tolist())
303 |         ]
304 |         return stats, answers
305 | 
306 | 
307 | def get_prompts(interactive: bool) -> Iterable[list[str]]:
308 |     if interactive:
309 |         while True:
310 |             try:
311 |                 prompts = input("enter prompt: ").split("\n")
312 |             except EOFError:
313 |                 print("exiting")
314 |                 sys.exit(0)
315 |             yield prompts
316 |     else:
317 |         yield [
318 |             "Hello, my name is",
319 |         ]
320 | 
321 | 
322 | def main(ckpt_dir: str, interactive: bool = False, chat_format: bool = False, sampling: bool = False):
323 | 
324 |     local_rank = 0
325 |     device = f"cuda:{local_rank}"
326 |     torch.cuda.set_device(local_rank)
327 | 
328 |     g = FastGen.build(ckpt_dir, GenArgs(), device)
329 | 
330 |     if chat_format:
331 |         g.tokenizer = ChatFormat(g.tokenizer)
332 | 
333 |     for prompts in get_prompts(interactive):
334 |         # prompts = [f"{prompt}\n" for prompt in prompts]
335 |         if chat_format:
336 |             # prompts = [f'<|begin_of_text|>User: {prompt}<|eot_id|>Assistant: ' for prompt in prompts]
337 |             tokens = [g.tokenizer.encode_dialog_prompt(dialog=[{"role": "user", "content": prompt}], completion=True) for prompt in prompts]
338 |         else:
339 |             tokens = [g.tokenizer.encode(x, bos=False, eos=False) for x in prompts]
340 | 
341 |         print(tokens)
342 |         stats, out_tokens = g.generate_all(
343 |             tokens, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ, use_sampling=sampling,
344 |         )
345 | 
346 |         for i, prompt in enumerate(prompts):
347 |             print(f"> {prompt}")
348 |             answer = g.tokenizer.decode(out_tokens[i])
349 |             print(answer)
350 |             print("---------------")
351 | 
352 |         for phase_stats in stats.phases:
353 |             print(phase_stats.show())
354 | 
355 |         print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
356 | 
357 | 
358 | if __name__ == "__main__":
359 |     fire.Fire(main)
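
Because `main` is wrapped with `fire.Fire`, its arguments map directly onto CLI flags (e.g. `python generate.py --ckpt_dir ./checkpoints --interactive --chat_format`); note that defining `NO_CUDA_GRAPHS` in the environment only changes how `generate_all` labels its timing phases. The generator can also be driven programmatically; a minimal sketch, assuming `./checkpoints` holds `model_state_fp16.pt` and `model_state_int2.pt` and `tokenizer.model` sits in the working directory:

    from generate import FastGen, GenArgs

    g = FastGen.build("./checkpoints", GenArgs(gen_length=64), device="cuda:0")
    prompt = g.tokenizer.encode("Hello, my name is", bos=False, eos=False)
    stats, answers = g.generate_all([prompt], use_cuda_graphs=True, use_sampling=False)
    print(g.tokenizer.decode(answers[0]))
    for phase in stats.phases:
        print(phase.show())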


--------------------------------------------------------------------------------
/gpu/model.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  2 | #
  3 | # This source code is licensed under the BSD license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | from dataclasses import dataclass
  7 | from typing import Optional, Tuple, Union
  8 | 
  9 | import torch
 10 | from torch import nn
 11 | from torch.nn import functional as F
 12 | 
 13 | from xformers.ops import RMSNorm, fmha, rope_padded
 14 | from xformers.ops.fmha.attn_bias import (
 15 |     BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
 16 | )
 17 | 
 18 | import ctypes
 19 | bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
 20 | 
 21 | def bitnet_int8xint2_linear(input0, input1, s, ws):
 22 |     out_shape = list(input0.shape)
 23 |     out_shape[-1] = input1.shape[0]
 24 | 
 25 |     stream = torch.cuda.current_stream()
 26 | 
 27 |     M = input0.shape[0]  # GEMM rows: flattened batch (and sequence) size
 28 |     if len(out_shape) == 3:
 29 |         M *= input0.shape[1]
 30 |     N = input1.shape[0]  # output features
 31 |     K = input1.shape[1] * 4  # input features: four 2-bit weights are packed per int8 byte
 32 | 
 33 |     ret = torch.zeros(*out_shape, dtype=torch.bfloat16, device=input0.device)
 34 | 
 35 |     bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)])
 36 | 
 37 |     return ret
 38 | 
 39 | @dataclass
 40 | class ModelArgs:
 41 |     dim: int = 2560
 42 |     n_layers: int = 30
 43 |     n_heads: int = 20
 44 |     n_kv_heads: int = 5
 45 |     vocab_size: int = 128256
 46 |     ffn_dim: int = 6912
 47 |     norm_eps: float = 1e-5
 48 |     rope_theta: float = 500000.0
 49 |     use_kernel: bool = False
 50 | 
 51 | 
 52 | LayerCache = Tuple[torch.Tensor, torch.Tensor]
 53 | 
 54 | class BitLinearKernel(nn.Module):
 55 |     in_features: int
 56 |     out_features: int
 57 |     weight: torch.Tensor
 58 |     weight_scale: torch.Tensor
 59 | 
 60 |     def __init__(self, in_features: int, out_features: int, bias: bool = False):
 61 |         super().__init__()
 62 |         self.in_features = in_features
 63 |         self.out_features = out_features
 64 | 
 65 |         self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features//4, dtype=torch.int8), requires_grad=False)
 66 |         self.weight_scale = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False)
 67 | 
 68 |     @torch.compile
 69 |     def quant_input(self, input):
 70 |         s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
 71 |         return (input * s).round().clamp(-128, 127).to(torch.int8), s
 72 | 
 73 |     def forward(self, input):
 74 |         input, s = self.quant_input(input)
 75 |         return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale)
 76 | 
 77 | class BitLinear(nn.Linear):
 78 |     @torch.compile
 79 |     def quant_input(self, input):
 80 |         s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
 81 |         return (input * s).round().clamp(-128, 127) / s
 82 | 
 83 |     def forward(self, input):
 84 |         input = self.quant_input(input)
 85 |         return F.linear(input, self.weight)
 86 | 
 87 | class Attention(nn.Module):
 88 |     def __init__(
 89 |         self,
 90 |         dim: int,
 91 |         head_dim: int,
 92 |         n_heads: int,
 93 |         n_kv_heads: int,
 94 |         rope_theta: float,
 95 |         norm_eps: float,
 96 |         use_kernel: bool,
 97 |     ):
 98 |         super().__init__()
 99 | 
100 |         self.head_dim = head_dim
101 |         self.rope_theta = rope_theta
102 | 
103 |         self.n_local_heads = n_heads
104 |         self.n_local_kv_heads = n_kv_heads
105 | 
106 |         Linear = BitLinearKernel if use_kernel else BitLinear
107 | 
108 |         self.wqkv = Linear(
109 |             dim,
110 |             (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
111 |             bias=False,
112 |         )
113 |         self.wo = Linear(
114 |             self.n_local_heads * head_dim,
115 |             dim,
116 |             bias=False,
117 |         )
118 | 
119 |         self.attn_sub_norm = RMSNorm(dim, norm_eps)
120 | 
121 |     def forward(
122 |         self,
123 |         x: torch.Tensor,
124 |         cache: LayerCache,
125 |         attn_bias: AttnBias,
126 |     ) -> torch.Tensor:
127 | 
128 |         xqkv = self.wqkv(x)
129 |         xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
130 |         xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
131 |         xk, xv = xkv.chunk(2, 1)
132 | 
133 |         output_shape = xq.shape
134 |         heads_per_group = self.n_local_heads // self.n_local_kv_heads
135 |         xq = xq.view(
136 |             1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim
137 |         )
138 |         xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim)
139 |         # xq = rearrange(xq, 'b (g h l d) -> 1 b h g (d l)', g=heads_per_group, h=self.n_local_kv_heads, d=self.head_dim // 2, l=2)
140 |         # xk = rearrange(xk, 'b (g l d) -> 1 b g 1 (d l)', g=self.n_local_kv_heads, d=self.head_dim // 2)
141 |         xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim)
142 |         cache_k, cache_v = cache
143 | 
144 |         xq = rope_padded(
145 |             xq=xq,
146 |             xk=xk,
147 |             xv=xv,
148 |             cache_k=cache_k,
149 |             cache_v=cache_v,
150 |             attn_bias=attn_bias,
151 |             theta=self.rope_theta,
152 |         )
153 | 
154 |         output = fmha.memory_efficient_attention_forward(
155 |             xq, cache_k, cache_v, attn_bias, op = fmha.flash.FwOp
156 |         )
157 | 
158 |         output = output.reshape(output_shape)
159 |         output = self.attn_sub_norm(output)
160 |         output = self.wo(output)
161 | 
162 |         return output
163 | 
164 | @torch.compile
165 | def squared_relu(x: torch.Tensor) -> torch.Tensor:
166 |     return F.relu(x) ** 2
167 | 
168 | class FeedForward(nn.Module):
169 |     def __init__(
170 |         self,
171 |         dim: int,
172 |         hidden_dim: int,
173 |         norm_eps: float,
174 |         use_kernel: bool,
175 |     ):
176 |         super().__init__()
177 | 
178 |         Linear = BitLinearKernel if use_kernel else BitLinear
179 | 
180 |         self.w13 = Linear(
181 |             dim,
182 |             2 * hidden_dim,
183 |             bias=False,
184 |         )
185 |         self.w2 = Linear(
186 |             hidden_dim,
187 |             dim,
188 |             bias=False,
189 |         )
190 |         self.ffn_sub_norm = RMSNorm(hidden_dim, norm_eps)
191 | 
192 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
193 |         x13 = self.w13(x)
194 |         x1, x3 = x13.chunk(2, -1)
195 |         inner = self.ffn_sub_norm(squared_relu(x1) * x3)
196 |         output = self.w2(inner)
197 |         return output
198 | 
199 | 
200 | class TransformerBlock(nn.Module):
201 |     def __init__(self, args: ModelArgs):
202 |         super().__init__()
203 | 
204 |         assert args.dim % args.n_heads == 0
205 |         head_dim = args.dim // args.n_heads
206 |         if args.n_kv_heads is not None:
207 |             n_kv_heads = args.n_kv_heads
208 |         else:
209 |             n_kv_heads = args.n_heads
210 | 
211 |         assert args.n_heads % n_kv_heads == 0
212 | 
213 |         self.attention = Attention(
214 |             dim=args.dim,
215 |             head_dim=head_dim,
216 |             n_heads=args.n_heads,
217 |             n_kv_heads=n_kv_heads,
218 |             rope_theta=args.rope_theta,
219 |             norm_eps=args.norm_eps,
220 |             use_kernel=args.use_kernel,
221 |         )
222 |         self.feed_forward = FeedForward(
223 |             dim=args.dim,
224 |             hidden_dim=args.ffn_dim,
225 |             norm_eps=args.norm_eps,
226 |             use_kernel=args.use_kernel,
227 |         )
228 |         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
229 |         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
230 | 
231 |     def forward(
232 |         self,
233 |         x: torch.Tensor,
234 |         cache: LayerCache,
235 |         attn_bias: AttnBias,
236 |     ) -> torch.Tensor:
237 |         h = x + self.attention.forward(
238 |             self.attention_norm(x),
239 |             cache,
240 |             attn_bias,
241 |         )
242 |         out = h + self.feed_forward(self.ffn_norm(h))
243 |         return out
244 | 
245 | 
246 | class Transformer(nn.Module):
247 |     def __init__(self, args: ModelArgs):
248 |         super().__init__()
249 |         assert args.vocab_size > 0
250 | 
251 |         self.tok_embeddings = nn.Embedding(
252 |             num_embeddings=args.vocab_size,
253 |             embedding_dim=args.dim,
254 |         )
255 | 
256 |         self.layers = nn.ModuleList()
257 |         for _ in range(args.n_layers):
258 |             self.layers.append(TransformerBlock(args))
259 | 
260 |         self.norm = RMSNorm(args.dim, eps=args.norm_eps)
261 | 
262 |         self.output = nn.Linear(
263 |             args.dim,
264 |             args.vocab_size,
265 |             bias=False,
266 |         )
267 | 
268 |     @torch.no_grad()
269 |     def forward_with_attn_bias(
270 |         self,
271 |         token_values: torch.Tensor,
272 |         attn_bias: AttnBias,
273 |         cache: list[LayerCache],
274 |     ) -> torch.Tensor:
275 |         h = self.tok_embeddings(token_values)
276 | 
277 |         for i, layer in enumerate(self.layers):
278 |             h = layer(h, cache[i], attn_bias)
279 | 
280 |         logits = self.output(self.norm(h))
281 |         return logits.float()
282 | 
283 |     def forward(
284 |         self,
285 |         token_values: torch.Tensor,
286 |         token_lengths: torch.Tensor,
287 |         start_pos: torch.Tensor,
288 |         cache: list[LayerCache],
289 |         kv_padding: int,
290 |     ) -> torch.Tensor:
291 |         attn_bias = AttnBias.from_seqlens(
292 |             q_seqlen=token_lengths.tolist(),
293 |             kv_seqlen=(start_pos + token_lengths).tolist(),
294 |             kv_padding=kv_padding,
295 |         )
296 |         return self.forward_with_attn_bias(token_values, attn_bias, cache)
297 | 
298 | 
299 | def make_cache(
300 |     args: ModelArgs,
301 |     length: int,
302 |     device: Optional[Union[str, torch.device]] = None,
303 |     n_layers: Optional[int] = None,
304 |     dtype: Optional[torch.dtype] = None,
305 | ) -> list[LayerCache]:
306 |     """
307 |     Allocate a cache to be used with the Transformer module.
308 | 
309 |     Args:
310 |         args (ModelArgs): the model configuration.
311 |         length (int): per layer cache size.
312 |             It is usually budgeted as ``max_batch * max_seq``
313 |         device (torch.device, optional): the device on which
314 |             the cache should be allocated.
315 |         n_layers (int, optional): the number of layers to
316 |             allocate a cache for (defaults to the model
317 |             settings).
318 |         dtype (torch.dtype, optional): the dtype to use for
319 |             cache entries (defaults to the default dtype).
320 | 
321 |     Returns:
322 |         The cache object to pass to ``Transformer.forward``.
323 |     """
324 | 
325 |     head_dim = args.dim // args.n_heads
326 |     n_kv_heads = args.n_kv_heads
327 |     if n_kv_heads is None:
328 |         n_kv_heads = args.n_heads
329 |     n_local_kv_heads = n_kv_heads
330 | 
331 |     if n_layers is None:
332 |         n_layers = args.n_layers
333 | 
334 |     shape = (1, length, n_local_kv_heads, 1, head_dim)
335 |     heads_per_group = args.n_heads // n_kv_heads
336 |     expansion = (-1, -1, -1, heads_per_group, -1)
337 |     return [
338 |         (
339 |             torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
340 |             torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
341 |         )
342 |         for _ in range(n_layers)
343 |     ]
344 | 
345 | 
346 | def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
347 |     """
348 |     Take a prefix view of a larger cache.
349 | 
350 |     The original cache object keeps its full size and remains valid
351 |     after the shrunken alias has been used. This function is useful
352 |     when a cache was allocated for a larger batch size than is
353 |     necessary.
354 | 
355 |     Args:
356 |         cache: the cache to take a view in.
357 |         length (int): the desired length
358 | 
359 |     Returns:
360 |         A view in the input cache object.
361 |     """
362 | 
363 |     if len(cache) > 0:
364 |         assert cache[0][0].shape[1] >= length
365 | 
366 |     return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
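
Both linear variants share the same absmax activation quantizer; only what happens to the quantized tensor differs. A standalone sketch of the round trip, using just the arithmetic from `quant_input` (shapes illustrative):

    import torch

    x = torch.randn(4, 16)
    s = 127 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    x_int8 = (x * s).round().clamp(-128, 127).to(torch.int8)  # kernel path: int8 activations
    x_fake = x_int8.to(x.dtype) / s                           # fallback path: fake-quantized floats

`BitLinearKernel` hands `x_int8` and the per-row scale `s` to the CUDA kernel together with the weight scale, while `BitLinear` divides `s` back out and falls through to `F.linear`.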


--------------------------------------------------------------------------------
/gpu/pack_weight.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | 
 4 | 
 5 | def B_global_16x32_to_shared_load_16x32_layout(i, j):
 6 |     """
 7 |          stride * 8 * (tx // HALF_WARP_expr)
 8 |                 + (tx % 8) * stride
 9 |                 + 16 * ((tx % HALF_WARP_expr) // 8)
10 |     """
11 |     thread_id = i * 2 + j // 16
12 |     row = (thread_id // 16) * 8 + (thread_id % 8)
13 |     col = (j % 16) + 16 * ((thread_id % 16) // 8)
14 |     return row, col
15 | 
16 | 
17 | def permutate_weight_fastest(weight):
18 |     wmma_n = 16
19 |     wmma_k = 32
20 |     N = weight.shape[0]
21 |     K = weight.shape[1]
22 |     
23 |     # Create a lookup table for the permutation
24 |     mapping = np.zeros((wmma_n, wmma_k, 2), dtype=int)
25 |     for ii in range(wmma_n):
26 |         for jj in range(wmma_k):
27 |             mapping[ii, jj] = B_global_16x32_to_shared_load_16x32_layout(ii, jj)
28 |     
29 |     # Reshape weight for the final format
30 |     permutated_weight = np.zeros((N // wmma_n, K // wmma_k, wmma_n, wmma_k), dtype="int8")
31 |     
32 |     # Use advanced indexing for the entire operation
33 |     i_indices = np.arange(N // wmma_n)[:, np.newaxis, np.newaxis, np.newaxis]
34 |     j_indices = np.arange(K // wmma_k)[np.newaxis, :, np.newaxis, np.newaxis]
35 |     
36 |     # Create the source indices
37 |     src_i = i_indices * wmma_n + mapping[:, :, 0]
38 |     src_j = j_indices * wmma_k + mapping[:, :, 1]
39 |     
40 |     # Extract and reshape in one go
41 |     permutated_weight = weight[src_i, src_j]
42 |     
43 |     return permutated_weight
44 | 
45 | 
46 | def compress_int2_to_int8(int2_weight):
47 |     int8_weight = np.zeros(
48 |         (*int2_weight.shape[:-1], int2_weight.shape[-1] // 4), dtype=np.int8
49 |     )
50 |     for j in range(int2_weight.shape[-1] // 4):
51 |         for k in range(4):
52 |             int8_weight[:, :, :, j] |= int2_weight[:, :, :, j * 4 + k] << (k * 2)
53 |     return int8_weight
54 | 
55 | 
 56 | def interleave_weight_int8(qweight, nbits=2):
57 |     # reinterpret the data type of qweight to int32
58 |     # shift = [ 0,  8, 16, 24,  2, 10, 18, 26,  4, 12, 20, 28,  6, 14, 22, 30]
59 |     # index: [ 0,  4,  8, 12,  1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15]
60 |     qweight = qweight.view(np.int32)
61 |     new_qweight = np.zeros_like(qweight)
62 |     bits_stride = 8
63 |     mask = (1 << nbits) - 1  # for 4bit the val is 0x0000000f
64 |     num_groups = 32 // bits_stride # 4
65 |     elems_per_group = bits_stride // nbits  # 4
66 |     for i in range(num_groups):
67 |         for j in range(elems_per_group):
68 |             offset = i * elems_per_group + j
69 |             shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
70 | 
71 |             new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
72 |     return new_qweight.view(np.int8)
73 | 
74 | 
75 | 
76 | def convert_weight_int8_to_int2(weight):
77 |     N = weight.shape[0]
78 |     K = weight.shape[1]
79 | 
 80 |     weight = weight + 2  # shift ternary weights {-1, 0, 1} to {1, 2, 3} so each value fits in two unsigned bits
81 |     
82 |     weight = weight.cpu().numpy()
83 | 
84 |     # print(weight)
85 |     # print(torch.max(weight), torch.min(weight))
86 | 
87 |     # permutated_weight_slow = permutate_weight(weight)
88 |     permutated_weight = permutate_weight_fastest(weight)
89 |     # assert np.all(permutated_weight_slow == permutated_weight)
90 |     # print("Permutation is correct")
91 |     compressed_weight = compress_int2_to_int8(permutated_weight)
92 |     interleaved_weight = interleave_weight_int8(compressed_weight, 2)
93 | 
94 |     ret = torch.from_numpy(interleaved_weight)
95 | 
96 |     ret = torch.reshape(ret, (N, K // 4))
97 | 
98 |     return ret
99 | 
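
Taken together, `convert_weight_int8_to_int2` shifts ternary weights into {1, 2, 3}, permutes them into the 16x32 fragment layout, packs four 2-bit values into each int8, and interleaves the bits into the kernel's decode order. A small sketch of just the packing step, assuming values already in [0, 3]:

    import numpy as np

    w = np.random.randint(0, 4, size=(1, 1, 1, 8)).astype(np.int8)
    packed = np.zeros((1, 1, 1, 2), dtype=np.int8)
    for j in range(2):           # each output byte
        for k in range(4):       # four 2-bit lanes per byte
            packed[..., j] |= w[..., j * 4 + k] << (k * 2)
    # lane k of byte j recovers as (packed[..., j] >> (k * 2)) & 0b11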


--------------------------------------------------------------------------------
/gpu/requirements.txt:
--------------------------------------------------------------------------------
1 | fire
2 | sentencepiece
3 | torch>=2.2.0
4 | xformers>=0.0.22
5 | tiktoken
6 | blobfile
7 | flask
8 | einops
9 | transformers


--------------------------------------------------------------------------------
/gpu/sample_utils.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 2 | #
 3 | # This source code is licensed under the BSD license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import torch
 7 | 
 8 | @torch.compile
 9 | def top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
10 |     """
11 |     Perform top-p (nucleus) sampling on a probability distribution.
12 | 
13 |     Args:
14 |         probs (torch.Tensor): probability distribution tensor.
15 |         p (float): probability threshold for top-p sampling.
16 | 
17 |     Returns:
18 |         torch.Tensor: sampled token indices.
19 | 
20 |     Note:
21 |         Top-p sampling selects the smallest set of tokens whose cumulative
22 |         probability mass exceeds the threshold p. The distribution is
23 |         renormalized based on the selected tokens.
24 |     """
25 |     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
26 |     probs_sum = torch.cumsum(probs_sort, dim=-1)
27 |     mask = probs_sum - probs_sort > p
28 |     probs_sort[mask] = 0.0
29 |     next_token = torch.multinomial(probs_sort, num_samples=1)
30 |     next_token = torch.gather(probs_idx, -1, next_token)
31 |     return next_token
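
A quick usage sketch, mirroring how generate.py turns logits into a temperature-scaled distribution before sampling (sizes illustrative):

    import torch
    from sample_utils import top_p

    logits = torch.randn(2, 128256)              # (batch, vocab)
    probs = torch.softmax(logits / 0.7, dim=-1)
    next_token = top_p(probs, 0.95)              # shape (2, 1): one sampled id per row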


--------------------------------------------------------------------------------
/gpu/stats.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 2 | #
 3 | # This source code is licensed under the BSD license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import time
 7 | from dataclasses import dataclass
 8 | from typing import Optional
 9 | 
10 | 
11 | @dataclass
12 | class PhaseStats:
13 |     name: str
14 |     tokens: int
15 |     time: float
16 | 
17 |     def show(self) -> str:
18 |         tps = self.tokens / self.time
19 |         return (
20 |             f"[{self.name}] "
21 |             f"generated tokens: {self.tokens}"
22 |             f" - total time: {self.time:.3f}s"
23 |             f" - {tps:.1f} tokens per second"
24 |         )
25 | 
26 | 
27 | class Stats:
28 |     """
29 |     Generation stats, split by phases.
30 |     """
31 | 
32 |     def __init__(self):
33 |         self.phases = []
34 |         self.current = None
35 | 
36 |     def end_phase(self, tokens: int, now: Optional[float] = None):
37 |         """Terminate the current phase."""
38 |         if self.current is None:
39 |             return
40 |         if now is None:
41 |             now = time.time()
42 |         cname, ctokens, ctime = self.current
43 |         stats = PhaseStats(
44 |             name=cname,
45 |             tokens=tokens - ctokens,
46 |             time=now - ctime,
47 |         )
48 |         self.phases.append(stats)
49 |         self.current = None  # prevent double-recording if end_phase is called again
50 | 
51 |     def phase(self, name: str, tokens: int = 0):
52 |         """
53 |         Start a new phase, and terminate the current one,
54 |         if one is ongoing.
55 |         """
56 |         now = time.time()
57 |         self.end_phase(tokens, now)
58 |         self.current = (name, tokens, now)

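
The intended call pattern (as in generate.py): `phase(name)` opens a new phase and closes the previous one, and the `tokens` argument is a running total, so each phase records the delta. A minimal sketch:

    import time
    from stats import Stats

    stats = Stats()
    stats.phase("prefill")
    time.sleep(0.1)                  # stand-in for the prefill pass
    stats.phase("decode", tokens=1)  # closes "prefill" with 1 token
    time.sleep(0.2)                  # stand-in for the decode loop
    stats.end_phase(tokens=32)       # closes "decode" with 31 tokens
    for p in stats.phases:
        print(p.show())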

--------------------------------------------------------------------------------
/gpu/test.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | from torch.utils import benchmark
  3 | from torch import nn
  4 | 
  5 | from pack_weight import convert_weight_int8_to_int2
  6 | from torch.profiler import profile, record_function, ProfilerActivity
  7 | import ctypes
  8 | import numpy as np
  9 | # set all seed
 10 | torch.manual_seed(42)
 11 | np.random.seed(42)
 12 | 
 13 | bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
 14 | 
 15 | def bitnet_int8xint2_linear(input0, input1, s, ws, ret):
 16 |     out_shape = list(input0.shape)
 17 |     out_shape[-1] = input1.shape[0]
 18 | 
 19 |     stream = torch.cuda.current_stream()
 20 | 
 21 |     M = input0.shape[0]
 22 |     if len(out_shape) == 3: 
 23 |         M *= input0.shape[1]
 24 |     N = input1.shape[0]
 25 |     K = input1.shape[1] * 4
 26 | 
 27 |     bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)])
 28 | 
 29 |     return ret
 30 | 
 31 | if __name__ == '__main__':
 32 |     test_list = [
 33 |         (2560,  2560), 
 34 |         (3840,  2560), 
 35 |         (13824, 2560),
 36 |         (2560,  6912) ,
 37 |         (3200, 3200), 
 38 |         (4800, 3200), 
 39 |         (3200, 10240),
 40 |         (20480, 3200),
 41 |     ]
 42 |     for N,K in test_list:
 43 |         weight = torch.randint(-1, 2, (N, K), dtype=torch.int8, device='cuda')
 44 |         weight_scale = torch.ones(1, dtype=torch.bfloat16, device='cuda')
 45 |         weight_compressed = convert_weight_int8_to_int2(weight).to('cuda')
 46 | 
 47 |         for i in range(1):
 48 |             input0 = torch.randint(-128,127,(1, K),dtype=torch.int8, device='cuda')
 49 |             input0_bf16 = input0.to(torch.bfloat16)
 50 |             input_np = input0.cpu().to(torch.int32).numpy()
 51 |             weight_np = weight.cpu().to(torch.int32).T.numpy()
 52 |             out_np = np.matmul(input_np,weight_np)
 53 |             out_np = torch.tensor(out_np).cuda().to(torch.bfloat16)
 54 | 
 55 |             s = torch.ones(1, dtype=torch.bfloat16, device='cuda')
 56 |             ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')
 57 | 
 58 |             ret = torch.empty((1,N), dtype=torch.bfloat16, device=input0.device)
 59 |             out = bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)
 60 | 
 61 |             print(f'custom == np {torch.all(out==out_np)}')
 62 | 
 63 |         input0 = torch.randint(-128,127,(1, K),dtype=torch.int8, device='cuda')
 64 |         input0_fp16 = input0.to(torch.float16)
 65 |         input0_bf16 = input0.to(torch.bfloat16)
 66 |         weight_fp16 = weight.to(torch.float16).T
 67 |         weight_bf16 = weight.to(torch.bfloat16).T
 68 |         ret = torch.empty((1,N), dtype=torch.bfloat16, device=input0.device)
 69 |         s = torch.ones(1, dtype=torch.bfloat16, device='cuda')
 70 |         ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')
 71 |         t0 = benchmark.Timer(
 72 |             stmt="bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)",
 73 |             setup="from __main__ import input0, weight_compressed, s, ws, ret, bitnet_int8xint2_linear",
 74 |             num_threads=1,
 75 |         )
 76 | 
 77 |         t1 = benchmark.Timer(
 78 |             stmt="torch.matmul(input0_bf16,weight_bf16)",
 79 |             setup="from __main__ import input0_bf16, weight_bf16",
 80 |             num_threads=1,
 81 |         )
 82 | 
 83 |         time0 = t0.timeit(50)
 84 |         time1 = t1.timeit(50)
 85 | 
 86 |         print(f'Shape{N,K}, W2A8: {time0.mean * 1e6:.2f}us, torch BF16: {time1.mean * 1e6:.2f}us')
 87 |         # activities = [ ProfilerActivity.CUDA, 
 88 |         #             #   ProfilerActivity.CPU
 89 |         #               ]
 90 |         # sort_by_keyword = 'cuda' + "_time_total"
 91 |         # with profile(activities=activities, record_shapes=True) as prof:
 92 |         #     with record_function("model_inference1"):
 93 |         #         for _ in range(10):
 94 |         #             bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)
 95 |         #             torch.matmul(input0_fp16,weight_fp16)
 96 |         #             torch.matmul(input0_bf16,weight_bf16)
 97 | 
 98 |         # print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=15))
 99 |         
100 | 


--------------------------------------------------------------------------------
/gpu/tokenizer.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | from logging import getLogger
  3 | from pathlib import Path
  4 | from typing import (
  5 |     AbstractSet,
  6 |     cast,
  7 |     Collection,
  8 |     Dict,
  9 |     Iterator,
 10 |     List,
 11 |     Literal,
 12 |     Sequence,
 13 |     TypedDict,
 14 |     Union,
 15 | )
 16 | 
 17 | import tiktoken
 18 | from tiktoken.load import load_tiktoken_bpe
 19 | 
 20 | 
 21 | logger = getLogger(__name__)
 22 | 
 23 | Role = Literal["system", "user", "assistant"]
 24 | 
 25 | 
 26 | class Message(TypedDict):
 27 |     role: Role
 28 |     content: str
 29 | 
 30 | 
 31 | Dialog = Sequence[Message]
 32 | 
 33 | 
 34 | class Tokenizer:
 35 |     """
 36 |     Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
 37 |     """
 38 | 
 39 |     special_tokens: Dict[str, int]
 40 | 
 41 |     num_reserved_special_tokens = 256
 42 | 
 43 |     pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
 44 | 
 45 |     def __init__(self, model_path: str):
 46 |         """
 47 |         Initializes the Tokenizer with a Tiktoken model.
 48 | 
 49 |         Args:
 50 |             model_path (str): The path to the Tiktoken model file.
 51 |         """
 52 |         assert os.path.isfile(model_path), model_path
 53 | 
 54 |         mergeable_ranks = load_tiktoken_bpe(model_path)
 55 |         num_base_tokens = len(mergeable_ranks)
 56 |         special_tokens = [
 57 |             "<|begin_of_text|>",
 58 |             "<|end_of_text|>",
 59 |             "<|reserved_special_token_0|>",
 60 |             "<|reserved_special_token_1|>",
 61 |             "<|reserved_special_token_2|>",
 62 |             "<|reserved_special_token_3|>",
 63 |             "<|start_header_id|>",
 64 |             "<|end_header_id|>",
 65 |             "<|reserved_special_token_4|>",
 66 |             "<|eot_id|>",  # end of turn
 67 |         ] + [
 68 |             f"<|reserved_special_token_{i}|>"
 69 |             for i in range(5, self.num_reserved_special_tokens - 5)
 70 |         ]
 71 |         self.special_tokens = {
 72 |             token: num_base_tokens + i for i, token in enumerate(special_tokens)
 73 |         }
 74 |         self.model = tiktoken.Encoding(
 75 |             name=Path(model_path).name,
 76 |             pat_str=self.pat_str,
 77 |             mergeable_ranks=mergeable_ranks,
 78 |             special_tokens=self.special_tokens,
 79 |         )
 80 |         logger.info(f"Reloaded tiktoken model from {model_path}")
 81 | 
 82 |         self.n_words: int = self.model.n_vocab
 83 |         # BOS / EOS token IDs
 84 |         self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
 85 |         self.eos_id: int = self.special_tokens["<|end_of_text|>"]
 86 |         self.pad_id: int = self.n_words - 1
 87 |         self.stop_tokens = {
 88 |             self.special_tokens["<|end_of_text|>"],
 89 |             self.special_tokens["<|eot_id|>"],
 90 |         }
 91 |         logger.info(
 92 |             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
 93 |         )
 94 | 
 95 |     def encode(
 96 |         self,
 97 |         s: str,
 98 |         *,
 99 |         bos: bool,
100 |         eos: bool,
101 |         allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
102 |         disallowed_special: Union[Literal["all"], Collection[str]] = (),
103 |     ) -> List[int]:
104 |         """
105 |         Encodes a string into a list of token IDs.
106 | 
107 |         Args:
108 |             s (str): The input string to be encoded.
109 |             bos (bool): Whether to prepend the beginning-of-sequence token.
110 |             eos (bool): Whether to append the end-of-sequence token.
111 |             allowed_special ("all"|set[str]): allowed special tokens in string
112 |             disallowed_special ("all"|set[str]): special tokens that raise an error when in string
113 | 
114 |         Returns:
115 |             list[int]: A list of token IDs.
116 | 
117 |         By default, setting disallowed_special=() encodes a string by ignoring
118 |         special tokens. Specifically:
119 |         - Setting `disallowed_special` to () will cause all text corresponding
120 |           to special tokens to be encoded as natural text (instead of raising
121 |           an error).
122 |         - Setting `allowed_special` to "all" will cause all text corresponding
123 |           to special tokens to be encoded as special tokens.
124 |         """
125 |         assert type(s) is str
126 | 
127 |         # The tiktoken tokenizer can handle <=400k chars without
128 |         # pyo3_runtime.PanicException.
129 |         TIKTOKEN_MAX_ENCODE_CHARS = 400_000
130 | 
131 |         # https://github.com/openai/tiktoken/issues/195
132 |         # Here we iterate over subsequences and split if we exceed the limit
133 |         # of max consecutive non-whitespace or whitespace characters.
134 |         MAX_NO_WHITESPACES_CHARS = 25_000
135 | 
136 |         substrs = (
137 |             substr
138 |             for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
139 |             for substr in self._split_whitespaces_or_nonwhitespaces(
140 |                 s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
141 |             )
142 |         )
143 |         t: List[int] = []
144 |         for substr in substrs:
145 |             t.extend(
146 |                 self.model.encode(
147 |                     substr,
148 |                     allowed_special=allowed_special,
149 |                     disallowed_special=disallowed_special,
150 |                 )
151 |             )
152 |         if bos:
153 |             t.insert(0, self.bos_id)
154 |         if eos:
155 |             t.append(self.eos_id)
156 |         return t
157 | 
158 |     def decode(self, t: Sequence[int]) -> str:
159 |         """
160 |         Decodes a list of token IDs into a string.
161 | 
162 |         Args:
163 |             t (List[int]): The list of token IDs to be decoded.
164 | 
165 |         Returns:
166 |             str: The decoded string.
167 |         """
168 |         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
169 |         return self.model.decode(cast(List[int], t))
170 | 
171 |     @staticmethod
172 |     def _split_whitespaces_or_nonwhitespaces(
173 |         s: str, max_consecutive_slice_len: int
174 |     ) -> Iterator[str]:
175 |         """
176 |         Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
177 |         consecutive whitespaces or consecutive non-whitespaces.
178 |         """
179 |         current_slice_len = 0
180 |         current_slice_is_space = s[0].isspace() if len(s) > 0 else False
181 |         slice_start = 0
182 | 
183 |         for i in range(len(s)):
184 |             is_now_space = s[i].isspace()
185 | 
186 |             if current_slice_is_space ^ is_now_space:
187 |                 current_slice_len = 1
188 |                 current_slice_is_space = is_now_space
189 |             else:
190 |                 current_slice_len += 1
191 |                 if current_slice_len > max_consecutive_slice_len:
192 |                     yield s[slice_start:i]
193 |                     slice_start = i
194 |                     current_slice_len = 1
195 |         yield s[slice_start:]
196 | 
197 | class ChatFormat:
198 |     def __init__(self, tokenizer: Tokenizer):
199 |         self.tokenizer = tokenizer
200 |         self.eot_id = tokenizer.special_tokens["<|eot_id|>"]
201 |     
202 |     def decode(self, tokens: List[int]) -> str:
203 |         # Decode the tokens to a string.
204 |         decoded_str = self.tokenizer.decode(tokens)
205 |         # Remove the special tokens from the decoded string.
206 |         decoded_str = decoded_str.replace("<|eot_id|>", "")
207 |         return decoded_str
208 | 
209 |     def encode_header(self, message: Message) -> List[int]:
210 |         tokens = []
211 |         if message["role"] == "system":
212 |             tokens.extend(self.tokenizer.encode("System: ", bos=False, eos=False))
213 |         elif message["role"] == "user":
214 |             tokens.extend(self.tokenizer.encode("User: ", bos=False, eos=False))
215 |         elif message["role"] == "assistant":
216 |             tokens.extend(self.tokenizer.encode("Assistant: ", bos=False, eos=False))
217 |         else:
218 |             raise NotImplementedError(f"Role {message['role']} not implemented.")
219 |         # tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
220 |         # tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
221 |         # tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
222 |         # tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
223 |         return tokens
224 | 
225 |     def encode_message(self, message: Message, return_target=False) -> tuple[List[int], Union[List[int], None]]:
226 |         tokens, targets = [], []
227 |         headers = self.encode_header(message)
228 |         contents = self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
229 |         contents.append(self.tokenizer.special_tokens["<|eot_id|>"])
230 |         tokens = headers + contents
231 | 
232 |         if message["role"] == "assistant":
233 |             targets = [-1] * len(headers) + contents
234 |         else:
235 |             targets = [-1] * len(tokens)
236 | 
237 |         if return_target:
238 |             return tokens, targets
239 | 
240 |         return tokens, None
241 | 
242 |     def encode_dialog_prompt(self, dialog: Dialog, completion=False, return_target=False) -> Union[List[int], tuple[List[int], List[int]]]:
243 |         tokens = [self.tokenizer.special_tokens["<|begin_of_text|>"]]
244 |         targets = [-1]
245 |         for message in dialog:
246 |             _tokens, _targets = self.encode_message(message, return_target=return_target)
247 |             tokens.extend(_tokens)
248 |             if _targets is not None:
249 |                 targets.extend(_targets)
250 |         # Add the start of an assistant message for the model to complete.
251 |         if completion:
252 |             tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
253 |         
254 |         if return_target:
255 |             return tokens, targets
256 | 
257 |         return tokens
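
A short sketch of plain encoding versus the chat wrapper, assuming the bundled `tokenizer.model` is present in the working directory:

    from tokenizer import Tokenizer, ChatFormat

    tok = Tokenizer("./tokenizer.model")
    ids = tok.encode("Hello, my name is", bos=True, eos=False)
    print(tok.decode(ids))

    chat = ChatFormat(tok)
    prompt_ids = chat.encode_dialog_prompt(
        dialog=[{"role": "user", "content": "Hello!"}],
        completion=True,  # append the assistant header for the model to complete
    )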


--------------------------------------------------------------------------------
/include/ggml-bitnet.h:
--------------------------------------------------------------------------------
 1 | #pragma once
 2 | 
 3 | #include "ggml.h"
 4 | #include "ggml-backend.h"
 5 | 
 6 | #ifdef __ARM_NEON
 7 | #include <arm_neon.h>
 8 | typedef float32_t bitnet_float_type;
 9 | #else
10 | typedef float bitnet_float_type;
11 | #endif
12 | 
13 | #ifdef  __cplusplus
14 | extern "C" {
15 | #endif
16 | 
17 | struct bitnet_tensor_extra {
18 |     int lut_scales_size;
19 |     int BK;
20 |     int n_tile_num;
21 |     uint8_t * qweights;
22 |     bitnet_float_type * scales;
23 | };
24 | 
25 | GGML_API void ggml_bitnet_init(void);
26 | GGML_API void ggml_bitnet_free(void);
27 | // src0->type == Q4_0/IQ2_XXS/IQ3_XXS
28 | // bitnet.cpp currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros)
 29 | // If i-quantization gguf models are used, the results will be wrong
30 | // TODO: add customized block types Q2_0/Q3_0
31 | GGML_API bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
32 | GGML_API size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
33 | GGML_API void ggml_bitnet_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits);
34 | GGML_API void ggml_bitnet_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits);
35 | GGML_API void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor);
36 | GGML_API int ggml_bitnet_get_type_bits(enum ggml_type type);
37 | GGML_API void ggml_bitnet_set_n_threads(int n_threads);
38 | #if defined(GGML_BITNET_ARM_TL1)
39 | GGML_API void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C);
40 | GGML_API void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT);
41 | #endif
42 | #if defined(GGML_BITNET_X86_TL2)
43 | GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C);
44 | GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT);
45 | #endif
46 | 
47 | #ifdef  __cplusplus
48 | }
49 | #endif
50 | 


--------------------------------------------------------------------------------
/media/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/404980eecae38affa4871c3e419eae3f44536a95/media/benchmark.png


--------------------------------------------------------------------------------
/media/demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/404980eecae38affa4871c3e419eae3f44536a95/media/demo.mp4


--------------------------------------------------------------------------------
/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl1.ini:
--------------------------------------------------------------------------------
 1 | [Kernels_0]
 2 | m = 14336
 3 | k = 4096
 4 | bm = 256
 5 | bk = 128
 6 | bmm = 64
 7 | 
 8 | [Kernels_1]
 9 | m = 4096
10 | k = 14336
11 | bm = 256
12 | bk = 128
13 | bmm = 32
14 | 
15 | [Kernels_2]
16 | m = 1024
17 | k = 4096
18 | bm = 128
19 | bk = 64
20 | bmm = 64
21 | 
22 | [Kernels_3]
23 | m = 4096
24 | k = 4096
25 | bm = 128
26 | bk = 64
27 | bmm = 32
28 | 
29 | 


--------------------------------------------------------------------------------
/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl2.ini:
--------------------------------------------------------------------------------
 1 | [Kernels_0]
 2 | m = 14336
 3 | k = 4096
 4 | bm = 256
 5 | bk = 96
 6 | bmm = 32
 7 | 
 8 | [Kernels_1]
 9 | m = 4096
10 | k = 14336
11 | bm = 128
12 | bk = 96
13 | bmm = 32
14 | 
15 | [Kernels_2]
16 | m = 1024
17 | k = 4096
18 | bm = 256
19 | bk = 96
20 | bmm = 32
21 | 
22 | [Kernels_3]
23 | m = 4096
24 | k = 4096
25 | bm = 128
26 | bk = 96
27 | bmm = 32
28 | 
29 | 


--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl1.h:
--------------------------------------------------------------------------------
  1 | #if defined(GGML_BITNET_ARM_TL1)
  2 | #include "ggml-bitnet.h"
  3 | #define GGML_BITNET_MAX_NODES 8192
  4 | static bool initialized = false;
  5 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
  6 | static size_t bitnet_tensor_extras_index = 0;
  7 | static void * aligned_malloc(size_t size) {
  8 | #if defined(_WIN32)
  9 |     return _aligned_malloc(size, 64);
 10 | #else
 11 |     void * ptr = nullptr;
 12 |     posix_memalign(&ptr, 64, size);
 13 |     return ptr;
 14 | #endif
 15 | }
 16 | static void aligned_free(void * ptr) {
 17 | #if defined(_WIN32)
 18 |     _aligned_free(ptr);
 19 | #else
 20 |     free(ptr);
 21 | #endif
 22 | }
 23 | 
 24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {
 25 |     bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
 26 |     bitnet_float_type* b = (bitnet_float_type*)b_;
 27 | #ifdef __ARM_NEON
 28 |     float32x4_t temp_max = vdupq_n_f32(0);
 29 |     for (int i=0; i < k / 4; i++) {
 30 |       float32x4_t vec_bs = vld1q_f32(b + 4 * i);
 31 |       float32x4_t abssum = vabsq_f32(vec_bs);
 32 |       temp_max = vmaxq_f32(abssum, temp_max);
 33 |     }
 34 |     float32_t scales = 127 / vmaxvq_f32(temp_max);
 35 |     *lut_scales = scales;
 36 | #elif defined __AVX2__
 37 |     __m256 max_vec = _mm256_set1_ps(0.f);
 38 |     const __m256 vec_sign = _mm256_set1_ps(-0.0f);
 39 |     // #pragma unroll
 40 |     for (int i = 0; i < k / 8; i++) {
 41 |         __m256 vec_b = _mm256_loadu_ps(b + i * 8);
 42 |         __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
 43 |         max_vec = _mm256_max_ps(vec_babs, max_vec);
 44 |     }
 45 |     __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
 46 |     max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
 47 |     max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
 48 |     float scales = 127 / _mm_cvtss_f32(max1);
 49 |     *lut_scales = scales;
 50 | #endif
 51 | }
 52 | 
 53 | void partial_max_reset(void* lut_scales_) {
 54 |     bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
 55 |     *lut_scales = 0.0;
 56 | }
 57 | 
 58 | #ifdef __ARM_NEON
 59 | inline void Transpose_8_8(
 60 |     int16x8_t *v0,
 61 |     int16x8_t *v1,
 62 |     int16x8_t *v2,
 63 |     int16x8_t *v3,
 64 |     int16x8_t *v4,
 65 |     int16x8_t *v5,
 66 |     int16x8_t *v6,
 67 |     int16x8_t *v7)
 68 | {
 69 |     int16x8x2_t q04 = vzipq_s16(*v0, *v4);
 70 |     int16x8x2_t q15 = vzipq_s16(*v1, *v5);
 71 |     int16x8x2_t q26 = vzipq_s16(*v2, *v6);
 72 |     int16x8x2_t q37 = vzipq_s16(*v3, *v7);
 73 | 
 74 |     int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
 75 |     int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
 76 |     int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
 77 |     int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
 78 | 
 79 |     int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
 80 |     int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
 81 |     int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
 82 |     int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
 83 | 
 84 |     *v0 = q_fin_0.val[0];
 85 |     *v1 = q_fin_0.val[1];
 86 |     *v2 = q_fin_1.val[0];
 87 |     *v3 = q_fin_1.val[1];
 88 |     *v4 = q_fin_2.val[0];
 89 |     *v5 = q_fin_2.val[1];
 90 |     *v6 = q_fin_3.val[0];
 91 |     *v7 = q_fin_3.val[1];
 92 | }
 93 | #endif
 94 | 
 95 | template<int act_k>
 96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {
 97 | #ifdef __ARM_NEON
 98 |     int16x8_t vec_lut[16];
 99 |     float32_t scales = *lut_scales;
100 |     // deinterleave mask: gather the even (low) bytes of the eight int16
101 |     // lanes first, then the odd (high) bytes
102 |     const uint8_t tbl_mask[16] = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
103 |     uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
118 | #pragma unroll
119 |     for (int k = 0; k < act_k / 16; ++k) {
120 |         float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
121 |         float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
122 |         float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
123 |         float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
124 |         float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
125 |         float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
126 |         int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
127 |         int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
128 |         int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
129 |         int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
130 |         int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
131 |         int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
132 |         int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
133 |         int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
134 |         int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
135 |         int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
136 |         vec_lut[0] = vdupq_n_s16(0);
137 |         vec_lut[0] = vec_lut[0] - vec_bs_0;
138 |         vec_lut[0] = vec_lut[0] - vec_bs_1;
139 |         vec_lut[1] = vdupq_n_s16(0);
140 |         vec_lut[1] = vec_lut[1] - vec_bs_0;
141 |         vec_lut[2] = vdupq_n_s16(0);
142 |         vec_lut[2] = vec_lut[2] - vec_bs_0;
143 |         vec_lut[2] = vec_lut[2] + vec_bs_1;
144 |         vec_lut[3] = vdupq_n_s16(0);
145 |         vec_lut[3] = vec_lut[3] - vec_bs_1;
146 |         vec_lut[4] = vdupq_n_s16(0);
147 |         vec_lut[5] = vec_bs_1;
148 |         vec_lut[6] = vec_bs_0;
149 |         vec_lut[6] = vec_lut[6] - vec_bs_1;
150 |         vec_lut[7] = vec_bs_0;
151 |         vec_lut[8] = vec_bs_0;
152 |         vec_lut[8] = vec_lut[8] + vec_bs_1;  // entries 9..15 are padding: two ternary weights yield only 9 sums, so those slots are never indexed
153 |         Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
154 |                       &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
155 |         Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
156 |                       &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
157 | #pragma unroll
158 |         for (int idx = 0; idx < 8; idx++) {
159 |             int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
160 |             int8x8_t q0_low = vget_low_s8(q0_s);
161 |             int8x8_t q0_high = vget_high_s8(q0_s);
162 |             int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
163 |             int8x8_t q1_low = vget_low_s8(q1_s);
164 |             int8x8_t q1_high = vget_high_s8(q1_s);
165 |             vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
166 |             vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
167 |             vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
168 |             vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
169 |         }
170 |     }
171 | #endif
172 | }
173 | 
174 | static bool is_type_supported(enum ggml_type type) {
175 |     if (type == GGML_TYPE_Q4_0 ||
176 |         type == GGML_TYPE_TL1) {
177 |         return true;
178 |     } else {
179 |         return false;
180 |     }
181 | }
182 | #include <arm_neon.h>
183 | 
184 | #define BM3200_8640 160
185 | #define BBK3200_8640 64
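    | // Tiled LUT-GEMV kernel for the 3200x8640 weight shape: BM is the tile of
    | // output rows handled per call, BBK the K-block covered by one set of
    | // tables. Each byte of the packed weight matrix holds two 4-bit LUT
    | // indices (high/low nibble); every vqtbl1q_s8 below replaces a
    | // dot-product fragment with a single table lookup.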
186 | inline void tbl_impl_3200_8640(int32_t* c, int8_t* lut, uint8_t* a) {
187 | #ifdef __ARM_NEON
188 |     const int KK = BBK3200_8640 / 2;
189 |     const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
190 |     const int16x8_t vec_zero = vdupq_n_s16(0);  // accumulators are int16x8, so the zero must be too
191 |     int8x16_t vec_lut[2 * KK];
192 |     int16x8_t vec_c[4];
193 | #pragma unroll
194 |     for (int k = 0; k < 2 * KK; k++) {
195 |         vec_lut[k] = vld1q_s8(lut + k * 16);
196 |     }
197 | 
198 | #pragma unroll
199 |     for (int i = 0; i < BM3200_8640; i += 32) {
200 |         #pragma unroll
201 |         for (int j = 0; j < 4; j++) {  // j, to avoid shadowing the outer i
202 |             vec_c[j] = vec_zero;       // reset accumulators instead of AND-ing uninitialized registers
203 |         }
204 | 
205 | #pragma unroll
206 |         for (int k = 0; k < KK / 4; k++) {
207 |             
208 |             uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
209 |             uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
210 |             uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
211 |             int8x16_t  vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
212 |             int8x16_t  vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
213 |             int8x16_t  vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
214 |             int8x16_t  vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
215 |             int8x16x2_t  vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
216 |             int8x16x2_t  vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
217 |             vec_c[0] += vec_v_left_0.val[0];
218 |             vec_c[0] += vec_v_right_0.val[0];
219 |             vec_c[1] += vec_v_left_0.val[1];
220 |             vec_c[1] += vec_v_right_0.val[1];
221 |         
222 |             uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
223 |             uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
224 |             uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
225 |             int8x16_t  vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
226 |             int8x16_t  vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
227 |             int8x16_t  vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
228 |             int8x16_t  vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
229 |             int8x16x2_t  vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
230 |             int8x16x2_t  vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
231 |             vec_c[0] += vec_v_left_1.val[0];
232 |             vec_c[0] += vec_v_right_1.val[0];
233 |             vec_c[1] += vec_v_left_1.val[1];
234 |             vec_c[1] += vec_v_right_1.val[1];
235 |         
236 |             uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
237 |             uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
238 |             uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
239 |             int8x16_t  vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
240 |             int8x16_t  vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
241 |             int8x16_t  vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
242 |             int8x16_t  vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
243 |             int8x16x2_t  vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
244 |             int8x16x2_t  vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
245 |             vec_c[2] += vec_v_left_2.val[0];
246 |             vec_c[2] += vec_v_right_2.val[0];
247 |             vec_c[3] += vec_v_left_2.val[1];
248 |             vec_c[3] += vec_v_right_2.val[1];
249 |         
250 |             uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
251 |             uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
252 |             uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
253 |             int8x16_t  vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
254 |             int8x16_t  vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
255 |             int8x16_t  vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
256 |             int8x16_t  vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
257 |             int8x16x2_t  vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
258 |             int8x16x2_t  vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
259 |             vec_c[2] += vec_v_left_3.val[0];
260 |             vec_c[2] += vec_v_right_3.val[0];
261 |             vec_c[3] += vec_v_left_3.val[1];
262 |             vec_c[3] += vec_v_right_3.val[1];
263 |         
264 |        }
265 | 
266 |         int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
267 |         int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
268 |         vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
269 |         vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
270 |         int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
271 |         int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
272 |         vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
273 |         vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
274 |         int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
275 |         int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
276 |         vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
277 |         vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
278 |         int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
279 |         int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
280 |         vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
281 |         vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
282 | 
283 |     }
284 | #endif
285 | }
286 | 
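    | // Per-shape GEMV driver: accumulate int32 partial sums K-block by K-block,
    | // then dequantize with C[i] = acc[i] / lut_scale * weight_scale, where
    | // lut_scale is the 127 / max|activation| factor from per_tensor_quant.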
287 | int32_t qgemm_lut_3200_8640(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
288 |     alignas(32) uint32_t CBits[BM3200_8640];
289 |     memset(&(CBits[0]), 0, BM3200_8640 * sizeof(int32_t));
290 | #pragma unroll
291 |     for (int32_t k_outer = 0; k_outer < 8640 / BBK3200_8640; ++k_outer) {
292 |         tbl_impl_3200_8640((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_8640 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_8640 / 2 / 2 * BM3200_8640)])));
293 |     }
294 | #pragma unroll
295 |     for (int i = 0; i < BM3200_8640; i++) {
296 |         ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
297 |     }
298 |     return 0;
299 | }
300 | #include <arm_neon.h>
301 | 
302 | #define BM3200_3200 320
303 | #define BBK3200_3200 128
304 | inline void tbl_impl_3200_3200(int32_t* c, int8_t* lut, uint8_t* a) {
305 | #ifdef __ARM_NEON
306 |     const int KK = BBK3200_3200 / 2;
307 |     const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
308 |     const int16x8_t vec_zero = vdupq_n_s16(0);  // accumulators are int16x8, so the zero must be too
309 |     int8x16_t vec_lut[2 * KK];
310 |     int16x8_t vec_c[8];
311 | #pragma unroll
312 |     for (int k = 0; k < 2 * KK; k++) {
313 |         vec_lut[k] = vld1q_s8(lut + k * 16);
314 |     }
315 | 
316 | #pragma unroll
317 |     for (int i = 0; i < BM3200_3200; i += 64) {
318 |         #pragma unroll
319 |         for (int j = 0; j < 8; j++) {  // reset accumulators (j avoids shadowing i)
320 |             vec_c[j] = vec_zero;
321 |         }
322 | 
323 | #pragma unroll
324 |         for (int k = 0; k < KK / 2; k++) {
325 |             
326 |             uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
327 |             uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
328 |             uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
329 |             int8x16_t  vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
330 |             int8x16_t  vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
331 |             int8x16_t  vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
332 |             int8x16_t  vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
333 |             int8x16x2_t  vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
334 |             int8x16x2_t  vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
335 |             vec_c[0] += vec_v_left_0.val[0];
336 |             vec_c[0] += vec_v_right_0.val[0];
337 |             vec_c[1] += vec_v_left_0.val[1];
338 |             vec_c[1] += vec_v_right_0.val[1];
339 |         
340 |             uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
341 |             uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
342 |             uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
343 |             int8x16_t  vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
344 |             int8x16_t  vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
345 |             int8x16_t  vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
346 |             int8x16_t  vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
347 |             int8x16x2_t  vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
348 |             int8x16x2_t  vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
349 |             vec_c[2] += vec_v_left_1.val[0];
350 |             vec_c[2] += vec_v_right_1.val[0];
351 |             vec_c[3] += vec_v_left_1.val[1];
352 |             vec_c[3] += vec_v_right_1.val[1];
353 |         
354 |             uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
355 |             uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
356 |             uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
357 |             int8x16_t  vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
358 |             int8x16_t  vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
359 |             int8x16_t  vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
360 |             int8x16_t  vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
361 |             int8x16x2_t  vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
362 |             int8x16x2_t  vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
363 |             vec_c[4] += vec_v_left_2.val[0];
364 |             vec_c[4] += vec_v_right_2.val[0];
365 |             vec_c[5] += vec_v_left_2.val[1];
366 |             vec_c[5] += vec_v_right_2.val[1];
367 |         
368 |             uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
369 |             uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
370 |             uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
371 |             int8x16_t  vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
372 |             int8x16_t  vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
373 |             int8x16_t  vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
374 |             int8x16_t  vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
375 |             int8x16x2_t  vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
376 |             int8x16x2_t  vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
377 |             vec_c[6] += vec_v_left_3.val[0];
378 |             vec_c[6] += vec_v_right_3.val[0];
379 |             vec_c[7] += vec_v_left_3.val[1];
380 |             vec_c[7] += vec_v_right_3.val[1];
381 |         
382 |        }
383 | 
384 |         int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
385 |         int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
386 |         vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
387 |         vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
388 |         int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
389 |         int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
390 |         vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
391 |         vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
392 |         int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
393 |         int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
394 |         vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
395 |         vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
396 |         int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
397 |         int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
398 |         vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
399 |         vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
400 |         int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
401 |         int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
402 |         vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
403 |         vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
404 |         int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
405 |         int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
406 |         vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
407 |         vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
408 |         int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
409 |         int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
410 |         vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
411 |         vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
412 |         int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
413 |         int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
414 |         vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
415 |         vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
416 | 
417 |     }
418 | #endif
419 | }
420 | 
421 | int32_t qgemm_lut_3200_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
422 |     alignas(32) uint32_t CBits[BM3200_3200];
423 |     memset(&(CBits[0]), 0, BM3200_3200 * sizeof(int32_t));
424 | #pragma unroll
425 |     for (int32_t k_outer = 0; k_outer < 3200 / BBK3200_3200; ++k_outer) {
426 |         tbl_impl_3200_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_3200 / 2 / 2 * BM3200_3200)])));
427 |     }
428 | #pragma unroll
429 |     for (int i = 0; i < BM3200_3200; i++) {
430 |         ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
431 |     }
432 |     return 0;
433 | }
434 | #include <arm_neon.h>
435 | 
436 | #define BM8640_3200 320
437 | #define BBK8640_3200 64
438 | inline void tbl_impl_8640_3200(int32_t* c, int8_t* lut, uint8_t* a) {
439 | #ifdef __ARM_NEON
440 |     const int KK = BBK8640_3200 / 2;
441 |     const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
442 |     const int16x8_t vec_zero = vdupq_n_s16(0);  // accumulators are int16x8, so the zero must be too
443 |     int8x16_t vec_lut[2 * KK];
444 |     int16x8_t vec_c[4];
445 | #pragma unroll
446 |     for (int k = 0; k < 2 * KK; k++) {
447 |         vec_lut[k] = vld1q_s8(lut + k * 16);
448 |     }
449 | 
450 | #pragma unroll
451 |     for (int i = 0; i < BM8640_3200; i += 32) {
452 |         #pragma unroll
453 |         for (int j = 0; j < 4; j++) {  // reset accumulators (j avoids shadowing i)
454 |             vec_c[j] = vec_zero;
455 |         }
456 | 
457 | #pragma unroll
458 |         for (int k = 0; k < KK / 4; k++) {
459 |             
460 |             uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
461 |             uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
462 |             uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
463 |             int8x16_t  vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
464 |             int8x16_t  vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
465 |             int8x16_t  vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
466 |             int8x16_t  vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
467 |             int8x16x2_t  vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
468 |             int8x16x2_t  vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
469 |             vec_c[0] += vec_v_left_0.val[0];
470 |             vec_c[0] += vec_v_right_0.val[0];
471 |             vec_c[1] += vec_v_left_0.val[1];
472 |             vec_c[1] += vec_v_right_0.val[1];
473 |         
474 |             uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
475 |             uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
476 |             uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
477 |             int8x16_t  vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
478 |             int8x16_t  vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
479 |             int8x16_t  vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
480 |             int8x16_t  vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
481 |             int8x16x2_t  vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
482 |             int8x16x2_t  vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
483 |             vec_c[0] += vec_v_left_1.val[0];
484 |             vec_c[0] += vec_v_right_1.val[0];
485 |             vec_c[1] += vec_v_left_1.val[1];
486 |             vec_c[1] += vec_v_right_1.val[1];
487 |         
488 |             uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
489 |             uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
490 |             uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
491 |             int8x16_t  vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
492 |             int8x16_t  vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
493 |             int8x16_t  vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
494 |             int8x16_t  vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
495 |             int8x16x2_t  vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
496 |             int8x16x2_t  vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
497 |             vec_c[2] += vec_v_left_2.val[0];
498 |             vec_c[2] += vec_v_right_2.val[0];
499 |             vec_c[3] += vec_v_left_2.val[1];
500 |             vec_c[3] += vec_v_right_2.val[1];
501 |         
502 |             uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
503 |             uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
504 |             uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
505 |             int8x16_t  vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
506 |             int8x16_t  vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
507 |             int8x16_t  vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
508 |             int8x16_t  vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
509 |             int8x16x2_t  vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
510 |             int8x16x2_t  vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
511 |             vec_c[2] += vec_v_left_3.val[0];
512 |             vec_c[2] += vec_v_right_3.val[0];
513 |             vec_c[3] += vec_v_left_3.val[1];
514 |             vec_c[3] += vec_v_right_3.val[1];
515 |         
516 |        }
517 | 
518 |         int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
519 |         int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
520 |         vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
521 |         vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
522 |         int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
523 |         int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
524 |         vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
525 |         vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
526 |         int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
527 |         int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
528 |         vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
529 |         vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
530 |         int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
531 |         int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
532 |         vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
533 |         vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
534 | 
535 |     }
536 | #endif
537 | }
538 | 
539 | int32_t qgemm_lut_8640_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
540 |     alignas(32) uint32_t CBits[BM8640_3200];
541 |     memset(&(CBits[0]), 0, BM8640_3200 * sizeof(int32_t));
542 | #pragma unroll
543 |     for (int32_t k_outer = 0; k_outer < 3200 / BBK8640_3200; ++k_outer) {
544 |         tbl_impl_8640_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK8640_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK8640_3200 / 2 / 2 * BM8640_3200)])));
545 |     }
546 | #pragma unroll
547 |     for (int i = 0; i < BM8640_3200; i++) {
548 |         ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
549 |     }
550 |     return 0;
551 | }
552 | 
553 | template<int K>
554 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {
555 |   partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
556 |   per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
557 | 
558 |   lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
559 | }
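    | // Shape dispatchers: route each (m, k) weight shape of this model to the
    | // specialized kernels generated above.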
560 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
561 |     if (m == 3200 && k == 8640) {
562 |         preprocessor_k<8640>(B, LUT_Scales, QLUT);
563 |     }
564 |     else if (m == 3200 && k == 3200) {
565 |         preprocessor_k<3200>(B, LUT_Scales, QLUT);
566 |     }
567 |     else if (m == 8640 && k == 3200) {
568 |         preprocessor_k<3200>(B, LUT_Scales, QLUT);
569 |     }
570 | }
571 | void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
572 |     if (m == 3200 && k == 8640) {
573 |         qgemm_lut_3200_8640(A, LUT, Scales, LUT_Scales, C);
574 |     }
575 |     else if (m == 3200 && k == 3200) {
576 |         qgemm_lut_3200_3200(A, LUT, Scales, LUT_Scales, C);
577 |     }
578 |     else if (m == 8640 && k == 3200) {
579 |         qgemm_lut_8640_3200(A, LUT, Scales, LUT_Scales, C);
580 |     }
581 | }
582 | 
583 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
584 |     if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
585 |         return;
586 |     }
587 | 
588 |     int k = tensor->ne[0];
589 |     int m = tensor->ne[1];
590 |     const int lut_scales_size = 1;
591 |     const int scales_size = 1;
592 |     int bk = 0;
593 |     int bm = 0;
594 | 
595 |     if (m == 3200 && k == 8640) {
596 |         bm = BM3200_8640;
597 |         bk = BBK3200_8640;
598 |     }
599 |     else if (m == 3200 && k == 3200) {
600 |         bm = BM3200_3200;
601 |         bk = BBK3200_3200;
602 |     }
603 |     else if (m == 8640 && k == 3200) {
604 |         bm = BM8640_3200;
605 |         bk = BBK8640_3200;
606 |     }
607 | 
608 |     const int n_tile_num = m / bm;  // assumes a supported (m, k) matched above, so bm != 0
609 |     const int BK = bk;
610 |     uint8_t * qweights;
611 |     bitnet_float_type * scales;
612 | 
613 |     scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
614 |     qweights = (uint8_t *) tensor->data;
615 |     float * i2_scales = (float *)(qweights + k * m / 4);  // per-tensor scale sits right after the packed 2-bit weights
616 |     scales[0] = (bitnet_float_type) i2_scales[0];
617 | 
618 |     tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
619 |     bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
620 |         /* .lut_scales_size = */ lut_scales_size,
621 |         /* .scales_size     = */ scales_size,
622 |         /* .n_tile_num      = */ n_tile_num,
623 |         /* .qweights        = */ qweights,
624 |         /* .scales          = */ scales
625 |     };
626 | }
627 | #endif


--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-3B/kernel_config_tl1.ini:
--------------------------------------------------------------------------------
 1 | [Kernels_0]
 2 | m = 3200
 3 | k = 8640
 4 | bm = 160
 5 | bk = 64
 6 | bmm = 32
 7 | 
 8 | [Kernels_1]
 9 | m = 3200
10 | k = 3200
11 | bm = 320
12 | bk = 128
13 | bmm = 64
14 | 
15 | [Kernels_2]
16 | m = 8640
17 | k = 3200
18 | bm = 320
19 | bk = 64
20 | bmm = 32
21 | 
22 | 


--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-3B/kernel_config_tl2.ini:
--------------------------------------------------------------------------------
 1 | [Kernels_0]
 2 | m = 3200
 3 | k = 8640
 4 | bm = 160
 5 | bk = 96
 6 | bmm = 32
 7 | 
 8 | [Kernels_1]
 9 | m = 3200
10 | k = 3200
11 | bm = 320
12 | bk = 96
13 | bmm = 32
14 | 
15 | [Kernels_2]
16 | m = 8640
17 | k = 3200
18 | bm = 320
19 | bk = 96
20 | bmm = 32
21 | 
22 | 


--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-large/kernel_config_tl1.ini:
--------------------------------------------------------------------------------
 1 | [Kernels_0]
 2 | m = 1536
 3 | k = 4096
 4 | bm = 256
 5 | bk = 128
 6 | bmm = 32
 7 | 
 8 | [Kernels_1]
 9 | m = 1536
10 | k = 1536
11 | bm = 128
12 | bk = 64
13 | bmm = 64
14 | 
15 | [Kernels_2]
16 | m = 4096
17 | k = 1536
18 | bm = 256
19 | bk = 128
20 | bmm = 32
21 | 
22 | 


--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-large/kernel_config_tl2.ini:
--------------------------------------------------------------------------------
 1 | [Kernels_0]
 2 | m = 1536
 3 | k = 4096
 4 | bm = 256
 5 | bk = 96
 6 | bmm = 32
 7 | 
 8 | [Kernels_1]
 9 | m = 1536
10 | k = 1536
11 | bm = 128
12 | bk = 192
13 | bmm = 32
14 | 
15 | [Kernels_2]
16 | m = 4096
17 | k = 1536
18 | bm = 256
19 | bk = 96
20 | bmm = 64
21 | 
22 | 


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | # These requirements include all dependencies for all top-level python scripts
 2 | # for llama.cpp. Avoid adding packages here directly.
 3 | #
 4 | # Package versions must stay compatible across all top-level python scripts.
 5 | #
 6 | 
 7 | -r 3rdparty/llama.cpp/requirements/requirements-convert_legacy_llama.txt
 8 | -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt
 9 | -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt
10 | -r 3rdparty/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt
11 | -r 3rdparty/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt


--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | import signal
 4 | import platform
 5 | import argparse
 6 | import subprocess
 7 | 
 8 | def run_command(command, shell=False):
 9 |     """Run a system command and ensure it succeeds."""
10 |     try:
11 |         subprocess.run(command, shell=shell, check=True)
12 |     except subprocess.CalledProcessError as e:
13 |         print(f"Error occurred while running command: {e}")
14 |         sys.exit(1)
15 | 
16 | def run_inference():
17 |     build_dir = "build"
18 |     if platform.system() == "Windows":
19 |         main_path = os.path.join(build_dir, "bin", "Release", "llama-cli.exe")
20 |         if not os.path.exists(main_path):
21 |             main_path = os.path.join(build_dir, "bin", "llama-cli")
22 |     else:
23 |         main_path = os.path.join(build_dir, "bin", "llama-cli")
24 |     command = [
25 |         main_path,
26 |         '-m', args.model,
27 |         '-n', str(args.n_predict),
28 |         '-t', str(args.threads),
29 |         '-p', args.prompt,
30 |         '-ngl', '0',
31 |         '-c', str(args.ctx_size),
32 |         '--temp', str(args.temperature),
33 |         "-b", "1",
34 |     ]
35 |     if args.conversation:
36 |         command.append("-cnv")
37 |     run_command(command)
38 | 
39 | def signal_handler(sig, frame):
40 |     print("Ctrl+C pressed, exiting...")
41 |     sys.exit(0)
42 | 
43 | if __name__ == "__main__":
44 |     signal.signal(signal.SIGINT, signal_handler)
45 |     # Usage: python run_inference.py -p "Microsoft Corporation is an American multinational corporation and technology company headquartered in Redmond, Washington."
46 |     parser = argparse.ArgumentParser(description='Run inference')
47 |     parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf")
48 |     parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict when generating text", required=False, default=128)
49 |     parser.add_argument("-p", "--prompt", type=str, help="Prompt to generate text from", required=True)
50 |     parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
51 |     parser.add_argument("-c", "--ctx-size", type=int, help="Size of the prompt context", required=False, default=2048)
52 |     parser.add_argument("-temp", "--temperature", type=float, help="Temperature, a hyperparameter that controls the randomness of the generated text", required=False, default=0.8)
53 |     parser.add_argument("-cnv", "--conversation", action='store_true', help="Whether to enable chat mode or not (for instruct models.)")
54 | 
55 |     args = parser.parse_args()
56 |     run_inference()


--------------------------------------------------------------------------------
/run_inference_server.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | import signal
 4 | import platform
 5 | import argparse
 6 | import subprocess
 7 | 
 8 | def run_command(command, shell=False):
 9 |     """Run a system command and ensure it succeeds."""
10 |     try:
11 |         subprocess.run(command, shell=shell, check=True)
12 |     except subprocess.CalledProcessError as e:
13 |         print(f"Error occurred while running command: {e}")
14 |         sys.exit(1)
15 | 
16 | def run_server():
17 |     build_dir = "build"
18 |     if platform.system() == "Windows":
19 |         server_path = os.path.join(build_dir, "bin", "Release", "llama-server.exe")
20 |         if not os.path.exists(server_path):
21 |             server_path = os.path.join(build_dir, "bin", "llama-server")
22 |     else:
23 |         server_path = os.path.join(build_dir, "bin", "llama-server")
24 |     
25 |     command = [
26 |         server_path,
27 |         '-m', args.model,
28 |         '-c', str(args.ctx_size),
29 |         '-t', str(args.threads),
30 |         '-n', str(args.n_predict),
31 |         '-ngl', '0',
32 |         '--temp', str(args.temperature),
33 |         '--host', args.host,
34 |         '--port', str(args.port),
35 |         '-cb'  # Enable continuous batching
36 |     ]
37 |     
38 |     if args.prompt:
39 |         command.extend(['-p', args.prompt])
40 |     
41 |     # Note: -cnv flag is removed as it's not supported by the server
42 |     
43 |     print(f"Starting server on {args.host}:{args.port}")
44 |     run_command(command)
45 | 
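   | # Once running, the server speaks llama.cpp's HTTP API; a minimal check,
   | # assuming the default host and port:
   | #   curl http://127.0.0.1:8080/completion -d '{"prompt": "Hello", "n_predict": 16}'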
46 | def signal_handler(sig, frame):
47 |     print("Ctrl+C pressed, shutting down server...")
48 |     sys.exit(0)
49 | 
50 | if __name__ == "__main__":
51 |     signal.signal(signal.SIGINT, signal_handler)
52 |     
53 |     parser = argparse.ArgumentParser(description='Run llama.cpp server')
54 |     parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf")
55 |     parser.add_argument("-p", "--prompt", type=str, help="System prompt for the model", required=False)
56 |     parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict", required=False, default=4096)
57 |     parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
58 |     parser.add_argument("-c", "--ctx-size", type=int, help="Size of the context window", required=False, default=2048)
59 |     parser.add_argument("--temperature", type=float, help="Temperature for sampling", required=False, default=0.8)
60 |     parser.add_argument("--host", type=str, help="IP address to listen on", required=False, default="127.0.0.1")
61 |     parser.add_argument("--port", type=int, help="Port to listen on", required=False, default=8080)
62 |     
63 |     args = parser.parse_args()
64 |     run_server()
65 | 


--------------------------------------------------------------------------------
/setup_env.py:
--------------------------------------------------------------------------------
  1 | import subprocess
  2 | import signal
  3 | import sys
  4 | import os
  5 | import platform
  6 | import argparse
  7 | import logging
  8 | import shutil
  9 | from pathlib import Path
 10 | 
 11 | logger = logging.getLogger("setup_env")
 12 | 
 13 | SUPPORTED_HF_MODELS = {
 14 |     "1bitLLM/bitnet_b1_58-large": {
 15 |         "model_name": "bitnet_b1_58-large",
 16 |     },
 17 |     "1bitLLM/bitnet_b1_58-3B": {
 18 |         "model_name": "bitnet_b1_58-3B",
 19 |     },
 20 |     "HF1BitLLM/Llama3-8B-1.58-100B-tokens": {
 21 |         "model_name": "Llama3-8B-1.58-100B-tokens",
 22 |     },
 23 |     "tiiuae/Falcon3-7B-Instruct-1.58bit": {
 24 |         "model_name": "Falcon3-7B-Instruct-1.58bit",
 25 |     },
 26 |     "tiiuae/Falcon3-7B-1.58bit": {
 27 |         "model_name": "Falcon3-7B-1.58bit",
 28 |     },
 29 |     "tiiuae/Falcon3-10B-Instruct-1.58bit": {
 30 |         "model_name": "Falcon3-10B-Instruct-1.58bit",
 31 |     },
 32 |     "tiiuae/Falcon3-10B-1.58bit": {
 33 |         "model_name": "Falcon3-10B-1.58bit",
 34 |     },
 35 |     "tiiuae/Falcon3-3B-Instruct-1.58bit": {
 36 |         "model_name": "Falcon3-3B-Instruct-1.58bit",
 37 |     },
 38 |     "tiiuae/Falcon3-3B-1.58bit": {
 39 |         "model_name": "Falcon3-3B-1.58bit",
 40 |     },
 41 |     "tiiuae/Falcon3-1B-Instruct-1.58bit": {
 42 |         "model_name": "Falcon3-1B-Instruct-1.58bit",
 43 |     },
 44 |     "microsoft/BitNet-b1.58-2B-4T": {
 45 |         "model_name": "BitNet-b1.58-2B-4T",
 46 |     },
 47 |     "tiiuae/Falcon-E-3B-Instruct": {
 48 |         "model_name": "Falcon-E-3B-Instruct",
 49 |     },
 50 |     "tiiuae/Falcon-E-1B-Instruct": {
 51 |         "model_name": "Falcon-E-1B-Instruct",
 52 |     },
 53 |     "tiiuae/Falcon-E-3B-Base": {
 54 |         "model_name": "Falcon-E-3B-Base",
 55 |     },
 56 |     "tiiuae/Falcon-E-1B-Base": {
 57 |         "model_name": "Falcon-E-1B-Base",
 58 |     },
 59 | }
 60 | 
 61 | SUPPORTED_QUANT_TYPES = {
 62 |     "arm64": ["i2_s", "tl1"],
 63 |     "x86_64": ["i2_s", "tl2"]
 64 | }
 65 | 
 66 | COMPILER_EXTRA_ARGS = {
 67 |     "arm64": ["-DBITNET_ARM_TL1=ON"],
 68 |     "x86_64": ["-DBITNET_X86_TL2=ON"]
 69 | }
 70 | 
 71 | OS_EXTRA_ARGS = {
 72 |     "Windows":["-T", "ClangCL"],
 73 | }
 74 | 
 75 | ARCH_ALIAS = {
 76 |     "AMD64": "x86_64",
 77 |     "x86": "x86_64",
 78 |     "x86_64": "x86_64",
 79 |     "aarch64": "arm64",
 80 |     "arm64": "arm64",
 81 |     "ARM64": "arm64",
 82 | }
 83 | 
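    | # Typical end-to-end invocation (repo ids come from SUPPORTED_HF_MODELS above):
    | #   python setup_env.py --hf-repo microsoft/BitNet-b1.58-2B-4T -q i2_s
    | # which generates kernels, builds the binaries, then downloads and converts
    | # the model to GGUF (see main() below for the exact order).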
 84 | def system_info():
 85 |     return platform.system(), ARCH_ALIAS[platform.machine()]
 86 | 
 87 | def get_model_name():
 88 |     if args.hf_repo:
 89 |         return SUPPORTED_HF_MODELS[args.hf_repo]["model_name"]
 90 |     return os.path.basename(os.path.normpath(args.model_dir))
 91 | 
 92 | def run_command(command, shell=False, log_step=None):
 93 |     """Run a system command and ensure it succeeds."""
 94 |     if log_step:
 95 |         log_file = os.path.join(args.log_dir, log_step + ".log")
 96 |         with open(log_file, "w") as f:
 97 |             try:
 98 |                 subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f)
 99 |             except subprocess.CalledProcessError as e:
100 |                 logging.error(f"Error occurred while running command: {e}, check details in {log_file}")
101 |                 sys.exit(1)
102 |     else:
103 |         try:
104 |             subprocess.run(command, shell=shell, check=True)
105 |         except subprocess.CalledProcessError as e:
106 |             logging.error(f"Error occurred while running command: {e}")
107 |             sys.exit(1)
108 | 
109 | def prepare_model():
110 |     _, arch = system_info()
111 |     hf_url = args.hf_repo
112 |     model_dir = args.model_dir
113 |     quant_type = args.quant_type
114 |     quant_embd = args.quant_embd
115 |     if hf_url is not None:
116 |         # download the model
117 |         model_dir = os.path.join(model_dir, SUPPORTED_HF_MODELS[hf_url]["model_name"])
118 |         Path(model_dir).mkdir(parents=True, exist_ok=True)
119 |         logging.info(f"Downloading model {hf_url} from HuggingFace to {model_dir}...")
120 |         run_command(["huggingface-cli", "download", hf_url, "--local-dir", model_dir], log_step="download_model")
121 |     elif not os.path.exists(model_dir):
122 |         logging.error(f"Model directory {model_dir} does not exist.")
123 |         sys.exit(1)
124 |     else:
125 |         logging.info(f"Loading model from directory {model_dir}.")
126 |     gguf_path = os.path.join(model_dir, "ggml-model-" + quant_type + ".gguf")
127 |     if not os.path.exists(gguf_path) or os.path.getsize(gguf_path) == 0:
128 |         logging.info(f"Converting HF model to GGUF format...")
129 |         if quant_type.startswith("tl"):
130 |             run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", quant_type, "--quant-embd"], log_step="convert_to_tl")
131 |         else: # i2s
132 |             # convert to f32
133 |             run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "f32"], log_step="convert_to_f32_gguf")
134 |             f32_model = os.path.join(model_dir, "ggml-model-f32.gguf")
135 |             i2s_model = os.path.join(model_dir, "ggml-model-i2_s.gguf")
136 |             # quantize to i2s
137 |             if platform.system() != "Windows":
138 |                 if quant_embd:
139 |                     run_command(["./build/bin/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s")
140 |                 else:
141 |                     run_command(["./build/bin/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s")
142 |             else:
143 |                 if quant_embd:
144 |                     run_command(["./build/bin/Release/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s")
145 |                 else:
146 |                     run_command(["./build/bin/Release/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s")
147 | 
148 |         logging.info(f"GGUF model saved at {gguf_path}")
149 |     else:
150 |         logging.info(f"GGUF model already exists at {gguf_path}")
151 | 
152 | def setup_gguf():
153 |     # Install the pip package
154 |     run_command([sys.executable, "-m", "pip", "install", "3rdparty/llama.cpp/gguf-py"], log_step="install_gguf")
155 | 
156 | def gen_code():
157 |     _, arch = system_info()
158 |     
159 |     llama3_f3_models = {model['model_name'] for model in SUPPORTED_HF_MODELS.values() if model['model_name'].startswith(("Falcon", "Llama"))}
160 | 
161 |     if arch == "arm64":
162 |         if args.use_pretuned:
163 |             pretuned_kernels = os.path.join("preset_kernels", get_model_name())
164 |             if not os.path.exists(pretuned_kernels):
165 |                 logging.error(f"Pretuned kernels not found for model {args.hf_repo}")
166 |                 sys.exit(1)
167 |             if args.quant_type == "tl1":
168 |                 shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl1.h"), "include/bitnet-lut-kernels.h")
169 |                 shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl1.ini"), "include/kernel_config.ini")
170 |             elif args.quant_type == "tl2":
171 |                 shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
172 |                 shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini")
173 |         if get_model_name() == "bitnet_b1_58-large":
174 |             run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
175 |         elif get_model_name() in llama3_f3_models:
176 |             run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
177 |         elif get_model_name() == "bitnet_b1_58-3B":
178 |             run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
179 |         elif get_model_name() == "BitNet-b1.58-2B-4T":
180 |             run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
181 |         else:
182 |             raise NotImplementedError()
183 |     else:
184 |         if args.use_pretuned:
185 |             # cp preset_kernels/model_name/bitnet-lut-kernels_tl1.h to include/bitnet-lut-kernels.h
186 |             pretuned_kernels = os.path.join("preset_kernels", get_model_name())
187 |             if not os.path.exists(pretuned_kernels):
188 |                 logging.error(f"Pretuned kernels not found for model {args.hf_repo}")
189 |                 sys.exit(1)
190 |             shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
191 |         if get_model_name() == "bitnet_b1_58-large":
192 |             run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
193 |         elif get_model_name() in llama3_f3_models:
194 |             run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
195 |         elif get_model_name() == "bitnet_b1_58-3B":
196 |             run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
197 |         elif get_model_name() == "BitNet-b1.58-2B-4T":
198 |             run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")    
199 |         else:
200 |             raise NotImplementedError()
201 | 
202 | 
203 | def compile():
204 |     # Check if cmake is installed
205 |     cmake_exists = subprocess.run(["cmake", "--version"], capture_output=True)
206 |     if cmake_exists.returncode != 0:
207 |         logging.error("Cmake is not available. Please install CMake and try again.")
208 |         sys.exit(1)
209 |     _, arch = system_info()
210 |     if arch not in COMPILER_EXTRA_ARGS.keys():
211 |         logging.error(f"Arch {arch} is not supported yet")
212 |         sys.exit(1)  # exit non-zero: an unsupported arch is an error
213 |     logging.info("Compiling the code using CMake.")
214 |     run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), []), "-DCMAKE_C_COMPILER=clang", "-DCMAKE_CXX_COMPILER=clang++"], log_step="generate_build_files")
215 |     # run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"])
216 |     run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile")
217 | 
218 | def main():
219 |     setup_gguf()
220 |     gen_code()
221 |     compile()
222 |     prepare_model()
223 |     
224 | def parse_args():
225 |     _, arch = system_info()
226 |     parser = argparse.ArgumentParser(description='Setup the environment for running the inference')
227 |     parser.add_argument("--hf-repo", "-hr", type=str, help="Model used for inference", choices=SUPPORTED_HF_MODELS.keys())
228 |     parser.add_argument("--model-dir", "-md", type=str, help="Directory to save/load the model", default="models")
229 |     parser.add_argument("--log-dir", "-ld", type=str, help="Directory to save the logging info", default="logs")
230 |     parser.add_argument("--quant-type", "-q", type=str, help="Quantization type", choices=SUPPORTED_QUANT_TYPES[arch], default="i2_s")
231 |     parser.add_argument("--quant-embd", action="store_true", help="Quantize the embeddings to f16")
232 |     parser.add_argument("--use-pretuned", "-p", action="store_true", help="Use the pretuned kernel parameters")
233 |     return parser.parse_args()
234 | 
235 | def signal_handler(sig, frame):
236 |     logging.info("Ctrl+C pressed, exiting...")
237 |     sys.exit(0)
238 | 
239 | if __name__ == "__main__":
240 |     signal.signal(signal.SIGINT, signal_handler)
241 |     args = parse_args()
242 |     Path(args.log_dir).mkdir(parents=True, exist_ok=True)
243 |     logging.basicConfig(level=logging.INFO)
244 |     main()
245 | 


--------------------------------------------------------------------------------
/src/CMakeLists.txt:
--------------------------------------------------------------------------------
 1 | set(GGML_HEADERS_BITNET ../include/ggml-bitnet.h)
 2 | set(GGML_SOURCES_BITNET ggml-bitnet-mad.cpp)
 3 | list(APPEND GGML_SOURCES_BITNET ggml-bitnet-lut.cpp)  # append; a second set() would discard the first source
 4 | 
 5 | include_directories(3rdparty/llama.cpp/ggml/include)
 6 | 
 7 | if (NOT (CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_C_COMPILER_ID STREQUAL "GNU") OR
 8 |     NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU"))
 9 |     message(FATAL_ERROR "Clang or GCC is required for Bitnet.cpp compilation")
10 | endif()
11 | 


--------------------------------------------------------------------------------
/src/ggml-bitnet-lut.cpp:
--------------------------------------------------------------------------------
  1 | #include <vector>
  2 | #include <type_traits>
  3 | 
  4 | #include <string.h>
  5 | #include <stdio.h>
  6 | #include <stdlib.h>
  7 | 
  8 | #include "ggml-bitnet.h"
  9 | #include "ggml-quants.h"
 10 | #include "bitnet-lut-kernels.h"
 11 | 
 12 | #if defined(GGML_BITNET_ARM_TL1)
 13 | 
 14 | void ggml_bitnet_init(void) {
 15 |     // LOG(INFO) << "ggml_bitnet_init";
 16 | 
 17 |     if (initialized) {
 18 |         return;
 19 |     }
 20 |     initialized = true;
 21 | 
 22 |     // if (wrapper == nullptr) {
 23 |     //     wrapper = new BITNET::BITNETGeMMWrapper<bitnet_bitnet_float_type>();
 24 |     // }
 25 |     if (bitnet_tensor_extras == nullptr) {
 26 |         bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES];
 27 |     }
 28 |     bitnet_tensor_extras_index = 0;
 29 | }
 30 | 
 31 | void ggml_bitnet_free(void) {
 32 |     // LOG(INFO) << "ggml_bitnet_free";
 33 | 
 34 |     if (!initialized) {
 35 |         return;
 36 |     }
 37 |     initialized = false;
 38 | 
 39 |     // delete wrapper;
 40 |     // wrapper = nullptr;
 41 |     for (size_t i = 0; i < bitnet_tensor_extras_index; i++) {
 42 |         // aligned_free(bitnet_tensor_extras[i].qweights);
 43 |         // aligned_free(bitnet_tensor_extras[i].scales);
 44 |     }
 45 |     delete[] bitnet_tensor_extras;
 46 |     bitnet_tensor_extras = nullptr;
 47 | }
 48 | 
 49 | static bool do_permutate(enum ggml_type type) {
 50 |     if (type == GGML_TYPE_TL1) {
 51 |         // Add additional args to decide if permuted I2 or naive I2
 52 |         return false;
 53 |     } else {
 54 |         return true;
 55 |     }
 56 | }
 57 | 
 58 | bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
 59 |     if ((is_type_supported(src0->type)) &&
 60 |         src1->type == GGML_TYPE_F32 &&
 61 |         dst->type == GGML_TYPE_F32 &&
 62 |         src0->backend == GGML_BACKEND_TYPE_CPU) {
 63 |         if (src1->ne[1] <= 1) {
 64 |             return true;
 65 |         }
 66 |     }
 67 |     return false;
 68 | }
 69 | 
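    | // Scratch sizing for one mul_mat: space for the quantized LUT plus the
    | // per-batch scale slots, padded up to a 64-byte multiple so the SIMD
    | // kernels can keep their loads and stores aligned.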
 70 | size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
 71 |     const size_t ne01 = src0->ne[1];
 72 |     const size_t ne10 = src1->ne[0];
 73 |     const size_t ne11 = src1->ne[1];
 74 |     const int bits = ggml_bitnet_get_type_bits(src0->type);
 75 |     
 76 |     size_t wsize = ne10 * ne11 * 15 * sizeof(int8_t) + 1 * ne11 * 2 * sizeof(bitnet_float_type);
 77 |     if (sizeof(bitnet_float_type) == 2) {
 78 |         // Need fp32 to fp16 conversion
 79 |         wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type);
 80 |     }
 81 |     wsize = ((wsize - 1) / 64 + 1) * 64;
 82 |     return wsize;
 83 | }
 84 | 
 85 | int ggml_bitnet_get_type_bits(enum ggml_type type) {
 86 |     switch (type) {
 87 |         case GGML_TYPE_TL1:
 88 |             return 2;
 89 |         case GGML_TYPE_Q4_0:
 90 |             return 4;
 91 |         default:
 92 |             return 0;
 93 |     }
 94 | }
 95 | 
 96 | #endif
 97 | #if defined(GGML_BITNET_X86_TL2)
 98 | void ggml_bitnet_init(void) {
 99 |     // LOG(INFO) << "ggml_bitnet_init";
100 | 
101 |     if (initialized) {
102 |         return;
103 |     }
104 |     initialized = true;
105 | 
106 |     // if (wrapper == nullptr) {
107 |     //     wrapper = new BITNET::BITNETGeMMWrapper<bitnet_bitnet_float_type>();
108 |     // }
109 |     if (bitnet_tensor_extras == nullptr) {
110 |         bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES];
111 |     }
112 |     bitnet_tensor_extras_index = 0;
113 | }
114 | 
115 | void ggml_bitnet_free(void) {
116 |     // LOG(INFO) << "ggml_bitnet_free";
117 | 
118 |     if (!initialized) {
119 |         return;
120 |     }
121 |     initialized = false;
122 | 
123 |     // delete wrapper;
124 |     // wrapper = nullptr;
125 |     for (size_t i = 0; i < bitnet_tensor_extras_index; i++) {
126 |         // aligned_free(bitnet_tensor_extras[i].qweights);
127 |         // aligned_free(bitnet_tensor_extras[i].scales);
128 |     }
129 |     delete[] bitnet_tensor_extras;
130 |     bitnet_tensor_extras = nullptr;
131 | }
132 | 
133 | bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
134 |     if ((is_type_supported(src0->type)) &&
135 |         src1->type == GGML_TYPE_F32 &&
136 |         dst->type == GGML_TYPE_F32 &&
137 |         src0->backend == GGML_BACKEND_TYPE_CPU) {
138 |         return true;
139 |     }
140 |     return false;
141 | }
142 | 
143 | size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
144 |     const size_t ne01 = src0->ne[1];
145 |     const size_t ne10 = src1->ne[0];
146 |     const size_t ne11 = src1->ne[1];
147 |     
148 |     size_t wsize = ne10 * ne11 * 11 * sizeof(int8_t) + 2 * ne11 * 2 * sizeof(bitnet_float_type);
149 |     if (sizeof(bitnet_float_type) == 2) {
150 |         // Need fp32 to fp16 conversion
151 |         wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type);
152 |     }
153 |     wsize = ((wsize - 1) / 64 + 1) * 64;
154 |     return wsize;
155 | }
156 | 
157 | int ggml_bitnet_get_type_bits(enum ggml_type type) {
158 |     switch (type) {
159 |         case GGML_TYPE_TL2:
160 |             return 2;
161 |         case GGML_TYPE_Q4_0:
162 |             return 4;
163 |         default:
164 |             return 0;
165 |     }
166 | }
167 | #endif
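
For reference, a minimal Python mirror (illustrative only, not part of the build; the function name is ours) of the workspace-size arithmetic in the two ggml_bitnet_mul_mat_get_wsize variants above. The variants differ only in the per-element LUT byte count (15 for the TL1 path, 11 for the TL2 path) and the number of scale buffers (1 vs 2); both stage an extra conversion buffer when bitnet_float_type is fp16 and round up to a 64-byte multiple:

def mul_mat_wsize(ne01, ne10, ne11, lut_bytes, n_scale_bufs, float_size=2):
    # quantized-LUT bytes plus per-column scale buffers
    wsize = ne10 * ne11 * lut_bytes + n_scale_bufs * ne11 * 2 * float_size
    if float_size == 2:
        # fp16 builds stage an fp32 -> fp16 conversion buffer
        wsize += max(ne10, ne01) * ne11 * float_size
    return ((wsize - 1) // 64 + 1) * 64  # align to 64 bytes

print(mul_mat_wsize(4096, 4096, 1, lut_bytes=15, n_scale_bufs=1))  # TL1 variant
print(mul_mat_wsize(4096, 4096, 1, lut_bytes=11, n_scale_bufs=2))  # TL2 variant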


--------------------------------------------------------------------------------
/src/ggml-bitnet-mad.cpp:
--------------------------------------------------------------------------------
  1 | #include <vector>
  2 | #include <type_traits>
  3 | 
  4 | #include "ggml-bitnet.h"
  5 | #include "ggml-quants.h"
  6 | #include <cmath>
  7 | #include <cstring>
  8 | #include <cstdlib> // malloc / free used in quantize_i2_s
  9 | #define QK_I2_S 128
 10 | #define QK_I2 128
 11 | 
 12 | #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 13 | #include <immintrin.h>
 14 | // horizontally add 8 int32_t
 15 | static inline int hsum_i32_8(const __m256i a) {
 16 |     const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
 17 |     const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
 18 |     const __m128i sum64 = _mm_add_epi32(hi64, sum128);
 19 |     const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
 20 |     return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
 21 | }
 22 | #elif defined(__loongarch_asx)
 23 | // horizontally add 8 int32_t
 24 | static inline int hsum_i32_8(const __m256i a) {
 25 | 
 26 |     __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
 27 |     __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
 28 | 
 29 |     __m128i  tmp1_128 = lasx_extracti128_lo(tmp1);
 30 |     __m128i  tmp2_128 = lasx_extracti128_lo(tmp2);
 31 | 
 32 |     __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
 33 | 
 34 |     __m128i ev = __lsx_vpickev_w(sum128, sum128);
 35 |     __m128i od = __lsx_vpickod_w(sum128, sum128);
 36 |     __m128i sum64 = __lsx_vadd_w(ev, od);
 37 | 
 38 |     int sum64_1, sum64_2;
 39 |     sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
 40 |     sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
 41 | 
 42 |     return  sum64_1 + sum64_2;
 43 | }
 44 | #endif
 45 | 
 46 | size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
 47 |     // 2 bits per weight
 48 | 
 49 |     size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
 50 | 
 51 |     int n = nrow * n_per_row;
 52 | 
 53 |     // f32 -> ternary, stored biased: {-1, 0, +1} -> {0, 1, 2}
 54 |     double max = 0;
 55 |     for (int i = 0; i < n; ++i) {
 56 |         max = fmax(max, (double)fabs((double)src[i]));
 57 |     }
 58 |     double i2_scale = max;
 59 | 
 60 |     uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t));
 61 |     for (int i=0; i<n; i++) {
 62 |         if (fabs((double)(src[i])) < 1e-6) {
 63 |             q8[i] = 1;
 64 |             continue;
 65 |         }
 66 |         q8[i] = (double)src[i] * i2_scale > 0 ? 2 : 0;
 67 |     }
 68 | 
 69 |     memset(dst, 0, n * sizeof(uint8_t) / 4);
 70 | 
 71 |     // q8 -> 0, 1, 2
 72 |     //       |  |  |
 73 |     //      -1, 0, 1
 74 | 
 75 |     uint8_t* i2_weight = (uint8_t*)dst;
 76 |     for (int i = 0; i < n / QK_I2; i++) {
 77 |         for (int j = 0; j < QK_I2; j++) {
 78 |             int group_idx = j / 32;
 79 |             int group_pos = j % 32;
 80 |             uint8_t temp = (q8[i * QK_I2 + j] << (6 - 2 * group_idx));
 81 |             i2_weight[i * 32 + group_pos] |= temp;            
 82 |         }
 83 |     }
 84 | 
 85 |     float* scale_ptr = (float*)((char*)i2_weight + n / 4);
 86 |     scale_ptr[0] = i2_scale;
 87 | 
 88 |     free(q8);
 89 | 
 90 |     // 32B for alignment
 91 |     return nrow * row_size / 4 + 32;
 92 | }
 93 | 
 94 | void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
 95 |     const uint8_t *    x = (uint8_t *)vx;
 96 |     const int8_t  *    y = (int8_t *)vy;
 97 | 
 98 |     const int nb = n / QK_I2_S;
 99 |     const int group32_num = nb / 32;
100 |     const int la_num = nb % 32;
101 |     const int groupla_num = nb % 32 != 0 ? 1 : 0;
102 | 
103 | #if defined(__AVX2__)
104 | 
105 |     __m256i mask = _mm256_set1_epi8(0x03);
106 |     __m256i accu = _mm256_setzero_si256();
107 | 
108 |     for (int i=0; i < group32_num; i++){
109 |         __m256i accu32 = _mm256_setzero_si256();
110 |         for (int j=0; j < 32; j++) {
111 |             // one load = 32 bytes = 128 packed 2-bit weights
112 |             __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + i * 32 * 32 + j * 32));
113 |             __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
114 |             __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
115 |             __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
116 | 
117 |             // isolate each group of 32 weights (2 bits each, biased to 0..2)
118 |             xq8_3 = _mm256_and_si256(xq8_3, mask);
119 |             xq8_2 = _mm256_and_si256(xq8_2, mask);
120 |             xq8_1 = _mm256_and_si256(xq8_1, mask);
121 |             xq8_0 = _mm256_and_si256(xq8_0, mask);
122 | 
123 |             // the matching 4 x 32 int8 activations
124 |             __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 0));
125 |             __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 32));
126 |             __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 64));
127 |             __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 96));
128 | 
129 |             // Accumulate this step's 128 products into 16-bit lanes:
130 |             // maddubs multiplies the biased 2-bit weights (unsigned, 0..2)
131 |             // by the int8 activations and pair-sums the results into int16.
132 |             // With at most 4 products of ~256 per lane per step, a 32-step
133 |             // group stays around 127 * 256 < 2^15, so the int16 sums hold
134 |             // for one group before being widened to int32 below.
135 | 
136 |             xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
137 |             xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
138 |             xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
139 |             xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
140 | 
141 |             accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1));
142 |             accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3));
143 |         }
144 |         accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, _mm256_set1_epi16(1)), accu);
145 |     }
146 | 
147 |     for (int i = 0; i < groupla_num; i++){
148 |         __m256i accula = _mm256_setzero_si256();
149 |         for (int j = 0; j < la_num; j++) {
150 |             // one load = 32 bytes = 128 packed 2-bit weights
151 |             __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + group32_num * 32 * 32 + j * 32));
152 |             __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
153 |             __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
154 |             __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
155 | 
156 |             // isolate each group of 32 weights (2 bits each, biased to 0..2)
157 |             xq8_3 = _mm256_and_si256(xq8_3, mask);
158 |             xq8_2 = _mm256_and_si256(xq8_2, mask);
159 |             xq8_1 = _mm256_and_si256(xq8_1, mask);
160 |             xq8_0 = _mm256_and_si256(xq8_0, mask);
161 | 
162 |             // the matching 4 x 32 int8 activations
163 |             __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 0));
164 |             __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 32));
165 |             __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 64));
166 |             __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 96));
167 | 
168 |             // Tail group: same accumulation scheme as the main loop above;
169 |             // maddubs pair-sums the biased 2-bit weights (unsigned, 0..2)
170 |             // times the int8 activations into int16 lanes. With fewer than
171 |             // 32 remaining steps the int16 sums stay within the same
172 |             // ~127 * 256 bound, and the accumulator is widened to int32
173 |             // right after this inner loop.
174 | 
175 |             xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
176 |             xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
177 |             xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
178 |             xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
179 | 
180 |             accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1));
181 |             accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3));
182 |         }
183 |         accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, _mm256_set1_epi16(1)));
184 |     }
185 |     int sumi = hsum_i32_8(accu);
186 |     *s = (float)sumi;
187 | 
188 | #elif defined(__ARM_NEON)
189 | 
190 |     int32x4_t accu_0 = vdupq_n_s32(0);
191 |     int32x4_t accu_1 = vdupq_n_s32(0);
192 |     int32x4_t accu_2 = vdupq_n_s32(0);
193 |     int32x4_t accu_3 = vdupq_n_s32(0);
194 |     const uint8x16_t mask = vdupq_n_u8(3);
195 | 
196 |     for (int i=0; i < group32_num; i++) {
197 | 
198 | #if defined(__ARM_FEATURE_DOTPROD)
199 | 
200 | #else
201 |         int16x8_t accu32_0 = vdupq_n_s16(0);
202 |         int16x8_t accu32_1 = vdupq_n_s16(0);
203 |         int16x8_t accu32_2 = vdupq_n_s16(0);
204 |         int16x8_t accu32_3 = vdupq_n_s16(0);
205 | #endif
206 | 
207 |         for (int j=0; j < 32; j++) {
208 |             uint8x16_t xq8_6 = vld1q_u8(x + i * 32 * 32 + j * 32);
209 |             uint8x16_t xq8_7 = vld1q_u8(x + i * 32 * 32 + j * 32 + 16);
210 |             uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2);
211 |             uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2);
212 |             uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4);
213 |             uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4);
214 |             uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6);
215 |             uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6);
216 | 
217 |             int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
218 |             int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
219 |             int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
220 |             int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
221 |             int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask));
222 |             int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask));
223 |             int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask));
224 |             int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask));
225 | 
226 |             const int8x16_t yq8_0 = vld1q_s8(y + i * 128 * 32 + j * 128 + 0);
227 |             const int8x16_t yq8_1 = vld1q_s8(y + i * 128 * 32 + j * 128 + 16);
228 |             const int8x16_t yq8_2 = vld1q_s8(y + i * 128 * 32 + j * 128 + 32);
229 |             const int8x16_t yq8_3 = vld1q_s8(y + i * 128 * 32 + j * 128 + 48);
230 |             const int8x16_t yq8_4 = vld1q_s8(y + i * 128 * 32 + j * 128 + 64);
231 |             const int8x16_t yq8_5 = vld1q_s8(y + i * 128 * 32 + j * 128 + 80);
232 |             const int8x16_t yq8_6 = vld1q_s8(y + i * 128 * 32 + j * 128 + 96);
233 |             const int8x16_t yq8_7 = vld1q_s8(y + i * 128 * 32 + j * 128 + 112);
234 | 
235 | #if defined(__ARM_FEATURE_DOTPROD)
236 |             accu_0 = vdotq_s32(accu_0, q8_0, yq8_0);
237 |             accu_1 = vdotq_s32(accu_1, q8_1, yq8_1);
238 |             accu_2 = vdotq_s32(accu_2, q8_2, yq8_2);
239 |             accu_3 = vdotq_s32(accu_3, q8_3, yq8_3);
240 |             accu_0 = vdotq_s32(accu_0, q8_4, yq8_4);
241 |             accu_1 = vdotq_s32(accu_1, q8_5, yq8_5);
242 |             accu_2 = vdotq_s32(accu_2, q8_6, yq8_6);
243 |             accu_3 = vdotq_s32(accu_3, q8_7, yq8_7);
244 | #else
245 |             accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_0), vget_low_s8(yq8_0));
246 |             accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_0), vget_high_s8(yq8_0));
247 |             accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_1), vget_low_s8(yq8_1));
248 |             accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_1), vget_high_s8(yq8_1));
249 |             accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_2), vget_low_s8(yq8_2));
250 |             accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_2), vget_high_s8(yq8_2));
251 |             accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_3), vget_low_s8(yq8_3));
252 |             accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_3), vget_high_s8(yq8_3));
253 |             accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_4), vget_low_s8(yq8_4));
254 |             accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_4), vget_high_s8(yq8_4));
255 |             accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_5), vget_low_s8(yq8_5));
256 |             accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_5), vget_high_s8(yq8_5));
257 |             accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_6), vget_low_s8(yq8_6));
258 |             accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_6), vget_high_s8(yq8_6));
259 |             accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_7), vget_low_s8(yq8_7));
260 |             accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_7), vget_high_s8(yq8_7));
261 | #endif
262 |         }
263 | 
264 | #if defined(__ARM_FEATURE_DOTPROD)
265 | 
266 | #else
267 |         accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accu32_0)));
268 |         accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accu32_0));
269 |         accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accu32_1)));
270 |         accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accu32_1));
271 |         accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accu32_2)));
272 |         accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accu32_2));
273 |         accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accu32_3)));
274 |         accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accu32_3));
275 | #endif
276 |     }
277 | 
278 |     for (int i = 0; i < groupla_num; i++){
279 | #if defined(__ARM_FEATURE_DOTPROD)
280 | 
281 | #else
282 |         int16x8_t accula_0 = vdupq_n_s16(0);
283 |         int16x8_t accula_1 = vdupq_n_s16(0);
284 |         int16x8_t accula_2 = vdupq_n_s16(0);
285 |         int16x8_t accula_3 = vdupq_n_s16(0);
286 | #endif
287 |         for (int j = 0; j < la_num; j++) {
288 |             uint8x16_t xq8_6 = vld1q_u8(x + group32_num * 32 * 32 + j * 32);
289 |             uint8x16_t xq8_7 = vld1q_u8(x + group32_num * 32 * 32 + j * 32 + 16);
290 |             uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2);
291 |             uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2);
292 |             uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4);
293 |             uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4);
294 |             uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6);
295 |             uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6);
296 | 
297 |             int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
298 |             int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
299 |             int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
300 |             int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
301 |             int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask));
302 |             int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask));
303 |             int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask));
304 |             int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask));
305 | 
306 |             const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 0);
307 |             const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 16);
308 |             const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 32);
309 |             const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 48);
310 |             const int8x16_t yq8_4 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 64);
311 |             const int8x16_t yq8_5 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 80);
312 |             const int8x16_t yq8_6 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 96);
313 |             const int8x16_t yq8_7 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 112);
314 | 
315 | #if defined(__ARM_FEATURE_DOTPROD)
316 |             accu_0 = vdotq_s32(accu_0, q8_0, yq8_0);
317 |             accu_1 = vdotq_s32(accu_1, q8_1, yq8_1);
318 |             accu_2 = vdotq_s32(accu_2, q8_2, yq8_2);
319 |             accu_3 = vdotq_s32(accu_3, q8_3, yq8_3);
320 |             accu_0 = vdotq_s32(accu_0, q8_4, yq8_4);
321 |             accu_1 = vdotq_s32(accu_1, q8_5, yq8_5);
322 |             accu_2 = vdotq_s32(accu_2, q8_6, yq8_6);
323 |             accu_3 = vdotq_s32(accu_3, q8_7, yq8_7);
324 | #else
325 |             accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_0), vget_low_s8(yq8_0));
326 |             accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_0), vget_high_s8(yq8_0));
327 |             accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_1), vget_low_s8(yq8_1));
328 |             accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_1), vget_high_s8(yq8_1));
329 |             accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_2), vget_low_s8(yq8_2));
330 |             accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_2), vget_high_s8(yq8_2));
331 |             accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_3), vget_low_s8(yq8_3));
332 |             accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_3), vget_high_s8(yq8_3));
333 |             accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_4), vget_low_s8(yq8_4));
334 |             accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_4), vget_high_s8(yq8_4));
335 |             accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_5), vget_low_s8(yq8_5));
336 |             accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_5), vget_high_s8(yq8_5));
337 |             accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_6), vget_low_s8(yq8_6));
338 |             accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_6), vget_high_s8(yq8_6));
339 |             accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_7), vget_low_s8(yq8_7));
340 |             accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_7), vget_high_s8(yq8_7));
341 | #endif
342 |         }
343 | #if defined(__ARM_FEATURE_DOTPROD)
344 | 
345 | #else
346 |         accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accula_0)));
347 |         accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accula_0));
348 |         accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accula_1)));
349 |         accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accula_1));
350 |         accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accula_2)));
351 |         accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accula_2));
352 |         accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accula_3)));
353 |         accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accula_3));
354 | #endif
355 |     }
356 |     accu_0 = vaddq_s32(accu_0, accu_1);
357 |     accu_2 = vaddq_s32(accu_2, accu_3);
358 |     accu_0 = vaddq_s32(accu_0, accu_2);
359 |     int sumi = vaddlvq_s32(accu_0);
360 |     *s = (float)sumi;
361 | 
362 | #endif
363 | }
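
A minimal Python sketch of the I2_S layout produced by quantize_i2_s above (illustrative only; pack_i2_block and unpack_i2_block are our names, not the repo's). Each 128-weight block packs into 32 bytes: byte p carries the weights at positions p, 32+p, 64+p and 96+p of the block, first group in the top two bits, values stored biased as {-1, 0, +1} -> {0, 1, 2}. The SIMD kernels recover the four groups with the same shift-and-mask pattern (shifts 6, 4, 2, 0) and keep the biased form for the multiply:

import random

def pack_i2_block(ternary):  # 128 values in {-1, 0, 1} -> 32 bytes
    assert len(ternary) == 128
    out = bytearray(32)
    for j, w in enumerate(ternary):
        group_idx, group_pos = divmod(j, 32)
        out[group_pos] |= (w + 1) << (6 - 2 * group_idx)
    return bytes(out)

def unpack_i2_block(packed):  # inverse of the packing above
    weights = [0] * 128
    for group_pos in range(32):
        for group_idx in range(4):
            bits = (packed[group_pos] >> (6 - 2 * group_idx)) & 3
            weights[group_idx * 32 + group_pos] = bits - 1
    return weights

w = [random.choice([-1, 0, 1]) for _ in range(128)]
assert unpack_i2_block(pack_i2_block(w)) == w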


--------------------------------------------------------------------------------
/utils/codegen_tl1.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import os
  3 | from configparser import ConfigParser
  4 | 
  5 | def gen_ctor_code():
  6 |     kernel_code = "\n\
  7 | #include \"ggml-bitnet.h\"\n\
  8 | #define GGML_BITNET_MAX_NODES 8192\n\
  9 | static bool initialized = false;\n\
 10 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\
 11 | static size_t bitnet_tensor_extras_index = 0;\n\
 12 | static void * aligned_malloc(size_t size) {{\n\
 13 | #if defined(_WIN32)\n\
 14 |     return _aligned_malloc(size, 64);\n\
 15 | #else\n\
 16 |     void * ptr = nullptr;\n\
 17 |     posix_memalign(&ptr, 64, size);\n\
 18 |     return ptr;\n\
 19 | #endif\n\
 20 | }}\n\
 21 | static void aligned_free(void * ptr) {{\n\
 22 | #if defined(_WIN32)\n\
 23 |     _aligned_free(ptr);\n\
 24 | #else\n\
 25 |     free(ptr);\n\
 26 | #endif\n\
 27 | }}\n\
 28 | \n\
 29 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{\n\
 30 |     bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
 31 |     bitnet_float_type* b = (bitnet_float_type*)b_;\n\
 32 | #ifdef __ARM_NEON\n\
 33 |     float32x4_t temp_max = vdupq_n_f32(0);\n\
 34 |     for (int i=0; i < k / 4; i++) {{\n\
 35 |       float32x4_t vec_bs = vld1q_f32(b + 4 * i);\n\
 36 |       float32x4_t abssum = vabsq_f32(vec_bs);\n\
 37 |       temp_max = vmaxq_f32(abssum, temp_max);\n\
 38 |     }}\n\
 39 |     float32_t scales = 127 / vmaxvq_f32(temp_max);\n\
 40 |     *lut_scales = scales;\n\
 41 | #elif defined __AVX2__\n\
 42 |     __m256 max_vec = _mm256_set1_ps(0.f);\n\
 43 |     const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\
 44 |     // #pragma unroll\n\
 45 |     for (int i = 0; i < k / 8; i++) {{\n\
 46 |         __m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\
 47 |         __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\
 48 |         max_vec = _mm256_max_ps(vec_babs, max_vec);\n\
 49 |     }}\n\
 50 |     __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\
 51 |     max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\
 52 |     max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\
 53 |     float scales = 127 / _mm_cvtss_f32(max1);\n\
 54 |     *lut_scales = scales;\n\
 55 | #endif\n\
 56 | }}\n\
 57 | \n\
 58 | void partial_max_reset(void* lut_scales_) {{\n\
 59 |     bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
 60 |     *lut_scales = 0.0;\n\
 61 | }}\n\
 62 | \n\
 63 | #ifdef __ARM_NEON\n\
 64 | inline void Transpose_8_8(\n\
 65 |     int16x8_t *v0,\n\
 66 |     int16x8_t *v1,\n\
 67 |     int16x8_t *v2,\n\
 68 |     int16x8_t *v3,\n\
 69 |     int16x8_t *v4,\n\
 70 |     int16x8_t *v5,\n\
 71 |     int16x8_t *v6,\n\
 72 |     int16x8_t *v7)\n\
 73 | {{\n\
 74 |     int16x8x2_t q04 = vzipq_s16(*v0, *v4);\n\
 75 |     int16x8x2_t q15 = vzipq_s16(*v1, *v5);\n\
 76 |     int16x8x2_t q26 = vzipq_s16(*v2, *v6);\n\
 77 |     int16x8x2_t q37 = vzipq_s16(*v3, *v7);\n\
 78 | \n\
 79 |     int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);\n\
 80 |     int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);\n\
 81 |     int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);\n\
 82 |     int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);\n\
 83 | \n\
 84 |     int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);\n\
 85 |     int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);\n\
 86 |     int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);\n\
 87 |     int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);\n\
 88 | \n\
 89 |     *v0 = q_fin_0.val[0];\n\
 90 |     *v1 = q_fin_0.val[1];\n\
 91 |     *v2 = q_fin_1.val[0];\n\
 92 |     *v3 = q_fin_1.val[1];\n\
 93 |     *v4 = q_fin_2.val[0];\n\
 94 |     *v5 = q_fin_2.val[1];\n\
 95 |     *v6 = q_fin_3.val[0];\n\
 96 |     *v7 = q_fin_3.val[1];\n\
 97 | }}\n\
 98 | #endif\n\
 99 | \n\
100 | template<int act_k>\n\
101 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{\n\
102 | #ifdef __ARM_NEON\n\
103 |     int16x8_t vec_lut[16];\n\
104 |     float32_t scales = *lut_scales;\n\
105 |         uint8_t tbl_mask[16];\n\
106 |         tbl_mask[0] = 0;\n\
107 |         tbl_mask[1] = 2;\n\
108 |         tbl_mask[2] = 4;\n\
109 |         tbl_mask[3] = 6;\n\
110 |         tbl_mask[4] = 8;\n\
111 |         tbl_mask[5] = 10;\n\
112 |         tbl_mask[6] = 12;\n\
113 |         tbl_mask[7] = 14;\n\
114 |         tbl_mask[8] = 1;\n\
115 |         tbl_mask[9] = 3;\n\
116 |         tbl_mask[10] = 5;\n\
117 |         tbl_mask[11] = 7;\n\
118 |         tbl_mask[12] = 9;\n\
119 |         tbl_mask[13] = 11;\n\
120 |         tbl_mask[14] = 13;\n\
121 |         tbl_mask[15] = 15;\n\
122 |         uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);\n\
123 | #pragma unroll\n\
124 |     for (int k = 0; k < act_k / 16; ++k) {{\n\
125 |         float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);\n\
126 |         float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);\n\
127 |         float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);\n\
128 |         float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);\n\
129 |         float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);\n\
130 |         float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);\n\
131 |         int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);\n\
132 |         int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);\n\
133 |         int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);\n\
134 |         int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);\n\
135 |         int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);\n\
136 |         int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);\n\
137 |         int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);\n\
138 |         int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);\n\
139 |         int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);\n\
140 |         int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);\n\
141 |         vec_lut[0] = vdupq_n_s16(0);\n\
142 |         vec_lut[0] = vec_lut[0] - vec_bs_0;\n\
143 |         vec_lut[0] = vec_lut[0] - vec_bs_1;\n\
144 |         vec_lut[1] = vdupq_n_s16(0);\n\
145 |         vec_lut[1] = vec_lut[1] - vec_bs_0;\n\
146 |         vec_lut[2] = vdupq_n_s16(0);\n\
147 |         vec_lut[2] = vec_lut[2] - vec_bs_0;\n\
148 |         vec_lut[2] = vec_lut[2] + vec_bs_1;\n\
149 |         vec_lut[3] = vdupq_n_s16(0);\n\
150 |         vec_lut[3] = vec_lut[3] - vec_bs_1;\n\
151 |         vec_lut[4] = vdupq_n_s16(0);\n\
152 |         vec_lut[5] = vec_bs_1;\n\
153 |         vec_lut[6] = vec_bs_0;\n\
154 |         vec_lut[6] = vec_lut[6] - vec_bs_1;\n\
155 |         vec_lut[7] = vec_bs_0;\n\
156 |         vec_lut[8] = vec_bs_0;\n\
157 |         vec_lut[8] = vec_lut[8] + vec_bs_1;\n\
158 |         Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),\n\
159 |                       &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));\n\
160 |         Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),\n\
161 |                       &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));\n\
162 | #pragma unroll\n\
163 |         for (int idx = 0; idx < 8; idx++) {{\n\
164 |             int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);\n\
165 |             int8x8_t q0_low = vget_low_s8(q0_s);\n\
166 |             int8x8_t q0_high = vget_high_s8(q0_s);\n\
167 |             int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);\n\
168 |             int8x8_t q1_low = vget_low_s8(q1_s);\n\
169 |             int8x8_t q1_high = vget_high_s8(q1_s);\n\
170 |             vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);\n\
171 |             vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);\n\
172 |             vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);\n\
173 |             vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);\n\
174 |         }}\n\
175 |     }}\n\
176 | #endif\n\
177 | }}\n\
178 | \n\
179 | static bool is_type_supported(enum ggml_type type) {{\n\
180 |     if (type == GGML_TYPE_Q4_0 ||\n\
181 |         type == GGML_TYPE_TL1) {{\n\
182 |         return true;\n\
183 |     }} else {{\n\
184 |         return false;\n\
185 |     }}\n\
186 | }}\n\
187 | "
188 |     return kernel_code.format()  # collapse the doubled {{ }} escapes into literal braces
189 | 
190 | def gen_body_core_code(bm, by):
191 |     length = 4
192 |     all_code = ""
193 |     for i in range(length):
194 |         core_code = "\n\
195 |             uint8x16_t vec_a_{0} = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + {0} * 16);\n\
196 |             uint8x16_t vec_a{0}_top = vshrq_n_u8(vec_a_{0}, 4);\n\
197 |             uint8x16_t vec_a{0}_bot = vandq_u8(vec_a_{0}, vec_mask);\n\
198 |             int8x16_t  vec_v_{0}_left_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {2}], vec_a{0}_top);\n\
199 |             int8x16_t  vec_v_{0}_left_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {3}], vec_a{0}_top);\n\
200 |             int8x16_t  vec_v_{0}_right_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {4}], vec_a{0}_bot);\n\
201 |             int8x16_t  vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\
202 |             int8x16x2_t  vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\
203 |             int8x16x2_t  vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\
204 |             vec_c[{6}] += vec_v_left_{0}.val[0];\n\
205 |             vec_c[{6}] += vec_v_right_{0}.val[0];\n\
206 |             vec_c[{7}] += vec_v_left_{0}.val[1];\n\
207 |             vec_c[{7}] += vec_v_right_{0}.val[1];\n\
208 |         ".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1)
209 |         
210 |         all_code = "".join([all_code, core_code])
211 | 
212 |     all_code = "".join([all_code, "\n       }\n\n"])
213 | 
214 |     for i in range(bm // 8):
215 |         core_code = "\
216 |         int32x4_t vec_v_bot_low_low_{0} = vmovl_s16(vget_low_s16(vec_c[{0}]));\n\
217 |         int32x4_t vec_v_bot_low_high_{0} = vmovl_high_s16(vec_c[{0}]);\n\
218 |         vst1q_s32(c + i + {1}, vld1q_s32(c + i + {1}) + vec_v_bot_low_low_{0});\n\
219 |         vst1q_s32(c + i + {2}, vld1q_s32(c + i + {2}) + vec_v_bot_low_high_{0});\n".format(i, i * 8, i * 8 + 4)
220 |         all_code = "".join([all_code, core_code])
221 | 
222 |     return all_code
223 | 
224 | def gen_tbl_impl(pre, BM, BK, bm, k):
225 | 
226 |     kernel_code = "\
227 | #include <arm_neon.h>\n\
228 | \n\
229 | #define BM{0} {1}\n\
230 | #define BBK{0} {2}\n\
231 | inline void tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\
232 | #ifdef __ARM_NEON\n\
233 |     const int KK = BBK{0} / 2;\n\
234 |     const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\
235 |     const int16x8_t vec_zero = vdupq_n_s16(0x0000);\n\
236 |     int8x16_t vec_lut[2 * KK];\n\
237 | ".format(pre, BM, BK)
238 |     
239 |     kernel_code = "".join([kernel_code, "    int16x8_t vec_c[{}];".format(bm // 8)])
240 | 
241 |     kernel_code = "".join([kernel_code, "\n\
242 | #pragma unroll\n\
243 |     for (int k = 0; k < 2 * KK; k++) {\n\
244 |         vec_lut[k] = vld1q_s8(lut + k * 16);\n\
245 |     }\n"])
246 | 
247 |     pre_core_code = "\n\
248 | #pragma unroll\n\
249 |     for (int i = 0; i < BM{}; i += {}) {{\n\
250 |         #pragma unroll\n\
251 |         for (int i=0; i<{}; i++) {{\n\
252 |             vec_c[i] = vandq_s16(vec_c[i], vec_zero);\n\
253 |         }}\n".format(pre, bm, bm // 8)
254 | 
255 |     body_core_pre_code = "\n\
256 | #pragma unroll\n\
257 |         for (int k = 0; k < KK / {}; k++) {{\n\
258 |             ".format(256 // bm // 2)
259 | 
260 |     body_core_post_code = "\n\
261 |     }\n\
262 | \
263 | #endif\n\
264 | }\n"
265 | 
266 |     kernel_code = "".join([kernel_code, pre_core_code, body_core_pre_code, gen_body_core_code(bm, 256 // bm), body_core_post_code])
267 | 
268 |     kernel_code = "".join([kernel_code, "\n\
269 | int32_t qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
270 |     alignas({1}) uint32_t CBits[BM{0}];\n\
271 |     memset(&(CBits[0]), 0, BM{0} * sizeof(int32_t));\n\
272 | #pragma unroll\n\
273 |     for (int32_t k_outer = 0; k_outer < {2} / BBK{0}; ++k_outer) {{\n\
274 |         tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{0} / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 2 / 2 * BM{0})])));\n\
275 |     }}\n\
276 | #pragma unroll\n\
277 |     for (int i = 0; i < BM{0}; i++) {{\n\
278 |         ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];\n\
279 |     }}\n\
280 |   return 0;\n\
281 | }};\n".format(pre, min(32, BK), k)])
282 | 
283 |     return kernel_code
284 | 
285 | def gen_top_api(kernel_shapes):
286 | 
287 |     kernel_code = "void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {{\n\
288 |     if (m == {0} && k == {1}) {{\n\
289 |         preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\
290 |     }}\n\
291 | ".format(kernel_shapes[0][0], kernel_shapes[0][1])
292 |     for i in range(1, len(kernel_shapes)):
293 |         kernel_code = "".join([kernel_code, "    else if (m == {0} && k == {1}) {{\n\
294 |         preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\
295 |     }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])
296 |     kernel_code = "".join([kernel_code, "}\n"])
297 |     kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
298 |     if (m == {0} && k == {1}) {{\n\
299 |         qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\
300 |     }}\n\
301 | ".format(kernel_shapes[0][0], kernel_shapes[0][1])])
302 |     for i in range(1, len(kernel_shapes)):
303 |         kernel_code = "".join([kernel_code, "    else if (m == {0} && k == {1}) {{\n\
304 |         qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\
305 |     }}\n\
306 | ".format(kernel_shapes[i][0], kernel_shapes[i][1])])
307 |     kernel_code = "".join([kernel_code, "}\n"])
308 |     return kernel_code
309 | 
310 | def gen_preprocess_code():
311 |     kernel_code = "\n\
312 | template<int K>\n\
313 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{\n\
314 |   partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));\n\
315 |   per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));\n\
316 |   \n\
317 |   lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));\n\
318 | }}\n"
319 |     return kernel_code.format()  # collapse the doubled {{ }} escapes into literal braces
320 | 
321 | def gen_transform_code(kernel_shapes):
322 |     kernel_code = "\n\
323 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\
324 |     if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\
325 |         return;\n\
326 |     }\n\
327 | \n\
328 |     int k = tensor->ne[0];\n\
329 |     int m = tensor->ne[1];\n\
330 |     const int lut_scales_size = 1;\n\
331 |     const int scales_size = 1;\n\
332 |     int bk = 0;\n\
333 |     int bm = 0;\n"
334 | 
335 |     kernel_code = "".join([kernel_code, "\n\
336 |     if (m == {0} && k == {1}) {{\n\
337 |         bm = BM{0}_{1};\n\
338 |         bk = BBK{0}_{1};\n\
339 |     }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])])
340 | 
341 |     for i in range(1, len(kernel_shapes)):
342 |         kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\
343 |         bm = BM{0}_{1};\n\
344 |         bk = BBK{0}_{1};\n\
345 |     }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])
346 | 
347 |     kernel_code = "".join([kernel_code, "\n\
348 |     const int n_tile_num = m / bm;\n\
349 |     const int BK = bk;\n\
350 |     uint8_t * qweights;\n\
351 |     bitnet_float_type * scales;\n\
352 | \n\
353 |     scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\
354 |     qweights = (uint8_t *) tensor->data;\n\
355 |     float * i2_scales = (float * )(qweights + k * m / 4);\n\
356 |     scales[0] = (bitnet_float_type) i2_scales[0];\n\
357 | \n\
358 |     tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
359 |     bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
360 |         /* .lut_scales_size = */ lut_scales_size,\n\
361 |         /* .BK              = */ BK,\n\
362 |         /* .n_tile_num      = */ n_tile_num,\n\
363 |         /* .qweights        = */ qweights,\n\
364 |         /* .scales          = */ scales\n\
365 |     };\n\
366 | }\n"])
367 | 
368 |     return kernel_code
369 | 
370 | if __name__ == "__main__":
371 |     ModelShapeDict = {
372 |         "bitnet_b1_58-large"                : [[1536, 4096],
373 |                                                [1536, 1536],
374 |                                                [4096, 1536]],
375 |         "bitnet_b1_58-3B"                   : [[3200, 8640],
376 |                                                [3200, 3200],
377 |                                                [8640, 3200]],
378 |         "Llama3-8B-1.58-100B-tokens"        : [[14336, 4096],
379 |                                                [4096, 14336],
380 |                                                [1024, 4096],
381 |                                                [4096, 4096]] 
382 |     }
383 |     
384 |     parser = argparse.ArgumentParser(description='gen impl')
385 |     parser.add_argument('--model',default="input", type=str, dest="model", 
386 |                         help="choose from bitnet_b1_58-large/bitnet_b1_58-3B/Llama3-8B-1.58-100B-tokens.")
387 |     parser.add_argument('--BM',default="input", type=str,
388 |                         help="block length when cutting one weight (M, K) into M / BM weights (BM, K).")
389 |     parser.add_argument('--BK',default="input", type=str,
390 |                         help="block length when cutting one weight (M, K) into K / BK weights (M, BK).")
391 |     parser.add_argument('--bm',default="input", type=str,
392 |                         help="tile height bm; SIMD instructions compute one (bm, 256 / bm) tile per block")
393 |     args = parser.parse_args()
394 | 
395 |     kernel_shapes = ModelShapeDict[args.model]
396 | 
397 |     BM_list = [int(item) for item in args.BM.split(',')]
398 |     BK_list = [int(item) for item in args.BK.split(',')]
399 |     bm_list = [int(item) for item in args.bm.split(',')]
400 | 
401 |     assert(len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes)), "number of BM / BK / bm should be {}".format(len(kernel_shapes))
402 |     
403 |     for i in range(len(kernel_shapes)):
404 |         assert kernel_shapes[i][0] % BM_list[i] == 0, "M %% BM should be 0"
405 |         assert kernel_shapes[i][1] % BK_list[i] == 0, "K %% BK should be 0"
406 |         assert bm_list[i] in [32, 64], "choose bm from [32, 64]"
407 | 
408 |     tbl_impl_code = []
409 | 
410 |     for i in range(len(kernel_shapes)):
411 |         tbl_impl_code.append(
412 |             gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], kernel_shapes[i][1])
413 |         )
414 |     api_code = gen_top_api(kernel_shapes)
415 |     pre_code = gen_preprocess_code()
416 |     ctor_code = gen_ctor_code()
417 |     trans_code = gen_transform_code(kernel_shapes)
418 | 
419 |     output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include")
420 | 
421 |     with open(os.path.join(output_dir, "bitnet-lut-kernels.h"), 'w') as f:
422 |         f.write("#if defined(GGML_BITNET_ARM_TL1)")
423 |         f.write(ctor_code)
424 |         for code in tbl_impl_code:
425 |             f.write(code)
426 |         f.write(pre_code)
427 |         f.write(api_code)
428 |         f.write(trans_code)
429 |         f.write("#endif")
430 | 
431 |     config = ConfigParser()
432 | 
433 |     for i in range(len(kernel_shapes)):
434 |         config.add_section('Kernels_{}'.format(i))
435 |         config.set('Kernels_{}'.format(i), 'M', str(kernel_shapes[i][0]))
436 |         config.set('Kernels_{}'.format(i), 'K', str(kernel_shapes[i][1]))
437 |         config.set('Kernels_{}'.format(i), 'BM', str(BM_list[i]))
438 |         config.set('Kernels_{}'.format(i), 'BK', str(BK_list[i]))
439 |         config.set('Kernels_{}'.format(i), 'bmm', str(bm_list[i]))
440 | 
441 |     with open(os.path.join(output_dir, "kernel_config.ini"), 'w') as configfile:
442 |         config.write(configfile)
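
The kernels generated above implement the TL1 lookup-table trick: lut_ctor precomputes, for every pair of activations (b0, b1), all nine partial sums a ternary weight pair can select (the vec_lut[0..8] construction), and the tbl_impl_* bodies fetch those partial sums with vqtbl1q_s8 instead of multiplying. A minimal Python sketch of the idea (illustrative only; build_pair_lut and lut_dot are our names, and the real kernels additionally quantize the table to int8 and index it with packed 4-bit weight codes):

def build_pair_lut(b0, b1):
    # 9-entry table indexed by (w0 + 1) * 3 + (w1 + 1), with w in {-1, 0, 1}
    return [w0 * b0 + w1 * b1 for w0 in (-1, 0, 1) for w1 in (-1, 0, 1)]

def lut_dot(weights, activations):
    # dot product of a ternary weight vector with activations,
    # two elements at a time, via table lookups instead of multiplies
    acc = 0
    for i in range(0, len(weights), 2):
        # in the real kernel the table is built once per activation pair
        # and reused across every weight row
        lut = build_pair_lut(activations[i], activations[i + 1])
        acc += lut[(weights[i] + 1) * 3 + (weights[i + 1] + 1)]
    return acc

assert lut_dot([1, -1, 0, 1], [2, 3, 5, 7]) == 2 - 3 + 0 + 7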


--------------------------------------------------------------------------------
/utils/convert-helper-bitnet.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | 
  3 | import sys
  4 | import os
  5 | import shutil
  6 | import subprocess
  7 | from pathlib import Path
  8 | 
  9 | def run_command(command_list, cwd=None, check=True):
 10 |     print(f"Executing: {' '.join(map(str, command_list))}")
 11 |     try:
 12 |         process = subprocess.run(command_list, cwd=cwd, check=check, capture_output=False, text=True)
 13 |         return process
 14 |     except subprocess.CalledProcessError as e:
 15 |         print(f"Error executing command: {' '.join(map(str, e.cmd))}")
 16 |         print(f"Return code: {e.returncode}")
 17 |         raise
 18 | 
 19 | def main():
 20 |     if len(sys.argv) < 2:
 21 |         script_name = Path(sys.argv[0]).name
 22 |         print(f"Usage: python {script_name} <model-directory>")
 23 |         sys.exit(1)
 24 | 
 25 |     model_dir_arg = sys.argv[1]
 26 |     model_dir = Path(model_dir_arg).resolve()
 27 | 
 28 |     if not model_dir.is_dir():
 29 |         print(f"Error: Model directory '{model_dir}' not found or is not a directory.")
 30 |         sys.exit(1)
 31 | 
 32 |     utils_dir = Path(__file__).parent.resolve()
 33 |     project_root_dir = utils_dir.parent
 34 | 
 35 |     preprocess_script = utils_dir / "preprocess-huggingface-bitnet.py"
 36 |     convert_script = utils_dir / "convert-ms-to-gguf-bitnet.py"
 37 |     
 38 |     llama_quantize_binary = project_root_dir / "build" / "bin" / "llama-quantize"
 39 | 
 40 |     input_file = model_dir / "model.safetensors"
 41 |     input_backup_file = model_dir / "model.safetensors.backup"
 42 |     preprocessed_output_file = model_dir / "model.safetensors"
 43 | 
 44 |     gguf_f32_output = model_dir / "ggml-model-f32-bitnet.gguf"
 45 |     gguf_i2s_output = model_dir / "ggml-model-i2s-bitnet.gguf"
 46 | 
 47 |     if not preprocess_script.is_file():
 48 |         print(f"Error: Preprocess script not found at '{preprocess_script}'")
 49 |         sys.exit(1)
 50 |     if not convert_script.is_file():
 51 |         print(f"Error: Convert script not found at '{convert_script}'")
 52 |         sys.exit(1)
 53 |     if not llama_quantize_binary.is_file():
 54 |         print(f"Error: llama-quantize binary not found at '{llama_quantize_binary}'")
 55 |         sys.exit(1)
 56 | 
 57 |     if not input_file.is_file():
 58 |         print(f"Error: Input safetensors file not found at '{input_file}'")
 59 |         sys.exit(1)
 60 | 
 61 |     try:
 62 |         print(f"Backing up '{input_file}' to '{input_backup_file}'")
 63 |         if input_backup_file.exists():
 64 |              print(f"Warning: Removing existing backup file '{input_backup_file}'")
 65 |              input_backup_file.unlink()
 66 |         shutil.move(input_file, input_backup_file)
 67 | 
 68 |         print("Preprocessing huggingface checkpoint...")
 69 |         cmd_preprocess = [
 70 |             sys.executable,
 71 |             str(preprocess_script),
 72 |             "--input", str(input_backup_file),
 73 |             "--output", str(preprocessed_output_file)
 74 |         ]
 75 |         run_command(cmd_preprocess)
 76 | 
 77 |         print("Converting to GGUF (f32)...")
 78 |         cmd_convert = [
 79 |             sys.executable,
 80 |             str(convert_script),
 81 |             str(model_dir),
 82 |             "--vocab-type", "bpe",
 83 |             "--outtype", "f32",
 84 |             "--concurrency", "1",
 85 |             "--outfile", str(gguf_f32_output)
 86 |         ]
 87 |         run_command(cmd_convert)
 88 | 
 89 |         print("Quantizing model to I2_S...")
 90 |         cmd_quantize = [
 91 |             str(llama_quantize_binary),
 92 |             str(gguf_f32_output),
 93 |             str(gguf_i2s_output),
 94 |             "I2_S",
 95 |             "1"
 96 |         ]
 97 |         run_command(cmd_quantize)
 98 | 
 99 |         print("Converted successfully.")
100 | 
101 |     except Exception as e:
102 |         print(f"An error occurred: {e}")
103 |     finally:
104 |         print("Cleaning up intermediate files...")
105 |         if preprocessed_output_file.exists() and preprocessed_output_file != input_backup_file:
106 |             print(f"Removing preprocessed file: {preprocessed_output_file}")
107 |             try:
108 |                 preprocessed_output_file.unlink()
109 |             except OSError as e:
110 |                 print(f"Warning: Could not remove {preprocessed_output_file}: {e}")
111 |         
112 |         if gguf_f32_output.exists():
113 |             print(f"Removing f32 GGUF: {gguf_f32_output}")
114 |             try:
115 |                 gguf_f32_output.unlink()
116 |             except OSError as e:
117 |                 print(f"Warning: Could not remove {gguf_f32_output}: {e}")
118 |         
119 |         if input_backup_file.exists():
120 |             if not input_file.exists():
121 |                 print(f"Restoring original '{input_file}' from '{input_backup_file}'")
122 |                 try:
123 |                     shutil.move(input_backup_file, input_file)
124 |                 except Exception as e:
125 |                     print(f"Warning: Could not restore {input_file} from backup: {e}")
126 |             else:
127 |                 print(f"Removing backup '{input_backup_file}' as original '{input_file}' should be present.")
128 |                 try:
129 |                     input_backup_file.unlink()
130 |                 except OSError as e:
131 |                     print(f"Warning: Could not remove backup {input_backup_file}: {e}")
132 | 
133 | if __name__ == "__main__":
134 |     main()
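
A typical invocation of the helper above, assuming the model directory holds a single model.safetensors and llama-quantize has already been built under build/bin: python utils/convert-helper-bitnet.py <model-directory>. The script backs up the original checkpoint, rewrites the quantizable projection weights, converts the result to an f32 GGUF, quantizes that to I2_S, then removes the intermediates and restores the backup.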


--------------------------------------------------------------------------------
/utils/e2e_benchmark.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | import logging
 4 | import argparse
 5 | import platform
 6 | import subprocess
 7 | 
 8 | def run_command(command, shell=False, log_step=None):
 9 |     """Run a system command and ensure it succeeds."""
10 |     if log_step:
11 |         log_file = os.path.join(args.log_dir, log_step + ".log")  # relies on args.log_dir, which this script's parser does not define
12 |         with open(log_file, "w") as f:
13 |             try:
14 |                 subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f)
15 |             except subprocess.CalledProcessError as e:
16 |                 logging.error(f"Error occurred while running command: {e}, check details in {log_file}")
17 |                 sys.exit(1)
18 |     else:
19 |         try:
20 |             subprocess.run(command, shell=shell, check=True)
21 |         except subprocess.CalledProcessError as e:
22 |             logging.error(f"Error occurred while running command: {e}")
23 |             sys.exit(1)
24 | 
25 | def run_benchmark():
26 |     build_dir =  os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "build")
27 |     if platform.system() == "Windows":
28 |         bench_path = os.path.join(build_dir, "bin", "Release", "llama-bench.exe")
29 |         if not os.path.exists(bench_path):
30 |             bench_path = os.path.join(build_dir, "bin", "llama-bench")
31 |     else:
32 |         bench_path = os.path.join(build_dir, "bin", "llama-bench")
33 |     if not os.path.exists(bench_path):
34 |         logging.error("Benchmark binary not found, please build first.")
35 |         sys.exit(1)
36 |     command = [
37 |         bench_path,
38 |         '-m', args.model,
39 |         '-n', str(args.n_token),
40 |         '-ngl', '0',
41 |         '-b', '1',
42 |         '-t', str(args.threads),
43 |         '-p', str(args.n_prompt),
44 |         '-r', '5'
45 |     ]
46 |     run_command(command)
47 | 
48 | def parse_args():
49 |     parser = argparse.ArgumentParser(description='Run an end-to-end benchmark with llama-bench')
50 |     parser.add_argument("-m", "--model", type=str, help="Path to model file", required=True)
51 |     parser.add_argument("-n", "--n-token", type=int, help="Number of generated tokens", required=False, default=128)
52 |     parser.add_argument("-p", "--n-prompt", type=int, help="Number of prompt tokens to process", required=False, default=512)
53 |     parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
54 |     return parser.parse_args()
55 | 
56 | if __name__ == "__main__":
57 |     logging.basicConfig(level=logging.INFO)
58 |     args = parse_args()
59 |     run_benchmark()
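
A typical run of the benchmark above, relying on the parse_args defaults (128 generated tokens, 512 prompt tokens): python utils/e2e_benchmark.py -m <path-to-model.gguf> -t 4. Note that the assembled command passes -ngl 0 and -b 1 to llama-bench, pinning the run to CPU with batch size 1, and repeats each measurement five times (-r 5).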


--------------------------------------------------------------------------------
/utils/kernel_tuning.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/404980eecae38affa4871c3e419eae3f44536a95/utils/kernel_tuning.py


--------------------------------------------------------------------------------
/utils/preprocess-huggingface-bitnet.py:
--------------------------------------------------------------------------------
 1 | from safetensors import safe_open
 2 | from safetensors.torch import save_file
 3 | import torch
 4 | 
 5 | def quant_weight_fp16(weight):  # fake-quantize: round to ternary, return dequantized float weights
 6 |     weight = weight.to(torch.float)
 7 |     s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
 8 |     new_weight = (weight * s).round().clamp(-1, 1) / s
 9 |     return new_weight
10 | 
11 | def quant_model(input, output):
12 |     tensors = {}
13 | 
14 |     with safe_open(input, framework='pt') as f:
15 |         for name in f.keys():
16 |             tensors[name] = f.get_tensor(name)
17 | 
18 |             keyword_list = [
19 |                 'q_proj.weight', 
20 |                 'k_proj.weight', 
21 |                 'v_proj.weight',
22 |                 'o_proj.weight',
23 |                 'gate_proj.weight',
24 |                 'up_proj.weight',
25 |                 'down_proj.weight'
26 |             ]
27 | 
28 |             if any(keyword in name for keyword in keyword_list):
29 |                 print(f'[INFO] Quantizing {name}')
30 |                 tensors[name] = quant_weight_fp16(tensors[name])
31 |     
32 |     print(f'[INFO] Saving to {output}\nThis may take a while.')
33 |     save_file(tensors, output)
34 |                 
35 | 
36 | if __name__ == "__main__":
37 |     import argparse
38 |     parser = argparse.ArgumentParser(description="Quantize BitNet projection weights in a safetensors checkpoint (absmean ternary rounding)")
39 |     parser.add_argument(
40 |         "--input", type=str, required=True,
41 |     )
42 |     parser.add_argument(
43 |         "--output", type=str, required=True,
44 |     )
45 |     args = parser.parse_args()
46 | 
47 |     quant_model(
48 |         input=args.input,
49 |         output=args.output,
50 |     )
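
The quant_weight_fp16 routine above is an absmean ternary ("b1.58"-style) fake-quantizer: scale by the reciprocal mean absolute weight, round to {-1, 0, 1}, then divide the scale back out. A small worked example (illustrative only):

import torch

w = torch.tensor([[0.4, -1.2], [0.05, 0.8]])
s = 1.0 / w.abs().mean().clamp_(min=1e-5)   # mean(|w|) = 0.6125, so s ~= 1.633
q = (w * s).round().clamp(-1, 1)            # tensor([[ 1., -1.], [ 0.,  1.]])
print(q / s)                                # dequantized weights written back out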


--------------------------------------------------------------------------------