├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── assets
│   ├── header_model_release.png
│   ├── intel_performance.jpg
│   ├── m2_performance.jpg
│   ├── tl1.png
│   └── tl2.png
├── docs
│   └── codegen.md
├── gpu
│   ├── README.md
│   ├── bitnet_kernels
│   │   ├── bitnet_kernels.cu
│   │   ├── bitnet_kernels.h
│   │   ├── compile.sh
│   │   └── setup.py
│   ├── convert_checkpoint.py
│   ├── convert_safetensors.py
│   ├── generate.py
│   ├── model.py
│   ├── pack_weight.py
│   ├── requirements.txt
│   ├── sample_utils.py
│   ├── stats.py
│   ├── test.py
│   ├── tokenizer.model
│   └── tokenizer.py
├── include
│   └── ggml-bitnet.h
├── media
│   ├── benchmark.png
│   └── demo.mp4
├── preset_kernels
│   ├── Llama3-8B-1.58-100B-tokens
│   │   ├── bitnet-lut-kernels-tl1.h
│   │   ├── bitnet-lut-kernels-tl2.h
│   │   ├── kernel_config_tl1.ini
│   │   └── kernel_config_tl2.ini
│   ├── bitnet_b1_58-3B
│   │   ├── bitnet-lut-kernels-tl1.h
│   │   ├── bitnet-lut-kernels-tl2.h
│   │   ├── kernel_config_tl1.ini
│   │   └── kernel_config_tl2.ini
│   └── bitnet_b1_58-large
│       ├── bitnet-lut-kernels-tl1.h
│       ├── bitnet-lut-kernels-tl2.h
│       ├── kernel_config_tl1.ini
│       └── kernel_config_tl2.ini
├── requirements.txt
├── run_inference.py
├── run_inference_server.py
├── setup_env.py
├── src
│   ├── CMakeLists.txt
│   ├── ggml-bitnet-lut.cpp
│   └── ggml-bitnet-mad.cpp
└── utils
    ├── codegen_tl1.py
    ├── codegen_tl2.py
    ├── convert-hf-to-gguf-bitnet.py
    ├── convert-ms-to-gguf-bitnet.py
    ├── convert.py
    ├── e2e_benchmark.py
    ├── generate-dummy-bitnet-model.py
    └── kernel_tuning.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Extensions
2 |
3 | *.a
4 | *.bat
5 | *.bin
6 | *.dll
7 | *.dot
8 | *.etag
9 | *.exe
10 | *.gcda
11 | *.gcno
12 | *.gcov
13 | *.gguf
14 | *.gguf.json
15 | *.lastModified
16 | *.log
17 | *.metallib
18 | *.o
19 | *.so
20 | *.tmp
21 |
22 | # IDE / OS
23 |
24 | .cache/
25 | .ccls-cache/
26 | .direnv/
27 | .DS_Store
28 | .envrc
29 | .idea/
30 | .swiftpm
31 | .vs/
32 | .vscode/
33 | nppBackup
34 |
35 | # Models
36 | models/*
37 | gpu/checkpoints/*
38 |
39 | # Python
40 |
41 | /.venv
42 | __pycache__/
43 | */poetry.lock
44 | poetry.toml
45 |
46 | build/
47 | logs/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "3rdparty/llama.cpp"]
2 | path = 3rdparty/llama.cpp
3 | url = https://github.com/Eddie-Wang1120/llama.cpp.git
4 | branch = merge-dev
5 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
2 | project("bitnet.cpp" C CXX)
3 | include(CheckIncludeFileCXX)
4 |
5 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
6 |
7 | if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
8 | set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
9 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
10 | endif()
11 |
12 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
13 |
14 | # option list
15 | option(BITNET_ARM_TL1 "bitnet.cpp: use tl1 on arm platform" OFF)
16 | option(BITNET_X86_TL2 "bitnet.cpp: use tl2 on x86 platform" OFF)
17 |
18 |
19 | set(CMAKE_CXX_STANDARD_REQUIRED true)
20 | set(CMAKE_C_STANDARD 11)
21 | set(CMAKE_C_STANDARD_REQUIRED true)
22 | set(THREADS_PREFER_PTHREAD_FLAG ON)
23 |
24 | # override ggml options
25 | set(GGML_BITNET_ARM_TL1 ${BITNET_ARM_TL1})
26 | set(GGML_BITNET_X86_TL2 ${BITNET_X86_TL2})
27 |
28 | if (GGML_BITNET_ARM_TL1)
29 | add_compile_definitions(GGML_BITNET_ARM_TL1)
30 | endif()
31 | if (GGML_BITNET_X86_TL2)
32 | add_compile_definitions(GGML_BITNET_X86_TL2)
33 | endif()
34 |
35 | if (CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
36 | add_compile_options(-fpermissive)
37 | endif()
38 |
39 | find_package(Threads REQUIRED)
40 |
41 | add_subdirectory(src)
42 | set(LLAMA_BUILD_SERVER ON CACHE BOOL "Build llama.cpp server" FORCE)
43 | add_subdirectory(3rdparty/llama.cpp)
44 |
45 | # install
46 |
47 | include(GNUInstallDirs)
48 | include(CMakePackageConfigHelpers)
49 |
50 | set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR}
51 | CACHE PATH "Location of header files")
52 | set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR}
53 | CACHE PATH "Location of library files")
54 | set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR}
55 | CACHE PATH "Location of binary files")
56 | set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER})
57 | set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT})
58 | set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER})
59 |
60 | get_target_property(GGML_DIRECTORY ggml SOURCE_DIR)
61 | get_directory_property(GGML_DIR_DEFINES DIRECTORY ${GGML_DIRECTORY} COMPILE_DEFINITIONS)
62 | get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS)
63 | set(GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES} ${GGML_DIR_DEFINES})
64 | get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES)
65 |
66 | get_directory_property(LLAMA_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
67 |
68 | write_basic_package_version_file(
69 | ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake
70 | VERSION ${LLAMA_INSTALL_VERSION}
71 | COMPATIBILITY SameMajorVersion)
72 |
73 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake
74 | ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake
75 | DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama)
76 |
77 | set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/llama.h)
78 | install(TARGETS llama LIBRARY PUBLIC_HEADER)
79 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # bitnet.cpp
2 | [MIT License](https://opensource.org/licenses/MIT)
3 |
4 |
5 | [BitNet-b1.58-2B-4T on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
6 |
7 | Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or build and run it on your own [CPU](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) or [GPU](https://github.com/microsoft/BitNet/blob/main/gpu/README.md).
8 |
9 | bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels that support **fast** and **lossless** inference of 1.58-bit models on CPU and GPU (NPU support is coming next).
10 |
11 | The first release of bitnet.cpp supports inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x**, with energy reductions between **71.9%** and **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.
12 |
13 |
14 |
15 |
16 | >The tested models are dummy setups used in a research context to demonstrate the inference performance of bitnet.cpp.
17 |
18 | ## Demo
19 |
20 | A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2:
21 |
22 | https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1
23 |
24 | ## What's New:
25 | - 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md) 
26 | - 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
27 | - 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880)
28 | - 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965)
29 | - 10/21/2024 [1-bit AI Infra: Part 1.1, Fast and Lossless BitNet b1.58 Inference on CPUs](https://arxiv.org/abs/2410.16144)
30 | - 10/17/2024 bitnet.cpp 1.0 released.
31 | - 03/21/2024 [The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf)
32 | - 02/27/2024 [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764)
33 | - 10/17/2023 [BitNet: Scaling 1-bit Transformers for Large Language Models](https://arxiv.org/abs/2310.11453)
34 |
35 | ## Acknowledgements
36 |
37 | This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp) framework. We would like to thank all the authors for their contributions to the open-source community. Also, bitnet.cpp's kernels are built on top of the Lookup Table methodologies pioneered in [T-MAC](https://github.com/microsoft/T-MAC/). For inference of general low-bit LLMs beyond ternary models, we recommend using T-MAC.
38 | ## Official Models
39 |
40 |
41 |
42 | | Model | Parameters | CPU | I2_S | TL1 | TL2 |
43 | |:------|:-----------|:----|:----:|:---:|:---:|
44 | | BitNet-b1.58-2B-4T | 2.4B | x86 | ✅ | ❌ | ✅ |
45 | | | | ARM | ✅ | ✅ | ❌ |
65 |
66 |
67 |
68 | ## Supported Models
69 | ❗️**We use existing 1-bit LLMs available on [Hugging Face](https://huggingface.co/) to demonstrate the inference capabilities of bitnet.cpp. We hope the release of bitnet.cpp will inspire the development of 1-bit LLMs in large-scale settings in terms of model size and training tokens.**
70 |
71 |
72 |
73 |
74 | | Model | Parameters | CPU | I2_S | TL1 | TL2 |
75 | |:------|:-----------|:----|:----:|:---:|:---:|
76 | | bitnet_b1_58-large | 0.7B | x86 | ✅ | ❌ | ✅ |
77 | | | | ARM | ✅ | ✅ | ❌ |
78 | | bitnet_b1_58-3B | 3.3B | x86 | ❌ | ❌ | ✅ |
79 | | | | ARM | ❌ | ✅ | ❌ |
80 | | Llama3-8B-1.58-100B-tokens | 8.0B | x86 | ✅ | ❌ | ✅ |
81 | | | | ARM | ✅ | ✅ | ❌ |
82 | | Falcon3 Family | 1B-10B | x86 | ✅ | ❌ | ✅ |
83 | | | | ARM | ✅ | ✅ | ❌ |
84 | | Falcon-E Family | 1B-3B | x86 | ✅ | ❌ | ✅ |
85 | | | | ARM | ✅ | ✅ | ❌ |
153 |
154 |
155 |
156 |
157 |
158 | ## Installation
159 |
160 | ### Requirements
161 | - python>=3.9
162 | - cmake>=3.22
163 | - clang>=18
164 | - For Windows users, install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/). In the installer, toggle on at least the following options (this also automatically installs the required additional tools like CMake):
165 | - Desktop-development with C++
166 | - C++-CMake Tools for Windows
167 | - Git for Windows
168 | - C++-Clang Compiler for Windows
169 | - MS-Build Support for LLVM-Toolset (clang)
170 | - For Debian/Ubuntu users, you can install clang with the [automatic installation script](https://apt.llvm.org/):
171 |
172 | `bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)"`
173 | - conda (highly recommended)
174 |
175 | ### Build from source
176 |
177 | > [!IMPORTANT]
178 | > If you are using Windows, please remember to always use a Developer Command Prompt / PowerShell for VS2022 for the following commands. Please refer to the FAQs below if you see any issues.
179 |
180 | 1. Clone the repo
181 | ```bash
182 | git clone --recursive https://github.com/microsoft/BitNet.git
183 | cd BitNet
184 | ```
185 | 2. Install the dependencies
186 | ```bash
187 | # (Recommended) Create a new conda environment
188 | conda create -n bitnet-cpp python=3.9
189 | conda activate bitnet-cpp
190 |
191 | pip install -r requirements.txt
192 | ```
193 | 3. Build the project
194 | ```bash
195 | # Manually download the model and run with local path
196 | huggingface-cli download microsoft/BitNet-b1.58-2B-4T-gguf --local-dir models/BitNet-b1.58-2B-4T
197 | python setup_env.py -md models/BitNet-b1.58-2B-4T -q i2_s
198 |
199 | ```
200 |
201 | usage: setup_env.py [-h] [--hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}] [--model-dir MODEL_DIR] [--log-dir LOG_DIR] [--quant-type {i2_s,tl1}] [--quant-embd]
202 | [--use-pretuned]
203 |
204 | Setup the environment for running inference
205 |
206 | optional arguments:
207 | -h, --help show this help message and exit
208 | --hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}, -hr {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}
209 | Model used for inference
210 | --model-dir MODEL_DIR, -md MODEL_DIR
211 | Directory to save/load the model
212 | --log-dir LOG_DIR, -ld LOG_DIR
213 | Directory to save the logging info
214 | --quant-type {i2_s,tl1}, -q {i2_s,tl1}
215 | Quantization type
216 | --quant-embd Quantize the embeddings to f16
217 | --use-pretuned, -p Use the pretuned kernel parameters
218 |
219 | ## Usage
220 | ### Basic usage
221 | ```bash
222 | # Run inference with the quantized model
223 | python run_inference.py -m models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf -p "You are a helpful assistant" -cnv
224 | ```
225 |
226 | usage: run_inference.py [-h] [-m MODEL] [-n N_PREDICT] -p PROMPT [-t THREADS] [-c CTX_SIZE] [-temp TEMPERATURE] [-cnv]
227 |
228 | Run inference
229 |
230 | optional arguments:
231 | -h, --help show this help message and exit
232 | -m MODEL, --model MODEL
233 | Path to model file
234 | -n N_PREDICT, --n-predict N_PREDICT
235 | Number of tokens to predict when generating text
236 | -p PROMPT, --prompt PROMPT
237 | Prompt to generate text from
238 | -t THREADS, --threads THREADS
239 | Number of threads to use
240 | -c CTX_SIZE, --ctx-size CTX_SIZE
241 | Size of the prompt context
242 | -temp TEMPERATURE, --temperature TEMPERATURE
243 | Temperature, a hyperparameter that controls the randomness of the generated text
244 | -cnv, --conversation Whether to enable chat mode or not (for instruct models.)
245 | (When this option is turned on, the prompt specified by -p will be used as the system prompt.)
246 |
247 |
248 | ### Benchmark
249 | We provide a script to run the inference benchmark for a given model.
250 |
251 | ```
252 | usage: e2e_benchmark.py -m MODEL [-n N_TOKEN] [-p N_PROMPT] [-t THREADS]
253 |
254 | Setup the environment for running the inference
255 |
256 | required arguments:
257 | -m MODEL, --model MODEL
258 | Path to the model file.
259 |
260 | optional arguments:
261 | -h, --help
262 | Show this help message and exit.
263 | -n N_TOKEN, --n-token N_TOKEN
264 | Number of generated tokens.
265 | -p N_PROMPT, --n-prompt N_PROMPT
266 | Prompt to generate text from.
267 | -t THREADS, --threads THREADS
268 | Number of threads to use.
269 | ```
270 |
271 | Here's a brief explanation of each argument:
272 |
273 | - `-m`, `--model`: The path to the model file. This is a required argument that must be provided when running the script.
274 | - `-n`, `--n-token`: The number of tokens to generate during the inference. It is an optional argument with a default value of 128.
275 | - `-p`, `--n-prompt`: The number of prompt tokens to use for generating text. This is an optional argument with a default value of 512.
276 | - `-t`, `--threads`: The number of threads to use for running the inference. It is an optional argument with a default value of 2.
277 | - `-h`, `--help`: Show the help message and exit. Use this argument to display usage information.
278 |
279 | For example:
280 |
281 | ```sh
282 | python utils/e2e_benchmark.py -m /path/to/model -n 200 -p 256 -t 4
283 | ```
284 |
285 | This command would run the inference benchmark using the model located at `/path/to/model`, generating 200 tokens from a 256 token prompt, utilizing 4 threads.
286 |
287 | For model layouts that are not covered by any public model, we provide scripts to generate a dummy model with the given layout and run the benchmark on your machine:
288 |
289 | ```bash
290 | python utils/generate-dummy-bitnet-model.py models/bitnet_b1_58-large --outfile models/dummy-bitnet-125m.tl1.gguf --outtype tl1 --model-size 125M
291 |
292 | # Run benchmark with the generated model, use -m to specify the model path, -p to specify the prompt processed, -n to specify the number of token to generate
293 | python utils/e2e_benchmark.py -m models/dummy-bitnet-125m.tl1.gguf -p 512 -n 128
294 | ```
295 | ### FAQ (Frequently Asked Questions)📌
296 |
297 | #### Q1: The build fails with errors building llama.cpp due to issues with std::chrono in log.cpp?
298 |
299 | **A:**
300 | This is an issue introduced in a recent version of llama.cpp. Please refer to this [commit](https://github.com/tinglou/llama.cpp/commit/4e3db1e3d78cc1bcd22bcb3af54bd2a4628dd323) in the [discussion](https://github.com/abetlen/llama-cpp-python/issues/1942) to fix it.
301 |
302 | #### Q2: How to build with clang in a conda environment on Windows?
303 |
304 | **A:**
305 | Before building the project, verify your clang installation and access to Visual Studio tools by running:
306 | ```
307 | clang -v
308 | ```
309 |
310 | This command checks that you are using the correct version of clang and that the Visual Studio tools are available. If you see an error message such as:
311 | ```
312 | 'clang' is not recognized as an internal or external command, operable program or batch file.
313 | ```
314 |
315 | It indicates that your command line window is not properly initialized for Visual Studio tools.
316 |
317 | • If you are using Command Prompt, run:
318 | ```
319 | "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
320 | ```
321 |
322 | • If you are using Windows PowerShell, run the following commands:
323 | ```
324 | Import-Module "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\Microsoft.VisualStudio.DevShell.dll"
Enter-VsDevShell 3f0e31ad -SkipAutomaticLocation -DevCmdArguments "-arch=x64 -host_arch=x64"
325 | ```
326 |
327 | These steps will initialize your environment and allow you to use the correct Visual Studio tools.
328 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/assets/header_model_release.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/assets/header_model_release.png
--------------------------------------------------------------------------------
/assets/intel_performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/assets/intel_performance.jpg
--------------------------------------------------------------------------------
/assets/m2_performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/assets/m2_performance.jpg
--------------------------------------------------------------------------------
/assets/tl1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/assets/tl1.png
--------------------------------------------------------------------------------
/assets/tl2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/assets/tl2.png
--------------------------------------------------------------------------------
/docs/codegen.md:
--------------------------------------------------------------------------------
1 | Codegen for TL1 and TL2
2 | ------------------------
3 |
4 | codegen_tl1.py and codegen_tl2.py use per-device parameters to generate the kernel code that achieves the best TL1 and TL2 performance on a given device.
5 |
6 | We cut the weights into multiple compute blocks to best utilize hardware capabilities.
7 |
8 | ### Example
9 | bitnet_b1_58-large:
10 |
11 | - Check the matmul kernel shapes.\
12 |   For example, the bitnet_b1_58-large matmul kernel shapes are:\
13 |   [1536, 4096]\
14 |   [1536, 1536]\
15 |   [4096, 1536]
16 |
17 | - Make sure the BM, BK, and bm for each kernel meet the requirements below.
18 | - Generate the code.\
19 |   For example, for bitnet_b1_58-large we can generate the kernels like this:
20 |
21 | ```bash
22 | # For TL1
23 | python utils/codegen_tl1.py --model bitnet_b1_58-large --BM 256,128,256 --BK 128,64,128 --bm 32,64,32
24 |
25 | # For TL2
26 | python utils/codegen_tl2.py --model bitnet_b1_58-large --BM 256,128,256 --BK 96,192,96 --bm 32,32,32
27 | ```
28 |
29 | ### TL1:
30 | 
31 |
32 | For TL1, we first cut the weight into M / BM blocks, each of shape (BM, K). Then we cut each of these into K / BK blocks of shape (BM, BK). Each (BM, BK) block is cut the same way into (bm, compute_num / bm) compute blocks, within which the computation is completed.
33 |
34 | Thus, we need to make sure
35 | - M % BM == 0
36 | - K % BK == 0
37 | - BM % bm == 0
38 | - bm choose in [32, 64]
39 |
40 | ### TL2:
41 | 
42 |
43 | For TL2, things get a little more complicated. Because TL2 requires BK % 6 == 0, we split K into three_K and two_K: the (M, three_K) part is computed with TL2, and the (M, two_K) part is computed with TL1.
44 |
45 | Thus, we need to make sure
46 | - M % BM == 0
47 | - K % BK % 32 == 0
48 | - BM % bm == 0
49 | - bm is chosen from \[32\]
--------------------------------------------------------------------------------
/gpu/README.md:
--------------------------------------------------------------------------------
1 | # BitNet Inference Kernel
2 |
3 | This repository provides a highly efficient GEMV kernel implementation for the BitNet model, optimized for W2A8 inference — 2-bit weights and 8-bit activations. It is tailored for use with the [BitNet-b1.58-2B-4T](https://arxiv.org/abs/2504.12285) model.
4 |
5 | ## Features
6 |
7 | - Support for W2A8 (2-bit weight × 8-bit activation) GEMV computation
8 | - Custom CUDA kernels with low-latency execution
9 | - Optimizations for memory access, decoding, and compute throughput
10 |
11 | ## Usage
12 |
13 | Installation and kernel performance tests:
14 |
15 | ```bash
16 | # (Recommended) Create a new conda environment
17 | conda create --name bitnet-gpu "python<3.13"
18 | conda activate bitnet-gpu
19 |
20 | # Install dependencies
21 | pip install -r requirements.txt
22 |
23 | # Build the kernel
24 | cd bitnet_kernels
25 | bash compile.sh
26 | cd ..
27 |
28 | # Run performance tests
29 | python test.py
30 | ```
31 |
32 | End-to-end inference:
33 |
34 | ```bash
35 | # Download and convert the BitNet-b1.58-2B model
36 | mkdir checkpoints
37 | huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
38 | python ./convert_safetensors.py --safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors --output checkpoints/model_state.pt --model_name 2B
39 | python ./convert_checkpoint.py --input ./checkpoints/model_state.pt
40 | rm ./checkpoints/model_state.pt
41 |
42 | # Inference
43 | python3 ./generate.py ./checkpoints/ --interactive --chat_format
44 | ```
45 |
46 | ## Optimizations
47 |
48 | ### Weight Permutation
49 |
50 | The weight matrix is divided into 16×32 blocks to optimize memory access patterns.
51 |
52 | Within each block, values are stored contiguously in memory and permuted to facilitate efficient access and processing.
53 |
54 | See `convert_checkpoint.py` for details.
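
As a rough illustration (a sketch under assumptions, not the exact permutation used by `convert_checkpoint.py`), the blocking step can be pictured as regrouping the weight matrix so that each 16×32 tile becomes contiguous:

```python
import torch

def block_16x32(w: torch.Tensor) -> torch.Tensor:
    """Illustration only: regroup an (N, K) weight matrix into 16x32 tiles so
    that each tile is contiguous in memory. The real converter additionally
    permutes the values inside each tile; see convert_checkpoint.py."""
    N, K = w.shape
    assert N % 16 == 0 and K % 32 == 0
    tiles = w.reshape(N // 16, 16, K // 32, 32).permute(0, 2, 1, 3)
    return tiles.contiguous().view(-1)  # flat buffer: one 16x32 tile after another
```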
55 |
56 | ### Fast Decoding
57 |
58 | Every 16 two-bit values are packed into a single 32-bit integer using the following interleaving pattern:
59 | ```
60 | [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
61 | ```
62 |
63 | This layout is designed to accelerate decoding by enabling efficient extraction of 4 values at a time into `int8`.
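
A minimal Python sketch of the bit layout (illustration only; the real packing is done by the conversion scripts): slot `i` of the packed 32-bit word stores value number `pattern[i]`, so shifting the word by `2*i` and masking each byte recovers four consecutive values at once.

```python
# Slot i of the packed 32-bit word holds value PATTERN[i] (pattern from above).
PATTERN = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]

def pack16(vals):
    """Pack 16 two-bit values (0..3) into one 32-bit word using the interleaved order."""
    assert len(vals) == 16
    word = 0
    for slot, src in enumerate(PATTERN):
        word |= (vals[src] & 0x3) << (2 * slot)
    return word

def unpack4(word, i):
    """Mimic the kernel's decode step: shift by 2*i and keep the low 2 bits of each
    byte; this yields the original values 4*i .. 4*i + 3 as four int8 lanes at once."""
    shifted = word >> (2 * i)
    return [(shifted >> (8 * b)) & 0x3 for b in range(4)]

vals = [(7 * i) % 4 for i in range(16)]          # arbitrary two-bit test values
packed = pack16(vals)
assert [v for i in range(4) for v in unpack4(packed, i)] == vals
```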
64 |
65 | ### `dp4a` Instruction
66 |
67 | We use the `dp4a` instruction to accelerate low-precision dot product operations.
68 |
69 | This instruction performs a dot product between two 4-element vectors (each stored in a 32-bit word as 8-bit integers) and accumulates the result into a 32-bit integer.
70 |
71 | It significantly improves GEMV throughput when processing quantized weights and activations.
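
For reference, the signed `__dp4a(a, b, c)` intrinsic behaves like the following Python sketch (semantics only, not how the kernel invokes it):

```python
import struct

def dp4a(a: int, b: int, c: int) -> int:
    """Reference semantics of the signed __dp4a intrinsic: a and b are 32-bit
    words holding four signed 8-bit lanes; dot-product the lanes and add c."""
    a_lanes = struct.unpack("4b", struct.pack("<I", a & 0xFFFFFFFF))
    b_lanes = struct.unpack("4b", struct.pack("<I", b & 0xFFFFFFFF))
    return c + sum(x * y for x, y in zip(a_lanes, b_lanes))

a = int.from_bytes(struct.pack("4b", 1, -2, 3, 4), "little")
b = int.from_bytes(struct.pack("4b", 5, 6, -7, 8), "little")
assert dp4a(a, b, 0) == (1 * 5) + (-2 * 6) + (3 * -7) + (4 * 8)  # == 4
```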
72 |
73 |
74 | ## Performance
75 |
76 | Kernel performance (tested on NVIDIA A100 40GB GPU):
77 |
78 | | Shape (N×K) | W2A8 Latency (us) | BF16 Latency (us) | Speedup Ratio |
79 | |---------------------|-------------------|-------------------|----------------------|
80 | | 2560 × 2560 | 13.32 | 18.32 | 1.38 |
81 | | 3840 × 2560 | 14.90 | 18.87 | 1.27 |
82 | | 13824 × 2560 | 18.75 | 59.51 | 3.17 |
83 | | 2560 × 6912 | 14.49 | 37.78 | 2.61 |
84 | | 3200 × 3200 | 14.61 | 19.08 | 1.31 |
85 | | 4800 × 3200 | 13.09 | 21.84 | 1.67 |
86 | | 3200 × 10240 | 19.64 | 60.79 | 3.10 |
87 | | 20480 × 3200 | 30.99 | 112.39 | 3.63 |
88 |
89 | Generation throughput:
90 |
91 | | BF16 (tokens/s) | W2A8 (tokens/s) | Speedup Ratio |
92 | |---|---|---|
93 | | 10.9 | 213.3 | 19.6 |
--------------------------------------------------------------------------------
/gpu/bitnet_kernels/bitnet_kernels.cu:
--------------------------------------------------------------------------------
1 | #include "bitnet_kernels.h"
2 |
3 | extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){
4 |     if (M == 1 && N == 3840 && K == 2560){
5 |         ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<<dim3(3840 / 16), dim3(8, 16), 0, stream>>>(input0, input1, output0, s, ws);
6 |     }
7 |     else if (M == 1 && N == 2560 && K == 2560){
8 |         ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<<dim3(2560 / 16), dim3(8, 16), 0, stream>>>(input0, input1, output0, s, ws);
9 |     }
10 |     else if (M == 1 && N == 13824 && K == 2560){
11 |         ladder_int8xint2_kernel<1, 13824, 2560, 2, 8, 16><<<dim3(13824 / 16), dim3(8, 16), 0, stream>>>(input0, input1, output0, s, ws);
12 |     }
13 |     else if (M == 1 && N == 2560 && K == 6912){
14 |         ladder_int8xint2_kernel<1, 2560, 6912, 1, 8, 16><<<dim3(2560 / 16), dim3(8, 16), 0, stream>>>(input0, input1, output0, s, ws);
15 |     }
16 |     else if(M == 1 && N == 4800 && K == 3200){
17 |         ladder_int8xint2_kernel<1, 4800, 3200, 6, 8, 16><<<dim3(4800 / 16), dim3(8, 16), 0, stream>>>(input0, input1, output0, s, ws);
18 |     }
19 |     else if(M == 1 && N == 3200 && K == 3200){
20 |         ladder_int8xint2_kernel<1, 3200, 3200, 1, 8, 16><<<dim3(3200 / 16), dim3(8, 16), 0, stream>>>(input0, input1, output0, s, ws);
21 |     }
22 |     else if(M == 1 && N == 20480 && K == 3200){
23 |         ladder_int8xint2_kernel<1, 20480, 3200, 2, 8, 16><<<dim3(20480 / 16), dim3(8, 16), 0, stream>>>(input0, input1, output0, s, ws);
24 |     }
25 |     else if(M == 1 && N == 3200 && K == 10240){
26 |         ladder_int8xint2_kernel<1, 3200, 10240, 1, 8, 16><<<dim3(3200 / 16), dim3(8, 16), 0, stream>>>(input0, input1, output0, s, ws);
27 |     }
28 |     else if(M == 1 && N == 5120 && K == 27648){
29 |         ladder_int8xint2_kernel<1, 5120, 27648, 1, 8, 16><<<dim3(5120 / 16), dim3(8, 16), 0, stream>>>(input0, input1, output0, s, ws);
30 |     }
31 |     else if(M == 1 && N == 55296 && K == 5120){
32 |         ladder_int8xint2_kernel<1, 55296, 5120, 1, 8, 16><<<dim3(55296 / 16), dim3(8, 16), 0, stream>>>(input0, input1, output0, s, ws);
33 |     }
34 | else{
35 | std::cout << "required ladder gemm kernel: M " << M << ", N " << N << ", K " << K << std::endl;
36 | }
37 | }
--------------------------------------------------------------------------------
/gpu/bitnet_kernels/bitnet_kernels.h:
--------------------------------------------------------------------------------
1 | #include <cuda_runtime.h>
2 | #include <cuda_fp16.h>
3 | #include <cuda_bf16.h>
4 | #include <cstdint>
5 | #include <iostream>
9 |
10 |
11 | #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || (__CUDACC_VER_MAJOR__ > 11))
12 | #define TVM_ENABLE_L2_PREFETCH 1
13 | #else
14 | #define TVM_ENABLE_L2_PREFETCH 0
15 | #endif
16 |
17 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
18 | #define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1
19 | #else
20 | #define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0
21 | #endif
22 |
23 | template <typename T1, typename T2>
24 | __device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16)
25 | {
26 | // convert 8 int2b_t to 8 int8b_t -> 2 int32
27 | uint *i8s = reinterpret_cast<uint *>(_i8s);
28 |
29 | // i2s = {e0, e4, e8, e12, e1, e5, e9, e13, e2, e6, e10, e14, e3, e7, e11, e15}
30 | uint const i2s = *_i2s;
31 |
32 | static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
33 | static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3
34 | static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000;
35 |
36 | #pragma unroll
37 | for (int i = 0; i < (N / 4); i++)
38 | {
39 | asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
40 | : "=r"(i8s[i])
41 | : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut));
42 | i8s[i] = __vsubss4(i8s[i], 0x02020202);
43 | }
44 | }
45 |
46 | template <int M, int N, int K, int ws_num, int K_block_size, int N_block_size>
47 | __global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restrict__ A, int8_t* __restrict__ B, __nv_bfloat16* __restrict__ dtype_transform, __nv_bfloat16* __restrict__ s, __nv_bfloat16* __restrict__ ws) {
48 | constexpr int K_per_loop = 16;
49 | constexpr int wmma_K = 32;
50 | constexpr int wmma_N = 16;
51 | int in_thread_C_local[1];
52 | signed char A_local[K_per_loop];
53 | int B_reshape_local[1];
54 | signed char B_decode_local[K_per_loop];
55 | int red_buf0[1];
56 | in_thread_C_local[0] = 0;
57 | #pragma unroll
58 | for (int k_0 = 0; k_0 < K/(K_per_loop * K_block_size); ++k_0) {
59 | *(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop)));
60 | B_reshape_local[0] = *(int*)(B +
61 | (((int)blockIdx.x) * N_block_size * K / 4) +
62 | (k_0 * K_block_size * K_per_loop * wmma_N / 4) +
63 | ((((int)threadIdx.x) >> 1) * wmma_K * wmma_N / 4) +
64 | ((((int)threadIdx.y) >> 3) * (wmma_K * wmma_N / 2) / 4) +
65 | ((((int)threadIdx.x) & 1) * (wmma_K * wmma_N / 4) / 4) +
66 | ((((int)threadIdx.y) & 7) * (wmma_K / 2) / 4)
67 | );
68 | decode_i2s_to_i8s(B_reshape_local, B_decode_local, 16);
69 | #pragma unroll
70 | for (int k_2_0 = 0; k_2_0 < 4; ++k_2_0) {
71 | in_thread_C_local[0] = __dp4a(*(int *)&A_local[((k_2_0 * 4))],*(int *)&B_decode_local[((k_2_0 * 4))], in_thread_C_local[0]);
72 | }
73 | }
74 | red_buf0[0] = in_thread_C_local[0];
75 | #pragma unroll
76 | for (int offset = K_block_size/2; offset > 0; offset /= 2) {
77 | red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size);
78 | }
79 | int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y));
80 | int ws_idx = out_idx / (N / ws_num);
81 | if (threadIdx.x == 0)
82 | dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0])/(float)s[0]*(float)ws[ws_idx]);
83 | }
--------------------------------------------------------------------------------
/gpu/bitnet_kernels/compile.sh:
--------------------------------------------------------------------------------
1 | nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitnet_kernels.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libbitnet.so
2 |
3 |
4 |
--------------------------------------------------------------------------------
/gpu/bitnet_kernels/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name='bitlinear_cpp',
6 | ext_modules=[
7 | CUDAExtension('bitlinear_cuda', [
8 | 'bitnet_kernels.cu',
9 | ])
10 | ],
11 | cmdclass={
12 | 'build_ext': BuildExtension
13 | })
--------------------------------------------------------------------------------
/gpu/convert_checkpoint.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | import sys
5 | from pathlib import Path
6 | from typing import Optional
7 | from dataclasses import dataclass
8 | import torch
9 | from einops import rearrange
10 | from safetensors.torch import save_file
11 | import model
12 | from pack_weight import convert_weight_int8_to_int2
13 |
14 | @torch.inference_mode()
15 | def convert_ts_checkpoint(
16 | *,
17 | input_path: str = "",
18 | ) -> None:
19 |
20 | config = model.ModelArgs()
21 | print(f"Model config {config.__dict__}")
22 |
23 | def quant_weight_int8(weight):
24 | s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
25 | new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8)
26 | new_scale = (1.0 / s).to(torch.bfloat16)
27 | return new_weight, new_scale.reshape(1)
28 |
29 | def quant_weight_fp16(weight):
30 | s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
31 | new_weight = (weight * s).round().clamp(-1, 1) / s
32 | return new_weight
33 |
34 | def convert_int8_to_int2(weight):
35 | return convert_weight_int8_to_int2(weight)
36 |
37 | merged_result = torch.load(input_path, map_location="cpu", mmap=True)
38 | int2_result = {}
39 | fp16_result = {}
40 | zero = torch.zeros(1).to(torch.bfloat16)
41 | for key, value in merged_result.items():
42 | if 'wqkv' in key:
43 | wq = value[:config.dim]
44 | wk = value[config.dim:config.dim // config.n_heads * config.n_kv_heads + config.dim]
45 | wv = value[config.dim // config.n_heads * config.n_kv_heads + config.dim:]
46 | wq_weight, wa_scale = quant_weight_int8(wq)
47 | wk_weight, wb_scale = quant_weight_int8(wk)
48 | wv_weight, wc_scale = quant_weight_int8(wv)
49 | wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
50 | wqkv_scale = torch.cat([wa_scale, wb_scale, wc_scale, zero], dim=0)
51 | int2_result[key] = convert_int8_to_int2(wqkv_weight)
52 | int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale
53 |
54 | wq_weight = quant_weight_fp16(wq)
55 | wk_weight = quant_weight_fp16(wk)
56 | wv_weight = quant_weight_fp16(wv)
57 | wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
58 | fp16_result[key] = wqkv_weight
59 | elif 'w13' in key:
60 | w1 = value[:config.ffn_dim]
61 | w3 = value[config.ffn_dim:]
62 | w1_weight, w1_scale = quant_weight_int8(w1)
63 | w3_weight, w3_scale = quant_weight_int8(w3)
64 | w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
65 | w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0)
66 | int2_result[key] = convert_int8_to_int2(w13_weight)
67 | int2_result[key.replace('weight', 'weight_scale')] = w13_scale
68 |
69 | w1_weight = quant_weight_fp16(w1)
70 | w3_weight = quant_weight_fp16(w3)
71 | w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
72 | fp16_result[key] = w13_weight
73 | elif 'w2' in key or 'wo' in key:
74 | weight, scale = quant_weight_int8(value)
75 | scale = torch.cat([scale, zero, zero, zero], dim=0)
76 | int2_result[key] = convert_int8_to_int2(weight)
77 | int2_result[key.replace('weight', 'weight_scale')] = scale
78 |
79 | weight = quant_weight_fp16(value)
80 | fp16_result[key] = weight
81 | else:
82 | int2_result[key] = value.clone()
83 | fp16_result[key] = value.clone()
84 |
85 | output_dir = os.path.dirname(input_path)
86 | print(f"Saving checkpoint to {output_dir}/model_state_int2.pt")
87 | torch.save(int2_result, f"{output_dir}/model_state_int2.pt")
88 |
89 | print(f"Saving checkpoint to {output_dir}/model_state_fp16.pt")
90 | torch.save(fp16_result, f"{output_dir}/model_state_fp16.pt")
91 |
92 | if __name__ == '__main__':
93 | import argparse
94 | parser = argparse.ArgumentParser(description='Convert TorchScale checkpoint.')
95 | parser.add_argument('--input', type=str)
96 |
97 | args = parser.parse_args()
98 | convert_ts_checkpoint(
99 | input_path=args.input,
100 | )
101 |
--------------------------------------------------------------------------------
/gpu/convert_safetensors.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | from pathlib import Path
4 | from safetensors.torch import load_file
5 | from einops import rearrange
6 | from dataclasses import dataclass
7 | from typing import Optional
8 |
9 | transformer_configs = {
10 | "2B": dict(n_layer=30, n_head=20, dim=2560, vocab_size=128256, n_local_heads=5, intermediate_size=6912),
11 | }
12 |
13 | @dataclass
14 | class ModelArgs:
15 | block_size: int = 4096
16 | vocab_size: int = 32000
17 | n_layer: int = 32
18 | n_head: int = 32
19 | dim: int = 4096
20 | intermediate_size: int = None
21 | n_local_heads: int = -1
22 | head_dim: int = 64
23 | rope_base: float = 10000
24 | norm_eps: float = 1e-5
25 |
26 | def __post_init__(self):
27 | if self.n_local_heads == -1:
28 | self.n_local_heads = self.n_head
29 | if self.intermediate_size is None:
30 | hidden_dim = 4 * self.dim
31 | n_hidden = int(2 * hidden_dim / 3)
32 | self.intermediate_size = n_hidden + (256 - n_hidden % 256) if n_hidden % 256 else n_hidden
33 | self.head_dim = self.dim // self.n_head
34 |
35 | @classmethod
36 | def from_name(cls, name: str):
37 | if name in transformer_configs:
38 | return cls(**transformer_configs[name])
39 | config = [k for k in transformer_configs if k in name.upper() or k in name]
40 | assert len(config) == 1, f"Unknown model name: {name}"
41 | return cls(**transformer_configs[config[0]])
42 |
43 | def invert_convert_q(w: torch.Tensor, config: ModelArgs) -> torch.Tensor:
44 | return rearrange(w, '(h l d) i -> (h d l) i', h=config.n_head, l=2)
45 |
46 | def invert_convert_k(w: torch.Tensor, config: ModelArgs) -> torch.Tensor:
47 | return rearrange(w, '(h l d) i -> (h d l) i', h=config.n_local_heads, l=2)
48 |
49 | def convert_back(
50 | safetensors_path: str,
51 | output_file: str,
52 | model_name: Optional[str] = None,
53 | ):
54 | st_dict = load_file(safetensors_path)
55 |
56 | cfg = ModelArgs.from_name(model_name)
57 | print(f"Using model configurations: {cfg}")
58 |
59 | recovered: dict = {}
60 |
61 | for layer in range(cfg.n_layer):
62 | base = f"model.layers.{layer}."
63 |
64 | wq = st_dict[f"{base}self_attn.q_proj.weight"]
65 | wk = st_dict[f"{base}self_attn.k_proj.weight"]
66 | wv = st_dict[f"{base}self_attn.v_proj.weight"]
67 |
68 | wq = invert_convert_q(wq, cfg)
69 | wk = invert_convert_k(wk, cfg)
70 |
71 | wqkv = torch.cat([wq, wk, wv], dim=0)
72 | recovered[f"layers.{layer}.attention.wqkv.weight"] = wqkv
73 |
74 | recovered[f"layers.{layer}.attention.wo.weight"] = st_dict[f"{base}self_attn.o_proj.weight"]
75 |
76 | recovered[f"layers.{layer}.attention_norm.weight"] = st_dict[f"{base}input_layernorm.weight"]
77 | recovered[f"layers.{layer}.ffn_norm.weight"] = st_dict[f"{base}post_attention_layernorm.weight"]
78 | recovered[f"layers.{layer}.attention.attn_sub_norm.weight"] = st_dict[f"{base}self_attn.attn_sub_norm.weight"]
79 | recovered[f"layers.{layer}.feed_forward.ffn_sub_norm.weight"] = st_dict[f"{base}mlp.ffn_sub_norm.weight"]
80 |
81 | gate = st_dict[f"{base}mlp.gate_proj.weight"]
82 | up = st_dict[f"{base}mlp.up_proj.weight"]
83 | w13 = torch.cat([gate, up], dim=0)
84 | recovered[f"layers.{layer}.feed_forward.w13.weight"] = w13
85 |
86 | recovered[f"layers.{layer}.feed_forward.w2.weight"] = st_dict[f"{base}mlp.down_proj.weight"]
87 |
88 | recovered["tok_embeddings.weight"] = st_dict["model.embed_tokens.weight"]
89 | recovered["output.weight"] = st_dict["model.embed_tokens.weight"]
90 | recovered["norm.weight"] = st_dict["model.norm.weight"]
91 |
92 | print(f"Saving to {output_file}")
93 | torch.save(recovered, output_file)
94 |
95 | if __name__ == "__main__":
96 | import argparse
97 | parser = argparse.ArgumentParser(description="Convert Safetensors back to Torch .pth checkpoint")
98 | parser.add_argument(
99 | "--safetensors_file", type=str, required=True,
100 | help="Path to input .safetensors file"
101 | )
102 | parser.add_argument(
103 | "--output", type=str, default="./checkpoints/model_state.pt",
104 | help="Path to output .pt file"
105 | )
106 | parser.add_argument(
107 | "--model_name", type=str, default="2B",
108 | help="Model configuration name to use (e.g. 2B)"
109 | )
110 | args = parser.parse_args()
111 |
112 | convert_back(
113 | safetensors_path=args.safetensors_file,
114 | output_file=args.output,
115 | model_name=args.model_name,
116 | )
--------------------------------------------------------------------------------
/gpu/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2 | #
3 | # This source code is licensed under the BSD license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import json
7 | import os
8 | import readline # type: ignore # noqa
9 | import sys
10 | import time
11 | from dataclasses import dataclass
12 | from pathlib import Path
13 | from typing import Iterable, Optional, Tuple, Union
14 |
15 | import fire
16 | import model as fast
17 | import torch
18 | from stats import Stats
19 | from tokenizer import Tokenizer, ChatFormat
20 | import sample_utils
21 | from xformers.ops.fmha.attn_bias import (
22 | BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
23 | )
24 |
25 |
26 | @dataclass
27 | class GenArgs:
28 | gen_length: int = 32
29 | gen_bsz: int = 1
30 | prompt_length: int = 64
31 |
32 | use_sampling: bool = False
33 | temperature: float = 0.8
34 | top_p: float = 0.9
35 |
36 |
37 | class FastGen:
38 | GRAPH_WARMUPS: int = 1
39 | tokenizer: Tokenizer
40 |
41 | @staticmethod
42 | def build(
43 | ckpt_dir: str,
44 | gen_args: GenArgs,
45 | device: Union[torch.device, str],
46 | tokenizer_path: Optional[str] = None,
47 | num_layers: int = 13,
48 | use_full_vocab: bool = False,
49 | ) -> "FastGen":
50 | """
51 | Load a Llama or Code Llama checkpoint and return a new
52 | generator for this model.
53 | """
54 | start_time = time.time()
55 |
56 | model_args_prefill = fast.ModelArgs(use_kernel=False)
57 | model_args_decode = fast.ModelArgs(use_kernel=True)
58 | tokenizer = Tokenizer("./tokenizer.model")
59 |
60 | torch.set_default_device(device)
61 | torch.set_default_dtype(torch.bfloat16)
62 |
63 | prefill_model = fast.Transformer(model_args_prefill)
64 | decode_model = fast.Transformer(model_args_decode)
65 |
66 | fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
67 | fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
68 | int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
69 | int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
70 | prefill_model.load_state_dict(fp16_checkpoint, strict=True)
71 | decode_model.load_state_dict(int2_checkpoint, strict=True)
72 |
73 | torch.cuda.synchronize()
74 | print(f"loaded model in {time.time() - start_time:.2f} seconds")
75 | start_time = time.time()
76 |
77 | return FastGen(gen_args, model_args_prefill, prefill_model, decode_model, tokenizer)
78 |
79 | def __init__(
80 | self,
81 | args: GenArgs,
82 | model_args: fast.ModelArgs,
83 | prefill_model: fast.Transformer,
84 | decode_model: fast.Transformer,
85 | tokenizer: Tokenizer,
86 | ):
87 | self.gen_args = args
88 | self.max_seq_length = args.prompt_length + args.gen_length
89 | self.model_args = model_args
90 | # self.model = model
91 | self.prefill_model = prefill_model
92 | self.decode_model = decode_model
93 | self.tokenizer = tokenizer
94 | self._prefill_cuda_graph, self._prefill_compile_model, self._prefill_inputs, self._prefill_logits = None, None, None, None
95 | self._generate_cuda_graph, self._generate_compile_model, self._generate_inputs, self._generate_logits = None, None, None, None
96 | self._cache = None
97 | start_time = time.time()
98 | self._prefill_compile_model = self.compile_prefill()
99 | self._generate_compile_model = self.compile_generate()
100 | print(f"compiled model in {time.time() - start_time:.2f} seconds")
101 |
102 | def compile_prefill(self):
103 |
104 | if self._cache is None:
105 | self._cache = fast.make_cache(
106 | args=self.model_args,
107 | length=self.gen_args.gen_bsz * self.max_seq_length,
108 | )
109 |
110 | seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
111 |
112 | bias = AttnBias.from_seqlens(
113 | q_seqlen=seq_lens,
114 | kv_seqlen=seq_lens,
115 | kv_padding=self.max_seq_length,
116 | )
117 | bias.q_seqinfo.to("cuda")
118 | bias.k_seqinfo.to("cuda")
119 |
120 | tokens = torch.IntTensor([1] * self.gen_args.gen_bsz * self.gen_args.prompt_length).cuda()
121 | self._prefill_inputs = (tokens, bias)
122 |
123 | s = torch.cuda.Stream()
124 | s.wait_stream(torch.cuda.current_stream())
125 |
126 | with torch.cuda.stream(s):
127 | _ = self.prefill_model.forward_with_attn_bias(
128 | token_values=self._prefill_inputs[0],
129 | attn_bias=self._prefill_inputs[1],
130 | cache=self._cache,
131 | )
132 | torch.cuda.current_stream().wait_stream(s)
133 |
134 | self._prefill_cuda_graph = torch.cuda.CUDAGraph()
135 | recording_kwargs = {}
136 | if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
137 | # In PyTorch 2.1+ and nightlies from late Aug 2023,
138 | # we can do this to maybe avoid watchdog-related crashes
139 | recording_kwargs["capture_error_mode"] = "thread_local"
140 | with torch.cuda.graph(self._prefill_cuda_graph, **recording_kwargs):
141 | self._prefill_logits = self.prefill_model.forward_with_attn_bias(
142 | token_values=self._prefill_inputs[0],
143 | attn_bias=self._prefill_inputs[1],
144 | cache=self._cache,
145 | )
146 |
147 | def replay(tokens, seq_lens=None):
148 | self._prefill_inputs[0].copy_(tokens)
149 | if seq_lens is not None:
150 | self._prefill_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
151 |
152 | self._prefill_cuda_graph.replay()
153 | torch.cuda.synchronize()
154 |
155 | return self._prefill_logits
156 |
157 | return replay
158 |
159 | def compile_generate(self):
160 |
161 | if self._cache is None:
162 | self._cache = fast.make_cache(
163 | args=self.model_args,
164 | length=self.gen_args.gen_bsz * self.max_seq_length,
165 | )
166 |
167 | seq_lens = [1 for _ in range(self.gen_args.gen_bsz)]
168 | kv_seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
169 |
170 | bias = AttnBias.from_seqlens(
171 | q_seqlen=seq_lens,
172 | kv_seqlen=kv_seq_lens,
173 | kv_padding=self.max_seq_length,
174 | )
175 | bias.q_seqinfo.to("cuda")
176 | bias.k_seqinfo.to("cuda")
177 |
178 | tokens = torch.IntTensor([1] * self.gen_args.gen_bsz).cuda()
179 | self._generate_inputs = (tokens, bias)
180 |
181 | s = torch.cuda.Stream()
182 | s.wait_stream(torch.cuda.current_stream())
183 |
184 | with torch.cuda.stream(s):
185 | _ = self.decode_model.forward_with_attn_bias(
186 | token_values=self._generate_inputs[0],
187 | attn_bias=self._generate_inputs[1],
188 | cache=self._cache,
189 | )
190 | torch.cuda.current_stream().wait_stream(s)
191 |
192 | self._generate_cuda_graph = torch.cuda.CUDAGraph()
193 | recording_kwargs = {}
194 | if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
195 | # In PyTorch 2.1+ and nightlies from late Aug 2023,
196 | # we can do this to maybe avoid watchdog-related crashes
197 | recording_kwargs["capture_error_mode"] = "thread_local"
198 | with torch.cuda.graph(self._generate_cuda_graph, **recording_kwargs):
199 | self._generate_logits = self.decode_model.forward_with_attn_bias(
200 | token_values=self._generate_inputs[0],
201 | attn_bias=self._generate_inputs[1],
202 | cache=self._cache,
203 | )
204 |
205 | def replay(tokens, seq_lens):
206 | self._generate_inputs[0].copy_(tokens)
207 | self._generate_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
208 |
209 | self._generate_cuda_graph.replay()
210 |
211 | return self._generate_logits
212 |
213 | return replay
214 |
215 |
216 | @torch.inference_mode()
217 | def generate_all(
218 | self, prompts: list[list[int]], use_cuda_graphs: bool, use_sampling: bool
219 | ) -> Tuple[Stats, list[list[int]]]:
220 | bs = len(prompts)
221 | prompt_lens = [len(p) for p in prompts]
222 | padded_prompt_lens = [self.gen_args.prompt_length] * bs
223 | max_prompt_length = max(prompt_lens)
224 | gen_length = self.gen_args.gen_length
225 | max_seq_length = max_prompt_length + gen_length
226 | print(max_prompt_length, gen_length)
227 |
228 | bias = AttnBias.from_seqlens(
229 | q_seqlen=padded_prompt_lens,
230 | kv_seqlen=prompt_lens,
231 | kv_padding=max_seq_length,
232 | )
233 | bias.q_seqinfo.to("cuda")
234 | bias.k_seqinfo.to("cuda")
235 |
236 | # Input tensors to the cuda graph
237 | kv_seqlen = bias.k_seqinfo.seqlen
238 | prompts = [prompt + [1] * (self.gen_args.prompt_length - len(prompt)) for prompt in prompts]
239 | tokens = torch.IntTensor(sum(prompts, [])).cuda()
240 | out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)
241 |
242 | stats = Stats()
243 | torch.cuda.synchronize()
244 | stats.phase("prefill" if use_cuda_graphs else "total")
245 | # stats.phase("total")
246 |
247 | output = self._prefill_compile_model(tokens, None)
248 |
249 | logits = output[kv_seqlen - 1, :]
250 | logits = logits.view(bs, self.model_args.vocab_size)
251 |
252 | if use_sampling:
253 | temp = 0.7
254 | top_p = 0.95
255 | probs = torch.softmax(logits / temp, dim=-1)
256 | next_token = sample_utils.top_p(probs, top_p)
257 | else:
258 | next_token = torch.argmax(logits, dim=-1)
259 |
260 | next_token = next_token.reshape(bs)
261 | out_tokens[0, :] = next_token
262 |
263 | torch.cuda.synchronize()
264 | stats.phase("decode" if use_cuda_graphs else "total")
265 |
266 | eos_id = self.tokenizer.eot_id
267 | for niter in range(1, gen_length):
268 | kv_seqlen.add_(kv_seqlen < max_seq_length)
269 | output = self._generate_compile_model(next_token, kv_seqlen)
270 |
271 | logits = output.view(bs, self.model_args.vocab_size)
272 |
273 | if use_sampling:
274 | temp = 0.7
275 | top_p = 0.95
276 | probs = torch.softmax(logits / temp, dim=-1)
277 | next_token = sample_utils.top_p(probs, top_p)
278 | else:
279 | next_token = torch.argmax(logits, dim=-1)
280 |
281 | next_token = next_token.reshape(bs)
282 | out_tokens[niter, :] = next_token
283 |
284 | if next_token.eq(eos_id).any():
285 | break
286 |
287 | torch.cuda.synchronize()
288 | stats.end_phase(tokens=niter * bs)
289 |
290 | def trim_answer(prompt_len, tokens):
291 | # print(prompt, tokens)
292 | """Trim the answer to end it on an eos token."""
293 | tokens = tokens[: max_seq_length - prompt_len]
294 | eos_id = self.tokenizer.eot_id
295 | if eos_id in tokens:
296 | return tokens[: tokens.index(eos_id) + 1]
297 | else:
298 | return tokens
299 |
300 | answers = [
301 | trim_answer(prompt_len, answer)
302 | for prompt_len, answer in zip(prompt_lens, out_tokens.t().tolist())
303 | ]
304 | return stats, answers
305 |
306 |
307 | def get_prompts(interactive: bool) -> Iterable[list[str]]:
308 | if interactive:
309 | while True:
310 | try:
311 | prompts = input("enter prompt: ").split("\n")
312 | except EOFError:
313 | print("exiting")
314 | sys.exit(0)
315 | yield prompts
316 | else:
317 | yield [
318 | "Hello, my name is",
319 | ]
320 |
321 |
322 | def main(ckpt_dir: str, interactive: bool = False, chat_format: bool = False, sampling: bool = False):
323 |
324 | local_rank = 0
325 | device = f"cuda:{local_rank}"
326 | torch.cuda.set_device(local_rank)
327 |
328 | g = FastGen.build(ckpt_dir, GenArgs(), device)
329 |
330 | if chat_format:
331 | g.tokenizer = ChatFormat(g.tokenizer)
332 |
333 | for prompts in get_prompts(interactive):
334 | # prompts = [f"{prompt}\n" for prompt in prompts]
335 | if chat_format:
336 | # prompts = [f'<|begin_of_text|>User: {prompt}<|eot_id|>Assistant: ' for prompt in prompts]
337 | tokens = [g.tokenizer.encode_dialog_prompt(dialog=[{"role": "user", "content": prompt}], completion=True) for prompt in prompts]
338 | else:
339 | tokens = [g.tokenizer.encode(x, bos=False, eos=False) for x in prompts]
340 |
341 | print(tokens)
342 | stats, out_tokens = g.generate_all(
343 | tokens, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ, use_sampling=sampling,
344 | )
345 |
346 | for i, prompt in enumerate(prompts):
347 | print(f"> {prompt}")
348 | answer = g.tokenizer.decode(out_tokens[i])
349 | print(answer)
350 | print("---------------")
351 |
352 | for phase_stats in stats.phases:
353 | print(phase_stats.show())
354 |
355 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
356 |
357 |
358 | if __name__ == "__main__":
359 | fire.Fire(main)
--------------------------------------------------------------------------------
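
Note on the decode loop in generate.py above: `kv_seqlen.add_(kv_seqlen < max_seq_length)` advances each sequence's KV length by one, but only for sequences that still fit inside the padded cache. A minimal standalone sketch of that update (hypothetical lengths, CPU tensors):

import torch

# Hypothetical per-sequence KV lengths with a padded cache of 8 positions.
kv_seqlen = torch.tensor([5, 8, 3], dtype=torch.int32)
max_seq_length = 8

# The comparison yields 1 where a sequence can still grow and 0 where it is
# already at the padding limit, so the in-place add bumps only unfinished rows.
kv_seqlen.add_((kv_seqlen < max_seq_length).to(kv_seqlen.dtype))
print(kv_seqlen.tolist())  # [6, 8, 4]
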
/gpu/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2 | #
3 | # This source code is licensed under the BSD license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from dataclasses import dataclass
7 | from typing import Optional, Tuple, Union
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 |
13 | from xformers.ops import RMSNorm, fmha, rope_padded
14 | from xformers.ops.fmha.attn_bias import (
15 | BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
16 | )
17 |
18 | import ctypes
19 | bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
20 |
21 | def bitnet_int8xint2_linear(input0, input1, s, ws):
22 | out_shape = list(input0.shape)
23 | out_shape[-1] = input1.shape[0]
24 |
25 | stream = torch.cuda.current_stream()
26 |
27 | M = input0.shape[0]
28 | if len(out_shape) == 3:
29 | M *= input0.shape[1]
30 | N = input1.shape[0]
31 | K = input1.shape[1] * 4
32 |
33 | ret = torch.zeros(*out_shape, dtype=torch.bfloat16, device=input0.device)
34 |
35 | bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)])
36 |
37 | return ret
38 |
39 | @dataclass
40 | class ModelArgs:
41 | dim: int = 2560
42 | n_layers: int = 30
43 | n_heads: int = 20
44 | n_kv_heads: int = 5
45 | vocab_size: int = 128256
46 | ffn_dim: int = 6912
47 | norm_eps: float = 1e-5
48 | rope_theta: float = 500000.0
49 | use_kernel: bool = False
50 |
51 |
52 | LayerCache = Tuple[torch.Tensor, torch.Tensor]
53 |
54 | class BitLinearKernel(nn.Module):
55 | in_features: int
56 | out_features: int
57 | weight: torch.Tensor
58 | weight_scale: torch.Tensor
59 |
60 | def __init__(self, in_features: int, out_features: int, bias: bool = False):
61 | super().__init__()
62 | self.in_features = in_features
63 | self.out_features = out_features
64 |
65 | self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features//4, dtype=torch.int8), requires_grad=False)
66 | self.weight_scale = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False)
67 |
68 | @torch.compile
69 | def quant_input(self, input):
70 | s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
71 | return (input * s).round().clamp(-128, 127).to(torch.int8), s
72 |
73 | def forward(self, input):
74 | input, s = self.quant_input(input)
75 | return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale)
76 |
77 | class BitLinear(nn.Linear):
78 | @torch.compile
79 | def quant_input(self, input):
80 | s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
81 | return (input * s).round().clamp(-128, 127) / s
82 |
83 | def forward(self, input):
84 | input = self.quant_input(input)
85 | return F.linear(input, self.weight)
86 |
87 | class Attention(nn.Module):
88 | def __init__(
89 | self,
90 | dim: int,
91 | head_dim: int,
92 | n_heads: int,
93 | n_kv_heads: int,
94 | rope_theta: float,
95 | norm_eps: float,
96 | use_kernel: bool,
97 | ):
98 | super().__init__()
99 |
100 | self.head_dim = head_dim
101 | self.rope_theta = rope_theta
102 |
103 | self.n_local_heads = n_heads
104 | self.n_local_kv_heads = n_kv_heads
105 |
106 | Linear = BitLinearKernel if use_kernel else BitLinear
107 |
108 | self.wqkv = Linear(
109 | dim,
110 | (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
111 | bias=False,
112 | )
113 | self.wo = Linear(
114 | self.n_local_heads * head_dim,
115 | dim,
116 | bias=False,
117 | )
118 |
119 | self.attn_sub_norm = RMSNorm(dim, norm_eps)
120 |
121 | def forward(
122 | self,
123 | x: torch.Tensor,
124 | cache: LayerCache,
125 | attn_bias: AttnBias,
126 | ) -> torch.Tensor:
127 |
128 | xqkv = self.wqkv(x)
129 | xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
130 | xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
131 | xk, xv = xkv.chunk(2, 1)
132 |
133 | output_shape = xq.shape
134 | heads_per_group = self.n_local_heads // self.n_local_kv_heads
135 | xq = xq.view(
136 | 1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim
137 | )
138 | xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim)
139 | # xq = rearrange(xq, 'b (g h l d) -> 1 b h g (d l)', g=heads_per_group, h=self.n_local_kv_heads, d=self.head_dim // 2, l=2)
140 | # xk = rearrange(xk, 'b (g l d) -> 1 b g 1 (d l)', g=self.n_local_kv_heads, d=self.head_dim // 2)
141 | xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim)
142 | cache_k, cache_v = cache
143 |
144 | xq = rope_padded(
145 | xq=xq,
146 | xk=xk,
147 | xv=xv,
148 | cache_k=cache_k,
149 | cache_v=cache_v,
150 | attn_bias=attn_bias,
151 | theta=self.rope_theta,
152 | )
153 |
154 | output = fmha.memory_efficient_attention_forward(
155 | xq, cache_k, cache_v, attn_bias, op = fmha.flash.FwOp
156 | )
157 |
158 | output = output.reshape(output_shape)
159 | output = self.attn_sub_norm(output)
160 | output = self.wo(output)
161 |
162 | return output
163 |
164 | @torch.compile
165 | def squared_relu(x: torch.Tensor) -> torch.Tensor:
166 | return F.relu(x) ** 2
167 |
168 | class FeedForward(nn.Module):
169 | def __init__(
170 | self,
171 | dim: int,
172 | hidden_dim: int,
173 | norm_eps: float,
174 | use_kernel: bool,
175 | ):
176 | super().__init__()
177 |
178 | Linear = BitLinearKernel if use_kernel else BitLinear
179 |
180 | self.w13 = Linear(
181 | dim,
182 | 2 * hidden_dim,
183 | bias=False,
184 | )
185 | self.w2 = Linear(
186 | hidden_dim,
187 | dim,
188 | bias=False,
189 | )
190 | self.ffn_sub_norm = RMSNorm(hidden_dim, norm_eps)
191 |
192 | def forward(self, x: torch.Tensor) -> torch.Tensor:
193 | x13 = self.w13(x)
194 | x1, x3 = x13.chunk(2, -1)
195 | inner = self.ffn_sub_norm(squared_relu(x1) * x3)
196 | output = self.w2(inner)
197 | return output
198 |
199 |
200 | class TransformerBlock(nn.Module):
201 | def __init__(self, args: ModelArgs):
202 | super().__init__()
203 |
204 | assert args.dim % args.n_heads == 0
205 | head_dim = args.dim // args.n_heads
206 | if args.n_kv_heads is not None:
207 | n_kv_heads = args.n_kv_heads
208 | else:
209 | n_kv_heads = args.n_heads
210 |
211 | assert args.n_heads % n_kv_heads == 0
212 |
213 | self.attention = Attention(
214 | dim=args.dim,
215 | head_dim=head_dim,
216 | n_heads=args.n_heads,
217 | n_kv_heads=n_kv_heads,
218 | rope_theta=args.rope_theta,
219 | norm_eps=args.norm_eps,
220 | use_kernel=args.use_kernel,
221 | )
222 | self.feed_forward = FeedForward(
223 | dim=args.dim,
224 | hidden_dim=args.ffn_dim,
225 | norm_eps=args.norm_eps,
226 | use_kernel=args.use_kernel,
227 | )
228 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
229 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
230 |
231 | def forward(
232 | self,
233 | x: torch.Tensor,
234 | cache: LayerCache,
235 | attn_bias: AttnBias,
236 | ) -> torch.Tensor:
237 | h = x + self.attention.forward(
238 | self.attention_norm(x),
239 | cache,
240 | attn_bias,
241 | )
242 | out = h + self.feed_forward(self.ffn_norm(h))
243 | return out
244 |
245 |
246 | class Transformer(nn.Module):
247 | def __init__(self, args: ModelArgs):
248 | super().__init__()
249 | assert args.vocab_size > 0
250 |
251 | self.tok_embeddings = nn.Embedding(
252 | num_embeddings=args.vocab_size,
253 | embedding_dim=args.dim,
254 | )
255 |
256 | self.layers = nn.ModuleList()
257 | for _ in range(args.n_layers):
258 | self.layers.append(TransformerBlock(args))
259 |
260 | self.norm = RMSNorm(args.dim, eps=args.norm_eps)
261 |
262 | self.output = nn.Linear(
263 | args.dim,
264 | args.vocab_size,
265 | bias=False,
266 | )
267 |
268 | @torch.no_grad()
269 | def forward_with_attn_bias(
270 | self,
271 | token_values: torch.Tensor,
272 | attn_bias: AttnBias,
273 | cache: list[LayerCache],
274 | ) -> torch.Tensor:
275 | h = self.tok_embeddings(token_values)
276 |
277 | for i, layer in enumerate(self.layers):
278 | h = layer(h, cache[i], attn_bias)
279 |
280 | logits = self.output(self.norm(h))
281 | return logits.float()
282 |
283 | def forward(
284 | self,
285 | token_values: torch.Tensor,
286 | token_lengths: torch.Tensor,
287 | start_pos: torch.Tensor,
288 | cache: list[LayerCache],
289 | kv_padding: int,
290 | ) -> torch.Tensor:
291 | attn_bias = AttnBias.from_seqlens(
292 | q_seqlen=token_lengths.tolist(),
293 | kv_seqlen=(start_pos + token_lengths).tolist(),
294 | kv_padding=kv_padding,
295 | )
296 | return self.forward_with_attn_bias(token_values, attn_bias, cache)
297 |
298 |
299 | def make_cache(
300 | args: ModelArgs,
301 | length: int,
302 | device: Optional[Union[str, torch.device]] = None,
303 | n_layers: Optional[int] = None,
304 | dtype: Optional[torch.dtype] = None,
305 | ) -> list[LayerCache]:
306 | """
307 | Allocate a cache to be used with the Transformer module.
308 |
309 | Args:
310 | args (ModelArgs): the model configuration.
311 | length (int): per layer cache size.
312 | It is usually budgeted as ``max_batch * max_seq``
313 | device (torch.device, optional): the device on which
314 | the cache should be allocated.
315 | n_layers (int, optional): the number of layers to
316 | allocate a cache for (defaults to the model
317 | settings).
318 | dtype (torch.dtype, optional): the dtype to use for
319 | cache entries (defaults to the default dtype).
320 |
321 | Returns:
322 |         The cache object to pass to ``Transformer.forward``.
323 | """
324 |
325 | head_dim = args.dim // args.n_heads
326 | n_kv_heads = args.n_kv_heads
327 | if n_kv_heads is None:
328 | n_kv_heads = args.n_heads
329 | n_local_kv_heads = n_kv_heads
330 |
331 | if n_layers is None:
332 | n_layers = args.n_layers
333 |
334 | shape = (1, length, n_local_kv_heads, 1, head_dim)
335 | heads_per_group = args.n_heads // n_kv_heads
336 | expansion = (-1, -1, -1, heads_per_group, -1)
337 | return [
338 | (
339 | torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
340 | torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
341 | )
342 | for _ in range(n_layers)
343 | ]
344 |
345 |
346 | def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
347 | """
348 | Take a prefix view of a larger cache.
349 |
350 |     The original cache object keeps its full size and remains valid
351 |     after the shrunken alias has been used. This function is useful
352 | when a cache was allocated for a larger batch size than what is
353 | necessary.
354 |
355 | Args:
356 | cache: the cache to take a view in.
357 | length (int): the desired length
358 |
359 | Returns:
360 | A view in the input cache object.
361 | """
362 |
363 | if len(cache) > 0:
364 | assert cache[0][0].shape[1] >= length
365 |
366 | return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
--------------------------------------------------------------------------------
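
A small standalone sketch of the cache layout produced by make_cache in model.py above, using the default ModelArgs values. It mirrors the shape arithmetic rather than importing gpu/model.py, so it does not need xformers or the compiled bitnet kernel library:

import torch

# Default ModelArgs: dim=2560, n_heads=20, n_kv_heads=5.
dim, n_heads, n_kv_heads = 2560, 20, 5
head_dim = dim // n_heads                 # 128
heads_per_group = n_heads // n_kv_heads   # 4

length = 4 * 512                          # e.g. a max_batch * max_seq budget
shape = (1, length, n_kv_heads, 1, head_dim)
expansion = (-1, -1, -1, heads_per_group, -1)

# Keys/values are stored once per KV head; the per-group query heads are an
# expanded (stride-0) view, exactly as in make_cache.
cache_k = torch.zeros(shape, dtype=torch.bfloat16).expand(expansion)
print(cache_k.shape)                      # torch.Size([1, 2048, 5, 4, 128])
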
/gpu/pack_weight.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def B_global_16x32_to_shared_load_16x32_layout(i, j):
6 | """
7 | stride * 8 * (tx // HALF_WARP_expr)
8 | + (tx % 8) * stride
9 | + 16 * ((tx % HALF_WARP_expr) // 8)
10 | """
11 | thread_id = i * 2 + j // 16
12 | row = (thread_id // 16) * 8 + (thread_id % 8)
13 | col = (j % 16) + 16 * ((thread_id % 16) // 8)
14 | return row, col
15 |
16 |
17 | def permutate_weight_fastest(weight):
18 | wmma_n = 16
19 | wmma_k = 32
20 | N = weight.shape[0]
21 | K = weight.shape[1]
22 |
23 | # Create a lookup table for the permutation
24 | mapping = np.zeros((wmma_n, wmma_k, 2), dtype=int)
25 | for ii in range(wmma_n):
26 | for jj in range(wmma_k):
27 | mapping[ii, jj] = B_global_16x32_to_shared_load_16x32_layout(ii, jj)
28 |
29 | # Reshape weight for the final format
30 | permutated_weight = np.zeros((N // wmma_n, K // wmma_k, wmma_n, wmma_k), dtype="int8")
31 |
32 | # Use advanced indexing for the entire operation
33 | i_indices = np.arange(N // wmma_n)[:, np.newaxis, np.newaxis, np.newaxis]
34 | j_indices = np.arange(K // wmma_k)[np.newaxis, :, np.newaxis, np.newaxis]
35 |
36 | # Create the source indices
37 | src_i = i_indices * wmma_n + mapping[:, :, 0]
38 | src_j = j_indices * wmma_k + mapping[:, :, 1]
39 |
40 | # Extract and reshape in one go
41 | permutated_weight = weight[src_i, src_j]
42 |
43 | return permutated_weight
44 |
45 |
46 | def compress_int2_to_int8(int2_weight):
47 | int8_weight = np.zeros(
48 | (*int2_weight.shape[:-1], int2_weight.shape[-1] // 4), dtype=np.int8
49 | )
50 | for j in range(int2_weight.shape[-1] // 4):
51 | for k in range(4):
52 | int8_weight[:, :, :, j] |= int2_weight[:, :, :, j * 4 + k] << (k * 2)
53 | return int8_weight
54 |
55 |
56 | def interleave_weight_int8(qweight, nbits=2):
57 | # reinterpret the data type of qweight to int32
58 | # shift = [ 0, 8, 16, 24, 2, 10, 18, 26, 4, 12, 20, 28, 6, 14, 22, 30]
59 | # index: [ 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
60 | qweight = qweight.view(np.int32)
61 | new_qweight = np.zeros_like(qweight)
62 | bits_stride = 8
63 |     mask = (1 << nbits) - 1  # 0x3 for the 2-bit default (0xf for 4-bit)
64 | num_groups = 32 // bits_stride # 4
65 | elems_per_group = bits_stride // nbits # 4
66 | for i in range(num_groups):
67 | for j in range(elems_per_group):
68 | offset = i * elems_per_group + j
69 | shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
70 |
71 | new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
72 | return new_qweight.view(np.int8)
73 |
74 |
75 |
76 | def convert_weight_int8_to_int2(weight):
77 | N = weight.shape[0]
78 | K = weight.shape[1]
79 |
80 | weight = weight+2
81 |
82 | weight = weight.cpu().numpy()
83 |
84 | # print(weight)
85 | # print(torch.max(weight), torch.min(weight))
86 |
87 | # permutated_weight_slow = permutate_weight(weight)
88 | permutated_weight = permutate_weight_fastest(weight)
89 | # assert np.all(permutated_weight_slow == permutated_weight)
90 | # print("Permutation is correct")
91 | compressed_weight = compress_int2_to_int8(permutated_weight)
92 | interleaved_weight = interleave_weight_int8(compressed_weight, 2)
93 |
94 | ret = torch.from_numpy(interleaved_weight)
95 |
96 | ret = torch.reshape(ret, (N, K // 4))
97 |
98 | return ret
99 |
--------------------------------------------------------------------------------
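
The packing step in pack_weight.py above (compress_int2_to_int8) stores four 2-bit values per int8 byte, with value j*4+k placed at bit position k*2; interleave_weight_int8 then reorders the 2-bit fields inside each 32-bit word. A tiny standalone illustration of the packing arithmetic for one group of four weights:

import numpy as np

# Ternary weights shifted by +2 as in convert_weight_int8_to_int2,
# so {-1, 0, 1} becomes {1, 2, 3}; each shifted value needs only 2 bits.
vals = (np.array([-1, 0, 1, 1], dtype=np.int8) + 2).astype(np.uint8)  # [1, 2, 3, 3]

packed = 0
for k in range(4):
    packed |= int(vals[k]) << (k * 2)     # value k lands at bits 2k..2k+1

print(f"0b{packed:08b}")                  # 0b11111001
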
/gpu/requirements.txt:
--------------------------------------------------------------------------------
1 | fire
2 | sentencepiece
3 | torch>=2.2.0
4 | xformers>=0.0.22
5 | tiktoken
6 | blobfile
7 | flask
8 | einops
9 | transformers
--------------------------------------------------------------------------------
/gpu/sample_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2 | #
3 | # This source code is licensed under the BSD license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 |
8 | @torch.compile
9 | def top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
10 | """
11 | Perform top-p (nucleus) sampling on a probability distribution.
12 |
13 | Args:
14 | probs (torch.Tensor): probability distribution tensor.
15 | p (float): probability threshold for top-p sampling.
16 |
17 | Returns:
18 | torch.Tensor: sampled token indices.
19 |
20 | Note:
21 | Top-p sampling selects the smallest set of tokens whose cumulative
22 | probability mass exceeds the threshold p. The distribution is
23 | renormalized based on the selected tokens.
24 | """
25 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
26 | probs_sum = torch.cumsum(probs_sort, dim=-1)
27 | mask = probs_sum - probs_sort > p
28 | probs_sort[mask] = 0.0
29 | next_token = torch.multinomial(probs_sort, num_samples=1)
30 | next_token = torch.gather(probs_idx, -1, next_token)
31 | return next_token
--------------------------------------------------------------------------------
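
A toy usage of top_p above on a six-token vocabulary, with the same temperature that generate.py uses (this assumes gpu/ is on sys.path so sample_utils imports):

import torch
from sample_utils import top_p

# One batch row over a 6-token vocabulary; temperature 0.7 as in generate.py.
logits = torch.tensor([[2.0, 1.5, 0.2, 0.1, -1.0, -2.0]])
probs = torch.softmax(logits / 0.7, dim=-1)

# top_p zeroes the low-probability tail; torch.multinomial samples
# proportionally to the remaining weights, and the sampled position is
# mapped back to the original vocabulary index.
next_token = top_p(probs, p=0.95)
print(next_token.shape)   # torch.Size([1, 1])
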
/gpu/stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2 | #
3 | # This source code is licensed under the BSD license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import time
7 | from dataclasses import dataclass
8 | from typing import Optional
9 |
10 |
11 | @dataclass
12 | class PhaseStats:
13 | name: str
14 | tokens: int
15 | time: float
16 |
17 | def show(self) -> str:
18 | tps = self.tokens / self.time
19 | return (
20 | f"[{self.name}] "
21 | f"generated tokens: {self.tokens}"
22 | f" - total time: {self.time:.3f}s"
23 | f" - {tps:.1f} tokens per second"
24 | )
25 |
26 |
27 | class Stats:
28 | """
29 | Generation stats, split by phases.
30 | """
31 |
32 | def __init__(self):
33 | self.phases = []
34 | self.current = None
35 |
36 | def end_phase(self, tokens: int, now: Optional[float] = None):
37 | """Terminate the current phase."""
38 | if self.current is None:
39 | return
40 | if now is None:
41 | now = time.time()
42 | cname, ctokens, ctime = self.current
43 | stats = PhaseStats(
44 | name=cname,
45 | tokens=tokens - ctokens,
46 | time=now - ctime,
47 | )
48 | self.phases.append(stats)
49 |
50 | def phase(self, name: str, tokens: int = 0):
51 | """
52 | Start a new phase, and terminate the current one,
53 | if one is ongoing.
54 | """
55 | now = time.time()
56 | self.end_phase(tokens, now)
57 | self.current = (name, tokens, now)
--------------------------------------------------------------------------------
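
A minimal usage sketch of the Stats phase API above, mirroring how generate.py drives it: phase() closes any open phase and starts the next, end_phase() closes the last one, and token counts are cumulative, so each PhaseStats records the difference.

import time
from stats import Stats   # assumes gpu/ is on sys.path

stats = Stats()
stats.phase("prefill")
time.sleep(0.01)               # stand-in for the prefill forward pass
stats.phase("decode", tokens=1)
time.sleep(0.02)               # stand-in for the decode loop
stats.end_phase(tokens=33)     # 1 prefill token + 32 decoded tokens in total

for phase_stats in stats.phases:
    print(phase_stats.show())
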
/gpu/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import benchmark
3 | from torch import nn
4 |
5 | from pack_weight import convert_weight_int8_to_int2
6 | from torch.profiler import profile, record_function, ProfilerActivity
7 | import ctypes
8 | import numpy as np
9 | # set all seeds
10 | torch.manual_seed(42)
11 | np.random.seed(42)
12 |
13 | bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
14 |
15 | def bitnet_int8xint2_linear(input0, input1, s, ws, ret):
16 | out_shape = list(input0.shape)
17 | out_shape[-1] = input1.shape[0]
18 |
19 | stream = torch.cuda.current_stream()
20 |
21 | M = input0.shape[0]
22 | if len(out_shape) == 3:
23 | M *= input0.shape[1]
24 | N = input1.shape[0]
25 | K = input1.shape[1] * 4
26 |
27 | bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)])
28 |
29 | return ret
30 |
31 | if __name__ == '__main__':
32 | test_list = [
33 | (2560, 2560),
34 | (3840, 2560),
35 | (13824, 2560),
36 |         (2560, 6912),
37 | (3200, 3200),
38 | (4800, 3200),
39 | (3200, 10240),
40 | (20480, 3200),
41 | ]
42 | for N,K in test_list:
43 | weight = torch.randint(-1, 2, (N, K), dtype=torch.int8, device='cuda')
44 | weight_scale = torch.ones(1, dtype=torch.bfloat16, device='cuda')
45 | weight_compressed = convert_weight_int8_to_int2(weight).to('cuda')
46 |
47 | for i in range(1):
48 | input0 = torch.randint(-128,127,(1, K),dtype=torch.int8, device='cuda')
49 | input0_bf16 = input0.to(torch.bfloat16)
50 | input_np = input0.cpu().to(torch.int32).numpy()
51 | weight_np = weight.cpu().to(torch.int32).T.numpy()
52 | out_np = np.matmul(input_np,weight_np)
53 | out_np = torch.tensor(out_np).cuda().to(torch.bfloat16)
54 |
55 | s = torch.ones(1, dtype=torch.bfloat16, device='cuda')
56 | ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')
57 |
58 | ret = torch.empty((1,N), dtype=torch.bfloat16, device=input0.device)
59 | out = bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)
60 |
61 | print(f'custom == np {torch.all(out==out_np)}')
62 |
63 | input0 = torch.randint(-128,127,(1, K),dtype=torch.int8, device='cuda')
64 | input0_fp16 = input0.to(torch.float16)
65 | input0_bf16 = input0.to(torch.bfloat16)
66 | weight_fp16 = weight.to(torch.float16).T
67 | weight_bf16 = weight.to(torch.bfloat16).T
68 | ret = torch.empty((1,N), dtype=torch.bfloat16, device=input0.device)
69 | s = torch.ones(1, dtype=torch.bfloat16, device='cuda')
70 | ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')
71 | t0 = benchmark.Timer(
72 | stmt="bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)",
73 | setup="from __main__ import input0, weight_compressed, s, ws, ret, bitnet_int8xint2_linear",
74 | num_threads=1,
75 | )
76 |
77 | t1 = benchmark.Timer(
78 | stmt="torch.matmul(input0_bf16,weight_bf16)",
79 | setup="from __main__ import input0_bf16, weight_bf16",
80 | num_threads=1,
81 | )
82 |
83 | time0 = t0.timeit(50)
84 | time1 = t1.timeit(50)
85 |
86 | print(f'Shape{N,K}, W2A8: {time0.mean * 1e6:.2f}us, torch BF16: {time1.mean * 1e6:.2f}us')
87 | # activities = [ ProfilerActivity.CUDA,
88 | # # ProfilerActivity.CPU
89 | # ]
90 | # sort_by_keyword = 'cuda' + "_time_total"
91 | # with profile(activities=activities, record_shapes=True) as prof:
92 | # with record_function("model_inference1"):
93 | # for _ in range(10):
94 | # bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)
95 | # torch.matmul(input0_fp16,weight_fp16)
96 | # torch.matmul(input0_bf16,weight_bf16)
97 |
98 | # print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=15))
99 |
100 |
--------------------------------------------------------------------------------
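
The first four (N, K) pairs in test_list above match the BitLinear weight shapes implied by the default ModelArgs in gpu/model.py, with N as out_features and K as in_features. A quick check of that arithmetic:

# Default ModelArgs: dim=2560, n_heads=20, n_kv_heads=5, ffn_dim=6912.
dim, n_heads, n_kv_heads, ffn_dim = 2560, 20, 5, 6912
head_dim = dim // n_heads                              # 128

wqkv = ((n_heads + 2 * n_kv_heads) * head_dim, dim)    # (3840, 2560)
wo   = (dim, n_heads * head_dim)                       # (2560, 2560)
w13  = (2 * ffn_dim, dim)                              # (13824, 2560)
w2   = (dim, ffn_dim)                                  # (2560, 6912)
print(wqkv, wo, w13, w2)
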
/gpu/tokenizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from logging import getLogger
3 | from pathlib import Path
4 | from typing import (
5 | AbstractSet,
6 | cast,
7 | Collection,
8 | Dict,
9 | Iterator,
10 | List,
11 | Literal,
12 | Sequence,
13 | TypedDict,
14 | Union,
15 | )
16 |
17 | import tiktoken
18 | from tiktoken.load import load_tiktoken_bpe
19 |
20 |
21 | logger = getLogger(__name__)
22 |
23 | Role = Literal["system", "user", "assistant"]
24 |
25 |
26 | class Message(TypedDict):
27 | role: Role
28 | content: str
29 |
30 |
31 | Dialog = Sequence[Message]
32 |
33 |
34 | class Tokenizer:
35 | """
36 | Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
37 | """
38 |
39 | special_tokens: Dict[str, int]
40 |
41 | num_reserved_special_tokens = 256
42 |
43 | pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
44 |
45 | def __init__(self, model_path: str):
46 | """
47 | Initializes the Tokenizer with a Tiktoken model.
48 |
49 | Args:
50 | model_path (str): The path to the Tiktoken model file.
51 | """
52 | assert os.path.isfile(model_path), model_path
53 |
54 | mergeable_ranks = load_tiktoken_bpe(model_path)
55 | num_base_tokens = len(mergeable_ranks)
56 | special_tokens = [
57 | "<|begin_of_text|>",
58 | "<|end_of_text|>",
59 | "<|reserved_special_token_0|>",
60 | "<|reserved_special_token_1|>",
61 | "<|reserved_special_token_2|>",
62 | "<|reserved_special_token_3|>",
63 | "<|start_header_id|>",
64 | "<|end_header_id|>",
65 | "<|reserved_special_token_4|>",
66 | "<|eot_id|>", # end of turn
67 | ] + [
68 | f"<|reserved_special_token_{i}|>"
69 | for i in range(5, self.num_reserved_special_tokens - 5)
70 | ]
71 | self.special_tokens = {
72 | token: num_base_tokens + i for i, token in enumerate(special_tokens)
73 | }
74 | self.model = tiktoken.Encoding(
75 | name=Path(model_path).name,
76 | pat_str=self.pat_str,
77 | mergeable_ranks=mergeable_ranks,
78 | special_tokens=self.special_tokens,
79 | )
80 | logger.info(f"Reloaded tiktoken model from {model_path}")
81 |
82 | self.n_words: int = self.model.n_vocab
83 | # BOS / EOS token IDs
84 | self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
85 | self.eos_id: int = self.special_tokens["<|end_of_text|>"]
86 | self.pad_id: int = self.n_words - 1
87 | self.stop_tokens = {
88 | self.special_tokens["<|end_of_text|>"],
89 | self.special_tokens["<|eot_id|>"],
90 | }
91 | logger.info(
92 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
93 | )
94 |
95 | def encode(
96 | self,
97 | s: str,
98 | *,
99 | bos: bool,
100 | eos: bool,
101 | allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
102 | disallowed_special: Union[Literal["all"], Collection[str]] = (),
103 | ) -> List[int]:
104 | """
105 | Encodes a string into a list of token IDs.
106 |
107 | Args:
108 | s (str): The input string to be encoded.
109 | bos (bool): Whether to prepend the beginning-of-sequence token.
110 | eos (bool): Whether to append the end-of-sequence token.
111 |             allowed_special ("all"|set[str]): allowed special tokens in string
112 |             disallowed_special ("all"|set[str]): special tokens that raise an error when in string
113 |
114 | Returns:
115 | list[int]: A list of token IDs.
116 |
117 | By default, setting disallowed_special=() encodes a string by ignoring
118 | special tokens. Specifically:
119 | - Setting `disallowed_special` to () will cause all text corresponding
120 |           to special tokens to be encoded as natural text (instead of raising
121 |           an error).
122 |         - Setting `allowed_special` to "all" will cause all text corresponding
123 |           to special tokens to be encoded as special tokens.
124 | """
125 | assert type(s) is str
126 |
127 | # The tiktoken tokenizer can handle <=400k chars without
128 | # pyo3_runtime.PanicException.
129 | TIKTOKEN_MAX_ENCODE_CHARS = 400_000
130 |
131 | # https://github.com/openai/tiktoken/issues/195
132 | # Here we iterate over subsequences and split if we exceed the limit
133 | # of max consecutive non-whitespace or whitespace characters.
134 | MAX_NO_WHITESPACES_CHARS = 25_000
135 |
136 | substrs = (
137 | substr
138 | for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
139 | for substr in self._split_whitespaces_or_nonwhitespaces(
140 | s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
141 | )
142 | )
143 | t: List[int] = []
144 | for substr in substrs:
145 | t.extend(
146 | self.model.encode(
147 | substr,
148 | allowed_special=allowed_special,
149 | disallowed_special=disallowed_special,
150 | )
151 | )
152 | if bos:
153 | t.insert(0, self.bos_id)
154 | if eos:
155 | t.append(self.eos_id)
156 | return t
157 |
158 | def decode(self, t: Sequence[int]) -> str:
159 | """
160 | Decodes a list of token IDs into a string.
161 |
162 | Args:
163 | t (List[int]): The list of token IDs to be decoded.
164 |
165 | Returns:
166 | str: The decoded string.
167 | """
168 | # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
169 | return self.model.decode(cast(List[int], t))
170 |
171 | @staticmethod
172 | def _split_whitespaces_or_nonwhitespaces(
173 | s: str, max_consecutive_slice_len: int
174 | ) -> Iterator[str]:
175 | """
176 | Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
177 | consecutive whitespaces or consecutive non-whitespaces.
178 | """
179 | current_slice_len = 0
180 | current_slice_is_space = s[0].isspace() if len(s) > 0 else False
181 | slice_start = 0
182 |
183 | for i in range(len(s)):
184 | is_now_space = s[i].isspace()
185 |
186 | if current_slice_is_space ^ is_now_space:
187 | current_slice_len = 1
188 | current_slice_is_space = is_now_space
189 | else:
190 | current_slice_len += 1
191 | if current_slice_len > max_consecutive_slice_len:
192 | yield s[slice_start:i]
193 | slice_start = i
194 | current_slice_len = 1
195 | yield s[slice_start:]
196 |
197 | class ChatFormat:
198 | def __init__(self, tokenizer: Tokenizer):
199 | self.tokenizer = tokenizer
200 | self.eot_id = tokenizer.special_tokens["<|eot_id|>"]
201 |
202 | def decode(self, tokens: List[int]) -> str:
203 | # Decode the tokens to a string.
204 | decoded_str = self.tokenizer.decode(tokens)
205 | # Remove the special tokens from the decoded string.
206 | decoded_str = decoded_str.replace("<|eot_id|>", "")
207 | return decoded_str
208 |
209 | def encode_header(self, message: Message) -> List[int]:
210 | tokens = []
211 | if message["role"] == "system":
212 | tokens.extend(self.tokenizer.encode("System: ", bos=False, eos=False))
213 | elif message["role"] == "user":
214 | tokens.extend(self.tokenizer.encode("User: ", bos=False, eos=False))
215 | elif message["role"] == "assistant":
216 | tokens.extend(self.tokenizer.encode("Assistant: ", bos=False, eos=False))
217 | else:
218 | raise NotImplementedError(f"Role {message['role']} not implemented.")
219 | # tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
220 | # tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
221 | # tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
222 | # tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
223 | return tokens
224 |
225 | def encode_message(self, message: Message, return_target=False) -> List[int]:
226 | tokens, targets = [], []
227 | headers = self.encode_header(message)
228 | contents = self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
229 | contents.append(self.tokenizer.special_tokens["<|eot_id|>"])
230 | tokens = headers + contents
231 |
232 | if message["role"] == "assistant":
233 | targets = [-1] * len(headers) + contents
234 | else:
235 | targets = [-1] * len(tokens)
236 |
237 | if return_target:
238 | return tokens, targets
239 |
240 | return tokens, None
241 |
242 | def encode_dialog_prompt(self, dialog: Dialog, completion=False, return_target=False) -> List[int]:
243 | tokens = [self.tokenizer.special_tokens["<|begin_of_text|>"]]
244 | targets = [-1]
245 | for message in dialog:
246 | _tokens, _targets = self.encode_message(message, return_target=return_target)
247 | tokens.extend(_tokens)
248 | if _targets is not None:
249 | targets.extend(_targets)
250 | # Add the start of an assistant message for the model to complete.
251 | if completion:
252 | tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
253 |
254 | if return_target:
255 | return tokens, targets
256 |
257 | return tokens
--------------------------------------------------------------------------------
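
A sketch of how generate.py drives ChatFormat above in chat mode; it assumes tiktoken is installed, gpu/ is on sys.path, and the repo's gpu/tokenizer.model file is in the working directory:

from tokenizer import ChatFormat, Tokenizer

tok = ChatFormat(Tokenizer("tokenizer.model"))

# completion=True appends the "Assistant: " header so the model continues it.
tokens = tok.encode_dialog_prompt(
    dialog=[{"role": "user", "content": "Hello, my name is"}],
    completion=True,
)
print(len(tokens), tokens[:4])
print(tok.decode(tokens))   # ChatFormat.decode strips "<|eot_id|>" from the text
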
/include/ggml-bitnet.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 | #include "ggml-backend.h"
5 |
6 | #ifdef __ARM_NEON
7 | #include <arm_neon.h>
8 | typedef float32_t bitnet_float_type;
9 | #else
10 | typedef float bitnet_float_type;
11 | #endif
12 |
13 | #ifdef __cplusplus
14 | extern "C" {
15 | #endif
16 |
17 | struct bitnet_tensor_extra {
18 | int lut_scales_size;
19 | int BK;
20 | int n_tile_num;
21 | uint8_t * qweights;
22 | bitnet_float_type * scales;
23 | };
24 |
25 | GGML_API void ggml_bitnet_init(void);
26 | GGML_API void ggml_bitnet_free(void);
27 | // src0->type == Q4_0/IQ2_XXS/IQ3_XXS
28 | // bitnet.cpp currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros)
29 | // If i-quantization gguf models are used, the results will be wrong
30 | // TODO: add customized block types Q2_0/Q3_0
31 | GGML_API bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
32 | GGML_API size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
33 | GGML_API void ggml_bitnet_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits);
34 | GGML_API void ggml_bitnet_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits);
35 | GGML_API void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor);
36 | GGML_API int ggml_bitnet_get_type_bits(enum ggml_type type);
37 | GGML_API void ggml_bitnet_set_n_threads(int n_threads);
38 | #if defined(GGML_BITNET_ARM_TL1)
39 | GGML_API void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C);
40 | GGML_API void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT);
41 | #endif
42 | #if defined(GGML_BITNET_X86_TL2)
43 | GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C);
44 | GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT);
45 | #endif
46 |
47 | #ifdef __cplusplus
48 | }
49 | #endif
50 |
--------------------------------------------------------------------------------
/media/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/media/benchmark.png
--------------------------------------------------------------------------------
/media/demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/media/demo.mp4
--------------------------------------------------------------------------------
/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl1.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 14336
3 | k = 4096
4 | bm = 256
5 | bk = 128
6 | bmm = 64
7 |
8 | [Kernels_1]
9 | m = 4096
10 | k = 14336
11 | bm = 256
12 | bk = 128
13 | bmm = 32
14 |
15 | [Kernels_2]
16 | m = 1024
17 | k = 4096
18 | bm = 128
19 | bk = 64
20 | bmm = 64
21 |
22 | [Kernels_3]
23 | m = 4096
24 | k = 4096
25 | bm = 128
26 | bk = 64
27 | bmm = 32
28 |
29 |
--------------------------------------------------------------------------------
/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl2.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 14336
3 | k = 4096
4 | bm = 256
5 | bk = 96
6 | bmm = 32
7 |
8 | [Kernels_1]
9 | m = 4096
10 | k = 14336
11 | bm = 128
12 | bk = 96
13 | bmm = 32
14 |
15 | [Kernels_2]
16 | m = 1024
17 | k = 4096
18 | bm = 256
19 | bk = 96
20 | bmm = 32
21 |
22 | [Kernels_3]
23 | m = 4096
24 | k = 4096
25 | bm = 128
26 | bk = 96
27 | bmm = 32
28 |
29 |
--------------------------------------------------------------------------------
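
Each section of these kernel_config_*.ini presets records one GEMM shape (m, k) and its tile sizes (bm, bk, bmm), paired with the generated bitnet-lut-kernels-*.h header for the same model. A short standard-library sketch for inspecting a preset (the path shown is the tl2 file above):

import configparser

cfg = configparser.ConfigParser()
cfg.read("preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl2.ini")

for section in cfg.sections():
    m, k = cfg.getint(section, "m"), cfg.getint(section, "k")
    bm, bk, bmm = (cfg.getint(section, key) for key in ("bm", "bk", "bmm"))
    print(f"{section}: GEMM m={m} k={k}, tiles bm={bm} bk={bk} bmm={bmm}")
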
/preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl1.h:
--------------------------------------------------------------------------------
1 | #if defined(GGML_BITNET_ARM_TL1)
2 | #include "ggml-bitnet.h"
3 | #define GGML_BITNET_MAX_NODES 8192
4 | static bool initialized = false;
5 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
6 | static size_t bitnet_tensor_extras_index = 0;
7 | static void * aligned_malloc(size_t size) {{
8 | #if defined(_WIN32)
9 | return _aligned_malloc(size, 64);
10 | #else
11 | void * ptr = nullptr;
12 | posix_memalign(&ptr, 64, size);
13 | return ptr;
14 | #endif
15 | }}
16 | static void aligned_free(void * ptr) {{
17 | #if defined(_WIN32)
18 | _aligned_free(ptr);
19 | #else
20 | free(ptr);
21 | #endif
22 | }}
23 |
24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{
25 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
26 | bitnet_float_type* b = (bitnet_float_type*)b_;
27 | #ifdef __ARM_NEON
28 | float32x4_t temp_max = vdupq_n_f32(0);
29 | for (int i=0; i < k / 4; i++) {{
30 | float32x4_t vec_bs = vld1q_f32(b + 4 * i);
31 | float32x4_t abssum = vabsq_f32(vec_bs);
32 | temp_max = vmaxq_f32(abssum, temp_max);
33 | }}
34 | float32_t scales = 127 / vmaxvq_f32(temp_max);
35 | *lut_scales = scales;
36 | #elif defined __AVX2__
37 | __m256 max_vec = _mm256_set1_ps(0.f);
38 | const __m256 vec_sign = _mm256_set1_ps(-0.0f);
39 | // #pragma unroll
40 | for (int i = 0; i < k / 8; i++) {{
41 | __m256 vec_b = _mm256_loadu_ps(b + i * 8);
42 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
43 | max_vec = _mm256_max_ps(vec_babs, max_vec);
44 | }}
45 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
46 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
47 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
48 | float scales = 127 / _mm_cvtss_f32(max1);
49 | *lut_scales = scales;
50 | #endif
51 | }}
52 |
53 | void partial_max_reset(void* lut_scales_) {{
54 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
55 | *lut_scales = 0.0;
56 | }}
57 |
58 | #ifdef __ARM_NEON
59 | inline void Transpose_8_8(
60 | int16x8_t *v0,
61 | int16x8_t *v1,
62 | int16x8_t *v2,
63 | int16x8_t *v3,
64 | int16x8_t *v4,
65 | int16x8_t *v5,
66 | int16x8_t *v6,
67 | int16x8_t *v7)
68 | {{
69 | int16x8x2_t q04 = vzipq_s16(*v0, *v4);
70 | int16x8x2_t q15 = vzipq_s16(*v1, *v5);
71 | int16x8x2_t q26 = vzipq_s16(*v2, *v6);
72 | int16x8x2_t q37 = vzipq_s16(*v3, *v7);
73 |
74 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
75 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
76 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
77 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
78 |
79 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
80 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
81 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
82 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
83 |
84 | *v0 = q_fin_0.val[0];
85 | *v1 = q_fin_0.val[1];
86 | *v2 = q_fin_1.val[0];
87 | *v3 = q_fin_1.val[1];
88 | *v4 = q_fin_2.val[0];
89 | *v5 = q_fin_2.val[1];
90 | *v6 = q_fin_3.val[0];
91 | *v7 = q_fin_3.val[1];
92 | }}
93 | #endif
94 |
95 | template <int act_k>
96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{
97 | #ifdef __ARM_NEON
98 | int16x8_t vec_lut[16];
99 | float32_t scales = *lut_scales;
100 | uint8_t tbl_mask[16];
101 | tbl_mask[0] = 0;
102 | tbl_mask[1] = 2;
103 | tbl_mask[2] = 4;
104 | tbl_mask[3] = 6;
105 | tbl_mask[4] = 8;
106 | tbl_mask[5] = 10;
107 | tbl_mask[6] = 12;
108 | tbl_mask[7] = 14;
109 | tbl_mask[8] = 1;
110 | tbl_mask[9] = 3;
111 | tbl_mask[10] = 5;
112 | tbl_mask[11] = 7;
113 | tbl_mask[12] = 9;
114 | tbl_mask[13] = 11;
115 | tbl_mask[14] = 13;
116 | tbl_mask[15] = 15;
117 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
118 | #pragma unroll
119 | for (int k = 0; k < act_k / 16; ++k) {{
120 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
121 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
122 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
123 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
124 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
125 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
126 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
127 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
128 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
129 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
130 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
131 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
132 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
133 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
134 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
135 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
136 | vec_lut[0] = vdupq_n_s16(0);
137 | vec_lut[0] = vec_lut[0] - vec_bs_0;
138 | vec_lut[0] = vec_lut[0] - vec_bs_1;
139 | vec_lut[1] = vdupq_n_s16(0);
140 | vec_lut[1] = vec_lut[1] - vec_bs_0;
141 | vec_lut[2] = vdupq_n_s16(0);
142 | vec_lut[2] = vec_lut[2] - vec_bs_0;
143 | vec_lut[2] = vec_lut[2] + vec_bs_1;
144 | vec_lut[3] = vdupq_n_s16(0);
145 | vec_lut[3] = vec_lut[3] - vec_bs_1;
146 | vec_lut[4] = vdupq_n_s16(0);
147 | vec_lut[5] = vec_bs_1;
148 | vec_lut[6] = vec_bs_0;
149 | vec_lut[6] = vec_lut[6] - vec_bs_1;
150 | vec_lut[7] = vec_bs_0;
151 | vec_lut[8] = vec_bs_0;
152 | vec_lut[8] = vec_lut[8] + vec_bs_1;
153 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
154 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
155 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
156 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
157 | #pragma unroll
158 | for (int idx = 0; idx < 8; idx++) {{
159 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
160 | int8x8_t q0_low = vget_low_s8(q0_s);
161 | int8x8_t q0_high = vget_high_s8(q0_s);
162 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
163 | int8x8_t q1_low = vget_low_s8(q1_s);
164 | int8x8_t q1_high = vget_high_s8(q1_s);
165 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
166 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
167 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
168 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
169 | }}
170 | }}
171 | #endif
172 | }}
173 |
174 | static bool is_type_supported(enum ggml_type type) {{
175 | if (type == GGML_TYPE_Q4_0 ||
176 | type == GGML_TYPE_TL1) {{
177 | return true;
178 | }} else {{
179 | return false;
180 | }}
181 | }}
182 | #include <arm_neon.h>
183 |
184 | #define BM3200_8640 160
185 | #define BBK3200_8640 64
186 | inline void tbl_impl_3200_8640(int32_t* c, int8_t* lut, uint8_t* a) {
187 | #ifdef __ARM_NEON
188 | const int KK = BBK3200_8640 / 2;
189 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
190 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
191 | int8x16_t vec_lut[2 * KK];
192 | int16x8_t vec_c[4];
193 | #pragma unroll
194 | for (int k = 0; k < 2 * KK; k++) {
195 | vec_lut[k] = vld1q_s8(lut + k * 16);
196 | }
197 |
198 | #pragma unroll
199 | for (int i = 0; i < BM3200_8640; i += 32) {
200 | #pragma unroll
201 | for (int i=0; i<4; i++) {
202 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
203 | }
204 |
205 | #pragma unroll
206 | for (int k = 0; k < KK / 4; k++) {
207 |
208 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
209 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
210 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
211 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
212 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
213 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
214 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
215 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
216 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
217 | vec_c[0] += vec_v_left_0.val[0];
218 | vec_c[0] += vec_v_right_0.val[0];
219 | vec_c[1] += vec_v_left_0.val[1];
220 | vec_c[1] += vec_v_right_0.val[1];
221 |
222 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
223 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
224 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
225 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
226 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
227 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
228 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
229 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
230 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
231 | vec_c[0] += vec_v_left_1.val[0];
232 | vec_c[0] += vec_v_right_1.val[0];
233 | vec_c[1] += vec_v_left_1.val[1];
234 | vec_c[1] += vec_v_right_1.val[1];
235 |
236 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
237 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
238 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
239 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
240 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
241 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
242 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
243 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
244 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
245 | vec_c[2] += vec_v_left_2.val[0];
246 | vec_c[2] += vec_v_right_2.val[0];
247 | vec_c[3] += vec_v_left_2.val[1];
248 | vec_c[3] += vec_v_right_2.val[1];
249 |
250 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
251 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
252 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
253 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
254 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
255 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
256 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
257 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
258 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
259 | vec_c[2] += vec_v_left_3.val[0];
260 | vec_c[2] += vec_v_right_3.val[0];
261 | vec_c[3] += vec_v_left_3.val[1];
262 | vec_c[3] += vec_v_right_3.val[1];
263 |
264 | }
265 |
266 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
267 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
268 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
269 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
270 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
271 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
272 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
273 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
274 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
275 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
276 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
277 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
278 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
279 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
280 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
281 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
282 |
283 | }
284 | #endif
285 | }
286 |
287 | int32_t qgemm_lut_3200_8640(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
288 | alignas(32) uint32_t CBits[BM3200_8640];
289 | memset(&(CBits[0]), 0, BM3200_8640 * sizeof(int32_t));
290 | #pragma unroll
291 | for (int32_t k_outer = 0; k_outer < 8640 / BBK3200_8640; ++k_outer) {
292 | tbl_impl_3200_8640((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_8640 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_8640 / 2 / 2 * BM3200_8640)])));
293 | }
294 | #pragma unroll
295 | for (int i = 0; i < BM3200_8640; i++) {
296 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
297 | }
298 | return 0;
299 | };
300 | #include <arm_neon.h>
301 |
302 | #define BM3200_3200 320
303 | #define BBK3200_3200 128
304 | inline void tbl_impl_3200_3200(int32_t* c, int8_t* lut, uint8_t* a) {
305 | #ifdef __ARM_NEON
306 | const int KK = BBK3200_3200 / 2;
307 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
308 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
309 | int8x16_t vec_lut[2 * KK];
310 | int16x8_t vec_c[8];
311 | #pragma unroll
312 | for (int k = 0; k < 2 * KK; k++) {
313 | vec_lut[k] = vld1q_s8(lut + k * 16);
314 | }
315 |
316 | #pragma unroll
317 | for (int i = 0; i < BM3200_3200; i += 64) {
318 | #pragma unroll
319 | for (int i=0; i<8; i++) {
320 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
321 | }
322 |
323 | #pragma unroll
324 | for (int k = 0; k < KK / 2; k++) {
325 |
326 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
327 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
328 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
329 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
330 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
331 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
332 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
333 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
334 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
335 | vec_c[0] += vec_v_left_0.val[0];
336 | vec_c[0] += vec_v_right_0.val[0];
337 | vec_c[1] += vec_v_left_0.val[1];
338 | vec_c[1] += vec_v_right_0.val[1];
339 |
340 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
341 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
342 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
343 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
344 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
345 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
346 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
347 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
348 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
349 | vec_c[2] += vec_v_left_1.val[0];
350 | vec_c[2] += vec_v_right_1.val[0];
351 | vec_c[3] += vec_v_left_1.val[1];
352 | vec_c[3] += vec_v_right_1.val[1];
353 |
354 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
355 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
356 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
357 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
358 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
359 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
360 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
361 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
362 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
363 | vec_c[4] += vec_v_left_2.val[0];
364 | vec_c[4] += vec_v_right_2.val[0];
365 | vec_c[5] += vec_v_left_2.val[1];
366 | vec_c[5] += vec_v_right_2.val[1];
367 |
368 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
369 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
370 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
371 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
372 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
373 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
374 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
375 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
376 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
377 | vec_c[6] += vec_v_left_3.val[0];
378 | vec_c[6] += vec_v_right_3.val[0];
379 | vec_c[7] += vec_v_left_3.val[1];
380 | vec_c[7] += vec_v_right_3.val[1];
381 |
382 | }
383 |
384 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
385 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
386 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
387 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
388 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
389 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
390 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
391 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
392 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
393 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
394 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
395 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
396 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
397 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
398 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
399 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
400 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
401 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
402 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
403 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
404 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
405 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
406 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
407 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
408 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
409 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
410 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
411 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
412 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
413 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
414 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
415 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
416 |
417 | }
418 | #endif
419 | }
420 |
421 | int32_t qgemm_lut_3200_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
422 | alignas(32) uint32_t CBits[BM3200_3200];
423 | memset(&(CBits[0]), 0, BM3200_3200 * sizeof(int32_t));
424 | #pragma unroll
425 | for (int32_t k_outer = 0; k_outer < 3200 / BBK3200_3200; ++k_outer) {
426 | tbl_impl_3200_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_3200 / 2 / 2 * BM3200_3200)])));
427 | }
428 | #pragma unroll
429 | for (int i = 0; i < BM3200_3200; i++) {
430 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
431 | }
432 | return 0;
433 | };
434 | #include <arm_neon.h>
435 |
436 | #define BM8640_3200 320
437 | #define BBK8640_3200 64
438 | inline void tbl_impl_8640_3200(int32_t* c, int8_t* lut, uint8_t* a) {
439 | #ifdef __ARM_NEON
440 | const int KK = BBK8640_3200 / 2;
441 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
442 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
443 | int8x16_t vec_lut[2 * KK];
444 | int16x8_t vec_c[4];
445 | #pragma unroll
446 | for (int k = 0; k < 2 * KK; k++) {
447 | vec_lut[k] = vld1q_s8(lut + k * 16);
448 | }
449 |
450 | #pragma unroll
451 | for (int i = 0; i < BM8640_3200; i += 32) {
452 | #pragma unroll
453 | for (int i=0; i<4; i++) {
454 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
455 | }
456 |
457 | #pragma unroll
458 | for (int k = 0; k < KK / 4; k++) {
459 |
460 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
461 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
462 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
463 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
464 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
465 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
466 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
467 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
468 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
469 | vec_c[0] += vec_v_left_0.val[0];
470 | vec_c[0] += vec_v_right_0.val[0];
471 | vec_c[1] += vec_v_left_0.val[1];
472 | vec_c[1] += vec_v_right_0.val[1];
473 |
474 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
475 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
476 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
477 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
478 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
479 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
480 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
481 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
482 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
483 | vec_c[0] += vec_v_left_1.val[0];
484 | vec_c[0] += vec_v_right_1.val[0];
485 | vec_c[1] += vec_v_left_1.val[1];
486 | vec_c[1] += vec_v_right_1.val[1];
487 |
488 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
489 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
490 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
491 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
492 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
493 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
494 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
495 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
496 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
497 | vec_c[2] += vec_v_left_2.val[0];
498 | vec_c[2] += vec_v_right_2.val[0];
499 | vec_c[3] += vec_v_left_2.val[1];
500 | vec_c[3] += vec_v_right_2.val[1];
501 |
502 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
503 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
504 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
505 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
506 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
507 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
508 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
509 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
510 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
511 | vec_c[2] += vec_v_left_3.val[0];
512 | vec_c[2] += vec_v_right_3.val[0];
513 | vec_c[3] += vec_v_left_3.val[1];
514 | vec_c[3] += vec_v_right_3.val[1];
515 |
516 | }
517 |
518 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
519 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
520 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
521 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
522 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
523 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
524 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
525 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
526 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
527 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
528 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
529 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
530 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
531 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
532 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
533 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
534 |
535 | }
536 | #endif
537 | }
538 |
539 | int32_t qgemm_lut_8640_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
540 | alignas(32) uint32_t CBits[BM8640_3200];
541 | memset(&(CBits[0]), 0, BM8640_3200 * sizeof(int32_t));
542 | #pragma unroll
543 | for (int32_t k_outer = 0; k_outer < 3200 / BBK8640_3200; ++k_outer) {
544 | tbl_impl_8640_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK8640_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK8640_3200 / 2 / 2 * BM8640_3200)])));
545 | }
546 | #pragma unroll
547 | for (int i = 0; i < BM8640_3200; i++) {
548 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
549 | }
550 | return 0;
551 | };
552 |
553 | template<int K>
554 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{
555 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
556 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
557 |
558 |   lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
559 | }}
560 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
561 | if (m == 3200 && k == 8640) {
562 | preprocessor_k<8640>(B, LUT_Scales, QLUT);
563 | }
564 | else if (m == 3200 && k == 3200) {
565 | preprocessor_k<3200>(B, LUT_Scales, QLUT);
566 | }
567 | else if (m == 8640 && k == 3200) {
568 | preprocessor_k<3200>(B, LUT_Scales, QLUT);
569 | }
570 | }
571 | void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
572 | if (m == 3200 && k == 8640) {
573 | qgemm_lut_3200_8640(A, LUT, Scales, LUT_Scales, C);
574 | }
575 | else if (m == 3200 && k == 3200) {
576 | qgemm_lut_3200_3200(A, LUT, Scales, LUT_Scales, C);
577 | }
578 | else if (m == 8640 && k == 3200) {
579 | qgemm_lut_8640_3200(A, LUT, Scales, LUT_Scales, C);
580 | }
581 | }
582 |
583 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
584 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
585 | return;
586 | }
587 |
588 | int k = tensor->ne[0];
589 | int m = tensor->ne[1];
590 | const int lut_scales_size = 1;
591 | const int scales_size = 1;
592 | int bk = 0;
593 | int bm = 0;
594 |
595 | if (m == 3200 && k == 8640) {
596 | bm = BM3200_8640;
597 | bk = BBK3200_8640;
598 | }
599 | else if (m == 3200 && k == 3200) {
600 | bm = BM3200_3200;
601 | bk = BBK3200_3200;
602 | }
603 | else if (m == 8640 && k == 3200) {
604 | bm = BM8640_3200;
605 | bk = BBK8640_3200;
606 | }
607 |
608 | const int n_tile_num = m / bm;
609 | const int BK = bk;
610 | uint8_t * qweights;
611 | bitnet_float_type * scales;
612 |
613 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
614 | qweights = (uint8_t *) tensor->data;
615 | float * i2_scales = (float * )(qweights + k * m / 4);
616 | scales[0] = (bitnet_float_type) i2_scales[0];
617 |
618 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
619 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
620 | /* .lut_scales_size = */ lut_scales_size,
621 | /* .scales_size = */ scales_size,
622 | /* .n_tile_num = */ n_tile_num,
623 | /* .qweights = */ qweights,
624 | /* .scales = */ scales
625 | };
626 | }
627 | #endif
--------------------------------------------------------------------------------
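The qgemm_lut_* kernels above keep the matrix product in the integer domain: each call to tbl_impl_* accumulates int8 table lookups into the int32 CBits buffer, and only the final loop converts to floating point, dividing by the activation LUT scale (127 / max|activation|, set during preprocessing) and multiplying by the per-tensor weight scale. A minimal Python sketch of that rescale step, not part of the repository and with illustrative names only:

    # cbits: int32 LUT accumulators for one output tile
    # lut_scale: activation scale written by per_tensor_quant (127 / max|b|)
    # weight_scale: per-tensor scale stored alongside the 2-bit weights
    def rescale(cbits, lut_scale, weight_scale):
        return [v / lut_scale * weight_scale for v in cbits]
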
/preset_kernels/bitnet_b1_58-3B/kernel_config_tl1.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 3200
3 | k = 8640
4 | bm = 160
5 | bk = 64
6 | bmm = 32
7 |
8 | [Kernels_1]
9 | m = 3200
10 | k = 3200
11 | bm = 320
12 | bk = 128
13 | bmm = 64
14 |
15 | [Kernels_2]
16 | m = 8640
17 | k = 3200
18 | bm = 320
19 | bk = 64
20 | bmm = 32
21 |
22 |
--------------------------------------------------------------------------------
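Each [Kernels_N] section records the GEMM shape of one weight matrix (m, k) together with the tiling used when its kernels were generated: bm and bk correspond to the BM*/BBK* constants in the matching bitnet-lut-kernels header (and to the --BM/--BK lists setup_env.py passes to codegen_tl1.py), while bmm corresponds to the --bm list. A small sketch for reading these presets with Python's standard configparser; the path is illustrative:

    import configparser

    cfg = configparser.ConfigParser()
    cfg.read("preset_kernels/bitnet_b1_58-3B/kernel_config_tl1.ini")
    for section in cfg.sections():
        m, k = cfg.getint(section, "m"), cfg.getint(section, "k")
        bm, bk, bmm = (cfg.getint(section, key) for key in ("bm", "bk", "bmm"))
        print(f"{section}: {m}x{k} GEMM, tiles bm={bm}, bk={bk}, inner bm={bmm}")
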
/preset_kernels/bitnet_b1_58-3B/kernel_config_tl2.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 3200
3 | k = 8640
4 | bm = 160
5 | bk = 96
6 | bmm = 32
7 |
8 | [Kernels_1]
9 | m = 3200
10 | k = 3200
11 | bm = 320
12 | bk = 96
13 | bmm = 32
14 |
15 | [Kernels_2]
16 | m = 8640
17 | k = 3200
18 | bm = 320
19 | bk = 96
20 | bmm = 32
21 |
22 |
--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-large/bitnet-lut-kernels-tl1.h:
--------------------------------------------------------------------------------
1 | #if defined(GGML_BITNET_ARM_TL1)
2 | #include "ggml-bitnet.h"
3 | #define GGML_BITNET_MAX_NODES 8192
4 | static bool initialized = false;
5 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
6 | static size_t bitnet_tensor_extras_index = 0;
7 | static void * aligned_malloc(size_t size) {{
8 | #if defined(_WIN32)
9 | return _aligned_malloc(size, 64);
10 | #else
11 | void * ptr = nullptr;
12 | posix_memalign(&ptr, 64, size);
13 | return ptr;
14 | #endif
15 | }}
16 | static void aligned_free(void * ptr) {{
17 | #if defined(_WIN32)
18 | _aligned_free(ptr);
19 | #else
20 | free(ptr);
21 | #endif
22 | }}
23 |
24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{
25 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
26 | bitnet_float_type* b = (bitnet_float_type*)b_;
27 | #ifdef __ARM_NEON
28 | float32x4_t temp_max = vdupq_n_f32(0);
29 | for (int i=0; i < k / 4; i++) {{
30 | float32x4_t vec_bs = vld1q_f32(b + 4 * i);
31 | float32x4_t abssum = vabsq_f32(vec_bs);
32 | temp_max = vmaxq_f32(abssum, temp_max);
33 | }}
34 | float32_t scales = 127 / vmaxvq_f32(temp_max);
35 | *lut_scales = scales;
36 | #elif defined __AVX2__
37 | __m256 max_vec = _mm256_set1_ps(0.f);
38 | const __m256 vec_sign = _mm256_set1_ps(-0.0f);
39 | // #pragma unroll
40 | for (int i = 0; i < k / 8; i++) {{
41 | __m256 vec_b = _mm256_loadu_ps(b + i * 8);
42 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
43 | max_vec = _mm256_max_ps(vec_babs, max_vec);
44 | }}
45 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
46 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
47 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
48 | float scales = 127 / _mm_cvtss_f32(max1);
49 | *lut_scales = scales;
50 | #endif
51 | }}
52 |
53 | void partial_max_reset(void* lut_scales_) {{
54 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
55 | *lut_scales = 0.0;
56 | }}
57 |
58 | #ifdef __ARM_NEON
59 | inline void Transpose_8_8(
60 | int16x8_t *v0,
61 | int16x8_t *v1,
62 | int16x8_t *v2,
63 | int16x8_t *v3,
64 | int16x8_t *v4,
65 | int16x8_t *v5,
66 | int16x8_t *v6,
67 | int16x8_t *v7)
68 | {{
69 | int16x8x2_t q04 = vzipq_s16(*v0, *v4);
70 | int16x8x2_t q15 = vzipq_s16(*v1, *v5);
71 | int16x8x2_t q26 = vzipq_s16(*v2, *v6);
72 | int16x8x2_t q37 = vzipq_s16(*v3, *v7);
73 |
74 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
75 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
76 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
77 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
78 |
79 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
80 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
81 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
82 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
83 |
84 | *v0 = q_fin_0.val[0];
85 | *v1 = q_fin_0.val[1];
86 | *v2 = q_fin_1.val[0];
87 | *v3 = q_fin_1.val[1];
88 | *v4 = q_fin_2.val[0];
89 | *v5 = q_fin_2.val[1];
90 | *v6 = q_fin_3.val[0];
91 | *v7 = q_fin_3.val[1];
92 | }}
93 | #endif
94 |
95 | template<int act_k>
96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{
97 | #ifdef __ARM_NEON
98 | int16x8_t vec_lut[16];
99 | float32_t scales = *lut_scales;
100 | uint8_t tbl_mask[16];
101 | tbl_mask[0] = 0;
102 | tbl_mask[1] = 2;
103 | tbl_mask[2] = 4;
104 | tbl_mask[3] = 6;
105 | tbl_mask[4] = 8;
106 | tbl_mask[5] = 10;
107 | tbl_mask[6] = 12;
108 | tbl_mask[7] = 14;
109 | tbl_mask[8] = 1;
110 | tbl_mask[9] = 3;
111 | tbl_mask[10] = 5;
112 | tbl_mask[11] = 7;
113 | tbl_mask[12] = 9;
114 | tbl_mask[13] = 11;
115 | tbl_mask[14] = 13;
116 | tbl_mask[15] = 15;
117 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
118 | #pragma unroll
119 | for (int k = 0; k < act_k / 16; ++k) {{
120 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
121 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
122 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
123 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
124 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
125 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
126 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
127 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
128 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
129 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
130 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
131 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
132 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
133 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
134 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
135 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
136 | vec_lut[0] = vdupq_n_s16(0);
137 | vec_lut[0] = vec_lut[0] - vec_bs_0;
138 | vec_lut[0] = vec_lut[0] - vec_bs_1;
139 | vec_lut[1] = vdupq_n_s16(0);
140 | vec_lut[1] = vec_lut[1] - vec_bs_0;
141 | vec_lut[2] = vdupq_n_s16(0);
142 | vec_lut[2] = vec_lut[2] - vec_bs_0;
143 | vec_lut[2] = vec_lut[2] + vec_bs_1;
144 | vec_lut[3] = vdupq_n_s16(0);
145 | vec_lut[3] = vec_lut[3] - vec_bs_1;
146 | vec_lut[4] = vdupq_n_s16(0);
147 | vec_lut[5] = vec_bs_1;
148 | vec_lut[6] = vec_bs_0;
149 | vec_lut[6] = vec_lut[6] - vec_bs_1;
150 | vec_lut[7] = vec_bs_0;
151 | vec_lut[8] = vec_bs_0;
152 | vec_lut[8] = vec_lut[8] + vec_bs_1;
153 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
154 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
155 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
156 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
157 | #pragma unroll
158 | for (int idx = 0; idx < 8; idx++) {{
159 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
160 | int8x8_t q0_low = vget_low_s8(q0_s);
161 | int8x8_t q0_high = vget_high_s8(q0_s);
162 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
163 | int8x8_t q1_low = vget_low_s8(q1_s);
164 | int8x8_t q1_high = vget_high_s8(q1_s);
165 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
166 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
167 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
168 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
169 | }}
170 | }}
171 | #endif
172 | }}
173 |
174 | static bool is_type_supported(enum ggml_type type) {{
175 | if (type == GGML_TYPE_Q4_0 ||
176 | type == GGML_TYPE_TL1) {{
177 | return true;
178 | }} else {{
179 | return false;
180 | }}
181 | }}
182 | #include <arm_neon.h>
183 |
184 | #define BM1536_4096 256
185 | #define BBK1536_4096 128
186 | inline void tbl_impl_1536_4096(int32_t* c, int8_t* lut, uint8_t* a) {
187 | #ifdef __ARM_NEON
188 | const int KK = BBK1536_4096 / 2;
189 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
190 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
191 | int8x16_t vec_lut[2 * KK];
192 | int16x8_t vec_c[4];
193 | #pragma unroll
194 | for (int k = 0; k < 2 * KK; k++) {
195 | vec_lut[k] = vld1q_s8(lut + k * 16);
196 | }
197 |
198 | #pragma unroll
199 | for (int i = 0; i < BM1536_4096; i += 32) {
200 | #pragma unroll
201 | for (int i=0; i<4; i++) {
202 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
203 | }
204 |
205 | #pragma unroll
206 | for (int k = 0; k < KK / 4; k++) {
207 |
208 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
209 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
210 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
211 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
212 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
213 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
214 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
215 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
216 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
217 | vec_c[0] += vec_v_left_0.val[0];
218 | vec_c[0] += vec_v_right_0.val[0];
219 | vec_c[1] += vec_v_left_0.val[1];
220 | vec_c[1] += vec_v_right_0.val[1];
221 |
222 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
223 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
224 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
225 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
226 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
227 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
228 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
229 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
230 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
231 | vec_c[0] += vec_v_left_1.val[0];
232 | vec_c[0] += vec_v_right_1.val[0];
233 | vec_c[1] += vec_v_left_1.val[1];
234 | vec_c[1] += vec_v_right_1.val[1];
235 |
236 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
237 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
238 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
239 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
240 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
241 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
242 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
243 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
244 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
245 | vec_c[2] += vec_v_left_2.val[0];
246 | vec_c[2] += vec_v_right_2.val[0];
247 | vec_c[3] += vec_v_left_2.val[1];
248 | vec_c[3] += vec_v_right_2.val[1];
249 |
250 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
251 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
252 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
253 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
254 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
255 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
256 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
257 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
258 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
259 | vec_c[2] += vec_v_left_3.val[0];
260 | vec_c[2] += vec_v_right_3.val[0];
261 | vec_c[3] += vec_v_left_3.val[1];
262 | vec_c[3] += vec_v_right_3.val[1];
263 |
264 | }
265 |
266 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
267 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
268 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
269 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
270 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
271 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
272 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
273 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
274 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
275 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
276 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
277 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
278 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
279 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
280 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
281 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
282 |
283 | }
284 | #endif
285 | }
286 |
287 | int32_t qgemm_lut_1536_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
288 | alignas(32) uint32_t CBits[BM1536_4096];
289 | memset(&(CBits[0]), 0, BM1536_4096 * sizeof(int32_t));
290 | #pragma unroll
291 | for (int32_t k_outer = 0; k_outer < 4096 / BBK1536_4096; ++k_outer) {
292 | tbl_impl_1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_4096 / 2 / 2 * BM1536_4096)])));
293 | }
294 | #pragma unroll
295 | for (int i = 0; i < BM1536_4096; i++) {
296 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
297 | }
298 | return 0;
299 | };
300 | #include <arm_neon.h>
301 |
302 | #define BM1536_1536 128
303 | #define BBK1536_1536 64
304 | inline void tbl_impl_1536_1536(int32_t* c, int8_t* lut, uint8_t* a) {
305 | #ifdef __ARM_NEON
306 | const int KK = BBK1536_1536 / 2;
307 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
308 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
309 | int8x16_t vec_lut[2 * KK];
310 | int16x8_t vec_c[8];
311 | #pragma unroll
312 | for (int k = 0; k < 2 * KK; k++) {
313 | vec_lut[k] = vld1q_s8(lut + k * 16);
314 | }
315 |
316 | #pragma unroll
317 | for (int i = 0; i < BM1536_1536; i += 64) {
318 | #pragma unroll
319 | for (int i=0; i<8; i++) {
320 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
321 | }
322 |
323 | #pragma unroll
324 | for (int k = 0; k < KK / 2; k++) {
325 |
326 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
327 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
328 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
329 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
330 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
331 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
332 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
333 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
334 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
335 | vec_c[0] += vec_v_left_0.val[0];
336 | vec_c[0] += vec_v_right_0.val[0];
337 | vec_c[1] += vec_v_left_0.val[1];
338 | vec_c[1] += vec_v_right_0.val[1];
339 |
340 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
341 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
342 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
343 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
344 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
345 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
346 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
347 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
348 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
349 | vec_c[2] += vec_v_left_1.val[0];
350 | vec_c[2] += vec_v_right_1.val[0];
351 | vec_c[3] += vec_v_left_1.val[1];
352 | vec_c[3] += vec_v_right_1.val[1];
353 |
354 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
355 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
356 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
357 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
358 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
359 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
360 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
361 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
362 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
363 | vec_c[4] += vec_v_left_2.val[0];
364 | vec_c[4] += vec_v_right_2.val[0];
365 | vec_c[5] += vec_v_left_2.val[1];
366 | vec_c[5] += vec_v_right_2.val[1];
367 |
368 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
369 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
370 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
371 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
372 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
373 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
374 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
375 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
376 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
377 | vec_c[6] += vec_v_left_3.val[0];
378 | vec_c[6] += vec_v_right_3.val[0];
379 | vec_c[7] += vec_v_left_3.val[1];
380 | vec_c[7] += vec_v_right_3.val[1];
381 |
382 | }
383 |
384 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
385 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
386 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
387 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
388 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
389 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
390 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
391 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
392 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
393 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
394 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
395 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
396 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
397 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
398 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
399 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
400 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
401 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
402 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
403 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
404 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
405 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
406 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
407 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
408 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
409 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
410 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
411 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
412 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
413 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
414 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
415 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
416 |
417 | }
418 | #endif
419 | }
420 |
421 | int32_t qgemm_lut_1536_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
422 | alignas(32) uint32_t CBits[BM1536_1536];
423 | memset(&(CBits[0]), 0, BM1536_1536 * sizeof(int32_t));
424 | #pragma unroll
425 | for (int32_t k_outer = 0; k_outer < 1536 / BBK1536_1536; ++k_outer) {
426 | tbl_impl_1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_1536 / 2 / 2 * BM1536_1536)])));
427 | }
428 | #pragma unroll
429 | for (int i = 0; i < BM1536_1536; i++) {
430 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
431 | }
432 | return 0;
433 | };
434 | #include <arm_neon.h>
435 |
436 | #define BM4096_1536 256
437 | #define BBK4096_1536 128
438 | inline void tbl_impl_4096_1536(int32_t* c, int8_t* lut, uint8_t* a) {
439 | #ifdef __ARM_NEON
440 | const int KK = BBK4096_1536 / 2;
441 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
442 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
443 | int8x16_t vec_lut[2 * KK];
444 | int16x8_t vec_c[4];
445 | #pragma unroll
446 | for (int k = 0; k < 2 * KK; k++) {
447 | vec_lut[k] = vld1q_s8(lut + k * 16);
448 | }
449 |
450 | #pragma unroll
451 | for (int i = 0; i < BM4096_1536; i += 32) {
452 | #pragma unroll
453 | for (int i=0; i<4; i++) {
454 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
455 | }
456 |
457 | #pragma unroll
458 | for (int k = 0; k < KK / 4; k++) {
459 |
460 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
461 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
462 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
463 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
464 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
465 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
466 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
467 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
468 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
469 | vec_c[0] += vec_v_left_0.val[0];
470 | vec_c[0] += vec_v_right_0.val[0];
471 | vec_c[1] += vec_v_left_0.val[1];
472 | vec_c[1] += vec_v_right_0.val[1];
473 |
474 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
475 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
476 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
477 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
478 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
479 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
480 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
481 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
482 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
483 | vec_c[0] += vec_v_left_1.val[0];
484 | vec_c[0] += vec_v_right_1.val[0];
485 | vec_c[1] += vec_v_left_1.val[1];
486 | vec_c[1] += vec_v_right_1.val[1];
487 |
488 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
489 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
490 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
491 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
492 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
493 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
494 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
495 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
496 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
497 | vec_c[2] += vec_v_left_2.val[0];
498 | vec_c[2] += vec_v_right_2.val[0];
499 | vec_c[3] += vec_v_left_2.val[1];
500 | vec_c[3] += vec_v_right_2.val[1];
501 |
502 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
503 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
504 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
505 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
506 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
507 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
508 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
509 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
510 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
511 | vec_c[2] += vec_v_left_3.val[0];
512 | vec_c[2] += vec_v_right_3.val[0];
513 | vec_c[3] += vec_v_left_3.val[1];
514 | vec_c[3] += vec_v_right_3.val[1];
515 |
516 | }
517 |
518 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
519 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
520 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
521 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
522 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
523 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
524 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
525 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
526 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
527 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
528 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
529 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
530 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
531 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
532 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
533 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
534 |
535 | }
536 | #endif
537 | }
538 |
539 | int32_t qgemm_lut_4096_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
540 | alignas(32) uint32_t CBits[BM4096_1536];
541 | memset(&(CBits[0]), 0, BM4096_1536 * sizeof(int32_t));
542 | #pragma unroll
543 | for (int32_t k_outer = 0; k_outer < 1536 / BBK4096_1536; ++k_outer) {
544 | tbl_impl_4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_1536 / 2 / 2 * BM4096_1536)])));
545 | }
546 | #pragma unroll
547 | for (int i = 0; i < BM4096_1536; i++) {
548 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
549 | }
550 | return 0;
551 | };
552 |
553 | template<int K>
554 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{
555 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
556 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
557 |
558 |   lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
559 | }}
560 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
561 | if (m == 1536 && k == 4096) {
562 | preprocessor_k<4096>(B, LUT_Scales, QLUT);
563 | }
564 | else if (m == 1536 && k == 1536) {
565 | preprocessor_k<1536>(B, LUT_Scales, QLUT);
566 | }
567 | else if (m == 4096 && k == 1536) {
568 | preprocessor_k<1536>(B, LUT_Scales, QLUT);
569 | }
570 | }
571 | void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
572 | if (m == 1536 && k == 4096) {
573 | qgemm_lut_1536_4096(A, LUT, Scales, LUT_Scales, C);
574 | }
575 | else if (m == 1536 && k == 1536) {
576 | qgemm_lut_1536_1536(A, LUT, Scales, LUT_Scales, C);
577 | }
578 | else if (m == 4096 && k == 1536) {
579 | qgemm_lut_4096_1536(A, LUT, Scales, LUT_Scales, C);
580 | }
581 | }
582 |
583 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
584 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
585 | return;
586 | }
587 |
588 | int k = tensor->ne[0];
589 | int m = tensor->ne[1];
590 | const int lut_scales_size = 1;
591 | const int scales_size = 1;
592 | int bk = 0;
593 | int bm = 0;
594 |
595 | if (m == 1536 && k == 4096) {
596 | bm = BM1536_4096;
597 | bk = BBK1536_4096;
598 | }
599 | else if (m == 1536 && k == 1536) {
600 | bm = BM1536_1536;
601 | bk = BBK1536_1536;
602 | }
603 | else if (m == 4096 && k == 1536) {
604 | bm = BM4096_1536;
605 | bk = BBK4096_1536;
606 | }
607 |
608 | const int n_tile_num = m / bm;
609 | const int BK = bk;
610 | uint8_t * qweights;
611 | bitnet_float_type * scales;
612 |
613 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
614 | qweights = (uint8_t *) tensor->data;
615 | float * i2_scales = (float * )(qweights + k * m / 4);
616 | scales[0] = (bitnet_float_type) i2_scales[0];
617 |
618 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
619 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
620 | /* .lut_scales_size = */ lut_scales_size,
621 | /* .scales_size = */ scales_size,
622 | /* .n_tile_num = */ n_tile_num,
623 | /* .qweights = */ qweights,
624 | /* .scales = */ scales
625 | };
626 | }
627 | #endif
--------------------------------------------------------------------------------
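per_tensor_quant in the header above reduces the whole activation vector to a single scale, 127 divided by the largest absolute value, so the quantized lookup tables stay in int8 range; lut_ctor then rounds the scaled activations and precomputes the signed combinations consumed by the tbl_impl_* kernels. A scalar Python sketch of the scale computation, illustrative only:

    # Mirrors the reduction in per_tensor_quant: scale = 127 / max|b|.
    def per_tensor_scale(activations):
        peak = max(abs(x) for x in activations)
        return 127.0 / peak  # an all-zero vector needs separate handling

    # e.g. per_tensor_scale([0.5, -2.0, 1.25]) == 63.5
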
/preset_kernels/bitnet_b1_58-large/kernel_config_tl1.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 1536
3 | k = 4096
4 | bm = 256
5 | bk = 128
6 | bmm = 32
7 |
8 | [Kernels_1]
9 | m = 1536
10 | k = 1536
11 | bm = 128
12 | bk = 64
13 | bmm = 64
14 |
15 | [Kernels_2]
16 | m = 4096
17 | k = 1536
18 | bm = 256
19 | bk = 128
20 | bmm = 32
21 |
22 |
--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-large/kernel_config_tl2.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 1536
3 | k = 4096
4 | bm = 256
5 | bk = 96
6 | bmm = 32
7 |
8 | [Kernels_1]
9 | m = 1536
10 | k = 1536
11 | bm = 128
12 | bk = 192
13 | bmm = 32
14 |
15 | [Kernels_2]
16 | m = 4096
17 | k = 1536
18 | bm = 256
19 | bk = 96
20 | bmm = 64
21 |
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # These requirements include all dependencies for all top-level python scripts
2 | # for llama.cpp. Avoid adding packages here directly.
3 | #
4 | # Package versions must stay compatible across all top-level python scripts.
5 | #
6 |
7 | -r 3rdparty/llama.cpp/requirements/requirements-convert_legacy_llama.txt
8 | -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt
9 | -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt
10 | -r 3rdparty/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt
11 | -r 3rdparty/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt
--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import signal
4 | import platform
5 | import argparse
6 | import subprocess
7 |
8 | def run_command(command, shell=False):
9 | """Run a system command and ensure it succeeds."""
10 | try:
11 | subprocess.run(command, shell=shell, check=True)
12 | except subprocess.CalledProcessError as e:
13 | print(f"Error occurred while running command: {e}")
14 | sys.exit(1)
15 |
16 | def run_inference():
17 | build_dir = "build"
18 | if platform.system() == "Windows":
19 | main_path = os.path.join(build_dir, "bin", "Release", "llama-cli.exe")
20 | if not os.path.exists(main_path):
21 | main_path = os.path.join(build_dir, "bin", "llama-cli")
22 | else:
23 | main_path = os.path.join(build_dir, "bin", "llama-cli")
24 | command = [
25 | f'{main_path}',
26 | '-m', args.model,
27 | '-n', str(args.n_predict),
28 | '-t', str(args.threads),
29 | '-p', args.prompt,
30 | '-ngl', '0',
31 | '-c', str(args.ctx_size),
32 | '--temp', str(args.temperature),
33 | "-b", "1",
34 | ]
35 | if args.conversation:
36 | command.append("-cnv")
37 | run_command(command)
38 |
39 | def signal_handler(sig, frame):
40 | print("Ctrl+C pressed, exiting...")
41 | sys.exit(0)
42 |
43 | if __name__ == "__main__":
44 | signal.signal(signal.SIGINT, signal_handler)
45 | # Usage: python run_inference.py -p "Microsoft Corporation is an American multinational corporation and technology company headquartered in Redmond, Washington."
46 | parser = argparse.ArgumentParser(description='Run inference')
47 | parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf")
48 | parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict when generating text", required=False, default=128)
49 | parser.add_argument("-p", "--prompt", type=str, help="Prompt to generate text from", required=True)
50 | parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
51 | parser.add_argument("-c", "--ctx-size", type=int, help="Size of the prompt context", required=False, default=2048)
52 | parser.add_argument("-temp", "--temperature", type=float, help="Temperature, a hyperparameter that controls the randomness of the generated text", required=False, default=0.8)
53 |     parser.add_argument("-cnv", "--conversation", action='store_true', help="Whether to enable chat mode (for instruct models)")
54 |
55 | args = parser.parse_args()
56 | run_inference()
--------------------------------------------------------------------------------
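run_inference.py is a thin wrapper around the llama-cli binary produced in build/bin. A hedged usage sketch that drives it programmatically; the model path is an assumption and must point at a GGUF file produced by setup_env.py:

    import subprocess, sys

    # Illustrative invocation; adjust the model path to whatever setup_env.py produced.
    subprocess.run([
        sys.executable, "run_inference.py",
        "-m", "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf",
        "-p", "You are a helpful assistant",
        "-n", "64",
        "-t", "4",
        "-cnv",
    ], check=True)
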
/run_inference_server.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import signal
4 | import platform
5 | import argparse
6 | import subprocess
7 |
8 | def run_command(command, shell=False):
9 | """Run a system command and ensure it succeeds."""
10 | try:
11 | subprocess.run(command, shell=shell, check=True)
12 | except subprocess.CalledProcessError as e:
13 | print(f"Error occurred while running command: {e}")
14 | sys.exit(1)
15 |
16 | def run_server():
17 | build_dir = "build"
18 | if platform.system() == "Windows":
19 | server_path = os.path.join(build_dir, "bin", "Release", "llama-server.exe")
20 | if not os.path.exists(server_path):
21 | server_path = os.path.join(build_dir, "bin", "llama-server")
22 | else:
23 | server_path = os.path.join(build_dir, "bin", "llama-server")
24 |
25 | command = [
26 | f'{server_path}',
27 | '-m', args.model,
28 | '-c', str(args.ctx_size),
29 | '-t', str(args.threads),
30 | '-n', str(args.n_predict),
31 | '-ngl', '0',
32 | '--temp', str(args.temperature),
33 | '--host', args.host,
34 | '--port', str(args.port),
35 | '-cb' # Enable continuous batching
36 | ]
37 |
38 | if args.prompt:
39 | command.extend(['-p', args.prompt])
40 |
41 | # Note: -cnv flag is removed as it's not supported by the server
42 |
43 | print(f"Starting server on {args.host}:{args.port}")
44 | run_command(command)
45 |
46 | def signal_handler(sig, frame):
47 | print("Ctrl+C pressed, shutting down server...")
48 | sys.exit(0)
49 |
50 | if __name__ == "__main__":
51 | signal.signal(signal.SIGINT, signal_handler)
52 |
53 | parser = argparse.ArgumentParser(description='Run llama.cpp server')
54 | parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf")
55 | parser.add_argument("-p", "--prompt", type=str, help="System prompt for the model", required=False)
56 | parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict", required=False, default=4096)
57 | parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
58 | parser.add_argument("-c", "--ctx-size", type=int, help="Size of the context window", required=False, default=2048)
59 | parser.add_argument("--temperature", type=float, help="Temperature for sampling", required=False, default=0.8)
60 | parser.add_argument("--host", type=str, help="IP address to listen on", required=False, default="127.0.0.1")
61 | parser.add_argument("--port", type=int, help="Port to listen on", required=False, default=8080)
62 |
63 | args = parser.parse_args()
64 | run_server()
65 |
--------------------------------------------------------------------------------
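run_inference_server.py launches llama-server with continuous batching on 127.0.0.1:8080 by default. Once it is up, it can be queried over HTTP; the sketch below assumes llama.cpp's /completion endpoint and its JSON fields ("prompt", "n_predict", "content"), which come from the upstream server and are not defined anywhere in this repository:

    import json
    import urllib.request

    # Illustrative client; endpoint and field names follow llama.cpp's server API (assumption).
    payload = {"prompt": "Microsoft Corporation is", "n_predict": 32}
    req = urllib.request.Request(
        "http://127.0.0.1:8080/completion",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.loads(resp.read())["content"])
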
/setup_env.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import signal
3 | import sys
4 | import os
5 | import platform
6 | import argparse
7 | import logging
8 | import shutil
9 | from pathlib import Path
10 |
11 | logger = logging.getLogger("setup_env")
12 |
13 | SUPPORTED_HF_MODELS = {
14 | "1bitLLM/bitnet_b1_58-large": {
15 | "model_name": "bitnet_b1_58-large",
16 | },
17 | "1bitLLM/bitnet_b1_58-3B": {
18 | "model_name": "bitnet_b1_58-3B",
19 | },
20 | "HF1BitLLM/Llama3-8B-1.58-100B-tokens": {
21 | "model_name": "Llama3-8B-1.58-100B-tokens",
22 | },
23 | "tiiuae/Falcon3-7B-Instruct-1.58bit": {
24 | "model_name": "Falcon3-7B-Instruct-1.58bit",
25 | },
26 | "tiiuae/Falcon3-7B-1.58bit": {
27 | "model_name": "Falcon3-7B-1.58bit",
28 | },
29 | "tiiuae/Falcon3-10B-Instruct-1.58bit": {
30 | "model_name": "Falcon3-10B-Instruct-1.58bit",
31 | },
32 | "tiiuae/Falcon3-10B-1.58bit": {
33 | "model_name": "Falcon3-10B-1.58bit",
34 | },
35 | "tiiuae/Falcon3-3B-Instruct-1.58bit": {
36 | "model_name": "Falcon3-3B-Instruct-1.58bit",
37 | },
38 | "tiiuae/Falcon3-3B-1.58bit": {
39 | "model_name": "Falcon3-3B-1.58bit",
40 | },
41 | "tiiuae/Falcon3-1B-Instruct-1.58bit": {
42 | "model_name": "Falcon3-1B-Instruct-1.58bit",
43 | },
44 | "microsoft/BitNet-b1.58-2B-4T": {
45 | "model_name": "BitNet-b1.58-2B-4T",
46 | },
47 | "tiiuae/Falcon-E-3B-Instruct": {
48 | "model_name": "Falcon-E-3B-Instruct",
49 | },
50 | "tiiuae/Falcon-E-1B-Instruct": {
51 | "model_name": "Falcon-E-1B-Instruct",
52 | },
53 | "tiiuae/Falcon-E-3B-Base": {
54 | "model_name": "Falcon-E-3B-Base",
55 | },
56 | "tiiuae/Falcon-E-1B-Base": {
57 | "model_name": "Falcon-E-1B-Base",
58 | },
59 | }
60 |
61 | SUPPORTED_QUANT_TYPES = {
62 | "arm64": ["i2_s", "tl1"],
63 | "x86_64": ["i2_s", "tl2"]
64 | }
65 |
66 | COMPILER_EXTRA_ARGS = {
67 | "arm64": ["-DBITNET_ARM_TL1=ON"],
68 | "x86_64": ["-DBITNET_X86_TL2=ON"]
69 | }
70 |
71 | OS_EXTRA_ARGS = {
72 | "Windows":["-T", "ClangCL"],
73 | }
74 |
75 | ARCH_ALIAS = {
76 | "AMD64": "x86_64",
77 | "x86": "x86_64",
78 | "x86_64": "x86_64",
79 | "aarch64": "arm64",
80 | "arm64": "arm64",
81 | "ARM64": "arm64",
82 | }
83 |
84 | def system_info():
85 | return platform.system(), ARCH_ALIAS[platform.machine()]
86 |
87 | def get_model_name():
88 | if args.hf_repo:
89 | return SUPPORTED_HF_MODELS[args.hf_repo]["model_name"]
90 | return os.path.basename(os.path.normpath(args.model_dir))
91 |
92 | def run_command(command, shell=False, log_step=None):
93 | """Run a system command and ensure it succeeds."""
94 | if log_step:
95 | log_file = os.path.join(args.log_dir, log_step + ".log")
96 | with open(log_file, "w") as f:
97 | try:
98 | subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f)
99 | except subprocess.CalledProcessError as e:
100 | logging.error(f"Error occurred while running command: {e}, check details in {log_file}")
101 | sys.exit(1)
102 | else:
103 | try:
104 | subprocess.run(command, shell=shell, check=True)
105 | except subprocess.CalledProcessError as e:
106 | logging.error(f"Error occurred while running command: {e}")
107 | sys.exit(1)
108 |
109 | def prepare_model():
110 | _, arch = system_info()
111 | hf_url = args.hf_repo
112 | model_dir = args.model_dir
113 | quant_type = args.quant_type
114 | quant_embd = args.quant_embd
115 | if hf_url is not None:
116 | # download the model
117 | model_dir = os.path.join(model_dir, SUPPORTED_HF_MODELS[hf_url]["model_name"])
118 | Path(model_dir).mkdir(parents=True, exist_ok=True)
119 | logging.info(f"Downloading model {hf_url} from HuggingFace to {model_dir}...")
120 | run_command(["huggingface-cli", "download", hf_url, "--local-dir", model_dir], log_step="download_model")
121 | elif not os.path.exists(model_dir):
122 | logging.error(f"Model directory {model_dir} does not exist.")
123 | sys.exit(1)
124 | else:
125 | logging.info(f"Loading model from directory {model_dir}.")
126 | gguf_path = os.path.join(model_dir, "ggml-model-" + quant_type + ".gguf")
127 | if not os.path.exists(gguf_path) or os.path.getsize(gguf_path) == 0:
128 | logging.info(f"Converting HF model to GGUF format...")
129 | if quant_type.startswith("tl"):
130 | run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", quant_type, "--quant-embd"], log_step="convert_to_tl")
131 | else: # i2s
132 | # convert to f32
133 | run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "f32"], log_step="convert_to_f32_gguf")
134 | f32_model = os.path.join(model_dir, "ggml-model-f32.gguf")
135 | i2s_model = os.path.join(model_dir, "ggml-model-i2_s.gguf")
136 | # quantize to i2s
137 | if platform.system() != "Windows":
138 | if quant_embd:
139 | run_command(["./build/bin/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s")
140 | else:
141 | run_command(["./build/bin/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s")
142 | else:
143 | if quant_embd:
144 | run_command(["./build/bin/Release/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s")
145 | else:
146 | run_command(["./build/bin/Release/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s")
147 |
148 | logging.info(f"GGUF model saved at {gguf_path}")
149 | else:
150 | logging.info(f"GGUF model already exists at {gguf_path}")
151 |
152 | def setup_gguf():
153 | # Install the pip package
154 | run_command([sys.executable, "-m", "pip", "install", "3rdparty/llama.cpp/gguf-py"], log_step="install_gguf")
155 |
156 | def gen_code():
157 | _, arch = system_info()
158 |
159 | llama3_f3_models = set([model['model_name'] for model in SUPPORTED_HF_MODELS.values() if model['model_name'].startswith("Falcon") or model['model_name'].startswith("Llama")])
160 |
161 | if arch == "arm64":
162 | if args.use_pretuned:
163 | pretuned_kernels = os.path.join("preset_kernels", get_model_name())
164 | if not os.path.exists(pretuned_kernels):
165 | logging.error(f"Pretuned kernels not found for model {args.hf_repo}")
166 | sys.exit(1)
167 | if args.quant_type == "tl1":
168 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl1.h"), "include/bitnet-lut-kernels.h")
169 | shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl1.ini"), "include/kernel_config.ini")
170 | elif args.quant_type == "tl2":
171 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
172 | shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini")
173 | if get_model_name() == "bitnet_b1_58-large":
174 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
175 | elif get_model_name() in llama3_f3_models:
176 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
177 | elif get_model_name() == "bitnet_b1_58-3B":
178 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
179 | elif get_model_name() == "BitNet-b1.58-2B-4T":
180 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
181 | else:
182 | raise NotImplementedError()
183 | else:
184 | if args.use_pretuned:
185 |             # copy preset_kernels/<model_name>/bitnet-lut-kernels-tl2.h to include/bitnet-lut-kernels.h
186 | pretuned_kernels = os.path.join("preset_kernels", get_model_name())
187 | if not os.path.exists(pretuned_kernels):
188 | logging.error(f"Pretuned kernels not found for model {args.hf_repo}")
189 | sys.exit(1)
190 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
191 | if get_model_name() == "bitnet_b1_58-large":
192 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
193 | elif get_model_name() in llama3_f3_models:
194 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
195 | elif get_model_name() == "bitnet_b1_58-3B":
196 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
197 | elif get_model_name() == "BitNet-b1.58-2B-4T":
198 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
199 | else:
200 | raise NotImplementedError()
201 |
202 |
203 | def compile():
204 | # Check if cmake is installed
205 | cmake_exists = subprocess.run(["cmake", "--version"], capture_output=True)
206 | if cmake_exists.returncode != 0:
207 |         logging.error("CMake is not available. Please install CMake and try again.")
208 | sys.exit(1)
209 | _, arch = system_info()
210 | if arch not in COMPILER_EXTRA_ARGS.keys():
211 | logging.error(f"Arch {arch} is not supported yet")
212 | exit(0)
213 | logging.info("Compiling the code using CMake.")
214 | run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), []), "-DCMAKE_C_COMPILER=clang", "-DCMAKE_CXX_COMPILER=clang++"], log_step="generate_build_files")
215 | # run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"])
216 | run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile")
217 |
218 | def main():
219 | setup_gguf()
220 | gen_code()
221 | compile()
222 | prepare_model()
223 |
224 | def parse_args():
225 | _, arch = system_info()
226 | parser = argparse.ArgumentParser(description='Setup the environment for running the inference')
227 | parser.add_argument("--hf-repo", "-hr", type=str, help="Model used for inference", choices=SUPPORTED_HF_MODELS.keys())
228 | parser.add_argument("--model-dir", "-md", type=str, help="Directory to save/load the model", default="models")
229 | parser.add_argument("--log-dir", "-ld", type=str, help="Directory to save the logging info", default="logs")
230 | parser.add_argument("--quant-type", "-q", type=str, help="Quantization type", choices=SUPPORTED_QUANT_TYPES[arch], default="i2_s")
231 | parser.add_argument("--quant-embd", action="store_true", help="Quantize the embeddings to f16")
232 | parser.add_argument("--use-pretuned", "-p", action="store_true", help="Use the pretuned kernel parameters")
233 | return parser.parse_args()
234 |
235 | def signal_handler(sig, frame):
236 | logging.info("Ctrl+C pressed, exiting...")
237 | sys.exit(0)
238 |
239 | if __name__ == "__main__":
240 | signal.signal(signal.SIGINT, signal_handler)
241 | args = parse_args()
242 | Path(args.log_dir).mkdir(parents=True, exist_ok=True)
243 | logging.basicConfig(level=logging.INFO)
244 | main()
245 |
--------------------------------------------------------------------------------
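setup_env.py runs the full pipeline in order: install the bundled gguf-py package, generate the LUT kernels for the chosen model, build with CMake/Clang, then download and quantize the model. A hedged end-to-end sketch; the repo id must be one of the SUPPORTED_HF_MODELS keys above, and the quant type must be valid for the host architecture (i2_s or tl1 on arm64, i2_s or tl2 on x86_64):

    import subprocess, sys

    # Illustrative one-shot setup for the official 2B model with i2_s quantization.
    subprocess.run([
        sys.executable, "setup_env.py",
        "--hf-repo", "microsoft/BitNet-b1.58-2B-4T",
        "-q", "i2_s",
    ], check=True)
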
/src/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | set(GGML_HEADERS_BITNET ../include/ggml-bitnet.h)
2 | set(GGML_SOURCES_BITNET ggml-bitnet-mad.cpp)
3 | set(GGML_SOURCES_BITNET ggml-bitnet-lut.cpp)
4 |
5 | include_directories(3rdparty/llama.cpp/ggml/include)
6 |
7 | if (NOT (CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_C_COMPILER_ID STREQUAL "GNU") OR
8 | NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU"))
9 | message(FATAL_ERROR "Clang or GCC is required for Bitnet.cpp compilation")
10 | endif()
11 |
--------------------------------------------------------------------------------
/src/ggml-bitnet-lut.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include
5 | #include
6 | #include
7 |
8 | #include "ggml-bitnet.h"
9 | #include "ggml-quants.h"
10 | #include "bitnet-lut-kernels.h"
11 |
12 | #if defined(GGML_BITNET_ARM_TL1)
13 |
14 | void ggml_bitnet_init(void) {
15 | // LOG(INFO) << "ggml_bitnet_init";
16 |
17 | if (initialized) {
18 | return;
19 | }
20 | initialized = true;
21 |
22 | // if (wrapper == nullptr) {
23 | // wrapper = new BITNET::BITNETGeMMWrapper();
24 | // }
25 | if (bitnet_tensor_extras == nullptr) {
26 | bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES];
27 | }
28 | bitnet_tensor_extras_index = 0;
29 | }
30 |
31 | void ggml_bitnet_free(void) {
32 | // LOG(INFO) << "ggml_bitnet_free";
33 |
34 | if (!initialized) {
35 | return;
36 | }
37 | initialized = false;
38 |
39 | // delete wrapper;
40 | // wrapper = nullptr;
41 | for (size_t i = 0; i < bitnet_tensor_extras_index; i++) {
42 | // aligned_free(bitnet_tensor_extras[i].qweights);
43 | // aligned_free(bitnet_tensor_extras[i].scales);
44 | }
45 | delete[] bitnet_tensor_extras;
46 | bitnet_tensor_extras = nullptr;
47 | }
48 |
49 | static bool do_permutate(enum ggml_type type) {
50 | if (type == GGML_TYPE_TL1) {
51 | // Add additional args to decide if permuted I2 or naive I2
52 | return false;
53 | } else {
54 | return true;
55 | }
56 | }
57 |
58 | bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
59 | if ((is_type_supported(src0->type)) &&
60 | src1->type == GGML_TYPE_F32 &&
61 | dst->type == GGML_TYPE_F32 &&
62 | src0->backend == GGML_BACKEND_TYPE_CPU) {
63 | if (src1->ne[1] <= 1) {
64 | return true;
65 | }
66 | }
67 | return false;
68 | }
69 |
70 | size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
71 | const size_t ne01 = src0->ne[1];
72 | const size_t ne10 = src1->ne[0];
73 | const size_t ne11 = src1->ne[1];
74 | const int bits = ggml_bitnet_get_type_bits(src0->type);
75 |
76 | size_t wsize = ne10 * ne11 * 15 * sizeof(int8_t) + 1 * ne11 * 2 * sizeof(bitnet_float_type);
77 | if (sizeof(bitnet_float_type) == 2) {
78 | // Need fp32 to fp16 conversion
79 | wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type);
80 | }
81 | wsize = ((wsize - 1) / 64 + 1) * 64;
82 | return wsize;
83 | }
84 |
85 | int ggml_bitnet_get_type_bits(enum ggml_type type) {
86 | switch (type) {
87 | case GGML_TYPE_TL1:
88 | return 2;
89 | case GGML_TYPE_Q4_0:
90 | return 4;
91 | default:
92 | return 0;
93 | }
94 | }
95 |
96 | #endif
97 | #if defined(GGML_BITNET_X86_TL2)
98 | void ggml_bitnet_init(void) {
99 | // LOG(INFO) << "ggml_bitnet_init";
100 |
101 | if (initialized) {
102 | return;
103 | }
104 | initialized = true;
105 |
106 | // if (wrapper == nullptr) {
107 | // wrapper = new BITNET::BITNETGeMMWrapper();
108 | // }
109 | if (bitnet_tensor_extras == nullptr) {
110 | bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES];
111 | }
112 | bitnet_tensor_extras_index = 0;
113 | }
114 |
115 | void ggml_bitnet_free(void) {
116 | // LOG(INFO) << "ggml_bitnet_free";
117 |
118 | if (!initialized) {
119 | return;
120 | }
121 | initialized = false;
122 |
123 | // delete wrapper;
124 | // wrapper = nullptr;
125 | for (size_t i = 0; i < bitnet_tensor_extras_index; i++) {
126 | // aligned_free(bitnet_tensor_extras[i].qweights);
127 | // aligned_free(bitnet_tensor_extras[i].scales);
128 | }
129 | delete[] bitnet_tensor_extras;
130 | bitnet_tensor_extras = nullptr;
131 | }
132 |
133 | bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
134 | if ((is_type_supported(src0->type)) &&
135 | src1->type == GGML_TYPE_F32 &&
136 | dst->type == GGML_TYPE_F32 &&
137 | src0->backend == GGML_BACKEND_TYPE_CPU) {
138 | return true;
139 | }
140 | return false;
141 | }
142 |
143 | size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
144 | const size_t ne01 = src0->ne[1];
145 | const size_t ne10 = src1->ne[0];
146 | const size_t ne11 = src1->ne[1];
147 |
148 | size_t wsize = ne10 * ne11 * 11 * sizeof(int8_t) + 2 * ne11 * 2 * sizeof(bitnet_float_type);
149 | if (sizeof(bitnet_float_type) == 2) {
150 | // Need fp32 to fp16 conversion
151 | wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type);
152 | }
153 | wsize = ((wsize - 1) / 64 + 1) * 64;
154 | return wsize;
155 | }
156 |
157 | int ggml_bitnet_get_type_bits(enum ggml_type type) {
158 | switch (type) {
159 | case GGML_TYPE_TL2:
160 | return 2;
161 | case GGML_TYPE_Q4_0:
162 | return 4;
163 | default:
164 | return 0;
165 | }
166 | }
167 | #endif
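
Both ggml_bitnet_mul_mat_get_wsize variants above size a per-multiplication scratch buffer and round it up to a multiple of 64 bytes. A rough Python transcription of the TL1 branch, assuming bitnet_float_type is fp16 (2 bytes); the function name and the example shape are illustrative only:

def tl1_mul_mat_wsize(ne01: int, ne10: int, ne11: int, float_size: int = 2) -> int:
    # 15 int8 bytes of workspace per activation element, plus 2 scale slots per column.
    wsize = ne10 * ne11 * 15 + 1 * ne11 * 2 * float_size
    if float_size == 2:
        # Extra scratch for the fp32 -> fp16 conversion of the activations.
        wsize += max(ne10, ne01) * ne11 * float_size
    # Round up to the next multiple of 64 bytes for alignment.
    return ((wsize - 1) // 64 + 1) * 64

# e.g. one activation column against a 1536 x 4096 ternary weight:
print(tl1_mul_mat_wsize(ne01=1536, ne10=4096, ne11=1))   # 69696
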
--------------------------------------------------------------------------------
/src/ggml-bitnet-mad.cpp:
--------------------------------------------------------------------------------
1 | #include <cstdint>
2 | #include <cstring>
3 |
4 | #include "ggml-bitnet.h"
5 | #include "ggml-quants.h"
6 | #include <cstdlib>
7 | #include <cmath>
8 |
9 | #define QK_I2_S 128
10 | #define QK_I2 128
11 |
12 | #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
13 | #include <immintrin.h>
14 | // horizontally add 8 int32_t
15 | static inline int hsum_i32_8(const __m256i a) {
16 | const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
17 | const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
18 | const __m128i sum64 = _mm_add_epi32(hi64, sum128);
19 | const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
20 | return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
21 | }
22 | #elif defined(__loongarch_asx)
23 | // horizontally add 8 int32_t
24 | static inline int hsum_i32_8(const __m256i a) {
25 |
26 | __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
27 | __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
28 |
29 | __m128i tmp1_128 = lasx_extracti128_lo(tmp1);
30 | __m128i tmp2_128 = lasx_extracti128_lo(tmp2);
31 |
32 | __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
33 |
34 | __m128i ev = __lsx_vpickev_w(sum128, sum128);
35 | __m128i od = __lsx_vpickod_w(sum128, sum128);
36 | __m128i sum64 = __lsx_vadd_w(ev, od);
37 |
38 | int sum64_1, sum64_2;
39 | sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
40 | sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
41 |
42 | return sum64_1 + sum64_2;
43 | }
44 | #endif
45 |
46 | size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
47 | // 2 bits per weight
48 |
49 | size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
50 |
51 | int n = nrow * n_per_row;
52 |
53 | // f32 -> q8
54 | double max = 0;
55 | for (int i = 0; i < n; ++i) {
56 | max = fmax(max, (double)fabs((double)src[i]));
57 | }
58 | double i2_scale = max;
59 |
60 | uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t));
61 | for (int i = 0; i < n; i++) {
62 |     if (fabs((double)(src[i])) < 1e-6) {
63 |         q8[i] = 1;
64 |         continue;
65 |     }
66 |     q8[i] = (double)src[i] * i2_scale > 0 ? 2 : 0;
67 | }
68 |
69 | memset(dst, 0, n * sizeof(uint8_t) / 4);
70 |
71 | // q8 -> 0, 1, 2
72 | // | | |
73 | // -1, 0, 1
74 |
75 | uint8_t* i2_weight = (uint8_t*)dst;
76 | for (int i = 0; i < n / QK_I2; i++) {
77 | for (int j = 0; j < QK_I2; j++) {
78 | int group_idx = j / 32;
79 | int group_pos = j % 32;
80 | uint8_t temp = (q8[i * QK_I2 + j] << (6 - 2 * group_idx));
81 | i2_weight[i * 32 + group_pos] |= temp;
82 | }
83 | }
84 |
85 | float* scale_ptr = (float*)((char*)i2_weight + n / 4);
86 | scale_ptr[0] = i2_scale;
87 |
88 | free(q8);
89 |
90 | // 32B for alignment
91 | return nrow * row_size / 4 + 32;
92 | }
93 |
94 | void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
95 | const uint8_t * x = (uint8_t *)vx;
96 | const int8_t * y = (int8_t *)vy;
97 |
98 | const int nb = n / QK_I2_S;
99 | const int group32_num = nb / 32;
100 | const int la_num = nb % 32;
101 | const int groupla_num = nb % 32 != 0 ? 1 : 0;
102 |
103 | #if defined(__AVX2__)
104 |
105 | __m256i mask = _mm256_set1_epi8(0x03);
106 | __m256i accu = _mm256_setzero_si256();
107 |
108 | for (int i=0; i < group32_num; i++){
109 | __m256i accu32 = _mm256_setzero_si256();
110 | for (int j=0; j < 32; j++) {
111 | // 128 index
112 | __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + i * 32 * 32 + j * 32));
113 | __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
114 | __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
115 | __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
116 |
117 | // each 32 index
118 | xq8_3 = _mm256_and_si256(xq8_3, mask);
119 | xq8_2 = _mm256_and_si256(xq8_2, mask);
120 | xq8_1 = _mm256_and_si256(xq8_1, mask);
121 | xq8_0 = _mm256_and_si256(xq8_0, mask);
122 |
123 | // each 32 index
124 | __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 0));
125 | __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 32));
126 | __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 64));
127 | __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 96));
128 |
129 | // 128 index accumulation add
130 | // split into 32 accumulation block
131 | // each block each 128 index accumulated 4index
132 | // each index maximum 256
133 | // each block maximum 4 * 256
134 | // each block accumulation maximum 127 * 256
135 | // each 32 group index (128 index in one group) needs cast to int32
136 | xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
137 | xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
138 | xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
139 | xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
140 |
141 | accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1));
142 | accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3));
143 | }
144 | accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, _mm256_set1_epi16(1)), accu);
145 | }
146 |
147 | for (int i = 0; i < groupla_num; i++){
148 | __m256i accula = _mm256_setzero_si256();
149 | for (int j = 0; j < la_num; j++) {
150 | // 128 index
151 | __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + group32_num * 32 * 32 + j * 32));
152 | __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
153 | __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
154 | __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
155 |
156 | // each 32 index
157 | xq8_3 = _mm256_and_si256(xq8_3, mask);
158 | xq8_2 = _mm256_and_si256(xq8_2, mask);
159 | xq8_1 = _mm256_and_si256(xq8_1, mask);
160 | xq8_0 = _mm256_and_si256(xq8_0, mask);
161 |
162 | // each 32 index
163 | __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 0));
164 | __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 32));
165 | __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 64));
166 | __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 96));
167 |
168 | // 128 index accumulation add
169 | // split into 32 accumulation block
170 | // each block each 128 index accumulated 4index
171 | // each index maximum 256
172 | // each block maximum 4 * 256
173 | // each block accumulation maximum 127 * 256
174 | // each 32 group index (128 index in one group) needs cast to int32
175 | xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
176 | xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
177 | xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
178 | xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
179 |
180 | accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1));
181 | accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3));
182 | }
183 | accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, _mm256_set1_epi16(1)));
184 | }
185 | int sumi = hsum_i32_8(accu);
186 | *s = (float)sumi;
187 |
188 | #elif defined(__ARM_NEON)
189 |
190 | int32x4_t accu_0 = vdupq_n_s32(0);
191 | int32x4_t accu_1 = vdupq_n_s32(0);
192 | int32x4_t accu_2 = vdupq_n_s32(0);
193 | int32x4_t accu_3 = vdupq_n_s32(0);
194 | const uint8x16_t mask = vdupq_n_u8(3);
195 |
196 | for (int i=0; i < group32_num; i++) {
197 |
198 | #if defined(__ARM_FEATURE_DOTPROD)
199 |
200 | #else
201 | int16x8_t accu32_0 = vdupq_n_s16(0);
202 | int16x8_t accu32_1 = vdupq_n_s16(0);
203 | int16x8_t accu32_2 = vdupq_n_s16(0);
204 | int16x8_t accu32_3 = vdupq_n_s16(0);
205 | #endif
206 |
207 | for (int j=0; j < 32; j++) {
208 | uint8x16_t xq8_6 = vld1q_u8(x + i * 32 * 32 + j * 32);
209 | uint8x16_t xq8_7 = vld1q_u8(x + i * 32 * 32 + j * 32 + 16);
210 | uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2);
211 | uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2);
212 | uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4);
213 | uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4);
214 | uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6);
215 | uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6);
216 |
217 | int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
218 | int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
219 | int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
220 | int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
221 | int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask));
222 | int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask));
223 | int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask));
224 | int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask));
225 |
226 | const int8x16_t yq8_0 = vld1q_s8(y + i * 128 * 32 + j * 128 + 0);
227 | const int8x16_t yq8_1 = vld1q_s8(y + i * 128 * 32 + j * 128 + 16);
228 | const int8x16_t yq8_2 = vld1q_s8(y + i * 128 * 32 + j * 128 + 32);
229 | const int8x16_t yq8_3 = vld1q_s8(y + i * 128 * 32 + j * 128 + 48);
230 | const int8x16_t yq8_4 = vld1q_s8(y + i * 128 * 32 + j * 128 + 64);
231 | const int8x16_t yq8_5 = vld1q_s8(y + i * 128 * 32 + j * 128 + 80);
232 | const int8x16_t yq8_6 = vld1q_s8(y + i * 128 * 32 + j * 128 + 96);
233 | const int8x16_t yq8_7 = vld1q_s8(y + i * 128 * 32 + j * 128 + 112);
234 |
235 | #if defined(__ARM_FEATURE_DOTPROD)
236 | accu_0 = vdotq_s32(accu_0, q8_0, yq8_0);
237 | accu_1 = vdotq_s32(accu_1, q8_1, yq8_1);
238 | accu_2 = vdotq_s32(accu_2, q8_2, yq8_2);
239 | accu_3 = vdotq_s32(accu_3, q8_3, yq8_3);
240 | accu_0 = vdotq_s32(accu_0, q8_4, yq8_4);
241 | accu_1 = vdotq_s32(accu_1, q8_5, yq8_5);
242 | accu_2 = vdotq_s32(accu_2, q8_6, yq8_6);
243 | accu_3 = vdotq_s32(accu_3, q8_7, yq8_7);
244 | #else
245 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_0), vget_low_s8(yq8_0));
246 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_0), vget_high_s8(yq8_0));
247 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_1), vget_low_s8(yq8_1));
248 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_1), vget_high_s8(yq8_1));
249 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_2), vget_low_s8(yq8_2));
250 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_2), vget_high_s8(yq8_2));
251 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_3), vget_low_s8(yq8_3));
252 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_3), vget_high_s8(yq8_3));
253 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_4), vget_low_s8(yq8_4));
254 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_4), vget_high_s8(yq8_4));
255 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_5), vget_low_s8(yq8_5));
256 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_5), vget_high_s8(yq8_5));
257 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_6), vget_low_s8(yq8_6));
258 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_6), vget_high_s8(yq8_6));
259 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_7), vget_low_s8(yq8_7));
260 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_7), vget_high_s8(yq8_7));
261 | #endif
262 | }
263 |
264 | #if defined(__ARM_FEATURE_DOTPROD)
265 |
266 | #else
267 | accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accu32_0)));
268 | accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accu32_0));
269 | accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accu32_1)));
270 | accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accu32_1));
271 | accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accu32_2)));
272 | accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accu32_2));
273 | accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accu32_3)));
274 | accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accu32_3));
275 | #endif
276 | }
277 |
278 | for (int i = 0; i < groupla_num; i++){
279 | #if defined(__ARM_FEATURE_DOTPROD)
280 |
281 | #else
282 | int16x8_t accula_0 = vdupq_n_s16(0);
283 | int16x8_t accula_1 = vdupq_n_s16(0);
284 | int16x8_t accula_2 = vdupq_n_s16(0);
285 | int16x8_t accula_3 = vdupq_n_s16(0);
286 | #endif
287 | for (int j = 0; j < la_num; j++) {
288 | uint8x16_t xq8_6 = vld1q_u8(x + group32_num * 32 * 32 + j * 32);
289 | uint8x16_t xq8_7 = vld1q_u8(x + group32_num * 32 * 32 + j * 32 + 16);
290 | uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2);
291 | uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2);
292 | uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4);
293 | uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4);
294 | uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6);
295 | uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6);
296 |
297 | int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
298 | int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
299 | int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
300 | int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
301 | int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask));
302 | int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask));
303 | int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask));
304 | int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask));
305 |
306 | const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 0);
307 | const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 16);
308 | const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 32);
309 | const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 48);
310 | const int8x16_t yq8_4 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 64);
311 | const int8x16_t yq8_5 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 80);
312 | const int8x16_t yq8_6 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 96);
313 | const int8x16_t yq8_7 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 112);
314 |
315 | #if defined(__ARM_FEATURE_DOTPROD)
316 | accu_0 = vdotq_s32(accu_0, q8_0, yq8_0);
317 | accu_1 = vdotq_s32(accu_1, q8_1, yq8_1);
318 | accu_2 = vdotq_s32(accu_2, q8_2, yq8_2);
319 | accu_3 = vdotq_s32(accu_3, q8_3, yq8_3);
320 | accu_0 = vdotq_s32(accu_0, q8_4, yq8_4);
321 | accu_1 = vdotq_s32(accu_1, q8_5, yq8_5);
322 | accu_2 = vdotq_s32(accu_2, q8_6, yq8_6);
323 | accu_3 = vdotq_s32(accu_3, q8_7, yq8_7);
324 | #else
325 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_0), vget_low_s8(yq8_0));
326 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_0), vget_high_s8(yq8_0));
327 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_1), vget_low_s8(yq8_1));
328 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_1), vget_high_s8(yq8_1));
329 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_2), vget_low_s8(yq8_2));
330 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_2), vget_high_s8(yq8_2));
331 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_3), vget_low_s8(yq8_3));
332 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_3), vget_high_s8(yq8_3));
333 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_4), vget_low_s8(yq8_4));
334 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_4), vget_high_s8(yq8_4));
335 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_5), vget_low_s8(yq8_5));
336 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_5), vget_high_s8(yq8_5));
337 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_6), vget_low_s8(yq8_6));
338 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_6), vget_high_s8(yq8_6));
339 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_7), vget_low_s8(yq8_7));
340 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_7), vget_high_s8(yq8_7));
341 | #endif
342 | }
343 | #if defined(__ARM_FEATURE_DOTPROD)
344 |
345 | #else
346 | accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accula_0)));
347 | accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accula_0));
348 | accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accula_1)));
349 | accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accula_1));
350 | accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accula_2)));
351 | accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accula_2));
352 | accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accula_3)));
353 | accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accula_3));
354 | #endif
355 | }
356 | accu_0 = vaddq_s32(accu_0, accu_1);
357 | accu_2 = vaddq_s32(accu_2, accu_3);
358 | accu_0 = vaddq_s32(accu_0, accu_2);
359 | int sumi = vaddlvq_s32(accu_0);
360 | *s = (float)sumi;
361 |
362 | #endif
363 | }
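
The quantize_i2_s routine above packs each block of QK_I2 = 128 ternary codes (0/1/2 standing for -1/0/+1) into 32 bytes, placing weight j of a block in byte j % 32 at bit offset 6 - 2*(j // 32); the shift-by-6/4/2/0 plus mask-with-0x03 sequence in ggml_vec_dot_i2_i8_s undoes exactly this. A plain-Python sketch of the layout (helper names are illustrative):

QK_I2 = 128  # weights per block, packed into 32 bytes at 2 bits each

def pack_i2_block(codes):
    # Mirror of the packing loop in quantize_i2_s for a single block.
    assert len(codes) == QK_I2
    packed = bytearray(32)
    for j, c in enumerate(codes):
        packed[j % 32] |= c << (6 - 2 * (j // 32))
    return bytes(packed)

def unpack_i2_block(packed):
    # Mirror of the unpacking in the dot kernels: shift by 6/4/2/0, mask with 0x03.
    codes = [0] * QK_I2
    for pos, byte in enumerate(packed):
        for group in range(4):
            codes[group * 32 + pos] = (byte >> (6 - 2 * group)) & 0x03
    return codes

codes = [(j * 7) % 3 for j in range(QK_I2)]  # arbitrary ternary pattern
assert unpack_i2_block(pack_i2_block(codes)) == codes
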
--------------------------------------------------------------------------------
/utils/codegen_tl1.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from configparser import ConfigParser
4 |
5 | def gen_ctor_code():
6 | kernel_code = "\n\
7 | #include \"ggml-bitnet.h\"\n\
8 | #define GGML_BITNET_MAX_NODES 8192\n\
9 | static bool initialized = false;\n\
10 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\
11 | static size_t bitnet_tensor_extras_index = 0;\n\
12 | static void * aligned_malloc(size_t size) {{\n\
13 | #if defined(_WIN32)\n\
14 | return _aligned_malloc(size, 64);\n\
15 | #else\n\
16 | void * ptr = nullptr;\n\
17 | posix_memalign(&ptr, 64, size);\n\
18 | return ptr;\n\
19 | #endif\n\
20 | }}\n\
21 | static void aligned_free(void * ptr) {{\n\
22 | #if defined(_WIN32)\n\
23 | _aligned_free(ptr);\n\
24 | #else\n\
25 | free(ptr);\n\
26 | #endif\n\
27 | }}\n\
28 | \n\
29 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{\n\
30 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
31 | bitnet_float_type* b = (bitnet_float_type*)b_;\n\
32 | #ifdef __ARM_NEON\n\
33 | float32x4_t temp_max = vdupq_n_f32(0);\n\
34 | for (int i=0; i < k / 4; i++) {{\n\
35 | float32x4_t vec_bs = vld1q_f32(b + 4 * i);\n\
36 | float32x4_t abssum = vabsq_f32(vec_bs);\n\
37 | temp_max = vmaxq_f32(abssum, temp_max);\n\
38 | }}\n\
39 | float32_t scales = 127 / vmaxvq_f32(temp_max);\n\
40 | *lut_scales = scales;\n\
41 | #elif defined __AVX2__\n\
42 | __m256 max_vec = _mm256_set1_ps(0.f);\n\
43 | const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\
44 | // #pragma unroll\n\
45 | for (int i = 0; i < k / 8; i++) {{\n\
46 | __m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\
47 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\
48 | max_vec = _mm256_max_ps(vec_babs, max_vec);\n\
49 | }}\n\
50 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\
51 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\
52 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\
53 | float scales = 127 / _mm_cvtss_f32(max1);\n\
54 | *lut_scales = scales;\n\
55 | #endif\n\
56 | }}\n\
57 | \n\
58 | void partial_max_reset(void* lut_scales_) {{\n\
59 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
60 | *lut_scales = 0.0;\n\
61 | }}\n\
62 | \n\
63 | #ifdef __ARM_NEON\n\
64 | inline void Transpose_8_8(\n\
65 | int16x8_t *v0,\n\
66 | int16x8_t *v1,\n\
67 | int16x8_t *v2,\n\
68 | int16x8_t *v3,\n\
69 | int16x8_t *v4,\n\
70 | int16x8_t *v5,\n\
71 | int16x8_t *v6,\n\
72 | int16x8_t *v7)\n\
73 | {{\n\
74 | int16x8x2_t q04 = vzipq_s16(*v0, *v4);\n\
75 | int16x8x2_t q15 = vzipq_s16(*v1, *v5);\n\
76 | int16x8x2_t q26 = vzipq_s16(*v2, *v6);\n\
77 | int16x8x2_t q37 = vzipq_s16(*v3, *v7);\n\
78 | \n\
79 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);\n\
80 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);\n\
81 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);\n\
82 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);\n\
83 | \n\
84 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);\n\
85 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);\n\
86 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);\n\
87 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);\n\
88 | \n\
89 | *v0 = q_fin_0.val[0];\n\
90 | *v1 = q_fin_0.val[1];\n\
91 | *v2 = q_fin_1.val[0];\n\
92 | *v3 = q_fin_1.val[1];\n\
93 | *v4 = q_fin_2.val[0];\n\
94 | *v5 = q_fin_2.val[1];\n\
95 | *v6 = q_fin_3.val[0];\n\
96 | *v7 = q_fin_3.val[1];\n\
97 | }}\n\
98 | #endif\n\
99 | \n\
100 | template<int act_k>\n\
101 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{\n\
102 | #ifdef __ARM_NEON\n\
103 | int16x8_t vec_lut[16];\n\
104 | float32_t scales = *lut_scales;\n\
105 | uint8_t tbl_mask[16];\n\
106 | tbl_mask[0] = 0;\n\
107 | tbl_mask[1] = 2;\n\
108 | tbl_mask[2] = 4;\n\
109 | tbl_mask[3] = 6;\n\
110 | tbl_mask[4] = 8;\n\
111 | tbl_mask[5] = 10;\n\
112 | tbl_mask[6] = 12;\n\
113 | tbl_mask[7] = 14;\n\
114 | tbl_mask[8] = 1;\n\
115 | tbl_mask[9] = 3;\n\
116 | tbl_mask[10] = 5;\n\
117 | tbl_mask[11] = 7;\n\
118 | tbl_mask[12] = 9;\n\
119 | tbl_mask[13] = 11;\n\
120 | tbl_mask[14] = 13;\n\
121 | tbl_mask[15] = 15;\n\
122 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);\n\
123 | #pragma unroll\n\
124 | for (int k = 0; k < act_k / 16; ++k) {{\n\
125 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);\n\
126 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);\n\
127 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);\n\
128 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);\n\
129 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);\n\
130 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);\n\
131 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);\n\
132 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);\n\
133 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);\n\
134 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);\n\
135 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);\n\
136 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);\n\
137 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);\n\
138 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);\n\
139 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);\n\
140 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);\n\
141 | vec_lut[0] = vdupq_n_s16(0);\n\
142 | vec_lut[0] = vec_lut[0] - vec_bs_0;\n\
143 | vec_lut[0] = vec_lut[0] - vec_bs_1;\n\
144 | vec_lut[1] = vdupq_n_s16(0);\n\
145 | vec_lut[1] = vec_lut[1] - vec_bs_0;\n\
146 | vec_lut[2] = vdupq_n_s16(0);\n\
147 | vec_lut[2] = vec_lut[2] - vec_bs_0;\n\
148 | vec_lut[2] = vec_lut[2] + vec_bs_1;\n\
149 | vec_lut[3] = vdupq_n_s16(0);\n\
150 | vec_lut[3] = vec_lut[3] - vec_bs_1;\n\
151 | vec_lut[4] = vdupq_n_s16(0);\n\
152 | vec_lut[5] = vec_bs_1;\n\
153 | vec_lut[6] = vec_bs_0;\n\
154 | vec_lut[6] = vec_lut[6] - vec_bs_1;\n\
155 | vec_lut[7] = vec_bs_0;\n\
156 | vec_lut[8] = vec_bs_0;\n\
157 | vec_lut[8] = vec_lut[8] + vec_bs_1;\n\
158 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),\n\
159 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));\n\
160 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),\n\
161 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));\n\
162 | #pragma unroll\n\
163 | for (int idx = 0; idx < 8; idx++) {{\n\
164 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);\n\
165 | int8x8_t q0_low = vget_low_s8(q0_s);\n\
166 | int8x8_t q0_high = vget_high_s8(q0_s);\n\
167 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);\n\
168 | int8x8_t q1_low = vget_low_s8(q1_s);\n\
169 | int8x8_t q1_high = vget_high_s8(q1_s);\n\
170 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);\n\
171 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);\n\
172 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);\n\
173 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);\n\
174 | }}\n\
175 | }}\n\
176 | #endif\n\
177 | }}\n\
178 | \n\
179 | static bool is_type_supported(enum ggml_type type) {{\n\
180 | if (type == GGML_TYPE_Q4_0 ||\n\
181 | type == GGML_TYPE_TL1) {{\n\
182 | return true;\n\
183 | }} else {{\n\
184 | return false;\n\
185 | }}\n\
186 | }}\n\
187 | "
188 | return kernel_code
189 |
190 | def gen_body_core_code(bm, by):
191 | length = 4
192 | all_code = ""
193 | for i in range(length):
194 | core_code = "\n\
195 | uint8x16_t vec_a_{0} = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + {0} * 16);\n\
196 | uint8x16_t vec_a{0}_top = vshrq_n_u8(vec_a_{0}, 4);\n\
197 | uint8x16_t vec_a{0}_bot = vandq_u8(vec_a_{0}, vec_mask);\n\
198 | int8x16_t vec_v_{0}_left_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {2}], vec_a{0}_top);\n\
199 | int8x16_t vec_v_{0}_left_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {3}], vec_a{0}_top);\n\
200 | int8x16_t vec_v_{0}_right_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {4}], vec_a{0}_bot);\n\
201 | int8x16_t vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\
202 | int8x16x2_t vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\
203 | int8x16x2_t vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\
204 | vec_c[{6}] += vec_v_left_{0}.val[0];\n\
205 | vec_c[{6}] += vec_v_right_{0}.val[0];\n\
206 | vec_c[{7}] += vec_v_left_{0}.val[1];\n\
207 | vec_c[{7}] += vec_v_right_{0}.val[1];\n\
208 | ".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1)
209 |
210 | all_code = "".join([all_code, core_code])
211 |
212 | all_code = "".join([all_code, "\n }\n\n"])
213 |
214 | for i in range(bm // 8):
215 | core_code = "\
216 | int32x4_t vec_v_bot_low_low_{0} = vmovl_s16(vget_low_s16(vec_c[{0}]));\n\
217 | int32x4_t vec_v_bot_low_high_{0} = vmovl_high_s16(vec_c[{0}]);\n\
218 | vst1q_s32(c + i + {1}, vld1q_s32(c + i + {1}) + vec_v_bot_low_low_{0});\n\
219 | vst1q_s32(c + i + {2}, vld1q_s32(c + i + {2}) + vec_v_bot_low_high_{0});\n".format(i, i * 8, i * 8 + 4)
220 | all_code = "".join([all_code, core_code])
221 |
222 | return all_code
223 |
224 | def gen_tbl_impl(pre, BM, BK, bm, k):
225 |
226 | kernel_code = "\
227 | #include <arm_neon.h>\n\
228 | \n\
229 | #define BM{0} {1}\n\
230 | #define BBK{0} {2}\n\
231 | inline void tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\
232 | #ifdef __ARM_NEON\n\
233 | const int KK = BBK{0} / 2;\n\
234 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\
235 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);\n\
236 | int8x16_t vec_lut[2 * KK];\n\
237 | ".format(pre, BM, BK)
238 |
239 | kernel_code = "".join([kernel_code, " int16x8_t vec_c[{}];".format(bm // 8)])
240 |
241 | kernel_code = "".join([kernel_code, "\n\
242 | #pragma unroll\n\
243 | for (int k = 0; k < 2 * KK; k++) {\n\
244 | vec_lut[k] = vld1q_s8(lut + k * 16);\n\
245 | }\n"])
246 |
247 | pre_core_code = "\n\
248 | #pragma unroll\n\
249 | for (int i = 0; i < BM{}; i += {}) {{\n\
250 | #pragma unroll\n\
251 | for (int i=0; i<{}; i++) {{\n\
252 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);\n\
253 | }}\n".format(pre, bm, bm // 8)
254 |
255 | body_core_pre_code = "\n\
256 | #pragma unroll\n\
257 | for (int k = 0; k < KK / {}; k++) {{\n\
258 | ".format(256 // bm // 2)
259 |
260 | body_core_post_code = "\n\
261 | }\n\
262 | \
263 | #endif\n\
264 | }\n"
265 |
266 | kernel_code = "".join([kernel_code, pre_core_code, body_core_pre_code, gen_body_core_code(bm, 256 // bm), body_core_post_code])
267 |
268 | kernel_code = "".join([kernel_code, "\n\
269 | int32_t qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
270 | alignas({1}) uint32_t CBits[BM{0}];\n\
271 | memset(&(CBits[0]), 0, BM{0} * sizeof(int32_t));\n\
272 | #pragma unroll\n\
273 | for (int32_t k_outer = 0; k_outer < {2} / BBK{0}; ++k_outer) {{\n\
274 | tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{0} / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 2 / 2 * BM{0})])));\n\
275 | }}\n\
276 | #pragma unroll\n\
277 | for (int i = 0; i < BM{0}; i++) {{\n\
278 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];\n\
279 | }}\n\
280 | return 0;\n\
281 | }};\n".format(pre, min(32, BK), k)])
282 |
283 | return kernel_code
284 |
285 | def gen_top_api(kernel_shapes):
286 |
287 | kernel_code = "void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {{\n\
288 | if (m == {0} && k == {1}) {{\n\
289 | preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\
290 | }}\n\
291 | ".format(kernel_shapes[0][0], kernel_shapes[0][1])
292 | for i in range(1, len(kernel_shapes)):
293 | kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\
294 | preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\
295 | }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])
296 | kernel_code = "".join([kernel_code, "}\n"])
297 | kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
298 | if (m == {0} && k == {1}) {{\n\
299 | qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\
300 | }}\n\
301 | ".format(kernel_shapes[0][0], kernel_shapes[0][1])])
302 | for i in range(1, len(kernel_shapes)):
303 | kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\
304 | qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\
305 | }}\n\
306 | ".format(kernel_shapes[i][0], kernel_shapes[i][1])])
307 | kernel_code = "".join([kernel_code, "}\n"])
308 | return kernel_code
309 |
310 | def gen_preprocess_code():
311 | kernel_code = "\n\
312 | template<int K>\n\
313 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{\n\
314 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));\n\
315 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));\n\
316 | \n\
317 | lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));\n\
318 | }}\n"
319 | return kernel_code
320 |
321 | def gen_transform_code(kernel_shape):
322 | kernel_code = "\n\
323 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\
324 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\
325 | return;\n\
326 | }\n\
327 | \n\
328 | int k = tensor->ne[0];\n\
329 | int m = tensor->ne[1];\n\
330 | const int lut_scales_size = 1;\n\
331 | const int scales_size = 1;\n\
332 | int bk = 0;\n\
333 | int bm = 0;\n"
334 |
335 | kernel_code = "".join([kernel_code, "\n\
336 | if (m == {0} && k == {1}) {{\n\
337 | bm = BM{0}_{1};\n\
338 | bk = BBK{0}_{1};\n\
339 | }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])])
340 |
341 | for i in range(1, len(kernel_shapes)):
342 | kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\
343 | bm = BM{0}_{1};\n\
344 | bk = BBK{0}_{1};\n\
345 | }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])
346 |
347 | kernel_code = "".join([kernel_code, "\n\
348 | const int n_tile_num = m / bm;\n\
349 | const int BK = bk;\n\
350 | uint8_t * qweights;\n\
351 | bitnet_float_type * scales;\n\
352 | \n\
353 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\
354 | qweights = (uint8_t *) tensor->data;\n\
355 | float * i2_scales = (float * )(qweights + k * m / 4);\n\
356 | scales[0] = (bitnet_float_type) i2_scales[0];\n\
357 | \n\
358 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
359 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
360 | /* .lut_scales_size = */ lut_scales_size,\n\
361 | /* .BK = */ BK,\n\
362 | /* .n_tile_num = */ n_tile_num,\n\
363 | /* .qweights = */ qweights,\n\
364 | /* .scales = */ scales\n\
365 | };\n\
366 | }\n"])
367 |
368 | return kernel_code
369 |
370 | if __name__ == "__main__":
371 | ModelShapeDict = {
372 | "bitnet_b1_58-large" : [[1536, 4096],
373 | [1536, 1536],
374 | [4096, 1536]],
375 | "bitnet_b1_58-3B" : [[3200, 8640],
376 | [3200, 3200],
377 | [8640, 3200]],
378 | "Llama3-8B-1.58-100B-tokens" : [[14336, 4096],
379 | [4096, 14336],
380 | [1024, 4096],
381 | [4096, 4096]]
382 | }
383 |
384 | parser = argparse.ArgumentParser(description='gen impl')
385 | parser.add_argument('--model',default="input", type=str, dest="model",
386 | help="choose from bitnet_b1_58-large/bitnet_b1_58-3B/Llama3-8B-1.58-100B-tokens.")
387 | parser.add_argument('--BM',default="input", type=str,
388 | help="block length when cutting one weight (M, K) into M / BM weights (BM, K).")
389 | parser.add_argument('--BK',default="input", type=str,
390 | help="block length when cutting one weight (M, K) into K / BK weights (M, BK).")
391 | parser.add_argument('--bm',default="input", type=str,
392 | help="using simd instructions to compute (bm, 256 / bm) in one block")
393 | args = parser.parse_args()
394 |
395 | kernel_shapes = ModelShapeDict[args.model]
396 |
397 | BM_list = [int(item) for item in args.BM.split(',')]
398 | BK_list = [int(item) for item in args.BK.split(',')]
399 | bm_list = [int(item) for item in args.bm.split(',')]
400 |
401 | assert(len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes)), "number of BM / BK / bm should be {}".format(len(kernel_shapes))
402 |
403 | for i in range(len(kernel_shapes)):
404 | assert kernel_shapes[i][0] % BM_list[i] == 0, "M %% BM should be 0"
405 | assert kernel_shapes[i][1] % BK_list[i] == 0, "K %% BK should be 0"
406 | assert bm_list[i] in [32, 64], "choose bm from [32, 64]"
407 |
408 | tbl_impl_code = []
409 |
410 | for i in range(len(kernel_shapes)):
411 | tbl_impl_code.append(
412 | gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], kernel_shapes[i][1])
413 | )
414 | api_code = gen_top_api(kernel_shapes)
415 | pre_code = gen_preprocess_code()
416 | ctor_code = gen_ctor_code()
417 | trans_code = gen_transform_code(kernel_shapes)
418 |
419 | output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include")
420 |
421 | with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f:
422 | f.write(''.join("#if defined(GGML_BITNET_ARM_TL1)"))
423 | f.write(''.join(ctor_code))
424 | for code in tbl_impl_code:
425 | f.write(''.join(code))
426 | f.write(''.join(pre_code))
427 | f.write(''.join(api_code))
428 | f.write(''.join(trans_code))
429 | f.write(''.join("#endif"))
430 |
431 | config = ConfigParser()
432 |
433 | for i in range(len(kernel_shapes)):
434 | config.add_section('Kernels_{}'.format(i))
435 | config.set('Kernels_{}'.format(i), 'M'.format(i), str(kernel_shapes[i][0]))
436 | config.set('Kernels_{}'.format(i), 'K'.format(i), str(kernel_shapes[i][1]))
437 | config.set('Kernels_{}'.format(i), 'BM'.format(i), str(BM_list[i]))
438 | config.set('Kernels_{}'.format(i), 'BK'.format(i), str(BK_list[i]))
439 | config.set('Kernels_{}'.format(i), 'bmm'.format(i), str(bm_list[i]))
440 |
441 | with open(''.join([output_dir, "/kernel_config.ini"]), 'w') as configfile:
442 | config.write(configfile)
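
codegen_tl1.py expects one (BM, BK, bm) triple per kernel shape of the chosen model and only enforces the divisibility and bm-in-{32, 64} asserts above before emitting include/bitnet-lut-kernels.h and include/kernel_config.ini. A small pre-flight sketch of those checks for bitnet_b1_58-large; the block sizes (and the command line in the comment) are illustrative examples, not the tuned values shipped under preset_kernels/:

# Illustrative invocation:
#   python utils/codegen_tl1.py --model bitnet_b1_58-large --BM 256,128,256 --BK 128,64,128 --bm 32,32,32

shapes = [[1536, 4096], [1536, 1536], [4096, 1536]]  # (M, K) per kernel, as in ModelShapeDict
BM_list = [256, 128, 256]
BK_list = [128, 64, 128]
bm_list = [32, 32, 32]

for (M, K), BM, BK, bm in zip(shapes, BM_list, BK_list, bm_list):
    assert M % BM == 0, "M should be divisible by BM"
    assert K % BK == 0, "K should be divisible by BK"
    assert bm in (32, 64), "bm should be 32 or 64"
print("tiling passes codegen_tl1.py's shape checks")
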
--------------------------------------------------------------------------------
/utils/e2e_benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | import argparse
5 | import platform
6 | import subprocess
7 |
8 | def run_command(command, shell=False, log_step=None):
9 | """Run a system command and ensure it succeeds."""
10 | if log_step:
11 | log_file = os.path.join(args.log_dir, log_step + ".log")
12 | with open(log_file, "w") as f:
13 | try:
14 | subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f)
15 | except subprocess.CalledProcessError as e:
16 | logging.error(f"Error occurred while running command: {e}, check details in {log_file}")
17 | sys.exit(1)
18 | else:
19 | try:
20 | subprocess.run(command, shell=shell, check=True)
21 | except subprocess.CalledProcessError as e:
22 | logging.error(f"Error occurred while running command: {e}")
23 | sys.exit(1)
24 |
25 | def run_benchmark():
26 | build_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "build")
27 | if platform.system() == "Windows":
28 | bench_path = os.path.join(build_dir, "bin", "Release", "llama-bench.exe")
29 | if not os.path.exists(bench_path):
30 | bench_path = os.path.join(build_dir, "bin", "llama-bench")
31 | else:
32 | bench_path = os.path.join(build_dir, "bin", "llama-bench")
33 | if not os.path.exists(bench_path):
34 | logging.error(f"Benchmark binary not found, please build first.")
35 | sys.exit(1)
36 | command = [
37 | f'{bench_path}',
38 | '-m', args.model,
39 | '-n', str(args.n_token),
40 | '-ngl', '0',
41 | '-b', '1',
42 | '-t', str(args.threads),
43 | '-p', str(args.n_prompt),
44 | '-r', '5'
45 | ]
46 | run_command(command)
47 |
48 | def parse_args():
49 |     parser = argparse.ArgumentParser(description='Run the end-to-end inference benchmark with llama-bench')
50 | parser.add_argument("-m", "--model", type=str, help="Path to model file", required=True)
51 | parser.add_argument("-n", "--n-token", type=int, help="Number of generated tokens", required=False, default=128)
52 |     parser.add_argument("-p", "--n-prompt", type=int, help="Number of prompt tokens to process before generation", required=False, default=512)
53 | parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
54 | return parser.parse_args()
55 |
56 | if __name__ == "__main__":
57 | logging.basicConfig(level=logging.INFO)
58 | args = parse_args()
59 | run_benchmark()
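
e2e_benchmark.py is a thin wrapper that locates build/bin/llama-bench and forwards the model path, token counts, and thread count with -ngl 0, -b 1, and -r 5 fixed. A hedged usage sketch; the model path is a placeholder and the binary must already have been built (e.g. via setup_env.py):

import subprocess

MODEL = "models/ggml-model-i2_s.gguf"  # placeholder path to a converted GGUF model

# Same effect as running, from the repository root:
#   python utils/e2e_benchmark.py -m <model> -n 128 -p 512 -t 4
subprocess.run(
    ["python", "utils/e2e_benchmark.py",
     "-m", MODEL, "-n", "128", "-p", "512", "-t", "4"],
    check=True,
)
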
--------------------------------------------------------------------------------
/utils/kernel_tuning.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/utils/kernel_tuning.py
--------------------------------------------------------------------------------