├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── assets ├── header_model_release.png ├── intel_performance.jpg ├── m2_performance.jpg ├── tl1.png └── tl2.png ├── docs └── codegen.md ├── gpu ├── README.md ├── bitnet_kernels │ ├── bitnet_kernels.cu │ ├── bitnet_kernels.h │ ├── compile.sh │ └── setup.py ├── convert_checkpoint.py ├── convert_safetensors.py ├── generate.py ├── model.py ├── pack_weight.py ├── requirements.txt ├── sample_utils.py ├── stats.py ├── test.py ├── tokenizer.model └── tokenizer.py ├── include └── ggml-bitnet.h ├── media ├── benchmark.png └── demo.mp4 ├── preset_kernels ├── Llama3-8B-1.58-100B-tokens │ ├── bitnet-lut-kernels-tl1.h │ ├── bitnet-lut-kernels-tl2.h │ ├── kernel_config_tl1.ini │ └── kernel_config_tl2.ini ├── bitnet_b1_58-3B │ ├── bitnet-lut-kernels-tl1.h │ ├── bitnet-lut-kernels-tl2.h │ ├── kernel_config_tl1.ini │ └── kernel_config_tl2.ini └── bitnet_b1_58-large │ ├── bitnet-lut-kernels-tl1.h │ ├── bitnet-lut-kernels-tl2.h │ ├── kernel_config_tl1.ini │ └── kernel_config_tl2.ini ├── requirements.txt ├── run_inference.py ├── run_inference_server.py ├── setup_env.py ├── src ├── CMakeLists.txt ├── ggml-bitnet-lut.cpp └── ggml-bitnet-mad.cpp └── utils ├── codegen_tl1.py ├── codegen_tl2.py ├── convert-hf-to-gguf-bitnet.py ├── convert-ms-to-gguf-bitnet.py ├── convert.py ├── e2e_benchmark.py ├── generate-dummy-bitnet-model.py └── kernel_tuning.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Extensions 2 | 3 | *.a 4 | *.bat 5 | *.bin 6 | *.dll 7 | *.dot 8 | *.etag 9 | *.exe 10 | *.gcda 11 | *.gcno 12 | *.gcov 13 | *.gguf 14 | *.gguf.json 15 | *.lastModified 16 | *.log 17 | *.metallib 18 | *.o 19 | *.so 20 | *.tmp 21 | 22 | # IDE / OS 23 | 24 | .cache/ 25 | .ccls-cache/ 26 | .direnv/ 27 | .DS_Store 28 | .envrc 29 | .idea/ 30 | .swiftpm 31 | .vs/ 32 | .vscode/ 33 | nppBackup 34 | 35 | # Models 36 | models/* 37 | gpu/checkpoints/* 38 | 39 | # Python 40 | 41 | /.venv 42 | __pycache__/ 43 | */poetry.lock 44 | poetry.toml 45 | 46 | build/ 47 | logs/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/llama.cpp"] 2 | path = 3rdparty/llama.cpp 3 | url = https://github.com/Eddie-Wang1120/llama.cpp.git 4 | branch = merge-dev 5 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. 
2 | project("bitnet.cpp" C CXX) 3 | include(CheckIncludeFileCXX) 4 | 5 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 6 | 7 | if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) 8 | set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) 9 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") 10 | endif() 11 | 12 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 13 | 14 | # option list 15 | option(BITNET_ARM_TL1 "bitnet.cpp: use tl1 on arm platform" OFF) 16 | option(BITNET_X86_TL2 "bitnet.cpp: use tl2 on x86 platform" OFF) 17 | 18 | 19 | set(CMAKE_CXX_STANDARD_REQUIRED true) 20 | set(CMAKE_C_STANDARD 11) 21 | set(CMAKE_C_STANDARD_REQUIRED true) 22 | set(THREADS_PREFER_PTHREAD_FLAG ON) 23 | 24 | # override ggml options 25 | set(GGML_BITNET_ARM_TL1 ${BITNET_ARM_TL1}) 26 | set(GGML_BITNET_X86_TL2 ${BITNET_X86_TL2}) 27 | 28 | if (GGML_BITNET_ARM_TL1) 29 | add_compile_definitions(GGML_BITNET_ARM_TL1) 30 | endif() 31 | if (GGML_BITNET_X86_TL2) 32 | add_compile_definitions(GGML_BITNET_X86_TL2) 33 | endif() 34 | 35 | if (CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") 36 | add_compile_options(-fpermissive) 37 | endif() 38 | 39 | find_package(Threads REQUIRED) 40 | 41 | add_subdirectory(src) 42 | set(LLAMA_BUILD_SERVER ON CACHE BOOL "Build llama.cpp server" FORCE) 43 | add_subdirectory(3rdparty/llama.cpp) 44 | 45 | # install 46 | 47 | include(GNUInstallDirs) 48 | include(CMakePackageConfigHelpers) 49 | 50 | set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} 51 | CACHE PATH "Location of header files") 52 | set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} 53 | CACHE PATH "Location of library files") 54 | set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} 55 | CACHE PATH "Location of binary files") 56 | set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER}) 57 | set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT}) 58 | set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER}) 59 | 60 | get_target_property(GGML_DIRECTORY ggml SOURCE_DIR) 61 | get_directory_property(GGML_DIR_DEFINES DIRECTORY ${GGML_DIRECTORY} COMPILE_DEFINITIONS) 62 | get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS) 63 | set(GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES} ${GGML_DIR_DEFINES}) 64 | get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES) 65 | 66 | get_directory_property(LLAMA_TRANSIENT_DEFINES COMPILE_DEFINITIONS) 67 | 68 | write_basic_package_version_file( 69 | ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake 70 | VERSION ${LLAMA_INSTALL_VERSION} 71 | COMPATIBILITY SameMajorVersion) 72 | 73 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake 74 | ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake 75 | DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama) 76 | 77 | set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/llama.h) 78 | install(TARGETS llama LIBRARY PUBLIC_HEADER) 79 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 
5 | Resources:
6 | 
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) Microsoft Corporation.
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # bitnet.cpp
2 | [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
3 | ![version](https://img.shields.io/badge/version-1.0-blue)
4 | 
5 | [BitNet Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
6 | 
7 | Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or build and run it on your own [CPU](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) or [GPU](https://github.com/microsoft/BitNet/blob/main/gpu/README.md).
8 | 
9 | bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels that support **fast** and **lossless** inference of 1.58-bit models on CPU and GPU (NPU support is coming next).
10 | 
11 | The first release of bitnet.cpp supports inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x**, with energy reductions between **71.9%** and **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.
12 | 
13 | ![m2_performance](./assets/m2_performance.jpg)
14 | ![intel_performance](./assets/intel_performance.jpg)
15 | 
16 | >The tested models are dummy setups used in a research context to demonstrate the inference performance of bitnet.cpp.
17 | 18 | ## Demo 19 | 20 | A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2: 21 | 22 | https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1 23 | 24 | ## What's New: 25 | - 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md) ![NEW](https://img.shields.io/badge/NEW-red) 26 | - 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) 27 | - 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880) 28 | - 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965) 29 | - 10/21/2024 [1-bit AI Infra: Part 1.1, Fast and Lossless BitNet b1.58 Inference on CPUs](https://arxiv.org/abs/2410.16144) 30 | - 10/17/2024 bitnet.cpp 1.0 released. 31 | - 03/21/2024 [The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf) 32 | - 02/27/2024 [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764) 33 | - 10/17/2023 [BitNet: Scaling 1-bit Transformers for Large Language Models](https://arxiv.org/abs/2310.11453) 34 | 35 | ## Acknowledgements 36 | 37 | This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp) framework. We would like to thank all the authors for their contributions to the open-source community. Also, bitnet.cpp's kernels are built on top of the Lookup Table methodologies pioneered in [T-MAC](https://github.com/microsoft/T-MAC/). For inference of general low-bit LLMs beyond ternary models, we recommend using T-MAC. 38 | ## Official Models 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 |
| Model | Parameters | CPU | I2_S | TL1 | TL2 |
|:-----:|:----------:|:---:|:----:|:---:|:---:|
| [BitNet-b1.58-2B-4T](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) | 2.4B | x86 |  |  |  |
|  |  | ARM |  |  |  |
67 | 68 | ## Supported Models 69 | ❗️**We use existing 1-bit LLMs available on [Hugging Face](https://huggingface.co/) to demonstrate the inference capabilities of bitnet.cpp. We hope the release of bitnet.cpp will inspire the development of 1-bit LLMs in large-scale settings in terms of model size and training tokens.** 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 |
| Model | Parameters | CPU | I2_S | TL1 | TL2 |
|:-----:|:----------:|:---:|:----:|:---:|:---:|
| bitnet_b1_58-large | 0.7B | x86 |  |  |  |
|  |  | ARM |  |  |  |
| bitnet_b1_58-3B | 3.3B | x86 |  |  |  |
|  |  | ARM |  |  |  |
| Llama3-8B-1.58-100B-tokens | 8.0B | x86 |  |  |  |
|  |  | ARM |  |  |  |
| Falcon3 Family | 1B-10B | x86 |  |  |  |
|  |  | ARM |  |  |  |
| Falcon-E Family | 1B-3B | x86 |  |  |  |
|  |  | ARM |  |  |  |
155 | 
156 | 
157 | 
158 | ## Installation
159 | 
160 | ### Requirements
161 | - python>=3.9
162 | - cmake>=3.22
163 | - clang>=18
164 | - For Windows users, install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/). In the installer, toggle on at least the following options (this also automatically installs the required additional tools like CMake):
165 |     - Desktop development with C++
166 |     - C++-CMake Tools for Windows
167 |     - Git for Windows
168 |     - C++-Clang Compiler for Windows
169 |     - MS-Build Support for LLVM-Toolset (clang)
170 | - For Debian/Ubuntu users, you can install clang with the [automatic installation script](https://apt.llvm.org/):
171 | 
172 |     `bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)"`
173 | - conda (highly recommended)
174 | 
175 | ### Build from source
176 | 
177 | > [!IMPORTANT]
178 | > If you are using Windows, please remember to always use a Developer Command Prompt / PowerShell for VS2022 for the following commands. Please refer to the FAQs below if you see any issues.
179 | 
180 | 1. Clone the repo
181 | ```bash
182 | git clone --recursive https://github.com/microsoft/BitNet.git
183 | cd BitNet
184 | ```
185 | 2. Install the dependencies
186 | ```bash
187 | # (Recommended) Create a new conda environment
188 | conda create -n bitnet-cpp python=3.9
189 | conda activate bitnet-cpp
190 | 
191 | pip install -r requirements.txt
192 | ```
193 | 3. Build the project
194 | ```bash
195 | # Manually download the model and run with local path
196 | huggingface-cli download microsoft/BitNet-b1.58-2B-4T-gguf --local-dir models/BitNet-b1.58-2B-4T
197 | python setup_env.py -md models/BitNet-b1.58-2B-4T -q i2_s
198 | 
199 | ```
200 | 
201 | usage: setup_env.py [-h] [--hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}] [--model-dir MODEL_DIR] [--log-dir LOG_DIR] [--quant-type {i2_s,tl1}] [--quant-embd]
202 |                     [--use-pretuned]
203 | 
204 | Setup the environment for running inference
205 | 
206 | optional arguments:
207 |   -h, --help            show this help message and exit
208 |   --hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}, -hr {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}
209 |                         Model used for inference
210 |   --model-dir MODEL_DIR, -md MODEL_DIR
211 |                         Directory to save/load the model
212 |   --log-dir LOG_DIR, -ld LOG_DIR
213 |                         Directory to save the logging info
214 |   --quant-type {i2_s,tl1}, -q {i2_s,tl1}
215 |                         Quantization type
216 |   --quant-embd          Quantize the embeddings to f16
217 |   --use-pretuned, -p    Use the pretuned kernel parameters
218 | 
219 | ## Usage 220 | ### Basic usage 221 | ```bash 222 | # Run inference with the quantized model 223 | python run_inference.py -m models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf -p "You are a helpful assistant" -cnv 224 | ``` 225 |
226 | usage: run_inference.py [-h] [-m MODEL] [-n N_PREDICT] -p PROMPT [-t THREADS] [-c CTX_SIZE] [-temp TEMPERATURE] [-cnv]
227 | 
228 | Run inference
229 | 
230 | optional arguments:
231 |   -h, --help            show this help message and exit
232 |   -m MODEL, --model MODEL
233 |                         Path to model file
234 |   -n N_PREDICT, --n-predict N_PREDICT
235 |                         Number of tokens to predict when generating text
236 |   -p PROMPT, --prompt PROMPT
237 |                         Prompt to generate text from
238 |   -t THREADS, --threads THREADS
239 |                         Number of threads to use
240 |   -c CTX_SIZE, --ctx-size CTX_SIZE
241 |                         Size of the prompt context
242 |   -temp TEMPERATURE, --temperature TEMPERATURE
243 |                         Temperature, a hyperparameter that controls the randomness of the generated text
244 |   -cnv, --conversation  Whether to enable chat mode or not (for instruct models.)
245 |                         (When this option is turned on, the prompt specified by -p will be used as the system prompt.)
246 | 
247 | 
248 | ### Benchmark
249 | We provide scripts to run the inference benchmark for a given model.
250 | 
251 | ```
252 | usage: e2e_benchmark.py -m MODEL [-n N_TOKEN] [-p N_PROMPT] [-t THREADS]
253 | 
254 | Setup the environment for running the inference
255 | 
256 | required arguments:
257 |   -m MODEL, --model MODEL
258 |                         Path to the model file.
259 | 
260 | optional arguments:
261 |   -h, --help
262 |                         Show this help message and exit.
263 |   -n N_TOKEN, --n-token N_TOKEN
264 |                         Number of generated tokens.
265 |   -p N_PROMPT, --n-prompt N_PROMPT
266 |                         Prompt to generate text from.
267 |   -t THREADS, --threads THREADS
268 |                         Number of threads to use.
269 | ```
270 | 
271 | Here's a brief explanation of each argument:
272 | 
273 | - `-m`, `--model`: The path to the model file. This is a required argument that must be provided when running the script.
274 | - `-n`, `--n-token`: The number of tokens to generate during the inference. It is an optional argument with a default value of 128.
275 | - `-p`, `--n-prompt`: The number of prompt tokens to use for generating text. This is an optional argument with a default value of 512.
276 | - `-t`, `--threads`: The number of threads to use for running the inference. It is an optional argument with a default value of 2.
277 | - `-h`, `--help`: Show the help message and exit. Use this argument to display usage information.
278 | 
279 | For example:
280 | 
281 | ```sh
282 | python utils/e2e_benchmark.py -m /path/to/model -n 200 -p 256 -t 4
283 | ```
284 | 
285 | This command would run the inference benchmark using the model located at `/path/to/model`, generating 200 tokens from a 256-token prompt, utilizing 4 threads.
286 | 
287 | For model layouts that are not supported by any public model, we provide scripts to generate a dummy model with the given layout and run the benchmark on your machine:
288 | 
289 | ```bash
290 | python utils/generate-dummy-bitnet-model.py models/bitnet_b1_58-large --outfile models/dummy-bitnet-125m.tl1.gguf --outtype tl1 --model-size 125M
291 | 
292 | # Run benchmark with the generated model; use -m to specify the model path, -p to specify the number of prompt tokens processed, -n to specify the number of tokens to generate
293 | python utils/e2e_benchmark.py -m models/dummy-bitnet-125m.tl1.gguf -p 512 -n 128
294 | ```
295 | ### FAQ (Frequently Asked Questions)📌
296 | 
297 | #### Q1: The build fails with errors when building llama.cpp due to issues with std::chrono in log.cpp?
298 | 
299 | **A:**
300 | This is an issue introduced in a recent version of llama.cpp. Please refer to this [commit](https://github.com/tinglou/llama.cpp/commit/4e3db1e3d78cc1bcd22bcb3af54bd2a4628dd323) in the [discussion](https://github.com/abetlen/llama-cpp-python/issues/1942) to fix this issue.
301 | 
302 | #### Q2: How to build with clang in a conda environment on Windows?
303 | 
304 | **A:**
305 | Before building the project, verify your clang installation and access to Visual Studio tools by running:
306 | ```
307 | clang -v
308 | ```
309 | 
310 | This command checks that you are using the correct version of clang and that the Visual Studio tools are available. If you see an error message such as:
311 | ```
312 | 'clang' is not recognized as an internal or external command, operable program or batch file.
313 | ```
314 | 
315 | It indicates that your command line window is not properly initialized for Visual Studio tools.
316 | 317 | • If you are using Command Prompt, run: 318 | ``` 319 | "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64 320 | ``` 321 | 322 | • If you are using Windows PowerShell, run the following commands: 323 | ``` 324 | Import-Module "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\Microsoft.VisualStudio.DevShell.dll" Enter-VsDevShell 3f0e31ad -SkipAutomaticLocation -DevCmdArguments "-arch=x64 -host_arch=x64" 325 | ``` 326 | 327 | These steps will initialize your environment and allow you to use the correct Visual Studio tools. 328 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 
40 | 
41 | 
42 | 
--------------------------------------------------------------------------------
/assets/header_model_release.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/assets/header_model_release.png
--------------------------------------------------------------------------------
/assets/intel_performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/assets/intel_performance.jpg
--------------------------------------------------------------------------------
/assets/m2_performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/assets/m2_performance.jpg
--------------------------------------------------------------------------------
/assets/tl1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/assets/tl1.png
--------------------------------------------------------------------------------
/assets/tl2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/assets/tl2.png
--------------------------------------------------------------------------------
/docs/codegen.md:
--------------------------------------------------------------------------------
1 | Codegen for TL1 and TL2
2 | ------------------------
3 | 
4 | codegen_tl1.py and codegen_tl2.py use tuning parameters to generate kernel code for different devices, so that TL1 and TL2 achieve the fastest possible performance.
5 | 
6 | We cut the weight into multiple compute blocks to best utilize hardware capabilities.
7 | 
8 | ### Example
9 | bitnet_b1_58-large:
10 | 
11 | - Determine the Matmul kernel shapes \
12 | For example, the bitnet_b1_58-large Matmul kernel shapes are:\
13 | [1536, 4096]\
14 | [1536, 1536]\
15 | [4096, 1536]
16 | 
17 | - Choose BM, BK, and bm for each kernel so that they meet the requirements below
18 | - Generate the code\
19 | For example, for bitnet_b1_58-large, we can generate the code like this:
20 | 
21 | ```bash
22 | # For TL1
23 | python utils/codegen_tl1.py --model bitnet_b1_58-large --BM 256,128,256 --BK 128,64,128 --bm 32,64,32
24 | 
25 | # For TL2
26 | python utils/codegen_tl2.py --model bitnet_b1_58-large --BM 256,128,256 --BK 96,192,96 --bm 32,32,32
27 | ```
28 | 
29 | ### TL1:
30 | ![TL1](../assets/tl1.png)
31 | 
32 | For TL1, we cut the weight into M / BM blocks, each of shape (BM, K). Each of these is then cut into K / BK blocks of shape (BM, BK). Finally, each (BM, BK) block is cut, in the same way, into (bm, compute_num / bm) compute blocks, and the computation is finished inside each compute block.
33 | 
34 | Thus, we need to make sure
35 | - M % BM == 0
36 | - K % BK == 0
37 | - BM % bm == 0
38 | - bm is chosen from [32, 64]
39 | 
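The block-size constraints above can be sanity-checked before running the code generator. Below is a minimal, standalone Python sketch (illustrative only, not part of the repository) that verifies a set of `--BM`/`--BK`/`--bm` choices against the TL1 requirements, using the bitnet_b1_58-large kernel shapes and TL1 parameters from the example above.

```python
# Minimal sketch: verify TL1 block-size choices against the constraints above.
# Illustrative only -- the shapes and parameters are the bitnet_b1_58-large
# TL1 example from this document, not values read from the repository.

def check_tl1(shapes, BM, BK, bm):
    for (M, K), BM_i, BK_i, bm_i in zip(shapes, BM, BK, bm):
        assert M % BM_i == 0, f"M={M} is not divisible by BM={BM_i}"
        assert K % BK_i == 0, f"K={K} is not divisible by BK={BK_i}"
        assert BM_i % bm_i == 0, f"BM={BM_i} is not divisible by bm={bm_i}"
        assert bm_i in (32, 64), f"bm={bm_i} must be 32 or 64 for TL1"
        print(f"[{M}, {K}]  BM={BM_i} BK={BK_i} bm={bm_i} -> OK")

if __name__ == "__main__":
    check_tl1(
        shapes=[(1536, 4096), (1536, 1536), (4096, 1536)],
        BM=[256, 128, 256],
        BK=[128, 64, 128],
        bm=[32, 64, 32],
    )
```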
40 | ### TL2:
41 | ![TL2](../assets/tl2.png)
42 | 
43 | For TL2, things get a little more complicated. Because TL2 requires BK % 6 == 0, we need to split K into three_K and two_K: the (M, three_K) part is computed with TL2, and the (M, two_K) part is computed with TL1.
44 | 
45 | Thus, we need to make sure
46 | - M % BM == 0
47 | - K % BK % 32 == 0
48 | - BM % bm == 0
49 | - bm is chosen from \[32\]
--------------------------------------------------------------------------------
/gpu/README.md:
--------------------------------------------------------------------------------
1 | # BitNet Inference Kernel
2 | 
3 | This repository provides a highly efficient GEMV kernel implementation for the BitNet model, optimized for W2A8 inference — 2-bit weights and 8-bit activations. It is tailored for use with the [BitNet-b1.58-2B-4T](https://arxiv.org/abs/2504.12285) model.
4 | 
5 | ## Features
6 | 
7 | - Support for W2A8 (2-bit weight × 8-bit activation) GEMV computation
8 | - Custom CUDA kernels with low-latency execution
9 | - Optimizations for memory access, decoding, and compute throughput
10 | 
11 | ## Usage
12 | 
13 | Installation and kernel performance tests:
14 | 
15 | ```bash
16 | # (Recommended) Create a new conda environment
17 | conda create --name bitnet-gpu "python<3.13"
18 | conda activate bitnet-gpu
19 | 
20 | # Install dependencies
21 | pip install -r requirements.txt
22 | 
23 | # Build the kernel
24 | cd bitnet_kernels
25 | bash compile.sh
26 | cd ..
27 | 
28 | # Run performance tests
29 | python test.py
30 | ```
31 | 
32 | End-to-end inference:
33 | 
34 | ```bash
35 | # Download and convert the BitNet-b1.58-2B model
36 | mkdir checkpoints
37 | huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
38 | python ./convert_safetensors.py --safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors --output checkpoints/model_state.pt --model_name 2B
39 | python ./convert_checkpoint.py --input ./checkpoints/model_state.pt
40 | rm ./checkpoints/model_state.pt
41 | 
42 | # Inference
43 | python3 ./generate.py ./checkpoints/ --interactive --chat_format
44 | ```
45 | 
46 | ## Optimizations
47 | 
48 | ### Weight Permutation
49 | 
50 | The weight matrix is divided into 16×32 blocks to optimize memory access patterns.
51 | 
52 | Within each block, values are stored contiguously in memory and permuted to facilitate efficient access and processing.
53 | 
54 | See `convert_checkpoint.py` for details.
55 | 
56 | ### Fast Decoding
57 | 
58 | Every 16 two-bit values are packed into a single 32-bit integer using the following interleaving pattern:
59 | ```
60 | [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
61 | ```
62 | 
63 | This layout is designed to accelerate decoding by enabling efficient extraction of 4 values at a time into `int8`.
64 | 
65 | ### `dp4a` Instruction
66 | 
67 | We use the `dp4a` instruction to accelerate low-precision dot product operations.
68 | 
69 | This instruction performs a dot product between two 4-element vectors (each stored in a 32-bit word as 8-bit integers) and accumulates the result into a 32-bit integer.
70 | 
71 | It significantly improves GEMV throughput when processing quantized weights and activations.
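To make the layout concrete, here is a small, self-contained Python sketch (illustrative only, not part of the repository) that packs 16 ternary values into one 32-bit word with the interleaving pattern above, unpacks them four at a time in the spirit of `decode_i2s_to_i8s` from `bitnet_kernels.h`, and emulates the `dp4a` accumulation with plain integer arithmetic. The +2 offset on the stored 2-bit codes is an assumption read off the `__vsubss4(i8s[i], 0x02020202)` step in the kernel; the actual packing used at runtime is produced by `convert_weight_int8_to_int2` in `pack_weight.py`.

```python
# Illustrative sketch of the W2A8 packing/decoding scheme described above.
# Assumption: each ternary weight v in {-1, 0, +1} is stored as the 2-bit code
# v + 2 (mirroring the 0x02020202 subtraction in bitnet_kernels.h).

# Slot j of the packed 32-bit word holds original element PERM[j].
PERM = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]

def pack16(weights):
    """Pack 16 ternary values into one 32-bit integer."""
    assert len(weights) == 16
    word = 0
    for slot, idx in enumerate(PERM):
        word |= ((weights[idx] + 2) & 0x3) << (2 * slot)
    return word

def unpack16(word):
    """Recover the 16 values, four at a time, like decode_i2s_to_i8s."""
    out = []
    for i in range(4):                      # pass i yields elements 4*i .. 4*i+3
        group = (word >> (2 * i)) & 0x03030303
        for k in range(4):                  # byte k of the masked group
            out.append(((group >> (8 * k)) & 0xFF) - 2)
    return out

def dp4a(a4, b4, acc=0):
    """Emulate dp4a: int8x4 dot product accumulated into a 32-bit integer."""
    return acc + sum(x * y for x, y in zip(a4, b4))

if __name__ == "__main__":
    w = [1, -1, 0, 1, 0, 0, -1, 1, 1, 0, -1, -1, 0, 1, 1, -1]  # ternary weights
    a = list(range(-8, 8))                                     # toy int8 activations
    dec = unpack16(pack16(w))
    assert dec == w                                            # round-trip check
    acc = 0
    for i in range(4):
        acc = dp4a(a[4 * i:4 * i + 4], dec[4 * i:4 * i + 4], acc)
    assert acc == sum(x * y for x, y in zip(w, a))
    print(f"packed word = {pack16(w):#010x}, dot product = {acc}")
```

The round-trip assertion illustrates why the interleave is chosen this way: each shift-and-mask pass produces four consecutive weights already laid out as an int8x4 word, ready to feed directly into a `dp4a` against four consecutive activations.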
72 | 73 | 74 | ## Performance 75 | 76 | Kernel performance (tested on NVIDIA A100 40GB GPU): 77 | 78 | | Shape (N×K) | W2A8 Latency (us) | BF16 Latency (us) | Speedup Ratio | 79 | |---------------------|-------------------|-------------------|----------------------| 80 | | 2560 × 2560 | 13.32 | 18.32 | 1.38 | 81 | | 3840 × 2560 | 14.90 | 18.87 | 1.27 | 82 | | 13824 × 2560 | 18.75 | 59.51 | 3.17 | 83 | | 2560 × 6912 | 14.49 | 37.78 | 2.61 | 84 | | 3200 × 3200 | 14.61 | 19.08 | 1.31 | 85 | | 4800 × 3200 | 13.09 | 21.84 | 1.67 | 86 | | 3200 × 10240 | 19.64 | 60.79 | 3.10 | 87 | | 20480 × 3200 | 30.99 | 112.39 | 3.63 | 88 | 89 | Generation throughput: 90 | 91 | | BF16 (tokens/s) | W2A8 (tokens/s) | Speedup Ratio | 92 | |---|---|---| 93 | | 10.9 | 213.3 | 19.6 | -------------------------------------------------------------------------------- /gpu/bitnet_kernels/bitnet_kernels.cu: -------------------------------------------------------------------------------- 1 | #include "bitnet_kernels.h" 2 | 3 | extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){ 4 | if (M == 1 && N == 3840 && K == 2560){ 5 | ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<>>(input0, input1, output0, s, ws); 6 | } 7 | else if (M == 1 && N == 2560 && K == 2560){ 8 | ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<>>(input0, input1, output0, s, ws); 9 | } 10 | else if (M == 1 && N == 13824 && K == 2560){ 11 | ladder_int8xint2_kernel<1, 13824, 2560, 2, 8, 16><<>>(input0, input1, output0, s, ws); 12 | } 13 | else if (M == 1 && N == 2560 && K == 6912){ 14 | ladder_int8xint2_kernel<1, 2560, 6912, 1, 8, 16><<>>(input0, input1, output0, s, ws); 15 | } 16 | else if(M == 1 && N == 4800 && K == 3200){ 17 | ladder_int8xint2_kernel<1, 4800, 3200, 6, 8, 16><<>>(input0, input1, output0, s, ws); 18 | } 19 | else if(M == 1 && N == 3200 && K == 3200){ 20 | ladder_int8xint2_kernel<1, 3200, 3200, 1, 8, 16><<>>(input0, input1, output0, s, ws); 21 | } 22 | else if(M == 1 && N == 20480 && K == 3200){ 23 | ladder_int8xint2_kernel<1, 20480, 3200, 2, 8, 16><<>>(input0, input1, output0, s, ws); 24 | } 25 | else if(M == 1 && N == 3200 && K == 10240){ 26 | ladder_int8xint2_kernel<1, 3200, 10240, 1, 8, 16><<>>(input0, input1, output0, s, ws); 27 | } 28 | else if(M == 1 && N == 5120 && K == 27648){ 29 | ladder_int8xint2_kernel<1, 5120, 27648, 1, 8, 16><<>>(input0, input1, output0, s, ws); 30 | } 31 | else if(M == 1 && N == 55296 && K == 5120){ 32 | ladder_int8xint2_kernel<1, 55296, 5120, 1, 8, 16><<>>(input0, input1, output0, s, ws); 33 | } 34 | else{ 35 | std::cout << "required ladder gemm kernel: M " << M << ", N " << N << ", K " << K << std::endl; 36 | } 37 | } -------------------------------------------------------------------------------- /gpu/bitnet_kernels/bitnet_kernels.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || (__CUDACC_VER_MAJOR__ > 11)) 12 | #define TVM_ENABLE_L2_PREFETCH 1 13 | #else 14 | #define TVM_ENABLE_L2_PREFETCH 0 15 | #endif 16 | 17 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800 18 | #define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1 19 | #else 20 | #define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0 21 | #endif 22 | 23 | template 24 | __device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const 
int N = 16) 25 | { 26 | // convert 8 int2b_t to 8 int8b_t -> 2 int32 27 | uint *i8s = reinterpret_cast(_i8s); 28 | 29 | // i2s = {e0, e4, e8, e12, e1, e5, e9, e13, e2, e6, e10, e14, e3, e7, e11, e15} 30 | uint const i2s = *_i2s; 31 | 32 | static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 33 | static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 34 | static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000; 35 | 36 | #pragma unroll 37 | for (int i = 0; i < (N / 4); i++) 38 | { 39 | asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" 40 | : "=r"(i8s[i]) 41 | : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut)); 42 | i8s[i] = __vsubss4(i8s[i], 0x02020202); 43 | } 44 | } 45 | 46 | template 47 | __global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restrict__ A, int8_t* __restrict__ B, __nv_bfloat16* __restrict__ dtype_transform, __nv_bfloat16* __restrict__ s, __nv_bfloat16* __restrict__ ws) { 48 | constexpr int K_per_loop = 16; 49 | constexpr int wmma_K = 32; 50 | constexpr int wmma_N = 16; 51 | int in_thread_C_local[1]; 52 | signed char A_local[K_per_loop]; 53 | int B_reshape_local[1]; 54 | signed char B_decode_local[K_per_loop]; 55 | int red_buf0[1]; 56 | in_thread_C_local[0] = 0; 57 | #pragma unroll 58 | for (int k_0 = 0; k_0 < K/(K_per_loop * K_block_size); ++k_0) { 59 | *(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop))); 60 | B_reshape_local[0] = *(int*)(B + 61 | (((int)blockIdx.x) * N_block_size * K / 4) + 62 | (k_0 * K_block_size * K_per_loop * wmma_N / 4) + 63 | ((((int)threadIdx.x) >> 1) * wmma_K * wmma_N / 4) + 64 | ((((int)threadIdx.y) >> 3) * (wmma_K * wmma_N / 2) / 4) + 65 | ((((int)threadIdx.x) & 1) * (wmma_K * wmma_N / 4) / 4) + 66 | ((((int)threadIdx.y) & 7) * (wmma_K / 2) / 4) 67 | ); 68 | decode_i2s_to_i8s(B_reshape_local, B_decode_local, 16); 69 | #pragma unroll 70 | for (int k_2_0 = 0; k_2_0 < 4; ++k_2_0) { 71 | in_thread_C_local[0] = __dp4a(*(int *)&A_local[((k_2_0 * 4))],*(int *)&B_decode_local[((k_2_0 * 4))], in_thread_C_local[0]); 72 | } 73 | } 74 | red_buf0[0] = in_thread_C_local[0]; 75 | #pragma unroll 76 | for (int offset = K_block_size/2; offset > 0; offset /= 2) { 77 | red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size); 78 | } 79 | int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y)); 80 | int ws_idx = out_idx / (N / ws_num); 81 | if (threadIdx.x == 0) 82 | dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0])/(float)s[0]*(float)ws[ws_idx]); 83 | } -------------------------------------------------------------------------------- /gpu/bitnet_kernels/compile.sh: -------------------------------------------------------------------------------- 1 | nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitnet_kernels.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libbitnet.so 2 | 3 | 4 | -------------------------------------------------------------------------------- /gpu/bitnet_kernels/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='bitlinear_cpp', 6 | ext_modules=[ 7 | CUDAExtension('bitlinear_cuda', [ 8 | 'bitnet_kernels.cu', 9 | ]) 10 | ], 11 | cmdclass={ 12 | 'build_ext': BuildExtension 13 | }) 
-------------------------------------------------------------------------------- /gpu/convert_checkpoint.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | from pathlib import Path 6 | from typing import Optional 7 | from dataclasses import dataclass 8 | import torch 9 | from einops import rearrange 10 | from safetensors.torch import save_file 11 | import model 12 | from pack_weight import convert_weight_int8_to_int2 13 | 14 | @torch.inference_mode() 15 | def convert_ts_checkpoint( 16 | *, 17 | input_path: str = "", 18 | ) -> None: 19 | 20 | config = model.ModelArgs() 21 | print(f"Model config {config.__dict__}") 22 | 23 | def quant_weight_int8(weight): 24 | s = 1.0 / weight.abs().mean().clamp_(min=1e-5) 25 | new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8) 26 | new_scale = (1.0 / s).to(torch.bfloat16) 27 | return new_weight, new_scale.reshape(1) 28 | 29 | def quant_weight_fp16(weight): 30 | s = 1.0 / weight.abs().mean().clamp_(min=1e-5) 31 | new_weight = (weight * s).round().clamp(-1, 1) / s 32 | return new_weight 33 | 34 | def convert_int8_to_int2(weight): 35 | return convert_weight_int8_to_int2(weight) 36 | 37 | merged_result = torch.load(input_path, map_location="cpu", mmap=True) 38 | int2_result = {} 39 | fp16_result = {} 40 | zero = torch.zeros(1).to(torch.bfloat16) 41 | for key, value in merged_result.items(): 42 | if 'wqkv' in key: 43 | wq = value[:config.dim] 44 | wk = value[config.dim:config.dim // config.n_heads * config.n_kv_heads + config.dim] 45 | wv = value[config.dim // config.n_heads * config.n_kv_heads + config.dim:] 46 | wq_weight, wa_scale = quant_weight_int8(wq) 47 | wk_weight, wb_scale = quant_weight_int8(wk) 48 | wv_weight, wc_scale = quant_weight_int8(wv) 49 | wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0) 50 | wqkv_scale = torch.cat([wa_scale, wb_scale, wc_scale, zero], dim=0) 51 | int2_result[key] = convert_int8_to_int2(wqkv_weight) 52 | int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale 53 | 54 | wq_weight = quant_weight_fp16(wq) 55 | wk_weight = quant_weight_fp16(wk) 56 | wv_weight = quant_weight_fp16(wv) 57 | wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0) 58 | fp16_result[key] = wqkv_weight 59 | elif 'w13' in key: 60 | w1 = value[:config.ffn_dim] 61 | w3 = value[config.ffn_dim:] 62 | w1_weight, w1_scale = quant_weight_int8(w1) 63 | w3_weight, w3_scale = quant_weight_int8(w3) 64 | w13_weight = torch.cat([w1_weight, w3_weight], dim=0) 65 | w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0) 66 | int2_result[key] = convert_int8_to_int2(w13_weight) 67 | int2_result[key.replace('weight', 'weight_scale')] = w13_scale 68 | 69 | w1_weight = quant_weight_fp16(w1) 70 | w3_weight = quant_weight_fp16(w3) 71 | w13_weight = torch.cat([w1_weight, w3_weight], dim=0) 72 | fp16_result[key] = w13_weight 73 | elif 'w2' in key or 'wo' in key: 74 | weight, scale = quant_weight_int8(value) 75 | scale = torch.cat([scale, zero, zero, zero], dim=0) 76 | int2_result[key] = convert_int8_to_int2(weight) 77 | int2_result[key.replace('weight', 'weight_scale')] = scale 78 | 79 | weight = quant_weight_fp16(value) 80 | fp16_result[key] = weight 81 | else: 82 | int2_result[key] = value.clone() 83 | fp16_result[key] = value.clone() 84 | 85 | output_dir = os.path.dirname(input_path) 86 | print(f"Saving checkpoint to {output_dir}/model_state_int2.pt") 87 | torch.save(int2_result, f"{output_dir}/model_state_int2.pt") 88 | 89 | 
print(f"Saving checkpoint to {output_dir}/model_state_fp16.pt") 90 | torch.save(fp16_result, f"{output_dir}/model_state_fp16.pt") 91 | 92 | if __name__ == '__main__': 93 | import argparse 94 | parser = argparse.ArgumentParser(description='Convert TorchScale checkpoint.') 95 | parser.add_argument('--input', type=str) 96 | 97 | args = parser.parse_args() 98 | convert_ts_checkpoint( 99 | input_path=args.input, 100 | ) 101 | -------------------------------------------------------------------------------- /gpu/convert_safetensors.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from pathlib import Path 4 | from safetensors.torch import load_file 5 | from einops import rearrange 6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | transformer_configs = { 10 | "2B": dict(n_layer=30, n_head=20, dim=2560, vocab_size=128256, n_local_heads=5, intermediate_size=6912), 11 | } 12 | 13 | @dataclass 14 | class ModelArgs: 15 | block_size: int = 4096 16 | vocab_size: int = 32000 17 | n_layer: int = 32 18 | n_head: int = 32 19 | dim: int = 4096 20 | intermediate_size: int = None 21 | n_local_heads: int = -1 22 | head_dim: int = 64 23 | rope_base: float = 10000 24 | norm_eps: float = 1e-5 25 | 26 | def __post_init__(self): 27 | if self.n_local_heads == -1: 28 | self.n_local_heads = self.n_head 29 | if self.intermediate_size is None: 30 | hidden_dim = 4 * self.dim 31 | n_hidden = int(2 * hidden_dim / 3) 32 | self.intermediate_size = n_hidden + (256 - n_hidden % 256) if n_hidden % 256 else n_hidden 33 | self.head_dim = self.dim // self.n_head 34 | 35 | @classmethod 36 | def from_name(cls, name: str): 37 | if name in transformer_configs: 38 | return cls(**transformer_configs[name]) 39 | config = [k for k in transformer_configs if k in name.upper() or k in name] 40 | assert len(config) == 1, f"Unknown model name: {name}" 41 | return cls(**transformer_configs[config[0]]) 42 | 43 | def invert_convert_q(w: torch.Tensor, config: ModelArgs) -> torch.Tensor: 44 | return rearrange(w, '(h l d) i -> (h d l) i', h=config.n_head, l=2) 45 | 46 | def invert_convert_k(w: torch.Tensor, config: ModelArgs) -> torch.Tensor: 47 | return rearrange(w, '(h l d) i -> (h d l) i', h=config.n_local_heads, l=2) 48 | 49 | def convert_back( 50 | safetensors_path: str, 51 | output_file: str, 52 | model_name: Optional[str] = None, 53 | ): 54 | st_dict = load_file(safetensors_path) 55 | 56 | cfg = ModelArgs.from_name(model_name) 57 | print(f"Using model configurations: {cfg}") 58 | 59 | recovered: dict = {} 60 | 61 | for layer in range(cfg.n_layer): 62 | base = f"model.layers.{layer}." 
63 | 64 | wq = st_dict[f"{base}self_attn.q_proj.weight"] 65 | wk = st_dict[f"{base}self_attn.k_proj.weight"] 66 | wv = st_dict[f"{base}self_attn.v_proj.weight"] 67 | 68 | wq = invert_convert_q(wq, cfg) 69 | wk = invert_convert_k(wk, cfg) 70 | 71 | wqkv = torch.cat([wq, wk, wv], dim=0) 72 | recovered[f"layers.{layer}.attention.wqkv.weight"] = wqkv 73 | 74 | recovered[f"layers.{layer}.attention.wo.weight"] = st_dict[f"{base}self_attn.o_proj.weight"] 75 | 76 | recovered[f"layers.{layer}.attention_norm.weight"] = st_dict[f"{base}input_layernorm.weight"] 77 | recovered[f"layers.{layer}.ffn_norm.weight"] = st_dict[f"{base}post_attention_layernorm.weight"] 78 | recovered[f"layers.{layer}.attention.attn_sub_norm.weight"] = st_dict[f"{base}self_attn.attn_sub_norm.weight"] 79 | recovered[f"layers.{layer}.feed_forward.ffn_sub_norm.weight"] = st_dict[f"{base}mlp.ffn_sub_norm.weight"] 80 | 81 | gate = st_dict[f"{base}mlp.gate_proj.weight"] 82 | up = st_dict[f"{base}mlp.up_proj.weight"] 83 | w13 = torch.cat([gate, up], dim=0) 84 | recovered[f"layers.{layer}.feed_forward.w13.weight"] = w13 85 | 86 | recovered[f"layers.{layer}.feed_forward.w2.weight"] = st_dict[f"{base}mlp.down_proj.weight"] 87 | 88 | recovered["tok_embeddings.weight"] = st_dict["model.embed_tokens.weight"] 89 | recovered["output.weight"] = st_dict["model.embed_tokens.weight"] 90 | recovered["norm.weight"] = st_dict["model.norm.weight"] 91 | 92 | print(f"Saving to {output_file}") 93 | torch.save(recovered, output_file) 94 | 95 | if __name__ == "__main__": 96 | import argparse 97 | parser = argparse.ArgumentParser(description="Convert Safetensors back to Torch .pth checkpoint") 98 | parser.add_argument( 99 | "--safetensors_file", type=str, required=True, 100 | help="Path to input .safetensors file" 101 | ) 102 | parser.add_argument( 103 | "--output", type=str, default="./checkpoints/model_state.pt", 104 | help="Path to output .pt file" 105 | ) 106 | parser.add_argument( 107 | "--model_name", type=str, default="2B", 108 | help="Model configuration name to use (e.g. 2B)" 109 | ) 110 | args = parser.parse_args() 111 | 112 | convert_back( 113 | safetensors_path=args.safetensors_file, 114 | output_file=args.output, 115 | model_name=args.model_name, 116 | ) -------------------------------------------------------------------------------- /gpu/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import json 7 | import os 8 | import readline # type: ignore # noqa 9 | import sys 10 | import time 11 | from dataclasses import dataclass 12 | from pathlib import Path 13 | from typing import Iterable, Optional, Tuple, Union 14 | 15 | import fire 16 | import model as fast 17 | import torch 18 | from stats import Stats 19 | from tokenizer import Tokenizer, ChatFormat 20 | import sample_utils 21 | from xformers.ops.fmha.attn_bias import ( 22 | BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias, 23 | ) 24 | 25 | 26 | @dataclass 27 | class GenArgs: 28 | gen_length: int = 32 29 | gen_bsz: int = 1 30 | prompt_length: int = 64 31 | 32 | use_sampling: bool = False 33 | temperature: float = 0.8 34 | top_p: float = 0.9 35 | 36 | 37 | class FastGen: 38 | GRAPH_WARMUPS: int = 1 39 | tokenizer: Tokenizer 40 | 41 | @staticmethod 42 | def build( 43 | ckpt_dir: str, 44 | gen_args: GenArgs, 45 | device: Union[torch.device, str], 46 | tokenizer_path: Optional[str] = None, 47 | num_layers: int = 13, 48 | use_full_vocab: bool = False, 49 | ) -> "FastGen": 50 | """ 51 | Load a Llama or Code Llama checkpoint and return a new 52 | generator for this model. 53 | """ 54 | start_time = time.time() 55 | 56 | model_args_prefill = fast.ModelArgs(use_kernel=False) 57 | model_args_decode = fast.ModelArgs(use_kernel=True) 58 | tokenizer = Tokenizer("./tokenizer.model") 59 | 60 | torch.set_default_device(device) 61 | torch.set_default_dtype(torch.bfloat16) 62 | 63 | prefill_model = fast.Transformer(model_args_prefill) 64 | decode_model = fast.Transformer(model_args_decode) 65 | 66 | fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt") 67 | fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu") 68 | int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt") 69 | int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu") 70 | prefill_model.load_state_dict(fp16_checkpoint, strict=True) 71 | decode_model.load_state_dict(int2_checkpoint, strict=True) 72 | 73 | torch.cuda.synchronize() 74 | print(f"loaded model in {time.time() - start_time:.2f} seconds") 75 | start_time = time.time() 76 | 77 | return FastGen(gen_args, model_args_prefill, prefill_model, decode_model, tokenizer) 78 | 79 | def __init__( 80 | self, 81 | args: GenArgs, 82 | model_args: fast.ModelArgs, 83 | prefill_model: fast.Transformer, 84 | decode_model: fast.Transformer, 85 | tokenizer: Tokenizer, 86 | ): 87 | self.gen_args = args 88 | self.max_seq_length = args.prompt_length + args.gen_length 89 | self.model_args = model_args 90 | # self.model = model 91 | self.prefill_model = prefill_model 92 | self.decode_model = decode_model 93 | self.tokenizer = tokenizer 94 | self._prefill_cuda_graph, self._prefill_compile_model, self._prefill_inputs, self._prefill_logits = None, None, None, None 95 | self._generate_cuda_graph, self._generate_compile_model, self._generate_inputs, self._generate_logits = None, None, None, None 96 | self._cache = None 97 | start_time = time.time() 98 | self._prefill_compile_model = self.compile_prefill() 99 | self._generate_compile_model = self.compile_generate() 100 | print(f"compiled model in {time.time() - start_time:.2f} seconds") 101 | 102 | def compile_prefill(self): 103 | 104 | if self._cache is None: 105 | self._cache = fast.make_cache( 106 | args=self.model_args, 107 | length=self.gen_args.gen_bsz * self.max_seq_length, 108 | ) 109 | 110 | seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)] 111 | 112 | bias = AttnBias.from_seqlens( 113 | q_seqlen=seq_lens, 114 | 
kv_seqlen=seq_lens, 115 | kv_padding=self.max_seq_length, 116 | ) 117 | bias.q_seqinfo.to("cuda") 118 | bias.k_seqinfo.to("cuda") 119 | 120 | tokens = torch.IntTensor([1] * self.gen_args.gen_bsz * self.gen_args.prompt_length).cuda() 121 | self._prefill_inputs = (tokens, bias) 122 | 123 | s = torch.cuda.Stream() 124 | s.wait_stream(torch.cuda.current_stream()) 125 | 126 | with torch.cuda.stream(s): 127 | _ = self.prefill_model.forward_with_attn_bias( 128 | token_values=self._prefill_inputs[0], 129 | attn_bias=self._prefill_inputs[1], 130 | cache=self._cache, 131 | ) 132 | torch.cuda.current_stream().wait_stream(s) 133 | 134 | self._prefill_cuda_graph = torch.cuda.CUDAGraph() 135 | recording_kwargs = {} 136 | if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__: 137 | # In PyTorch 2.1+ and nightlies from late Aug 2023, 138 | # we can do this to maybe avoid watchdog-related crashes 139 | recording_kwargs["capture_error_mode"] = "thread_local" 140 | with torch.cuda.graph(self._prefill_cuda_graph, **recording_kwargs): 141 | self._prefill_logits = self.prefill_model.forward_with_attn_bias( 142 | token_values=self._prefill_inputs[0], 143 | attn_bias=self._prefill_inputs[1], 144 | cache=self._cache, 145 | ) 146 | 147 | def replay(tokens, seq_lens=None): 148 | self._prefill_inputs[0].copy_(tokens) 149 | if seq_lens is not None: 150 | self._prefill_inputs[1].k_seqinfo.seqlen.copy_(seq_lens) 151 | 152 | self._prefill_cuda_graph.replay() 153 | torch.cuda.synchronize() 154 | 155 | return self._prefill_logits 156 | 157 | return replay 158 | 159 | def compile_generate(self): 160 | 161 | if self._cache is None: 162 | self._cache = fast.make_cache( 163 | args=self.model_args, 164 | length=self.gen_args.gen_bsz * self.max_seq_length, 165 | ) 166 | 167 | seq_lens = [1 for _ in range(self.gen_args.gen_bsz)] 168 | kv_seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)] 169 | 170 | bias = AttnBias.from_seqlens( 171 | q_seqlen=seq_lens, 172 | kv_seqlen=kv_seq_lens, 173 | kv_padding=self.max_seq_length, 174 | ) 175 | bias.q_seqinfo.to("cuda") 176 | bias.k_seqinfo.to("cuda") 177 | 178 | tokens = torch.IntTensor([1] * self.gen_args.gen_bsz).cuda() 179 | self._generate_inputs = (tokens, bias) 180 | 181 | s = torch.cuda.Stream() 182 | s.wait_stream(torch.cuda.current_stream()) 183 | 184 | with torch.cuda.stream(s): 185 | _ = self.decode_model.forward_with_attn_bias( 186 | token_values=self._generate_inputs[0], 187 | attn_bias=self._generate_inputs[1], 188 | cache=self._cache, 189 | ) 190 | torch.cuda.current_stream().wait_stream(s) 191 | 192 | self._generate_cuda_graph = torch.cuda.CUDAGraph() 193 | recording_kwargs = {} 194 | if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__: 195 | # In PyTorch 2.1+ and nightlies from late Aug 2023, 196 | # we can do this to maybe avoid watchdog-related crashes 197 | recording_kwargs["capture_error_mode"] = "thread_local" 198 | with torch.cuda.graph(self._generate_cuda_graph, **recording_kwargs): 199 | self._generate_logits = self.decode_model.forward_with_attn_bias( 200 | token_values=self._generate_inputs[0], 201 | attn_bias=self._generate_inputs[1], 202 | cache=self._cache, 203 | ) 204 | 205 | def replay(tokens, seq_lens): 206 | self._generate_inputs[0].copy_(tokens) 207 | self._generate_inputs[1].k_seqinfo.seqlen.copy_(seq_lens) 208 | 209 | self._generate_cuda_graph.replay() 210 | 211 | return self._generate_logits 212 | 213 | return replay 214 | 215 | 216 | @torch.inference_mode() 217 | def generate_all( 218 | self, 
prompts: list[list[int]], use_cuda_graphs: bool, use_sampling: bool 219 | ) -> Tuple[Stats, list[list[int]]]: 220 | bs = len(prompts) 221 | prompt_lens = [len(p) for p in prompts] 222 | padded_prompt_lens = [self.gen_args.prompt_length] * bs 223 | max_prompt_length = max(prompt_lens) 224 | gen_length = self.gen_args.gen_length 225 | max_seq_length = max_prompt_length + gen_length 226 | print(max_prompt_length, gen_length) 227 | 228 | bias = AttnBias.from_seqlens( 229 | q_seqlen=padded_prompt_lens, 230 | kv_seqlen=prompt_lens, 231 | kv_padding=max_seq_length, 232 | ) 233 | bias.q_seqinfo.to("cuda") 234 | bias.k_seqinfo.to("cuda") 235 | 236 | # Input tensors to the cuda graph 237 | kv_seqlen = bias.k_seqinfo.seqlen 238 | prompts = [prompt + [1] * (self.gen_args.prompt_length - len(prompt)) for prompt in prompts] 239 | tokens = torch.IntTensor(sum(prompts, [])).cuda() 240 | out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int) 241 | 242 | stats = Stats() 243 | torch.cuda.synchronize() 244 | stats.phase("prefill" if use_cuda_graphs else "total") 245 | # stats.phase("total") 246 | 247 | output = self._prefill_compile_model(tokens, None) 248 | 249 | logits = output[kv_seqlen - 1, :] 250 | logits = logits.view(bs, self.model_args.vocab_size) 251 | 252 | if use_sampling: 253 | temp = 0.7 254 | top_p = 0.95 255 | probs = torch.softmax(logits / temp, dim=-1) 256 | next_token = sample_utils.top_p(probs, top_p) 257 | else: 258 | next_token = torch.argmax(logits, dim=-1) 259 | 260 | next_token = next_token.reshape(bs) 261 | out_tokens[0, :] = next_token 262 | 263 | torch.cuda.synchronize() 264 | stats.phase("decode" if use_cuda_graphs else "total") 265 | 266 | eos_id = self.tokenizer.eot_id 267 | for niter in range(1, gen_length): 268 | kv_seqlen.add_(kv_seqlen < max_seq_length) 269 | output = self._generate_compile_model(next_token, kv_seqlen) 270 | 271 | logits = output.view(bs, self.model_args.vocab_size) 272 | 273 | if use_sampling: 274 | temp = 0.7 275 | top_p = 0.95 276 | probs = torch.softmax(logits / temp, dim=-1) 277 | next_token = sample_utils.top_p(probs, top_p) 278 | else: 279 | next_token = torch.argmax(logits, dim=-1) 280 | 281 | next_token = next_token.reshape(bs) 282 | out_tokens[niter, :] = next_token 283 | 284 | if next_token.eq(eos_id).any(): 285 | break 286 | 287 | torch.cuda.synchronize() 288 | stats.end_phase(tokens=niter * bs) 289 | 290 | def trim_answer(prompt_len, tokens): 291 | # print(prompt, tokens) 292 | """Trim the answer to end it on an eos token.""" 293 | tokens = tokens[: max_seq_length - prompt_len] 294 | eos_id = self.tokenizer.eot_id 295 | if eos_id in tokens: 296 | return tokens[: tokens.index(eos_id) + 1] 297 | else: 298 | return tokens 299 | 300 | answers = [ 301 | trim_answer(prompt_len, answer) 302 | for prompt_len, answer in zip(prompt_lens, out_tokens.t().tolist()) 303 | ] 304 | return stats, answers 305 | 306 | 307 | def get_prompts(interactive: bool) -> Iterable[list[str]]: 308 | if interactive: 309 | while True: 310 | try: 311 | prompts = input("enter prompt: ").split("\n") 312 | except EOFError: 313 | print("exiting") 314 | sys.exit(0) 315 | yield prompts 316 | else: 317 | yield [ 318 | "Hello, my name is", 319 | ] 320 | 321 | 322 | def main(ckpt_dir: str, interactive: bool = False, chat_format: bool = False, sampling: bool = False): 323 | 324 | local_rank = 0 325 | device = f"cuda:{local_rank}" 326 | torch.cuda.set_device(local_rank) 327 | 328 | g = FastGen.build(ckpt_dir, GenArgs(), device) 329 | 330 | if chat_format: 331 | g.tokenizer = 
ChatFormat(g.tokenizer) 332 | 333 | for prompts in get_prompts(interactive): 334 | # prompts = [f"{prompt}\n" for prompt in prompts] 335 | if chat_format: 336 | # prompts = [f'<|begin_of_text|>User: {prompt}<|eot_id|>Assistant: ' for prompt in prompts] 337 | tokens = [g.tokenizer.encode_dialog_prompt(dialog=[{"role": "user", "content": prompt}], completion=True) for prompt in prompts] 338 | else: 339 | tokens = [g.tokenizer.encode(x, bos=False, eos=False) for x in prompts] 340 | 341 | print(tokens) 342 | stats, out_tokens = g.generate_all( 343 | tokens, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ, use_sampling=sampling, 344 | ) 345 | 346 | for i, prompt in enumerate(prompts): 347 | print(f"> {prompt}") 348 | answer = g.tokenizer.decode(out_tokens[i]) 349 | print(answer) 350 | print("---------------") 351 | 352 | for phase_stats in stats.phases: 353 | print(phase_stats.show()) 354 | 355 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 356 | 357 | 358 | if __name__ == "__main__": 359 | fire.Fire(main) -------------------------------------------------------------------------------- /gpu/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from dataclasses import dataclass 7 | from typing import Optional, Tuple, Union 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from xformers.ops import RMSNorm, fmha, rope_padded 14 | from xformers.ops.fmha.attn_bias import ( 15 | BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias, 16 | ) 17 | 18 | import ctypes 19 | bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so') 20 | 21 | def bitnet_int8xint2_linear(input0, input1, s, ws): 22 | out_shape = list(input0.shape) 23 | out_shape[-1] = input1.shape[0] 24 | 25 | stream = torch.cuda.current_stream() 26 | 27 | M = input0.shape[0] 28 | if len(out_shape) == 3: 29 | M *= input0.shape[1] 30 | N = input1.shape[0] 31 | K = input1.shape[1] * 4 32 | 33 | ret = torch.zeros(*out_shape, dtype=torch.bfloat16, device=input0.device) 34 | 35 | bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)]) 36 | 37 | return ret 38 | 39 | @dataclass 40 | class ModelArgs: 41 | dim: int = 2560 42 | n_layers: int = 30 43 | n_heads: int = 20 44 | n_kv_heads: int = 5 45 | vocab_size: int = 128256 46 | ffn_dim: int = 6912 47 | norm_eps: float = 1e-5 48 | rope_theta: float = 500000.0 49 | use_kernel: bool = False 50 | 51 | 52 | LayerCache = Tuple[torch.Tensor, torch.Tensor] 53 | 54 | class BitLinearKernel(nn.Module): 55 | in_features: int 56 | out_features: int 57 | weight: torch.Tensor 58 | weight_scale: torch.Tensor 59 | 60 | def __init__(self, in_features: int, out_features: int, bias: bool = False): 61 | super().__init__() 62 | self.in_features = in_features 63 | self.out_features = out_features 64 | 65 | self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features//4, dtype=torch.int8), requires_grad=False) 66 | self.weight_scale = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False) 67 | 68 | @torch.compile 69 | def 
quant_input(self, input): 70 | s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) 71 | return (input * s).round().clamp(-128, 127).to(torch.int8), s 72 | 73 | def forward(self, input): 74 | input, s = self.quant_input(input) 75 | return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale) 76 | 77 | class BitLinear(nn.Linear): 78 | @torch.compile 79 | def quant_input(self, input): 80 | s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) 81 | return (input * s).round().clamp(-128, 127) / s 82 | 83 | def forward(self, input): 84 | input = self.quant_input(input) 85 | return F.linear(input, self.weight) 86 | 87 | class Attention(nn.Module): 88 | def __init__( 89 | self, 90 | dim: int, 91 | head_dim: int, 92 | n_heads: int, 93 | n_kv_heads: int, 94 | rope_theta: float, 95 | norm_eps: float, 96 | use_kernel: bool, 97 | ): 98 | super().__init__() 99 | 100 | self.head_dim = head_dim 101 | self.rope_theta = rope_theta 102 | 103 | self.n_local_heads = n_heads 104 | self.n_local_kv_heads = n_kv_heads 105 | 106 | Linear = BitLinearKernel if use_kernel else BitLinear 107 | 108 | self.wqkv = Linear( 109 | dim, 110 | (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim, 111 | bias=False, 112 | ) 113 | self.wo = Linear( 114 | self.n_local_heads * head_dim, 115 | dim, 116 | bias=False, 117 | ) 118 | 119 | self.attn_sub_norm = RMSNorm(dim, norm_eps) 120 | 121 | def forward( 122 | self, 123 | x: torch.Tensor, 124 | cache: LayerCache, 125 | attn_bias: AttnBias, 126 | ) -> torch.Tensor: 127 | 128 | xqkv = self.wqkv(x) 129 | xq = xqkv[:, : (self.n_local_heads * self.head_dim)] 130 | xkv = xqkv[:, (self.n_local_heads * self.head_dim) :] 131 | xk, xv = xkv.chunk(2, 1) 132 | 133 | output_shape = xq.shape 134 | heads_per_group = self.n_local_heads // self.n_local_kv_heads 135 | xq = xq.view( 136 | 1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim 137 | ) 138 | xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim) 139 | # xq = rearrange(xq, 'b (g h l d) -> 1 b h g (d l)', g=heads_per_group, h=self.n_local_kv_heads, d=self.head_dim // 2, l=2) 140 | # xk = rearrange(xk, 'b (g l d) -> 1 b g 1 (d l)', g=self.n_local_kv_heads, d=self.head_dim // 2) 141 | xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim) 142 | cache_k, cache_v = cache 143 | 144 | xq = rope_padded( 145 | xq=xq, 146 | xk=xk, 147 | xv=xv, 148 | cache_k=cache_k, 149 | cache_v=cache_v, 150 | attn_bias=attn_bias, 151 | theta=self.rope_theta, 152 | ) 153 | 154 | output = fmha.memory_efficient_attention_forward( 155 | xq, cache_k, cache_v, attn_bias, op = fmha.flash.FwOp 156 | ) 157 | 158 | output = output.reshape(output_shape) 159 | output = self.attn_sub_norm(output) 160 | output = self.wo(output) 161 | 162 | return output 163 | 164 | @torch.compile 165 | def squared_relu(x: torch.Tensor) -> torch.Tensor: 166 | return F.relu(x) ** 2 167 | 168 | class FeedForward(nn.Module): 169 | def __init__( 170 | self, 171 | dim: int, 172 | hidden_dim: int, 173 | norm_eps: float, 174 | use_kernel: bool, 175 | ): 176 | super().__init__() 177 | 178 | Linear = BitLinearKernel if use_kernel else BitLinear 179 | 180 | self.w13 = Linear( 181 | dim, 182 | 2 * hidden_dim, 183 | bias=False, 184 | ) 185 | self.w2 = Linear( 186 | hidden_dim, 187 | dim, 188 | bias=False, 189 | ) 190 | self.ffn_sub_norm = RMSNorm(hidden_dim, norm_eps) 191 | 192 | def forward(self, x: torch.Tensor) -> torch.Tensor: 193 | x13 = self.w13(x) 194 | x1, x3 = x13.chunk(2, -1) 195 | inner = 
self.ffn_sub_norm(squared_relu(x1) * x3) 196 | output = self.w2(inner) 197 | return output 198 | 199 | 200 | class TransformerBlock(nn.Module): 201 | def __init__(self, args: ModelArgs): 202 | super().__init__() 203 | 204 | assert args.dim % args.n_heads == 0 205 | head_dim = args.dim // args.n_heads 206 | if args.n_kv_heads is not None: 207 | n_kv_heads = args.n_kv_heads 208 | else: 209 | n_kv_heads = args.n_heads 210 | 211 | assert args.n_heads % n_kv_heads == 0 212 | 213 | self.attention = Attention( 214 | dim=args.dim, 215 | head_dim=head_dim, 216 | n_heads=args.n_heads, 217 | n_kv_heads=n_kv_heads, 218 | rope_theta=args.rope_theta, 219 | norm_eps=args.norm_eps, 220 | use_kernel=args.use_kernel, 221 | ) 222 | self.feed_forward = FeedForward( 223 | dim=args.dim, 224 | hidden_dim=args.ffn_dim, 225 | norm_eps=args.norm_eps, 226 | use_kernel=args.use_kernel, 227 | ) 228 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 229 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 230 | 231 | def forward( 232 | self, 233 | x: torch.Tensor, 234 | cache: LayerCache, 235 | attn_bias: AttnBias, 236 | ) -> torch.Tensor: 237 | h = x + self.attention.forward( 238 | self.attention_norm(x), 239 | cache, 240 | attn_bias, 241 | ) 242 | out = h + self.feed_forward(self.ffn_norm(h)) 243 | return out 244 | 245 | 246 | class Transformer(nn.Module): 247 | def __init__(self, args: ModelArgs): 248 | super().__init__() 249 | assert args.vocab_size > 0 250 | 251 | self.tok_embeddings = nn.Embedding( 252 | num_embeddings=args.vocab_size, 253 | embedding_dim=args.dim, 254 | ) 255 | 256 | self.layers = nn.ModuleList() 257 | for _ in range(args.n_layers): 258 | self.layers.append(TransformerBlock(args)) 259 | 260 | self.norm = RMSNorm(args.dim, eps=args.norm_eps) 261 | 262 | self.output = nn.Linear( 263 | args.dim, 264 | args.vocab_size, 265 | bias=False, 266 | ) 267 | 268 | @torch.no_grad() 269 | def forward_with_attn_bias( 270 | self, 271 | token_values: torch.Tensor, 272 | attn_bias: AttnBias, 273 | cache: list[LayerCache], 274 | ) -> torch.Tensor: 275 | h = self.tok_embeddings(token_values) 276 | 277 | for i, layer in enumerate(self.layers): 278 | h = layer(h, cache[i], attn_bias) 279 | 280 | logits = self.output(self.norm(h)) 281 | return logits.float() 282 | 283 | def forward( 284 | self, 285 | token_values: torch.Tensor, 286 | token_lengths: torch.Tensor, 287 | start_pos: torch.Tensor, 288 | cache: list[LayerCache], 289 | kv_padding: int, 290 | ) -> torch.Tensor: 291 | attn_bias = AttnBias.from_seqlens( 292 | q_seqlen=token_lengths.tolist(), 293 | kv_seqlen=(start_pos + token_lengths).tolist(), 294 | kv_padding=kv_padding, 295 | ) 296 | return self.forward_with_attn_bias(token_values, attn_bias, cache) 297 | 298 | 299 | def make_cache( 300 | args: ModelArgs, 301 | length: int, 302 | device: Optional[Union[str, torch.device]] = None, 303 | n_layers: Optional[int] = None, 304 | dtype: Optional[torch.dtype] = None, 305 | ) -> list[LayerCache]: 306 | """ 307 | Allocate a cache to be used with the Transformer module. 308 | 309 | Args: 310 | args (ModelArgs): the model configuration. 311 | length (int): per layer cache size. 312 | It is usually budgeted as ``max_batch * max_seq`` 313 | device (torch.device, optional): the device on which 314 | the cache should be allocated. 315 | n_layers (int, optional): the number of layers to 316 | allocate a cache for (defaults to the model 317 | settings). 318 | dtype (torch.dtype, optional): the dtype to use for 319 | cache entries (defaults to the default dtype). 
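
    Note:
        Each layer's K/V tensors are allocated with one slot per KV head
        and expanded (as a view, not a copy) across the query-head group,
        matching the grouped-query attention layout used by
        ``Attention.forward``.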
320 | 321 | Returns: 322 | The cache object to pass to ``Tranformer.forward``. 323 | """ 324 | 325 | head_dim = args.dim // args.n_heads 326 | n_kv_heads = args.n_kv_heads 327 | if n_kv_heads is None: 328 | n_kv_heads = args.n_heads 329 | n_local_kv_heads = n_kv_heads 330 | 331 | if n_layers is None: 332 | n_layers = args.n_layers 333 | 334 | shape = (1, length, n_local_kv_heads, 1, head_dim) 335 | heads_per_group = args.n_heads // n_kv_heads 336 | expansion = (-1, -1, -1, heads_per_group, -1) 337 | return [ 338 | ( 339 | torch.zeros(shape, device=device, dtype=dtype).expand(expansion), 340 | torch.zeros(shape, device=device, dtype=dtype).expand(expansion), 341 | ) 342 | for _ in range(n_layers) 343 | ] 344 | 345 | 346 | def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]: 347 | """ 348 | Take a prefix view of a larger cache. 349 | 350 | The original cache object remains of identical size and valid 351 | after the shrinked alias has been used. This function is useful 352 | when a cache was allocated for a larger batch size than what is 353 | necessary. 354 | 355 | Args: 356 | cache: the cache to take a view in. 357 | length (int): the desired length 358 | 359 | Returns: 360 | A view in the input cache object. 361 | """ 362 | 363 | if len(cache) > 0: 364 | assert cache[0][0].shape[1] >= length 365 | 366 | return [(ck[:, :length], cv[:, :length]) for ck, cv in cache] -------------------------------------------------------------------------------- /gpu/pack_weight.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def B_global_16x32_to_shared_load_16x32_layout(i, j): 6 | """ 7 | stride * 8 * (tx // HALF_WARP_expr) 8 | + (tx % 8) * stride 9 | + 16 * ((tx % HALF_WARP_expr) // 8) 10 | """ 11 | thread_id = i * 2 + j // 16 12 | row = (thread_id // 16) * 8 + (thread_id % 8) 13 | col = (j % 16) + 16 * ((thread_id % 16) // 8) 14 | return row, col 15 | 16 | 17 | def permutate_weight_fastest(weight): 18 | wmma_n = 16 19 | wmma_k = 32 20 | N = weight.shape[0] 21 | K = weight.shape[1] 22 | 23 | # Create a lookup table for the permutation 24 | mapping = np.zeros((wmma_n, wmma_k, 2), dtype=int) 25 | for ii in range(wmma_n): 26 | for jj in range(wmma_k): 27 | mapping[ii, jj] = B_global_16x32_to_shared_load_16x32_layout(ii, jj) 28 | 29 | # Reshape weight for the final format 30 | permutated_weight = np.zeros((N // wmma_n, K // wmma_k, wmma_n, wmma_k), dtype="int8") 31 | 32 | # Use advanced indexing for the entire operation 33 | i_indices = np.arange(N // wmma_n)[:, np.newaxis, np.newaxis, np.newaxis] 34 | j_indices = np.arange(K // wmma_k)[np.newaxis, :, np.newaxis, np.newaxis] 35 | 36 | # Create the source indices 37 | src_i = i_indices * wmma_n + mapping[:, :, 0] 38 | src_j = j_indices * wmma_k + mapping[:, :, 1] 39 | 40 | # Extract and reshape in one go 41 | permutated_weight = weight[src_i, src_j] 42 | 43 | return permutated_weight 44 | 45 | 46 | def compress_int2_to_int8(int2_weight): 47 | int8_weight = np.zeros( 48 | (*int2_weight.shape[:-1], int2_weight.shape[-1] // 4), dtype=np.int8 49 | ) 50 | for j in range(int2_weight.shape[-1] // 4): 51 | for k in range(4): 52 | int8_weight[:, :, :, j] |= int2_weight[:, :, :, j * 4 + k] << (k * 2) 53 | return int8_weight 54 | 55 | 56 | def interleave_weight_int8(qweight, nbits=2):\ 57 | # reinterpret the data type of qweight to int32 58 | # shift = [ 0, 8, 16, 24, 2, 10, 18, 26, 4, 12, 20, 28, 6, 14, 22, 30] 59 | # index: [ 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 
10, 14, 3, 7, 11, 15] 60 | qweight = qweight.view(np.int32) 61 | new_qweight = np.zeros_like(qweight) 62 | bits_stride = 8 63 | mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f 64 | num_groups = 32 // bits_stride # 4 65 | elems_per_group = bits_stride // nbits # 4 66 | for i in range(num_groups): 67 | for j in range(elems_per_group): 68 | offset = i * elems_per_group + j 69 | shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits 70 | 71 | new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift 72 | return new_qweight.view(np.int8) 73 | 74 | 75 | 76 | def convert_weight_int8_to_int2(weight): 77 | N = weight.shape[0] 78 | K = weight.shape[1] 79 | 80 | weight = weight+2 81 | 82 | weight = weight.cpu().numpy() 83 | 84 | # print(weight) 85 | # print(torch.max(weight), torch.min(weight)) 86 | 87 | # permutated_weight_slow = permutate_weight(weight) 88 | permutated_weight = permutate_weight_fastest(weight) 89 | # assert np.all(permutated_weight_slow == permutated_weight) 90 | # print("Permutation is correct") 91 | compressed_weight = compress_int2_to_int8(permutated_weight) 92 | interleaved_weight = interleave_weight_int8(compressed_weight, 2) 93 | 94 | ret = torch.from_numpy(interleaved_weight) 95 | 96 | ret = torch.reshape(ret, (N, K // 4)) 97 | 98 | return ret 99 | -------------------------------------------------------------------------------- /gpu/requirements.txt: -------------------------------------------------------------------------------- 1 | fire 2 | sentencepiece 3 | torch>=2.2.0 4 | xformers>=0.0.22 5 | tiktoken 6 | blobfile 7 | flask 8 | einops 9 | transformers -------------------------------------------------------------------------------- /gpu/sample_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | @torch.compile 9 | def top_p(probs: torch.Tensor, p: float) -> torch.Tensor: 10 | """ 11 | Perform top-p (nucleus) sampling on a probability distribution. 12 | 13 | Args: 14 | probs (torch.Tensor): probability distribution tensor. 15 | p (float): probability threshold for top-p sampling. 16 | 17 | Returns: 18 | torch.Tensor: sampled token indices. 19 | 20 | Note: 21 | Top-p sampling selects the smallest set of tokens whose cumulative 22 | probability mass exceeds the threshold p. The distribution is 23 | renormalized based on the selected tokens. 24 | """ 25 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 26 | probs_sum = torch.cumsum(probs_sort, dim=-1) 27 | mask = probs_sum - probs_sort > p 28 | probs_sort[mask] = 0.0 29 | next_token = torch.multinomial(probs_sort, num_samples=1) 30 | next_token = torch.gather(probs_idx, -1, next_token) 31 | return next_token -------------------------------------------------------------------------------- /gpu/stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
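
# Usage sketch (mirroring how gpu/generate.py drives these classes;
# `n_tokens` below is a placeholder for the total generated-token count):
#
#   stats = Stats()
#   stats.phase("prefill")             # start timing the prefill phase
#   ...                                # run prefill
#   stats.phase("decode")              # closes "prefill", starts "decode"
#   ...                                # run the decode loop
#   stats.end_phase(tokens=n_tokens)   # closes "decode"
#   for phase_stats in stats.phases:
#       print(phase_stats.show())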
5 | 6 | import time 7 | from dataclasses import dataclass 8 | from typing import Optional 9 | 10 | 11 | @dataclass 12 | class PhaseStats: 13 | name: str 14 | tokens: int 15 | time: float 16 | 17 | def show(self) -> str: 18 | tps = self.tokens / self.time 19 | return ( 20 | f"[{self.name}] " 21 | f"generated tokens: {self.tokens}" 22 | f" - total time: {self.time:.3f}s" 23 | f" - {tps:.1f} tokens per second" 24 | ) 25 | 26 | 27 | class Stats: 28 | """ 29 | Generation stats, split by phases. 30 | """ 31 | 32 | def __init__(self): 33 | self.phases = [] 34 | self.current = None 35 | 36 | def end_phase(self, tokens: int, now: Optional[float] = None): 37 | """Terminate the current phase.""" 38 | if self.current is None: 39 | return 40 | if now is None: 41 | now = time.time() 42 | cname, ctokens, ctime = self.current 43 | stats = PhaseStats( 44 | name=cname, 45 | tokens=tokens - ctokens, 46 | time=now - ctime, 47 | ) 48 | self.phases.append(stats) 49 | 50 | def phase(self, name: str, tokens: int = 0): 51 | """ 52 | Start a new phase, and terminate the current one, 53 | if one is ongoing. 54 | """ 55 | now = time.time() 56 | self.end_phase(tokens, now) 57 | self.current = (name, tokens, now) -------------------------------------------------------------------------------- /gpu/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import benchmark 3 | from torch import nn 4 | 5 | from pack_weight import convert_weight_int8_to_int2 6 | from torch.profiler import profile, record_function, ProfilerActivity 7 | import ctypes 8 | import numpy as np 9 | # set all seed 10 | torch.manual_seed(42) 11 | np.random.seed(42) 12 | 13 | bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so') 14 | 15 | def bitnet_int8xint2_linear(input0, input1, s, ws, ret): 16 | out_shape = list(input0.shape) 17 | out_shape[-1] = input1.shape[0] 18 | 19 | stream = torch.cuda.current_stream() 20 | 21 | M = input0.shape[0] 22 | if len(out_shape) == 3: 23 | M *= input0.shape[1] 24 | N = input1.shape[0] 25 | K = input1.shape[1] * 4 26 | 27 | bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)]) 28 | 29 | return ret 30 | 31 | if __name__ == '__main__': 32 | test_list = [ 33 | (2560, 2560), 34 | (3840, 2560), 35 | (13824, 2560), 36 | (2560, 6912) , 37 | (3200, 3200), 38 | (4800, 3200), 39 | (3200, 10240), 40 | (20480, 3200), 41 | ] 42 | for N,K in test_list: 43 | weight = torch.randint(-1, 2, (N, K), dtype=torch.int8, device='cuda') 44 | weight_scale = torch.ones(1, dtype=torch.bfloat16, device='cuda') 45 | weight_compressed = convert_weight_int8_to_int2(weight).to('cuda') 46 | 47 | for i in range(1): 48 | input0 = torch.randint(-128,127,(1, K),dtype=torch.int8, device='cuda') 49 | input0_bf16 = input0.to(torch.bfloat16) 50 | input_np = input0.cpu().to(torch.int32).numpy() 51 | weight_np = weight.cpu().to(torch.int32).T.numpy() 52 | out_np = np.matmul(input_np,weight_np) 53 | out_np = torch.tensor(out_np).cuda().to(torch.bfloat16) 54 | 55 | s = torch.ones(1, dtype=torch.bfloat16, device='cuda') 56 | ws = torch.ones(6, dtype=torch.bfloat16, device='cuda') 57 | 58 | ret = torch.empty((1,N), dtype=torch.bfloat16, device=input0.device) 59 | out = bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret) 60 | 61 | print(f'custom == 
np {torch.all(out==out_np)}') 62 | 63 | input0 = torch.randint(-128,127,(1, K),dtype=torch.int8, device='cuda') 64 | input0_fp16 = input0.to(torch.float16) 65 | input0_bf16 = input0.to(torch.bfloat16) 66 | weight_fp16 = weight.to(torch.float16).T 67 | weight_bf16 = weight.to(torch.bfloat16).T 68 | ret = torch.empty((1,N), dtype=torch.bfloat16, device=input0.device) 69 | s = torch.ones(1, dtype=torch.bfloat16, device='cuda') 70 | ws = torch.ones(6, dtype=torch.bfloat16, device='cuda') 71 | t0 = benchmark.Timer( 72 | stmt="bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)", 73 | setup="from __main__ import input0, weight_compressed, s, ws, ret, bitnet_int8xint2_linear", 74 | num_threads=1, 75 | ) 76 | 77 | t1 = benchmark.Timer( 78 | stmt="torch.matmul(input0_bf16,weight_bf16)", 79 | setup="from __main__ import input0_bf16, weight_bf16", 80 | num_threads=1, 81 | ) 82 | 83 | time0 = t0.timeit(50) 84 | time1 = t1.timeit(50) 85 | 86 | print(f'Shape{N,K}, W2A8: {time0.mean * 1e6:.2f}us, torch BF16: {time1.mean * 1e6:.2f}us') 87 | # activities = [ ProfilerActivity.CUDA, 88 | # # ProfilerActivity.CPU 89 | # ] 90 | # sort_by_keyword = 'cuda' + "_time_total" 91 | # with profile(activities=activities, record_shapes=True) as prof: 92 | # with record_function("model_inference1"): 93 | # for _ in range(10): 94 | # bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret) 95 | # torch.matmul(input0_fp16,weight_fp16) 96 | # torch.matmul(input0_bf16,weight_bf16) 97 | 98 | # print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=15)) 99 | 100 | -------------------------------------------------------------------------------- /gpu/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logging import getLogger 3 | from pathlib import Path 4 | from typing import ( 5 | AbstractSet, 6 | cast, 7 | Collection, 8 | Dict, 9 | Iterator, 10 | List, 11 | Literal, 12 | Sequence, 13 | TypedDict, 14 | Union, 15 | ) 16 | 17 | import tiktoken 18 | from tiktoken.load import load_tiktoken_bpe 19 | 20 | 21 | logger = getLogger(__name__) 22 | 23 | Role = Literal["system", "user", "assistant"] 24 | 25 | 26 | class Message(TypedDict): 27 | role: Role 28 | content: str 29 | 30 | 31 | Dialog = Sequence[Message] 32 | 33 | 34 | class Tokenizer: 35 | """ 36 | Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 37 | """ 38 | 39 | special_tokens: Dict[str, int] 40 | 41 | num_reserved_special_tokens = 256 42 | 43 | pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 44 | 45 | def __init__(self, model_path: str): 46 | """ 47 | Initializes the Tokenizer with a Tiktoken model. 48 | 49 | Args: 50 | model_path (str): The path to the Tiktoken model file. 
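
        Note:
            In addition to the BPE ranks loaded from ``model_path``, the
            constructor registers 256 reserved special tokens (e.g.
            ``<|begin_of_text|>``, ``<|end_of_text|>``, ``<|eot_id|>`` and
            the header markers) at the end of the vocabulary.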
51 | """ 52 | assert os.path.isfile(model_path), model_path 53 | 54 | mergeable_ranks = load_tiktoken_bpe(model_path) 55 | num_base_tokens = len(mergeable_ranks) 56 | special_tokens = [ 57 | "<|begin_of_text|>", 58 | "<|end_of_text|>", 59 | "<|reserved_special_token_0|>", 60 | "<|reserved_special_token_1|>", 61 | "<|reserved_special_token_2|>", 62 | "<|reserved_special_token_3|>", 63 | "<|start_header_id|>", 64 | "<|end_header_id|>", 65 | "<|reserved_special_token_4|>", 66 | "<|eot_id|>", # end of turn 67 | ] + [ 68 | f"<|reserved_special_token_{i}|>" 69 | for i in range(5, self.num_reserved_special_tokens - 5) 70 | ] 71 | self.special_tokens = { 72 | token: num_base_tokens + i for i, token in enumerate(special_tokens) 73 | } 74 | self.model = tiktoken.Encoding( 75 | name=Path(model_path).name, 76 | pat_str=self.pat_str, 77 | mergeable_ranks=mergeable_ranks, 78 | special_tokens=self.special_tokens, 79 | ) 80 | logger.info(f"Reloaded tiktoken model from {model_path}") 81 | 82 | self.n_words: int = self.model.n_vocab 83 | # BOS / EOS token IDs 84 | self.bos_id: int = self.special_tokens["<|begin_of_text|>"] 85 | self.eos_id: int = self.special_tokens["<|end_of_text|>"] 86 | self.pad_id: int = self.n_words - 1 87 | self.stop_tokens = { 88 | self.special_tokens["<|end_of_text|>"], 89 | self.special_tokens["<|eot_id|>"], 90 | } 91 | logger.info( 92 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 93 | ) 94 | 95 | def encode( 96 | self, 97 | s: str, 98 | *, 99 | bos: bool, 100 | eos: bool, 101 | allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), 102 | disallowed_special: Union[Literal["all"], Collection[str]] = (), 103 | ) -> List[int]: 104 | """ 105 | Encodes a string into a list of token IDs. 106 | 107 | Args: 108 | s (str): The input string to be encoded. 109 | bos (bool): Whether to prepend the beginning-of-sequence token. 110 | eos (bool): Whether to append the end-of-sequence token. 111 | allowed_tokens ("all"|set[str]): allowed special tokens in string 112 | disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string 113 | 114 | Returns: 115 | list[int]: A list of token IDs. 116 | 117 | By default, setting disallowed_special=() encodes a string by ignoring 118 | special tokens. Specifically: 119 | - Setting `disallowed_special` to () will cause all text corresponding 120 | to special tokens to be encoded as natural text (insteading of raising 121 | an error). 122 | - Setting `allowed_special` to "all" will treat all text corresponding 123 | to special tokens to be encoded as special tokens. 124 | """ 125 | assert type(s) is str 126 | 127 | # The tiktoken tokenizer can handle <=400k chars without 128 | # pyo3_runtime.PanicException. 129 | TIKTOKEN_MAX_ENCODE_CHARS = 400_000 130 | 131 | # https://github.com/openai/tiktoken/issues/195 132 | # Here we iterate over subsequences and split if we exceed the limit 133 | # of max consecutive non-whitespace or whitespace characters. 
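        # Concretely: the prompt is first cut into <= 400k-character
        # windows, each window is further split so that no run of
        # whitespace or of non-whitespace characters exceeds 25k
        # characters, and the encoded pieces are concatenated below.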
134 | MAX_NO_WHITESPACES_CHARS = 25_000 135 | 136 | substrs = ( 137 | substr 138 | for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) 139 | for substr in self._split_whitespaces_or_nonwhitespaces( 140 | s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS 141 | ) 142 | ) 143 | t: List[int] = [] 144 | for substr in substrs: 145 | t.extend( 146 | self.model.encode( 147 | substr, 148 | allowed_special=allowed_special, 149 | disallowed_special=disallowed_special, 150 | ) 151 | ) 152 | if bos: 153 | t.insert(0, self.bos_id) 154 | if eos: 155 | t.append(self.eos_id) 156 | return t 157 | 158 | def decode(self, t: Sequence[int]) -> str: 159 | """ 160 | Decodes a list of token IDs into a string. 161 | 162 | Args: 163 | t (List[int]): The list of token IDs to be decoded. 164 | 165 | Returns: 166 | str: The decoded string. 167 | """ 168 | # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 169 | return self.model.decode(cast(List[int], t)) 170 | 171 | @staticmethod 172 | def _split_whitespaces_or_nonwhitespaces( 173 | s: str, max_consecutive_slice_len: int 174 | ) -> Iterator[str]: 175 | """ 176 | Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` 177 | consecutive whitespaces or consecutive non-whitespaces. 178 | """ 179 | current_slice_len = 0 180 | current_slice_is_space = s[0].isspace() if len(s) > 0 else False 181 | slice_start = 0 182 | 183 | for i in range(len(s)): 184 | is_now_space = s[i].isspace() 185 | 186 | if current_slice_is_space ^ is_now_space: 187 | current_slice_len = 1 188 | current_slice_is_space = is_now_space 189 | else: 190 | current_slice_len += 1 191 | if current_slice_len > max_consecutive_slice_len: 192 | yield s[slice_start:i] 193 | slice_start = i 194 | current_slice_len = 1 195 | yield s[slice_start:] 196 | 197 | class ChatFormat: 198 | def __init__(self, tokenizer: Tokenizer): 199 | self.tokenizer = tokenizer 200 | self.eot_id = tokenizer.special_tokens["<|eot_id|>"] 201 | 202 | def decode(self, tokens: List[int]) -> str: 203 | # Decode the tokens to a string. 204 | decoded_str = self.tokenizer.decode(tokens) 205 | # Remove the special tokens from the decoded string. 
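        # Only the end-of-turn marker is stripped here; other special
        # tokens (e.g. <|begin_of_text|>) would still appear verbatim in
        # the returned string.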
206 | decoded_str = decoded_str.replace("<|eot_id|>", "") 207 | return decoded_str 208 | 209 | def encode_header(self, message: Message) -> List[int]: 210 | tokens = [] 211 | if message["role"] == "system": 212 | tokens.extend(self.tokenizer.encode("System: ", bos=False, eos=False)) 213 | elif message["role"] == "user": 214 | tokens.extend(self.tokenizer.encode("User: ", bos=False, eos=False)) 215 | elif message["role"] == "assistant": 216 | tokens.extend(self.tokenizer.encode("Assistant: ", bos=False, eos=False)) 217 | else: 218 | raise NotImplementedError(f"Role {message['role']} not implemented.") 219 | # tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) 220 | # tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False)) 221 | # tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"]) 222 | # tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False)) 223 | return tokens 224 | 225 | def encode_message(self, message: Message, return_target=False) -> List[int]: 226 | tokens, targets = [], [] 227 | headers = self.encode_header(message) 228 | contents = self.tokenizer.encode(message["content"].strip(), bos=False, eos=False) 229 | contents.append(self.tokenizer.special_tokens["<|eot_id|>"]) 230 | tokens = headers + contents 231 | 232 | if message["role"] == "assistant": 233 | targets = [-1] * len(headers) + contents 234 | else: 235 | targets = [-1] * len(tokens) 236 | 237 | if return_target: 238 | return tokens, targets 239 | 240 | return tokens, None 241 | 242 | def encode_dialog_prompt(self, dialog: Dialog, completion=False, return_target=False) -> List[int]: 243 | tokens = [self.tokenizer.special_tokens["<|begin_of_text|>"]] 244 | targets = [-1] 245 | for message in dialog: 246 | _tokens, _targets = self.encode_message(message, return_target=return_target) 247 | tokens.extend(_tokens) 248 | if _targets is not None: 249 | targets.extend(_targets) 250 | # Add the start of an assistant message for the model to complete. 
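        # With completion=True only the "Assistant: " header (see
        # encode_header) is appended, so the model's next tokens are
        # generated as the assistant's turn; gpu/generate.py relies on
        # this when its chat_format option is enabled.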
251 | if completion: 252 | tokens.extend(self.encode_header({"role": "assistant", "content": ""})) 253 | 254 | if return_target: 255 | return tokens, targets 256 | 257 | return tokens -------------------------------------------------------------------------------- /include/ggml-bitnet.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | #include "ggml-backend.h" 5 | 6 | #ifdef __ARM_NEON 7 | #include 8 | typedef float32_t bitnet_float_type; 9 | #else 10 | typedef float bitnet_float_type; 11 | #endif 12 | 13 | #ifdef __cplusplus 14 | extern "C" { 15 | #endif 16 | 17 | struct bitnet_tensor_extra { 18 | int lut_scales_size; 19 | int BK; 20 | int n_tile_num; 21 | uint8_t * qweights; 22 | bitnet_float_type * scales; 23 | }; 24 | 25 | GGML_API void ggml_bitnet_init(void); 26 | GGML_API void ggml_bitnet_free(void); 27 | // src0->type == Q4_0/IQ2_XXS/IQ3_XXS 28 | // bitnet.cpp currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros) 29 | // If use i-quantization gguf models, the results will be wrong 30 | // TODO: add customized block types Q2_0/Q3_0 31 | GGML_API bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst); 32 | GGML_API size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst); 33 | GGML_API void ggml_bitnet_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits); 34 | GGML_API void ggml_bitnet_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits); 35 | GGML_API void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor); 36 | GGML_API int ggml_bitnet_get_type_bits(enum ggml_type type); 37 | GGML_API void ggml_bitnet_set_n_threads(int n_threads); 38 | #if defined(GGML_BITNET_ARM_TL1) 39 | GGML_API void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C); 40 | GGML_API void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT); 41 | #endif 42 | #if defined(GGML_BITNET_X86_TL2) 43 | GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C); 44 | GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT); 45 | #endif 46 | 47 | #ifdef __cplusplus 48 | } 49 | #endif 50 | -------------------------------------------------------------------------------- /media/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/media/benchmark.png -------------------------------------------------------------------------------- /media/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/media/demo.mp4 -------------------------------------------------------------------------------- /preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl1.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 14336 3 | k = 4096 4 | bm = 256 5 | bk = 128 6 | bmm = 64 7 | 8 | [Kernels_1] 9 | m = 4096 10 
| k = 14336 11 | bm = 256 12 | bk = 128 13 | bmm = 32 14 | 15 | [Kernels_2] 16 | m = 1024 17 | k = 4096 18 | bm = 128 19 | bk = 64 20 | bmm = 64 21 | 22 | [Kernels_3] 23 | m = 4096 24 | k = 4096 25 | bm = 128 26 | bk = 64 27 | bmm = 32 28 | 29 | -------------------------------------------------------------------------------- /preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl2.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 14336 3 | k = 4096 4 | bm = 256 5 | bk = 96 6 | bmm = 32 7 | 8 | [Kernels_1] 9 | m = 4096 10 | k = 14336 11 | bm = 128 12 | bk = 96 13 | bmm = 32 14 | 15 | [Kernels_2] 16 | m = 1024 17 | k = 4096 18 | bm = 256 19 | bk = 96 20 | bmm = 32 21 | 22 | [Kernels_3] 23 | m = 4096 24 | k = 4096 25 | bm = 128 26 | bk = 96 27 | bmm = 32 28 | 29 | -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl1.h: -------------------------------------------------------------------------------- 1 | #if defined(GGML_BITNET_ARM_TL1) 2 | #include "ggml-bitnet.h" 3 | #define GGML_BITNET_MAX_NODES 8192 4 | static bool initialized = false; 5 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; 6 | static size_t bitnet_tensor_extras_index = 0; 7 | static void * aligned_malloc(size_t size) {{ 8 | #if defined(_WIN32) 9 | return _aligned_malloc(size, 64); 10 | #else 11 | void * ptr = nullptr; 12 | posix_memalign(&ptr, 64, size); 13 | return ptr; 14 | #endif 15 | }} 16 | static void aligned_free(void * ptr) {{ 17 | #if defined(_WIN32) 18 | _aligned_free(ptr); 19 | #else 20 | free(ptr); 21 | #endif 22 | }} 23 | 24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{ 25 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; 26 | bitnet_float_type* b = (bitnet_float_type*)b_; 27 | #ifdef __ARM_NEON 28 | float32x4_t temp_max = vdupq_n_f32(0); 29 | for (int i=0; i < k / 4; i++) {{ 30 | float32x4_t vec_bs = vld1q_f32(b + 4 * i); 31 | float32x4_t abssum = vabsq_f32(vec_bs); 32 | temp_max = vmaxq_f32(abssum, temp_max); 33 | }} 34 | float32_t scales = 127 / vmaxvq_f32(temp_max); 35 | *lut_scales = scales; 36 | #elif defined __AVX2__ 37 | __m256 max_vec = _mm256_set1_ps(0.f); 38 | const __m256 vec_sign = _mm256_set1_ps(-0.0f); 39 | // #pragma unroll 40 | for (int i = 0; i < k / 8; i++) {{ 41 | __m256 vec_b = _mm256_loadu_ps(b + i * 8); 42 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); 43 | max_vec = _mm256_max_ps(vec_babs, max_vec); 44 | }} 45 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); 46 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); 47 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); 48 | float scales = 127 / _mm_cvtss_f32(max1); 49 | *lut_scales = scales; 50 | #endif 51 | }} 52 | 53 | void partial_max_reset(void* lut_scales_) {{ 54 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; 55 | *lut_scales = 0.0; 56 | }} 57 | 58 | #ifdef __ARM_NEON 59 | inline void Transpose_8_8( 60 | int16x8_t *v0, 61 | int16x8_t *v1, 62 | int16x8_t *v2, 63 | int16x8_t *v3, 64 | int16x8_t *v4, 65 | int16x8_t *v5, 66 | int16x8_t *v6, 67 | int16x8_t *v7) 68 | {{ 69 | int16x8x2_t q04 = vzipq_s16(*v0, *v4); 70 | int16x8x2_t q15 = vzipq_s16(*v1, *v5); 71 | int16x8x2_t q26 = vzipq_s16(*v2, *v6); 72 | int16x8x2_t q37 = vzipq_s16(*v3, *v7); 73 | 74 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]); 75 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]); 76 | 
int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]); 77 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]); 78 | 79 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]); 80 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]); 81 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]); 82 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]); 83 | 84 | *v0 = q_fin_0.val[0]; 85 | *v1 = q_fin_0.val[1]; 86 | *v2 = q_fin_1.val[0]; 87 | *v3 = q_fin_1.val[1]; 88 | *v4 = q_fin_2.val[0]; 89 | *v5 = q_fin_2.val[1]; 90 | *v6 = q_fin_3.val[0]; 91 | *v7 = q_fin_3.val[1]; 92 | }} 93 | #endif 94 | 95 | template 96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{ 97 | #ifdef __ARM_NEON 98 | int16x8_t vec_lut[16]; 99 | float32_t scales = *lut_scales; 100 | uint8_t tbl_mask[16]; 101 | tbl_mask[0] = 0; 102 | tbl_mask[1] = 2; 103 | tbl_mask[2] = 4; 104 | tbl_mask[3] = 6; 105 | tbl_mask[4] = 8; 106 | tbl_mask[5] = 10; 107 | tbl_mask[6] = 12; 108 | tbl_mask[7] = 14; 109 | tbl_mask[8] = 1; 110 | tbl_mask[9] = 3; 111 | tbl_mask[10] = 5; 112 | tbl_mask[11] = 7; 113 | tbl_mask[12] = 9; 114 | tbl_mask[13] = 11; 115 | tbl_mask[14] = 13; 116 | tbl_mask[15] = 15; 117 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask); 118 | #pragma unroll 119 | for (int k = 0; k < act_k / 16; ++k) {{ 120 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16); 121 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8); 122 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales); 123 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales); 124 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales); 125 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales); 126 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0); 127 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1); 128 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2); 129 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3); 130 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0); 131 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1); 132 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2); 133 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3); 134 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2); 135 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3); 136 | vec_lut[0] = vdupq_n_s16(0); 137 | vec_lut[0] = vec_lut[0] - vec_bs_0; 138 | vec_lut[0] = vec_lut[0] - vec_bs_1; 139 | vec_lut[1] = vdupq_n_s16(0); 140 | vec_lut[1] = vec_lut[1] - vec_bs_0; 141 | vec_lut[2] = vdupq_n_s16(0); 142 | vec_lut[2] = vec_lut[2] - vec_bs_0; 143 | vec_lut[2] = vec_lut[2] + vec_bs_1; 144 | vec_lut[3] = vdupq_n_s16(0); 145 | vec_lut[3] = vec_lut[3] - vec_bs_1; 146 | vec_lut[4] = vdupq_n_s16(0); 147 | vec_lut[5] = vec_bs_1; 148 | vec_lut[6] = vec_bs_0; 149 | vec_lut[6] = vec_lut[6] - vec_bs_1; 150 | vec_lut[7] = vec_bs_0; 151 | vec_lut[8] = vec_bs_0; 152 | vec_lut[8] = vec_lut[8] + vec_bs_1; 153 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]), 154 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7])); 155 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]), 156 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15])); 157 | #pragma unroll 158 | for (int idx = 0; idx < 8; idx++) {{ 159 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q); 160 | int8x8_t q0_low = vget_low_s8(q0_s); 161 | int8x8_t q0_high = vget_high_s8(q0_s); 162 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q); 163 | int8x8_t q1_low = 
vget_low_s8(q1_s); 164 | int8x8_t q1_high = vget_high_s8(q1_s); 165 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high); 166 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high); 167 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low); 168 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low); 169 | }} 170 | }} 171 | #endif 172 | }} 173 | 174 | static bool is_type_supported(enum ggml_type type) {{ 175 | if (type == GGML_TYPE_Q4_0 || 176 | type == GGML_TYPE_TL1) {{ 177 | return true; 178 | }} else {{ 179 | return false; 180 | }} 181 | }} 182 | #include 183 | 184 | #define BM3200_8640 160 185 | #define BBK3200_8640 64 186 | inline void tbl_impl_3200_8640(int32_t* c, int8_t* lut, uint8_t* a) { 187 | #ifdef __ARM_NEON 188 | const int KK = BBK3200_8640 / 2; 189 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 190 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 191 | int8x16_t vec_lut[2 * KK]; 192 | int16x8_t vec_c[4]; 193 | #pragma unroll 194 | for (int k = 0; k < 2 * KK; k++) { 195 | vec_lut[k] = vld1q_s8(lut + k * 16); 196 | } 197 | 198 | #pragma unroll 199 | for (int i = 0; i < BM3200_8640; i += 32) { 200 | #pragma unroll 201 | for (int i=0; i<4; i++) { 202 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 203 | } 204 | 205 | #pragma unroll 206 | for (int k = 0; k < KK / 4; k++) { 207 | 208 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 209 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 210 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 211 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); 212 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); 213 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); 214 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); 215 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 216 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 217 | vec_c[0] += vec_v_left_0.val[0]; 218 | vec_c[0] += vec_v_right_0.val[0]; 219 | vec_c[1] += vec_v_left_0.val[1]; 220 | vec_c[1] += vec_v_right_0.val[1]; 221 | 222 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 223 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 224 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 225 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); 226 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); 227 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); 228 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); 229 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 230 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 231 | vec_c[0] += vec_v_left_1.val[0]; 232 | vec_c[0] += vec_v_right_1.val[0]; 233 | vec_c[1] += vec_v_left_1.val[1]; 234 | vec_c[1] += vec_v_right_1.val[1]; 235 | 236 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 237 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 238 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 239 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); 240 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); 241 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); 242 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); 243 | int8x16x2_t vec_v_left_2 = 
vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 244 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 245 | vec_c[2] += vec_v_left_2.val[0]; 246 | vec_c[2] += vec_v_right_2.val[0]; 247 | vec_c[3] += vec_v_left_2.val[1]; 248 | vec_c[3] += vec_v_right_2.val[1]; 249 | 250 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 251 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 252 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 253 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); 254 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); 255 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); 256 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); 257 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 258 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 259 | vec_c[2] += vec_v_left_3.val[0]; 260 | vec_c[2] += vec_v_right_3.val[0]; 261 | vec_c[3] += vec_v_left_3.val[1]; 262 | vec_c[3] += vec_v_right_3.val[1]; 263 | 264 | } 265 | 266 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 267 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 268 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 269 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 270 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 271 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 272 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 273 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 274 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 275 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 276 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 277 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 278 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 279 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 280 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 281 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 282 | 283 | } 284 | #endif 285 | } 286 | 287 | int32_t qgemm_lut_3200_8640(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 288 | alignas(32) uint32_t CBits[BM3200_8640]; 289 | memset(&(CBits[0]), 0, BM3200_8640 * sizeof(int32_t)); 290 | #pragma unroll 291 | for (int32_t k_outer = 0; k_outer < 8640 / BBK3200_8640; ++k_outer) { 292 | tbl_impl_3200_8640((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_8640 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_8640 / 2 / 2 * BM3200_8640)]))); 293 | } 294 | #pragma unroll 295 | for (int i = 0; i < BM3200_8640; i++) { 296 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 297 | } 298 | return 0; 299 | }; 300 | #include 301 | 302 | #define BM3200_3200 320 303 | #define BBK3200_3200 128 304 | inline void tbl_impl_3200_3200(int32_t* c, int8_t* lut, uint8_t* a) { 305 | #ifdef __ARM_NEON 306 | const int KK = BBK3200_3200 / 2; 307 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 308 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 309 | int8x16_t vec_lut[2 * KK]; 310 | int16x8_t vec_c[8]; 311 | #pragma unroll 312 | for (int k = 0; k < 2 * KK; k++) { 313 | vec_lut[k] = vld1q_s8(lut + k * 16); 314 | } 315 | 
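    // Main tile loop: each iteration below produces 64 output rows.
    // Every byte of `a` packs two 4-bit LUT indices (high and low
    // nibble); vqtbl1q_s8 turns each index into a 16-entry table lookup
    // against the precomputed LUTs, partial sums are kept in the int16
    // accumulators vec_c, and they are widened to int32 and added into
    // `c` once the k-loop finishes.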
316 | #pragma unroll 317 | for (int i = 0; i < BM3200_3200; i += 64) { 318 | #pragma unroll 319 | for (int i=0; i<8; i++) { 320 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 321 | } 322 | 323 | #pragma unroll 324 | for (int k = 0; k < KK / 2; k++) { 325 | 326 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 327 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 328 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 329 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top); 330 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top); 331 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot); 332 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot); 333 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 334 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 335 | vec_c[0] += vec_v_left_0.val[0]; 336 | vec_c[0] += vec_v_right_0.val[0]; 337 | vec_c[1] += vec_v_left_0.val[1]; 338 | vec_c[1] += vec_v_right_0.val[1]; 339 | 340 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 341 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 342 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 343 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top); 344 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top); 345 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot); 346 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot); 347 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 348 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 349 | vec_c[2] += vec_v_left_1.val[0]; 350 | vec_c[2] += vec_v_right_1.val[0]; 351 | vec_c[3] += vec_v_left_1.val[1]; 352 | vec_c[3] += vec_v_right_1.val[1]; 353 | 354 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 355 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 356 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 357 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top); 358 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top); 359 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot); 360 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot); 361 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 362 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 363 | vec_c[4] += vec_v_left_2.val[0]; 364 | vec_c[4] += vec_v_right_2.val[0]; 365 | vec_c[5] += vec_v_left_2.val[1]; 366 | vec_c[5] += vec_v_right_2.val[1]; 367 | 368 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 369 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 370 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 371 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top); 372 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top); 373 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot); 374 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot); 375 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 376 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 377 | vec_c[6] += vec_v_left_3.val[0]; 378 | vec_c[6] += vec_v_right_3.val[0]; 379 | vec_c[7] += 
vec_v_left_3.val[1]; 380 | vec_c[7] += vec_v_right_3.val[1]; 381 | 382 | } 383 | 384 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 385 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 386 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 387 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 388 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 389 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 390 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 391 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 392 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 393 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 394 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 395 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 396 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 397 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 398 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 399 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 400 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4])); 401 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]); 402 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4); 403 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4); 404 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5])); 405 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]); 406 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5); 407 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5); 408 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6])); 409 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]); 410 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6); 411 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6); 412 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7])); 413 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]); 414 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7); 415 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7); 416 | 417 | } 418 | #endif 419 | } 420 | 421 | int32_t qgemm_lut_3200_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 422 | alignas(32) uint32_t CBits[BM3200_3200]; 423 | memset(&(CBits[0]), 0, BM3200_3200 * sizeof(int32_t)); 424 | #pragma unroll 425 | for (int32_t k_outer = 0; k_outer < 3200 / BBK3200_3200; ++k_outer) { 426 | tbl_impl_3200_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_3200 / 2 / 2 * BM3200_3200)]))); 427 | } 428 | #pragma unroll 429 | for (int i = 0; i < BM3200_3200; i++) { 430 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 431 | } 432 | return 0; 433 | }; 434 | #include 435 | 436 | #define BM8640_3200 320 437 | #define BBK8640_3200 64 438 | inline void tbl_impl_8640_3200(int32_t* c, int8_t* lut, uint8_t* a) { 439 | #ifdef __ARM_NEON 440 | const int KK = BBK8640_3200 / 2; 441 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 442 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 443 | int8x16_t vec_lut[2 * KK]; 444 | int16x8_t vec_c[4]; 445 | #pragma unroll 446 | for (int k = 
0; k < 2 * KK; k++) { 447 | vec_lut[k] = vld1q_s8(lut + k * 16); 448 | } 449 | 450 | #pragma unroll 451 | for (int i = 0; i < BM8640_3200; i += 32) { 452 | #pragma unroll 453 | for (int i=0; i<4; i++) { 454 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 455 | } 456 | 457 | #pragma unroll 458 | for (int k = 0; k < KK / 4; k++) { 459 | 460 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 461 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 462 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 463 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); 464 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); 465 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); 466 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); 467 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 468 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 469 | vec_c[0] += vec_v_left_0.val[0]; 470 | vec_c[0] += vec_v_right_0.val[0]; 471 | vec_c[1] += vec_v_left_0.val[1]; 472 | vec_c[1] += vec_v_right_0.val[1]; 473 | 474 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 475 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 476 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 477 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); 478 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); 479 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); 480 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); 481 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 482 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 483 | vec_c[0] += vec_v_left_1.val[0]; 484 | vec_c[0] += vec_v_right_1.val[0]; 485 | vec_c[1] += vec_v_left_1.val[1]; 486 | vec_c[1] += vec_v_right_1.val[1]; 487 | 488 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 489 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 490 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 491 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); 492 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); 493 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); 494 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); 495 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 496 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 497 | vec_c[2] += vec_v_left_2.val[0]; 498 | vec_c[2] += vec_v_right_2.val[0]; 499 | vec_c[3] += vec_v_left_2.val[1]; 500 | vec_c[3] += vec_v_right_2.val[1]; 501 | 502 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 503 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 504 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 505 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); 506 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); 507 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); 508 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); 509 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 510 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 511 | vec_c[2] += 
vec_v_left_3.val[0]; 512 | vec_c[2] += vec_v_right_3.val[0]; 513 | vec_c[3] += vec_v_left_3.val[1]; 514 | vec_c[3] += vec_v_right_3.val[1]; 515 | 516 | } 517 | 518 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 519 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 520 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 521 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 522 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 523 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 524 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 525 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 526 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 527 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 528 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 529 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 530 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 531 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 532 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 533 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 534 | 535 | } 536 | #endif 537 | } 538 | 539 | int32_t qgemm_lut_8640_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 540 | alignas(32) uint32_t CBits[BM8640_3200]; 541 | memset(&(CBits[0]), 0, BM8640_3200 * sizeof(int32_t)); 542 | #pragma unroll 543 | for (int32_t k_outer = 0; k_outer < 3200 / BBK8640_3200; ++k_outer) { 544 | tbl_impl_8640_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK8640_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK8640_3200 / 2 / 2 * BM8640_3200)]))); 545 | } 546 | #pragma unroll 547 | for (int i = 0; i < BM8640_3200; i++) { 548 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 549 | } 550 | return 0; 551 | }; 552 | 553 | template 554 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{ 555 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0]))); 556 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0]))); 557 | 558 | lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0]))); 559 | }} 560 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) { 561 | if (m == 3200 && k == 8640) { 562 | preprocessor_k<8640>(B, LUT_Scales, QLUT); 563 | } 564 | else if (m == 3200 && k == 3200) { 565 | preprocessor_k<3200>(B, LUT_Scales, QLUT); 566 | } 567 | else if (m == 8640 && k == 3200) { 568 | preprocessor_k<3200>(B, LUT_Scales, QLUT); 569 | } 570 | } 571 | void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 572 | if (m == 3200 && k == 8640) { 573 | qgemm_lut_3200_8640(A, LUT, Scales, LUT_Scales, C); 574 | } 575 | else if (m == 3200 && k == 3200) { 576 | qgemm_lut_3200_3200(A, LUT, Scales, LUT_Scales, C); 577 | } 578 | else if (m == 8640 && k == 3200) { 579 | qgemm_lut_8640_3200(A, LUT, Scales, LUT_Scales, C); 580 | } 581 | } 582 | 583 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { 584 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { 585 | return; 586 | } 587 | 588 | int k = tensor->ne[0]; 589 | int m = tensor->ne[1]; 590 | const 
int lut_scales_size = 1; 591 | const int scales_size = 1; 592 | int bk = 0; 593 | int bm = 0; 594 | 595 | if (m == 3200 && k == 8640) { 596 | bm = BM3200_8640; 597 | bk = BBK3200_8640; 598 | } 599 | else if (m == 3200 && k == 3200) { 600 | bm = BM3200_3200; 601 | bk = BBK3200_3200; 602 | } 603 | else if (m == 8640 && k == 3200) { 604 | bm = BM8640_3200; 605 | bk = BBK8640_3200; 606 | } 607 | 608 | const int n_tile_num = m / bm; 609 | const int BK = bk; 610 | uint8_t * qweights; 611 | bitnet_float_type * scales; 612 | 613 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); 614 | qweights = (uint8_t *) tensor->data; 615 | float * i2_scales = (float * )(qweights + k * m / 4); 616 | scales[0] = (bitnet_float_type) i2_scales[0]; 617 | 618 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; 619 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = { 620 | /* .lut_scales_size = */ lut_scales_size, 621 | /* .scales_size = */ scales_size, 622 | /* .n_tile_num = */ n_tile_num, 623 | /* .qweights = */ qweights, 624 | /* .scales = */ scales 625 | }; 626 | } 627 | #endif -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-3B/kernel_config_tl1.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 3200 3 | k = 8640 4 | bm = 160 5 | bk = 64 6 | bmm = 32 7 | 8 | [Kernels_1] 9 | m = 3200 10 | k = 3200 11 | bm = 320 12 | bk = 128 13 | bmm = 64 14 | 15 | [Kernels_2] 16 | m = 8640 17 | k = 3200 18 | bm = 320 19 | bk = 64 20 | bmm = 32 21 | 22 | -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-3B/kernel_config_tl2.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 3200 3 | k = 8640 4 | bm = 160 5 | bk = 96 6 | bmm = 32 7 | 8 | [Kernels_1] 9 | m = 3200 10 | k = 3200 11 | bm = 320 12 | bk = 96 13 | bmm = 32 14 | 15 | [Kernels_2] 16 | m = 8640 17 | k = 3200 18 | bm = 320 19 | bk = 96 20 | bmm = 32 21 | 22 | -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-large/bitnet-lut-kernels-tl1.h: -------------------------------------------------------------------------------- 1 | #if defined(GGML_BITNET_ARM_TL1) 2 | #include "ggml-bitnet.h" 3 | #define GGML_BITNET_MAX_NODES 8192 4 | static bool initialized = false; 5 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; 6 | static size_t bitnet_tensor_extras_index = 0; 7 | static void * aligned_malloc(size_t size) {{ 8 | #if defined(_WIN32) 9 | return _aligned_malloc(size, 64); 10 | #else 11 | void * ptr = nullptr; 12 | posix_memalign(&ptr, 64, size); 13 | return ptr; 14 | #endif 15 | }} 16 | static void aligned_free(void * ptr) {{ 17 | #if defined(_WIN32) 18 | _aligned_free(ptr); 19 | #else 20 | free(ptr); 21 | #endif 22 | }} 23 | 24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{ 25 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; 26 | bitnet_float_type* b = (bitnet_float_type*)b_; 27 | #ifdef __ARM_NEON 28 | float32x4_t temp_max = vdupq_n_f32(0); 29 | for (int i=0; i < k / 4; i++) {{ 30 | float32x4_t vec_bs = vld1q_f32(b + 4 * i); 31 | float32x4_t abssum = vabsq_f32(vec_bs); 32 | temp_max = vmaxq_f32(abssum, temp_max); 33 | }} 34 | float32_t scales = 127 / vmaxvq_f32(temp_max); 35 | *lut_scales = scales; 36 | #elif defined __AVX2__ 37 | __m256 max_vec = _mm256_set1_ps(0.f); 38 | const __m256 
vec_sign = _mm256_set1_ps(-0.0f); 39 | // #pragma unroll 40 | for (int i = 0; i < k / 8; i++) {{ 41 | __m256 vec_b = _mm256_loadu_ps(b + i * 8); 42 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); 43 | max_vec = _mm256_max_ps(vec_babs, max_vec); 44 | }} 45 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); 46 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); 47 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); 48 | float scales = 127 / _mm_cvtss_f32(max1); 49 | *lut_scales = scales; 50 | #endif 51 | }} 52 | 53 | void partial_max_reset(void* lut_scales_) {{ 54 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; 55 | *lut_scales = 0.0; 56 | }} 57 | 58 | #ifdef __ARM_NEON 59 | inline void Transpose_8_8( 60 | int16x8_t *v0, 61 | int16x8_t *v1, 62 | int16x8_t *v2, 63 | int16x8_t *v3, 64 | int16x8_t *v4, 65 | int16x8_t *v5, 66 | int16x8_t *v6, 67 | int16x8_t *v7) 68 | {{ 69 | int16x8x2_t q04 = vzipq_s16(*v0, *v4); 70 | int16x8x2_t q15 = vzipq_s16(*v1, *v5); 71 | int16x8x2_t q26 = vzipq_s16(*v2, *v6); 72 | int16x8x2_t q37 = vzipq_s16(*v3, *v7); 73 | 74 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]); 75 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]); 76 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]); 77 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]); 78 | 79 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]); 80 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]); 81 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]); 82 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]); 83 | 84 | *v0 = q_fin_0.val[0]; 85 | *v1 = q_fin_0.val[1]; 86 | *v2 = q_fin_1.val[0]; 87 | *v3 = q_fin_1.val[1]; 88 | *v4 = q_fin_2.val[0]; 89 | *v5 = q_fin_2.val[1]; 90 | *v6 = q_fin_3.val[0]; 91 | *v7 = q_fin_3.val[1]; 92 | }} 93 | #endif 94 | 95 | template 96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{ 97 | #ifdef __ARM_NEON 98 | int16x8_t vec_lut[16]; 99 | float32_t scales = *lut_scales; 100 | uint8_t tbl_mask[16]; 101 | tbl_mask[0] = 0; 102 | tbl_mask[1] = 2; 103 | tbl_mask[2] = 4; 104 | tbl_mask[3] = 6; 105 | tbl_mask[4] = 8; 106 | tbl_mask[5] = 10; 107 | tbl_mask[6] = 12; 108 | tbl_mask[7] = 14; 109 | tbl_mask[8] = 1; 110 | tbl_mask[9] = 3; 111 | tbl_mask[10] = 5; 112 | tbl_mask[11] = 7; 113 | tbl_mask[12] = 9; 114 | tbl_mask[13] = 11; 115 | tbl_mask[14] = 13; 116 | tbl_mask[15] = 15; 117 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask); 118 | #pragma unroll 119 | for (int k = 0; k < act_k / 16; ++k) {{ 120 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16); 121 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8); 122 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales); 123 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales); 124 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales); 125 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales); 126 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0); 127 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1); 128 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2); 129 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3); 130 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0); 131 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1); 132 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2); 133 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3); 134 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2); 135 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3); 136 | vec_lut[0] = 
vdupq_n_s16(0); 137 | vec_lut[0] = vec_lut[0] - vec_bs_0; 138 | vec_lut[0] = vec_lut[0] - vec_bs_1; 139 | vec_lut[1] = vdupq_n_s16(0); 140 | vec_lut[1] = vec_lut[1] - vec_bs_0; 141 | vec_lut[2] = vdupq_n_s16(0); 142 | vec_lut[2] = vec_lut[2] - vec_bs_0; 143 | vec_lut[2] = vec_lut[2] + vec_bs_1; 144 | vec_lut[3] = vdupq_n_s16(0); 145 | vec_lut[3] = vec_lut[3] - vec_bs_1; 146 | vec_lut[4] = vdupq_n_s16(0); 147 | vec_lut[5] = vec_bs_1; 148 | vec_lut[6] = vec_bs_0; 149 | vec_lut[6] = vec_lut[6] - vec_bs_1; 150 | vec_lut[7] = vec_bs_0; 151 | vec_lut[8] = vec_bs_0; 152 | vec_lut[8] = vec_lut[8] + vec_bs_1; 153 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]), 154 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7])); 155 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]), 156 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15])); 157 | #pragma unroll 158 | for (int idx = 0; idx < 8; idx++) {{ 159 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q); 160 | int8x8_t q0_low = vget_low_s8(q0_s); 161 | int8x8_t q0_high = vget_high_s8(q0_s); 162 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q); 163 | int8x8_t q1_low = vget_low_s8(q1_s); 164 | int8x8_t q1_high = vget_high_s8(q1_s); 165 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high); 166 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high); 167 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low); 168 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low); 169 | }} 170 | }} 171 | #endif 172 | }} 173 | 174 | static bool is_type_supported(enum ggml_type type) {{ 175 | if (type == GGML_TYPE_Q4_0 || 176 | type == GGML_TYPE_TL1) {{ 177 | return true; 178 | }} else {{ 179 | return false; 180 | }} 181 | }} 182 | #include 183 | 184 | #define BM1536_4096 256 185 | #define BBK1536_4096 128 186 | inline void tbl_impl_1536_4096(int32_t* c, int8_t* lut, uint8_t* a) { 187 | #ifdef __ARM_NEON 188 | const int KK = BBK1536_4096 / 2; 189 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 190 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 191 | int8x16_t vec_lut[2 * KK]; 192 | int16x8_t vec_c[4]; 193 | #pragma unroll 194 | for (int k = 0; k < 2 * KK; k++) { 195 | vec_lut[k] = vld1q_s8(lut + k * 16); 196 | } 197 | 198 | #pragma unroll 199 | for (int i = 0; i < BM1536_4096; i += 32) { 200 | #pragma unroll 201 | for (int i=0; i<4; i++) { 202 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 203 | } 204 | 205 | #pragma unroll 206 | for (int k = 0; k < KK / 4; k++) { 207 | 208 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 209 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 210 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 211 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); 212 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); 213 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); 214 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); 215 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 216 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 217 | vec_c[0] += vec_v_left_0.val[0]; 218 | vec_c[0] += vec_v_right_0.val[0]; 219 | vec_c[1] += vec_v_left_0.val[1]; 220 | vec_c[1] += vec_v_right_0.val[1]; 221 | 222 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 223 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 
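// Each loaded byte of vec_a_1 packs two 4-bit LUT indices: the high nibble (vec_a1_top, extracted above)
// and the low nibble (vec_a1_bot, masked out just below). Both nibbles are then used as vqtbl1q_s8 table
// lookups into the precomputed 16-entry LUTs and accumulated into the int16 lanes of vec_c.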
224 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 225 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); 226 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); 227 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); 228 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); 229 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 230 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 231 | vec_c[0] += vec_v_left_1.val[0]; 232 | vec_c[0] += vec_v_right_1.val[0]; 233 | vec_c[1] += vec_v_left_1.val[1]; 234 | vec_c[1] += vec_v_right_1.val[1]; 235 | 236 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 237 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 238 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 239 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); 240 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); 241 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); 242 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); 243 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 244 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 245 | vec_c[2] += vec_v_left_2.val[0]; 246 | vec_c[2] += vec_v_right_2.val[0]; 247 | vec_c[3] += vec_v_left_2.val[1]; 248 | vec_c[3] += vec_v_right_2.val[1]; 249 | 250 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 251 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 252 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 253 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); 254 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); 255 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); 256 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); 257 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 258 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 259 | vec_c[2] += vec_v_left_3.val[0]; 260 | vec_c[2] += vec_v_right_3.val[0]; 261 | vec_c[3] += vec_v_left_3.val[1]; 262 | vec_c[3] += vec_v_right_3.val[1]; 263 | 264 | } 265 | 266 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 267 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 268 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 269 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 270 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 271 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 272 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 273 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 274 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 275 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 276 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 277 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 278 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 279 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 280 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 281 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 282 | 283 
| } 284 | #endif 285 | } 286 | 287 | int32_t qgemm_lut_1536_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 288 | alignas(32) uint32_t CBits[BM1536_4096]; 289 | memset(&(CBits[0]), 0, BM1536_4096 * sizeof(int32_t)); 290 | #pragma unroll 291 | for (int32_t k_outer = 0; k_outer < 4096 / BBK1536_4096; ++k_outer) { 292 | tbl_impl_1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_4096 / 2 / 2 * BM1536_4096)]))); 293 | } 294 | #pragma unroll 295 | for (int i = 0; i < BM1536_4096; i++) { 296 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 297 | } 298 | return 0; 299 | }; 300 | #include 301 | 302 | #define BM1536_1536 128 303 | #define BBK1536_1536 64 304 | inline void tbl_impl_1536_1536(int32_t* c, int8_t* lut, uint8_t* a) { 305 | #ifdef __ARM_NEON 306 | const int KK = BBK1536_1536 / 2; 307 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 308 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 309 | int8x16_t vec_lut[2 * KK]; 310 | int16x8_t vec_c[8]; 311 | #pragma unroll 312 | for (int k = 0; k < 2 * KK; k++) { 313 | vec_lut[k] = vld1q_s8(lut + k * 16); 314 | } 315 | 316 | #pragma unroll 317 | for (int i = 0; i < BM1536_1536; i += 64) { 318 | #pragma unroll 319 | for (int i=0; i<8; i++) { 320 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 321 | } 322 | 323 | #pragma unroll 324 | for (int k = 0; k < KK / 2; k++) { 325 | 326 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 327 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 328 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 329 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top); 330 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top); 331 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot); 332 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot); 333 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 334 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 335 | vec_c[0] += vec_v_left_0.val[0]; 336 | vec_c[0] += vec_v_right_0.val[0]; 337 | vec_c[1] += vec_v_left_0.val[1]; 338 | vec_c[1] += vec_v_right_0.val[1]; 339 | 340 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 341 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 342 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 343 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top); 344 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top); 345 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot); 346 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot); 347 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 348 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 349 | vec_c[2] += vec_v_left_1.val[0]; 350 | vec_c[2] += vec_v_right_1.val[0]; 351 | vec_c[3] += vec_v_left_1.val[1]; 352 | vec_c[3] += vec_v_right_1.val[1]; 353 | 354 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 355 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 356 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 357 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top); 358 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top); 359 | int8x16_t 
vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot); 360 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot); 361 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 362 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 363 | vec_c[4] += vec_v_left_2.val[0]; 364 | vec_c[4] += vec_v_right_2.val[0]; 365 | vec_c[5] += vec_v_left_2.val[1]; 366 | vec_c[5] += vec_v_right_2.val[1]; 367 | 368 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 369 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 370 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 371 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top); 372 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top); 373 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot); 374 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot); 375 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 376 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 377 | vec_c[6] += vec_v_left_3.val[0]; 378 | vec_c[6] += vec_v_right_3.val[0]; 379 | vec_c[7] += vec_v_left_3.val[1]; 380 | vec_c[7] += vec_v_right_3.val[1]; 381 | 382 | } 383 | 384 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 385 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 386 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 387 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 388 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 389 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 390 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 391 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 392 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 393 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 394 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 395 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 396 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 397 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 398 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 399 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 400 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4])); 401 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]); 402 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4); 403 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4); 404 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5])); 405 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]); 406 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5); 407 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5); 408 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6])); 409 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]); 410 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6); 411 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6); 412 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7])); 413 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]); 414 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7); 415 | 
vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7); 416 | 417 | } 418 | #endif 419 | } 420 | 421 | int32_t qgemm_lut_1536_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 422 | alignas(32) uint32_t CBits[BM1536_1536]; 423 | memset(&(CBits[0]), 0, BM1536_1536 * sizeof(int32_t)); 424 | #pragma unroll 425 | for (int32_t k_outer = 0; k_outer < 1536 / BBK1536_1536; ++k_outer) { 426 | tbl_impl_1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_1536 / 2 / 2 * BM1536_1536)]))); 427 | } 428 | #pragma unroll 429 | for (int i = 0; i < BM1536_1536; i++) { 430 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 431 | } 432 | return 0; 433 | }; 434 | #include 435 | 436 | #define BM4096_1536 256 437 | #define BBK4096_1536 128 438 | inline void tbl_impl_4096_1536(int32_t* c, int8_t* lut, uint8_t* a) { 439 | #ifdef __ARM_NEON 440 | const int KK = BBK4096_1536 / 2; 441 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 442 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 443 | int8x16_t vec_lut[2 * KK]; 444 | int16x8_t vec_c[4]; 445 | #pragma unroll 446 | for (int k = 0; k < 2 * KK; k++) { 447 | vec_lut[k] = vld1q_s8(lut + k * 16); 448 | } 449 | 450 | #pragma unroll 451 | for (int i = 0; i < BM4096_1536; i += 32) { 452 | #pragma unroll 453 | for (int i=0; i<4; i++) { 454 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 455 | } 456 | 457 | #pragma unroll 458 | for (int k = 0; k < KK / 4; k++) { 459 | 460 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 461 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 462 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 463 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); 464 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); 465 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); 466 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); 467 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 468 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 469 | vec_c[0] += vec_v_left_0.val[0]; 470 | vec_c[0] += vec_v_right_0.val[0]; 471 | vec_c[1] += vec_v_left_0.val[1]; 472 | vec_c[1] += vec_v_right_0.val[1]; 473 | 474 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 475 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 476 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 477 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); 478 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); 479 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); 480 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); 481 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 482 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 483 | vec_c[0] += vec_v_left_1.val[0]; 484 | vec_c[0] += vec_v_right_1.val[0]; 485 | vec_c[1] += vec_v_left_1.val[1]; 486 | vec_c[1] += vec_v_right_1.val[1]; 487 | 488 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 489 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 490 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 491 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); 492 | int8x16_t 
vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); 493 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); 494 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); 495 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 496 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 497 | vec_c[2] += vec_v_left_2.val[0]; 498 | vec_c[2] += vec_v_right_2.val[0]; 499 | vec_c[3] += vec_v_left_2.val[1]; 500 | vec_c[3] += vec_v_right_2.val[1]; 501 | 502 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 503 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 504 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 505 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); 506 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); 507 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); 508 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); 509 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 510 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 511 | vec_c[2] += vec_v_left_3.val[0]; 512 | vec_c[2] += vec_v_right_3.val[0]; 513 | vec_c[3] += vec_v_left_3.val[1]; 514 | vec_c[3] += vec_v_right_3.val[1]; 515 | 516 | } 517 | 518 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 519 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 520 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 521 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 522 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 523 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 524 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 525 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 526 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 527 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 528 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 529 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 530 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 531 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 532 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 533 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 534 | 535 | } 536 | #endif 537 | } 538 | 539 | int32_t qgemm_lut_4096_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 540 | alignas(32) uint32_t CBits[BM4096_1536]; 541 | memset(&(CBits[0]), 0, BM4096_1536 * sizeof(int32_t)); 542 | #pragma unroll 543 | for (int32_t k_outer = 0; k_outer < 1536 / BBK4096_1536; ++k_outer) { 544 | tbl_impl_4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_1536 / 2 / 2 * BM4096_1536)]))); 545 | } 546 | #pragma unroll 547 | for (int i = 0; i < BM4096_1536; i++) { 548 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 549 | } 550 | return 0; 551 | }; 552 | 553 | template 554 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{ 555 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0]))); 556 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), 
(&(((bitnet_float_type*)B)[0]))); 557 | 558 | lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0]))); 559 | }} 560 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) { 561 | if (m == 1536 && k == 4096) { 562 | preprocessor_k<4096>(B, LUT_Scales, QLUT); 563 | } 564 | else if (m == 1536 && k == 1536) { 565 | preprocessor_k<1536>(B, LUT_Scales, QLUT); 566 | } 567 | else if (m == 4096 && k == 1536) { 568 | preprocessor_k<1536>(B, LUT_Scales, QLUT); 569 | } 570 | } 571 | void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 572 | if (m == 1536 && k == 4096) { 573 | qgemm_lut_1536_4096(A, LUT, Scales, LUT_Scales, C); 574 | } 575 | else if (m == 1536 && k == 1536) { 576 | qgemm_lut_1536_1536(A, LUT, Scales, LUT_Scales, C); 577 | } 578 | else if (m == 4096 && k == 1536) { 579 | qgemm_lut_4096_1536(A, LUT, Scales, LUT_Scales, C); 580 | } 581 | } 582 | 583 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { 584 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { 585 | return; 586 | } 587 | 588 | int k = tensor->ne[0]; 589 | int m = tensor->ne[1]; 590 | const int lut_scales_size = 1; 591 | const int scales_size = 1; 592 | int bk = 0; 593 | int bm = 0; 594 | 595 | if (m == 1536 && k == 4096) { 596 | bm = BM1536_4096; 597 | bk = BBK1536_4096; 598 | } 599 | else if (m == 1536 && k == 1536) { 600 | bm = BM1536_1536; 601 | bk = BBK1536_1536; 602 | } 603 | else if (m == 4096 && k == 1536) { 604 | bm = BM4096_1536; 605 | bk = BBK4096_1536; 606 | } 607 | 608 | const int n_tile_num = m / bm; 609 | const int BK = bk; 610 | uint8_t * qweights; 611 | bitnet_float_type * scales; 612 | 613 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); 614 | qweights = (uint8_t *) tensor->data; 615 | float * i2_scales = (float * )(qweights + k * m / 4); 616 | scales[0] = (bitnet_float_type) i2_scales[0]; 617 | 618 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; 619 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = { 620 | /* .lut_scales_size = */ lut_scales_size, 621 | /* .scales_size = */ scales_size, 622 | /* .n_tile_num = */ n_tile_num, 623 | /* .qweights = */ qweights, 624 | /* .scales = */ scales 625 | }; 626 | } 627 | #endif -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-large/kernel_config_tl1.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 1536 3 | k = 4096 4 | bm = 256 5 | bk = 128 6 | bmm = 32 7 | 8 | [Kernels_1] 9 | m = 1536 10 | k = 1536 11 | bm = 128 12 | bk = 64 13 | bmm = 64 14 | 15 | [Kernels_2] 16 | m = 4096 17 | k = 1536 18 | bm = 256 19 | bk = 128 20 | bmm = 32 21 | 22 | -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-large/kernel_config_tl2.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 1536 3 | k = 4096 4 | bm = 256 5 | bk = 96 6 | bmm = 32 7 | 8 | [Kernels_1] 9 | m = 1536 10 | k = 1536 11 | bm = 128 12 | bk = 192 13 | bmm = 32 14 | 15 | [Kernels_2] 16 | m = 4096 17 | k = 1536 18 | bm = 256 19 | bk = 96 20 | bmm = 64 21 | 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 
These requirements include all dependencies for all top-level python scripts 2 | # for llama.cpp. Avoid adding packages here directly. 3 | # 4 | # Package versions must stay compatible across all top-level python scripts. 5 | # 6 | 7 | -r 3rdparty/llama.cpp/requirements/requirements-convert_legacy_llama.txt 8 | -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt 9 | -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt 10 | -r 3rdparty/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt 11 | -r 3rdparty/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import signal 4 | import platform 5 | import argparse 6 | import subprocess 7 | 8 | def run_command(command, shell=False): 9 | """Run a system command and ensure it succeeds.""" 10 | try: 11 | subprocess.run(command, shell=shell, check=True) 12 | except subprocess.CalledProcessError as e: 13 | print(f"Error occurred while running command: {e}") 14 | sys.exit(1) 15 | 16 | def run_inference(): 17 | build_dir = "build" 18 | if platform.system() == "Windows": 19 | main_path = os.path.join(build_dir, "bin", "Release", "llama-cli.exe") 20 | if not os.path.exists(main_path): 21 | main_path = os.path.join(build_dir, "bin", "llama-cli") 22 | else: 23 | main_path = os.path.join(build_dir, "bin", "llama-cli") 24 | command = [ 25 | f'{main_path}', 26 | '-m', args.model, 27 | '-n', str(args.n_predict), 28 | '-t', str(args.threads), 29 | '-p', args.prompt, 30 | '-ngl', '0', 31 | '-c', str(args.ctx_size), 32 | '--temp', str(args.temperature), 33 | "-b", "1", 34 | ] 35 | if args.conversation: 36 | command.append("-cnv") 37 | run_command(command) 38 | 39 | def signal_handler(sig, frame): 40 | print("Ctrl+C pressed, exiting...") 41 | sys.exit(0) 42 | 43 | if __name__ == "__main__": 44 | signal.signal(signal.SIGINT, signal_handler) 45 | # Usage: python run_inference.py -p "Microsoft Corporation is an American multinational corporation and technology company headquartered in Redmond, Washington." 
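    # A fuller illustrative invocation (values are examples only; the model path below assumes
    # setup_env.py has already produced the i2_s GGUF file):
    #   python run_inference.py -m models/bitnet_b1_58-3B/ggml-model-i2_s.gguf \
    #       -p "You are a helpful assistant" -n 128 -t 4 -c 2048 -temp 0.8 -cnv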
46 | parser = argparse.ArgumentParser(description='Run inference') 47 | parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf") 48 | parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict when generating text", required=False, default=128) 49 | parser.add_argument("-p", "--prompt", type=str, help="Prompt to generate text from", required=True) 50 | parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2) 51 | parser.add_argument("-c", "--ctx-size", type=int, help="Size of the prompt context", required=False, default=2048) 52 | parser.add_argument("-temp", "--temperature", type=float, help="Temperature, a hyperparameter that controls the randomness of the generated text", required=False, default=0.8) 53 | parser.add_argument("-cnv", "--conversation", action='store_true', help="Whether to enable chat mode or not (for instruct models.)") 54 | 55 | args = parser.parse_args() 56 | run_inference() -------------------------------------------------------------------------------- /run_inference_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import signal 4 | import platform 5 | import argparse 6 | import subprocess 7 | 8 | def run_command(command, shell=False): 9 | """Run a system command and ensure it succeeds.""" 10 | try: 11 | subprocess.run(command, shell=shell, check=True) 12 | except subprocess.CalledProcessError as e: 13 | print(f"Error occurred while running command: {e}") 14 | sys.exit(1) 15 | 16 | def run_server(): 17 | build_dir = "build" 18 | if platform.system() == "Windows": 19 | server_path = os.path.join(build_dir, "bin", "Release", "llama-server.exe") 20 | if not os.path.exists(server_path): 21 | server_path = os.path.join(build_dir, "bin", "llama-server") 22 | else: 23 | server_path = os.path.join(build_dir, "bin", "llama-server") 24 | 25 | command = [ 26 | f'{server_path}', 27 | '-m', args.model, 28 | '-c', str(args.ctx_size), 29 | '-t', str(args.threads), 30 | '-n', str(args.n_predict), 31 | '-ngl', '0', 32 | '--temp', str(args.temperature), 33 | '--host', args.host, 34 | '--port', str(args.port), 35 | '-cb' # Enable continuous batching 36 | ] 37 | 38 | if args.prompt: 39 | command.extend(['-p', args.prompt]) 40 | 41 | # Note: -cnv flag is removed as it's not supported by the server 42 | 43 | print(f"Starting server on {args.host}:{args.port}") 44 | run_command(command) 45 | 46 | def signal_handler(sig, frame): 47 | print("Ctrl+C pressed, shutting down server...") 48 | sys.exit(0) 49 | 50 | if __name__ == "__main__": 51 | signal.signal(signal.SIGINT, signal_handler) 52 | 53 | parser = argparse.ArgumentParser(description='Run llama.cpp server') 54 | parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf") 55 | parser.add_argument("-p", "--prompt", type=str, help="System prompt for the model", required=False) 56 | parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict", required=False, default=4096) 57 | parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2) 58 | parser.add_argument("-c", "--ctx-size", type=int, help="Size of the context window", required=False, default=2048) 59 | parser.add_argument("--temperature", type=float, help="Temperature for sampling", 
required=False, default=0.8) 60 | parser.add_argument("--host", type=str, help="IP address to listen on", required=False, default="127.0.0.1") 61 | parser.add_argument("--port", type=int, help="Port to listen on", required=False, default=8080) 62 | 63 | args = parser.parse_args() 64 | run_server() 65 | -------------------------------------------------------------------------------- /setup_env.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import signal 3 | import sys 4 | import os 5 | import platform 6 | import argparse 7 | import logging 8 | import shutil 9 | from pathlib import Path 10 | 11 | logger = logging.getLogger("setup_env") 12 | 13 | SUPPORTED_HF_MODELS = { 14 | "1bitLLM/bitnet_b1_58-large": { 15 | "model_name": "bitnet_b1_58-large", 16 | }, 17 | "1bitLLM/bitnet_b1_58-3B": { 18 | "model_name": "bitnet_b1_58-3B", 19 | }, 20 | "HF1BitLLM/Llama3-8B-1.58-100B-tokens": { 21 | "model_name": "Llama3-8B-1.58-100B-tokens", 22 | }, 23 | "tiiuae/Falcon3-7B-Instruct-1.58bit": { 24 | "model_name": "Falcon3-7B-Instruct-1.58bit", 25 | }, 26 | "tiiuae/Falcon3-7B-1.58bit": { 27 | "model_name": "Falcon3-7B-1.58bit", 28 | }, 29 | "tiiuae/Falcon3-10B-Instruct-1.58bit": { 30 | "model_name": "Falcon3-10B-Instruct-1.58bit", 31 | }, 32 | "tiiuae/Falcon3-10B-1.58bit": { 33 | "model_name": "Falcon3-10B-1.58bit", 34 | }, 35 | "tiiuae/Falcon3-3B-Instruct-1.58bit": { 36 | "model_name": "Falcon3-3B-Instruct-1.58bit", 37 | }, 38 | "tiiuae/Falcon3-3B-1.58bit": { 39 | "model_name": "Falcon3-3B-1.58bit", 40 | }, 41 | "tiiuae/Falcon3-1B-Instruct-1.58bit": { 42 | "model_name": "Falcon3-1B-Instruct-1.58bit", 43 | }, 44 | "microsoft/BitNet-b1.58-2B-4T": { 45 | "model_name": "BitNet-b1.58-2B-4T", 46 | }, 47 | "tiiuae/Falcon-E-3B-Instruct": { 48 | "model_name": "Falcon-E-3B-Instruct", 49 | }, 50 | "tiiuae/Falcon-E-1B-Instruct": { 51 | "model_name": "Falcon-E-1B-Instruct", 52 | }, 53 | "tiiuae/Falcon-E-3B-Base": { 54 | "model_name": "Falcon-E-3B-Base", 55 | }, 56 | "tiiuae/Falcon-E-1B-Base": { 57 | "model_name": "Falcon-E-1B-Base", 58 | }, 59 | } 60 | 61 | SUPPORTED_QUANT_TYPES = { 62 | "arm64": ["i2_s", "tl1"], 63 | "x86_64": ["i2_s", "tl2"] 64 | } 65 | 66 | COMPILER_EXTRA_ARGS = { 67 | "arm64": ["-DBITNET_ARM_TL1=ON"], 68 | "x86_64": ["-DBITNET_X86_TL2=ON"] 69 | } 70 | 71 | OS_EXTRA_ARGS = { 72 | "Windows":["-T", "ClangCL"], 73 | } 74 | 75 | ARCH_ALIAS = { 76 | "AMD64": "x86_64", 77 | "x86": "x86_64", 78 | "x86_64": "x86_64", 79 | "aarch64": "arm64", 80 | "arm64": "arm64", 81 | "ARM64": "arm64", 82 | } 83 | 84 | def system_info(): 85 | return platform.system(), ARCH_ALIAS[platform.machine()] 86 | 87 | def get_model_name(): 88 | if args.hf_repo: 89 | return SUPPORTED_HF_MODELS[args.hf_repo]["model_name"] 90 | return os.path.basename(os.path.normpath(args.model_dir)) 91 | 92 | def run_command(command, shell=False, log_step=None): 93 | """Run a system command and ensure it succeeds.""" 94 | if log_step: 95 | log_file = os.path.join(args.log_dir, log_step + ".log") 96 | with open(log_file, "w") as f: 97 | try: 98 | subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f) 99 | except subprocess.CalledProcessError as e: 100 | logging.error(f"Error occurred while running command: {e}, check details in {log_file}") 101 | sys.exit(1) 102 | else: 103 | try: 104 | subprocess.run(command, shell=shell, check=True) 105 | except subprocess.CalledProcessError as e: 106 | logging.error(f"Error occurred while running command: {e}") 107 | sys.exit(1) 108 | 109 | 
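# Typical invocations (illustrative; see parse_args() below for the full flag list;
# tl1 is only offered on arm64 and tl2 on x86_64):
#   python setup_env.py --hf-repo microsoft/BitNet-b1.58-2B-4T --quant-type i2_s
#   python setup_env.py --hf-repo 1bitLLM/bitnet_b1_58-large --quant-type tl1 --use-pretuned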
def prepare_model(): 110 | _, arch = system_info() 111 | hf_url = args.hf_repo 112 | model_dir = args.model_dir 113 | quant_type = args.quant_type 114 | quant_embd = args.quant_embd 115 | if hf_url is not None: 116 | # download the model 117 | model_dir = os.path.join(model_dir, SUPPORTED_HF_MODELS[hf_url]["model_name"]) 118 | Path(model_dir).mkdir(parents=True, exist_ok=True) 119 | logging.info(f"Downloading model {hf_url} from HuggingFace to {model_dir}...") 120 | run_command(["huggingface-cli", "download", hf_url, "--local-dir", model_dir], log_step="download_model") 121 | elif not os.path.exists(model_dir): 122 | logging.error(f"Model directory {model_dir} does not exist.") 123 | sys.exit(1) 124 | else: 125 | logging.info(f"Loading model from directory {model_dir}.") 126 | gguf_path = os.path.join(model_dir, "ggml-model-" + quant_type + ".gguf") 127 | if not os.path.exists(gguf_path) or os.path.getsize(gguf_path) == 0: 128 | logging.info(f"Converting HF model to GGUF format...") 129 | if quant_type.startswith("tl"): 130 | run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", quant_type, "--quant-embd"], log_step="convert_to_tl") 131 | else: # i2s 132 | # convert to f32 133 | run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "f32"], log_step="convert_to_f32_gguf") 134 | f32_model = os.path.join(model_dir, "ggml-model-f32.gguf") 135 | i2s_model = os.path.join(model_dir, "ggml-model-i2_s.gguf") 136 | # quantize to i2s 137 | if platform.system() != "Windows": 138 | if quant_embd: 139 | run_command(["./build/bin/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s") 140 | else: 141 | run_command(["./build/bin/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s") 142 | else: 143 | if quant_embd: 144 | run_command(["./build/bin/Release/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s") 145 | else: 146 | run_command(["./build/bin/Release/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s") 147 | 148 | logging.info(f"GGUF model saved at {gguf_path}") 149 | else: 150 | logging.info(f"GGUF model already exists at {gguf_path}") 151 | 152 | def setup_gguf(): 153 | # Install the pip package 154 | run_command([sys.executable, "-m", "pip", "install", "3rdparty/llama.cpp/gguf-py"], log_step="install_gguf") 155 | 156 | def gen_code(): 157 | _, arch = system_info() 158 | 159 | llama3_f3_models = set([model['model_name'] for model in SUPPORTED_HF_MODELS.values() if model['model_name'].startswith("Falcon") or model['model_name'].startswith("Llama")]) 160 | 161 | if arch == "arm64": 162 | if args.use_pretuned: 163 | pretuned_kernels = os.path.join("preset_kernels", get_model_name()) 164 | if not os.path.exists(pretuned_kernels): 165 | logging.error(f"Pretuned kernels not found for model {args.hf_repo}") 166 | sys.exit(1) 167 | if args.quant_type == "tl1": 168 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl1.h"), "include/bitnet-lut-kernels.h") 169 | shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl1.ini"), "include/kernel_config.ini") 170 | elif args.quant_type == "tl2": 171 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h") 172 | shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini") 173 | if 
get_model_name() == "bitnet_b1_58-large": 174 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen") 175 | elif get_model_name() in llama3_f3_models: 176 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen") 177 | elif get_model_name() == "bitnet_b1_58-3B": 178 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen") 179 | elif get_model_name() == "BitNet-b1.58-2B-4T": 180 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen") 181 | else: 182 | raise NotImplementedError() 183 | else: 184 | if args.use_pretuned: 185 | # cp preset_kernels/model_name/bitnet-lut-kernels_tl1.h to include/bitnet-lut-kernels.h 186 | pretuned_kernels = os.path.join("preset_kernels", get_model_name()) 187 | if not os.path.exists(pretuned_kernels): 188 | logging.error(f"Pretuned kernels not found for model {args.hf_repo}") 189 | sys.exit(1) 190 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h") 191 | if get_model_name() == "bitnet_b1_58-large": 192 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen") 193 | elif get_model_name() in llama3_f3_models: 194 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen") 195 | elif get_model_name() == "bitnet_b1_58-3B": 196 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen") 197 | elif get_model_name() == "BitNet-b1.58-2B-4T": 198 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen") 199 | else: 200 | raise NotImplementedError() 201 | 202 | 203 | def compile(): 204 | # Check if cmake is installed 205 | cmake_exists = subprocess.run(["cmake", "--version"], capture_output=True) 206 | if cmake_exists.returncode != 0: 207 | logging.error("Cmake is not available. 
Please install CMake and try again.") 208 | sys.exit(1) 209 | _, arch = system_info() 210 | if arch not in COMPILER_EXTRA_ARGS.keys(): 211 | logging.error(f"Arch {arch} is not supported yet") 212 | exit(0) 213 | logging.info("Compiling the code using CMake.") 214 | run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), []), "-DCMAKE_C_COMPILER=clang", "-DCMAKE_CXX_COMPILER=clang++"], log_step="generate_build_files") 215 | # run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"]) 216 | run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile") 217 | 218 | def main(): 219 | setup_gguf() 220 | gen_code() 221 | compile() 222 | prepare_model() 223 | 224 | def parse_args(): 225 | _, arch = system_info() 226 | parser = argparse.ArgumentParser(description='Setup the environment for running the inference') 227 | parser.add_argument("--hf-repo", "-hr", type=str, help="Model used for inference", choices=SUPPORTED_HF_MODELS.keys()) 228 | parser.add_argument("--model-dir", "-md", type=str, help="Directory to save/load the model", default="models") 229 | parser.add_argument("--log-dir", "-ld", type=str, help="Directory to save the logging info", default="logs") 230 | parser.add_argument("--quant-type", "-q", type=str, help="Quantization type", choices=SUPPORTED_QUANT_TYPES[arch], default="i2_s") 231 | parser.add_argument("--quant-embd", action="store_true", help="Quantize the embeddings to f16") 232 | parser.add_argument("--use-pretuned", "-p", action="store_true", help="Use the pretuned kernel parameters") 233 | return parser.parse_args() 234 | 235 | def signal_handler(sig, frame): 236 | logging.info("Ctrl+C pressed, exiting...") 237 | sys.exit(0) 238 | 239 | if __name__ == "__main__": 240 | signal.signal(signal.SIGINT, signal_handler) 241 | args = parse_args() 242 | Path(args.log_dir).mkdir(parents=True, exist_ok=True) 243 | logging.basicConfig(level=logging.INFO) 244 | main() 245 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(GGML_HEADERS_BITNET ../include/ggml-bitnet.h) 2 | set(GGML_SOURCES_BITNET ggml-bitnet-mad.cpp) 3 | set(GGML_SOURCES_BITNET ggml-bitnet-lut.cpp) 4 | 5 | include_directories(3rdparty/llama.cpp/ggml/include) 6 | 7 | if (NOT (CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_C_COMPILER_ID STREQUAL "GNU") OR 8 | NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")) 9 | message(FATAL_ERROR "Clang or GCC is required for Bitnet.cpp compilation") 10 | endif() 11 | -------------------------------------------------------------------------------- /src/ggml-bitnet-lut.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "ggml-bitnet.h" 9 | #include "ggml-quants.h" 10 | #include "bitnet-lut-kernels.h" 11 | 12 | #if defined(GGML_BITNET_ARM_TL1) 13 | 14 | void ggml_bitnet_init(void) { 15 | // LOG(INFO) << "ggml_bitnet_init"; 16 | 17 | if (initialized) { 18 | return; 19 | } 20 | initialized = true; 21 | 22 | // if (wrapper == nullptr) { 23 | // wrapper = new BITNET::BITNETGeMMWrapper(); 24 | // } 25 | if (bitnet_tensor_extras == nullptr) { 26 | bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES]; 27 | } 28 | bitnet_tensor_extras_index = 0; 29 | } 30 | 31 | void ggml_bitnet_free(void) 
{ 32 | // LOG(INFO) << "ggml_bitnet_free"; 33 | 34 | if (!initialized) { 35 | return; 36 | } 37 | initialized = false; 38 | 39 | // delete wrapper; 40 | // wrapper = nullptr; 41 | for (size_t i = 0; i < bitnet_tensor_extras_index; i++) { 42 | // aligned_free(bitnet_tensor_extras[i].qweights); 43 | // aligned_free(bitnet_tensor_extras[i].scales); 44 | } 45 | delete[] bitnet_tensor_extras; 46 | bitnet_tensor_extras = nullptr; 47 | } 48 | 49 | static bool do_permutate(enum ggml_type type) { 50 | if (type == GGML_TYPE_TL1) { 51 | // Add additional args to decide if permuted I2 or naive I2 52 | return false; 53 | } else { 54 | return true; 55 | } 56 | } 57 | 58 | bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { 59 | if ((is_type_supported(src0->type)) && 60 | src1->type == GGML_TYPE_F32 && 61 | dst->type == GGML_TYPE_F32 && 62 | src0->backend == GGML_BACKEND_TYPE_CPU) { 63 | if (src1->ne[1] <= 1) { 64 | return true; 65 | } 66 | } 67 | return false; 68 | } 69 | 70 | size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { 71 | const size_t ne01 = src0->ne[1]; 72 | const size_t ne10 = src1->ne[0]; 73 | const size_t ne11 = src1->ne[1]; 74 | const int bits = ggml_bitnet_get_type_bits(src0->type); 75 | 76 | size_t wsize = ne10 * ne11 * 15 * sizeof(int8_t) + 1 * ne11 * 2 * sizeof(bitnet_float_type); 77 | if (sizeof(bitnet_float_type) == 2) { 78 | // Need fp32 to fp16 conversion 79 | wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type); 80 | } 81 | wsize = ((wsize - 1) / 64 + 1) * 64; 82 | return wsize; 83 | } 84 | 85 | int ggml_bitnet_get_type_bits(enum ggml_type type) { 86 | switch (type) { 87 | case GGML_TYPE_TL1: 88 | return 2; 89 | case GGML_TYPE_Q4_0: 90 | return 4; 91 | default: 92 | return 0; 93 | } 94 | } 95 | 96 | #endif 97 | #if defined(GGML_BITNET_X86_TL2) 98 | void ggml_bitnet_init(void) { 99 | // LOG(INFO) << "ggml_bitnet_init"; 100 | 101 | if (initialized) { 102 | return; 103 | } 104 | initialized = true; 105 | 106 | // if (wrapper == nullptr) { 107 | // wrapper = new BITNET::BITNETGeMMWrapper(); 108 | // } 109 | if (bitnet_tensor_extras == nullptr) { 110 | bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES]; 111 | } 112 | bitnet_tensor_extras_index = 0; 113 | } 114 | 115 | void ggml_bitnet_free(void) { 116 | // LOG(INFO) << "ggml_bitnet_free"; 117 | 118 | if (!initialized) { 119 | return; 120 | } 121 | initialized = false; 122 | 123 | // delete wrapper; 124 | // wrapper = nullptr; 125 | for (size_t i = 0; i < bitnet_tensor_extras_index; i++) { 126 | // aligned_free(bitnet_tensor_extras[i].qweights); 127 | // aligned_free(bitnet_tensor_extras[i].scales); 128 | } 129 | delete[] bitnet_tensor_extras; 130 | bitnet_tensor_extras = nullptr; 131 | } 132 | 133 | bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { 134 | if ((is_type_supported(src0->type)) && 135 | src1->type == GGML_TYPE_F32 && 136 | dst->type == GGML_TYPE_F32 && 137 | src0->backend == GGML_BACKEND_TYPE_CPU) { 138 | return true; 139 | } 140 | return false; 141 | } 142 | 143 | size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { 144 | const size_t ne01 = src0->ne[1]; 145 | const size_t ne10 = src1->ne[0]; 146 | const size_t ne11 = src1->ne[1]; 147 | 148 | size_t wsize = ne10 * ne11 * 
11 * sizeof(int8_t) + 2 * ne11 * 2 * sizeof(bitnet_float_type); 149 | if (sizeof(bitnet_float_type) == 2) { 150 | // Need fp32 to fp16 conversion 151 | wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type); 152 | } 153 | wsize = ((wsize - 1) / 64 + 1) * 64; 154 | return wsize; 155 | } 156 | 157 | int ggml_bitnet_get_type_bits(enum ggml_type type) { 158 | switch (type) { 159 | case GGML_TYPE_TL2: 160 | return 2; 161 | case GGML_TYPE_Q4_0: 162 | return 4; 163 | default: 164 | return 0; 165 | } 166 | } 167 | #endif -------------------------------------------------------------------------------- /src/ggml-bitnet-mad.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ggml-bitnet.h" 5 | #include "ggml-quants.h" 6 | #include 7 | #include 8 | 9 | #define QK_I2_S 128 10 | #define QK_I2 128 11 | 12 | #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) 13 | #include <immintrin.h> 14 | // horizontally add 8 int32_t 15 | static inline int hsum_i32_8(const __m256i a) { 16 | const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); 17 | const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); 18 | const __m128i sum64 = _mm_add_epi32(hi64, sum128); 19 | const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); 20 | return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); 21 | } 22 | #elif defined(__loongarch_asx) 23 | // horizontally add 8 int32_t 24 | static inline int hsum_i32_8(const __m256i a) { 25 | 26 | __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); 27 | __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); 28 | 29 | __m128i tmp1_128 = lasx_extracti128_lo(tmp1); 30 | __m128i tmp2_128 = lasx_extracti128_lo(tmp2); 31 | 32 | __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); 33 | 34 | __m128i ev = __lsx_vpickev_w(sum128, sum128); 35 | __m128i od = __lsx_vpickod_w(sum128, sum128); 36 | __m128i sum64 = __lsx_vadd_w(ev, od); 37 | 38 | int sum64_1, sum64_2; 39 | sum64_1 = __lsx_vpickve2gr_w(sum64, 0); 40 | sum64_2 = __lsx_vpickve2gr_w(sum64, 1); 41 | 42 | return sum64_1 + sum64_2; 43 | } 44 | #endif 45 | 46 | size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { 47 | // 2 bits per weight 48 | 49 | size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row); 50 | 51 | int n = nrow * n_per_row; 52 | 53 | // f32 -> q8 54 | double max = 0; 55 | for (int i = 0; i < n; ++i) { 56 | max = fmax(max, (double)fabs((double)src[i])); 57 | } 58 | double i2_scale = max; 59 | 60 | uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t)); 61 | for (int i=0; i<n; i++) { 62 | if (fabs((double)(src[i])) < 1e-6) { 63 | q8[i] = 1; 64 | continue; 65 | } 66 | q8[i] = (double)src[i] * i2_scale > 0 ? 
2 : 0; 67 | } 68 | 69 | memset(dst, 0, n * sizeof(uint8_t) / 4); 70 | 71 | // q8 -> 0, 1, 2 72 | // | | | 73 | // -1, 0, 1 74 | 75 | uint8_t* i2_weight = (uint8_t*)dst; 76 | for (int i = 0; i < n / QK_I2; i++) { 77 | for (int j = 0; j < QK_I2; j++) { 78 | int group_idx = j / 32; 79 | int group_pos = j % 32; 80 | uint8_t temp = (q8[i * QK_I2 + j] << (6 - 2 * group_idx)); 81 | i2_weight[i * 32 + group_pos] |= temp; 82 | } 83 | } 84 | 85 | float* scale_ptr = (float*)((char*)i2_weight + n / 4); 86 | scale_ptr[0] = i2_scale; 87 | 88 | free(q8); 89 | 90 | // 32B for alignment 91 | return nrow * row_size / 4 + 32; 92 | } 93 | 94 | void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { 95 | const uint8_t * x = (uint8_t *)vx; 96 | const int8_t * y = (int8_t *)vy; 97 | 98 | const int nb = n / QK_I2_S; 99 | const int group32_num = nb / 32; 100 | const int la_num = nb % 32; 101 | const int groupla_num = nb % 32 != 0 ? 1 : 0; 102 | 103 | #if defined(__AVX2__) 104 | 105 | __m256i mask = _mm256_set1_epi8(0x03); 106 | __m256i accu = _mm256_setzero_si256(); 107 | 108 | for (int i=0; i < group32_num; i++){ 109 | __m256i accu32 = _mm256_setzero_si256(); 110 | for (int j=0; j < 32; j++) { 111 | // 128 index 112 | __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + i * 32 * 32 + j * 32)); 113 | __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); 114 | __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); 115 | __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); 116 | 117 | // each 32 index 118 | xq8_3 = _mm256_and_si256(xq8_3, mask); 119 | xq8_2 = _mm256_and_si256(xq8_2, mask); 120 | xq8_1 = _mm256_and_si256(xq8_1, mask); 121 | xq8_0 = _mm256_and_si256(xq8_0, mask); 122 | 123 | // each 32 index 124 | __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 0)); 125 | __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 32)); 126 | __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 64)); 127 | __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 96)); 128 | 129 | // 128 index accumulation add 130 | // split into 32 accumulation block 131 | // each block each 128 index accumulated 4index 132 | // each index maximum 256 133 | // each block maximum 4 * 256 134 | // each block accumulation maximum 127 * 256 135 | // each 32 group index (128 index in one group) needs cast to int32 136 | xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); 137 | xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); 138 | xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); 139 | xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); 140 | 141 | accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1)); 142 | accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3)); 143 | } 144 | accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, _mm256_set1_epi16(1)), accu); 145 | } 146 | 147 | for (int i = 0; i < groupla_num; i++){ 148 | __m256i accula = _mm256_setzero_si256(); 149 | for (int j = 0; j < la_num; j++) { 150 | // 128 index 151 | __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + group32_num * 32 * 32 + j * 32)); 152 | __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); 153 | __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); 154 | __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); 155 | 156 | // each 32 index 157 | xq8_3 = _mm256_and_si256(xq8_3, mask); 158 | xq8_2 = _mm256_and_si256(xq8_2, mask); 159 | xq8_1 = _mm256_and_si256(xq8_1, mask); 160 | xq8_0 = _mm256_and_si256(xq8_0, mask); 161 | 162 | // each 32 index 
163 | __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 0)); 164 | __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 32)); 165 | __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 64)); 166 | __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 96)); 167 | 168 | // 128 index accumulation add 169 | // split into 32 accumulation block 170 | // each block each 128 index accumulated 4index 171 | // each index maximum 256 172 | // each block maximum 4 * 256 173 | // each block accumulation maximum 127 * 256 174 | // each 32 group index (128 index in one group) needs cast to int32 175 | xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); 176 | xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); 177 | xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); 178 | xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); 179 | 180 | accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1)); 181 | accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3)); 182 | } 183 | accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, _mm256_set1_epi16(1))); 184 | } 185 | int sumi = hsum_i32_8(accu); 186 | *s = (float)sumi; 187 | 188 | #elif defined(__ARM_NEON) 189 | 190 | int32x4_t accu_0 = vdupq_n_s32(0); 191 | int32x4_t accu_1 = vdupq_n_s32(0); 192 | int32x4_t accu_2 = vdupq_n_s32(0); 193 | int32x4_t accu_3 = vdupq_n_s32(0); 194 | const uint8x16_t mask = vdupq_n_u8(3); 195 | 196 | for (int i=0; i < group32_num; i++) { 197 | 198 | #if defined(__ARM_FEATURE_DOTPROD) 199 | 200 | #else 201 | int16x8_t accu32_0 = vdupq_n_s16(0); 202 | int16x8_t accu32_1 = vdupq_n_s16(0); 203 | int16x8_t accu32_2 = vdupq_n_s16(0); 204 | int16x8_t accu32_3 = vdupq_n_s16(0); 205 | #endif 206 | 207 | for (int j=0; j < 32; j++) { 208 | uint8x16_t xq8_6 = vld1q_u8(x + i * 32 * 32 + j * 32); 209 | uint8x16_t xq8_7 = vld1q_u8(x + i * 32 * 32 + j * 32 + 16); 210 | uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2); 211 | uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2); 212 | uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4); 213 | uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4); 214 | uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6); 215 | uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6); 216 | 217 | int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); 218 | int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); 219 | int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); 220 | int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); 221 | int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask)); 222 | int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask)); 223 | int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask)); 224 | int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask)); 225 | 226 | const int8x16_t yq8_0 = vld1q_s8(y + i * 128 * 32 + j * 128 + 0); 227 | const int8x16_t yq8_1 = vld1q_s8(y + i * 128 * 32 + j * 128 + 16); 228 | const int8x16_t yq8_2 = vld1q_s8(y + i * 128 * 32 + j * 128 + 32); 229 | const int8x16_t yq8_3 = vld1q_s8(y + i * 128 * 32 + j * 128 + 48); 230 | const int8x16_t yq8_4 = vld1q_s8(y + i * 128 * 32 + j * 128 + 64); 231 | const int8x16_t yq8_5 = vld1q_s8(y + i * 128 * 32 + j * 128 + 80); 232 | const int8x16_t yq8_6 = vld1q_s8(y + i * 128 * 32 + j * 128 + 96); 233 | const int8x16_t yq8_7 = vld1q_s8(y + i * 128 * 32 + j * 128 + 112); 234 | 235 | #if defined(__ARM_FEATURE_DOTPROD) 236 | accu_0 = vdotq_s32(accu_0, q8_0, yq8_0); 237 | accu_1 = vdotq_s32(accu_1, q8_1, yq8_1); 238 | accu_2 = 
vdotq_s32(accu_2, q8_2, yq8_2); 239 | accu_3 = vdotq_s32(accu_3, q8_3, yq8_3); 240 | accu_0 = vdotq_s32(accu_0, q8_4, yq8_4); 241 | accu_1 = vdotq_s32(accu_1, q8_5, yq8_5); 242 | accu_2 = vdotq_s32(accu_2, q8_6, yq8_6); 243 | accu_3 = vdotq_s32(accu_3, q8_7, yq8_7); 244 | #else 245 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_0), vget_low_s8(yq8_0)); 246 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_0), vget_high_s8(yq8_0)); 247 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_1), vget_low_s8(yq8_1)); 248 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_1), vget_high_s8(yq8_1)); 249 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_2), vget_low_s8(yq8_2)); 250 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_2), vget_high_s8(yq8_2)); 251 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_3), vget_low_s8(yq8_3)); 252 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_3), vget_high_s8(yq8_3)); 253 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_4), vget_low_s8(yq8_4)); 254 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_4), vget_high_s8(yq8_4)); 255 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_5), vget_low_s8(yq8_5)); 256 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_5), vget_high_s8(yq8_5)); 257 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_6), vget_low_s8(yq8_6)); 258 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_6), vget_high_s8(yq8_6)); 259 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_7), vget_low_s8(yq8_7)); 260 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_7), vget_high_s8(yq8_7)); 261 | #endif 262 | } 263 | 264 | #if defined(__ARM_FEATURE_DOTPROD) 265 | 266 | #else 267 | accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accu32_0))); 268 | accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accu32_0)); 269 | accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accu32_1))); 270 | accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accu32_1)); 271 | accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accu32_2))); 272 | accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accu32_2)); 273 | accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accu32_3))); 274 | accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accu32_3)); 275 | #endif 276 | } 277 | 278 | for (int i = 0; i < groupla_num; i++){ 279 | #if defined(__ARM_FEATURE_DOTPROD) 280 | 281 | #else 282 | int16x8_t accula_0 = vdupq_n_s16(0); 283 | int16x8_t accula_1 = vdupq_n_s16(0); 284 | int16x8_t accula_2 = vdupq_n_s16(0); 285 | int16x8_t accula_3 = vdupq_n_s16(0); 286 | #endif 287 | for (int j = 0; j < la_num; j++) { 288 | uint8x16_t xq8_6 = vld1q_u8(x + group32_num * 32 * 32 + j * 32); 289 | uint8x16_t xq8_7 = vld1q_u8(x + group32_num * 32 * 32 + j * 32 + 16); 290 | uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2); 291 | uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2); 292 | uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4); 293 | uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4); 294 | uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6); 295 | uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6); 296 | 297 | int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); 298 | int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); 299 | int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); 300 | int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); 301 | int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask)); 302 | int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask)); 303 | int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask)); 304 | int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask)); 305 | 306 | const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 0); 307 | const int8x16_t 
yq8_1 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 16); 308 | const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 32); 309 | const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 48); 310 | const int8x16_t yq8_4 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 64); 311 | const int8x16_t yq8_5 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 80); 312 | const int8x16_t yq8_6 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 96); 313 | const int8x16_t yq8_7 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 112); 314 | 315 | #if defined(__ARM_FEATURE_DOTPROD) 316 | accu_0 = vdotq_s32(accu_0, q8_0, yq8_0); 317 | accu_1 = vdotq_s32(accu_1, q8_1, yq8_1); 318 | accu_2 = vdotq_s32(accu_2, q8_2, yq8_2); 319 | accu_3 = vdotq_s32(accu_3, q8_3, yq8_3); 320 | accu_0 = vdotq_s32(accu_0, q8_4, yq8_4); 321 | accu_1 = vdotq_s32(accu_1, q8_5, yq8_5); 322 | accu_2 = vdotq_s32(accu_2, q8_6, yq8_6); 323 | accu_3 = vdotq_s32(accu_3, q8_7, yq8_7); 324 | #else 325 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_0), vget_low_s8(yq8_0)); 326 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_0), vget_high_s8(yq8_0)); 327 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_1), vget_low_s8(yq8_1)); 328 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_1), vget_high_s8(yq8_1)); 329 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_2), vget_low_s8(yq8_2)); 330 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_2), vget_high_s8(yq8_2)); 331 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_3), vget_low_s8(yq8_3)); 332 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_3), vget_high_s8(yq8_3)); 333 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_4), vget_low_s8(yq8_4)); 334 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_4), vget_high_s8(yq8_4)); 335 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_5), vget_low_s8(yq8_5)); 336 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_5), vget_high_s8(yq8_5)); 337 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_6), vget_low_s8(yq8_6)); 338 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_6), vget_high_s8(yq8_6)); 339 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_7), vget_low_s8(yq8_7)); 340 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_7), vget_high_s8(yq8_7)); 341 | #endif 342 | } 343 | #if defined(__ARM_FEATURE_DOTPROD) 344 | 345 | #else 346 | accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accula_0))); 347 | accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accula_0)); 348 | accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accula_1))); 349 | accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accula_1)); 350 | accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accula_2))); 351 | accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accula_2)); 352 | accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accula_3))); 353 | accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accula_3)); 354 | #endif 355 | } 356 | accu_0 = vaddq_s32(accu_0, accu_1); 357 | accu_2 = vaddq_s32(accu_2, accu_3); 358 | accu_0 = vaddq_s32(accu_0, accu_2); 359 | int sumi = vaddlvq_s32(accu_0); 360 | *s = (float)sumi; 361 | 362 | #endif 363 | } -------------------------------------------------------------------------------- /utils/codegen_tl1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from configparser import ConfigParser 4 | 5 | def gen_ctor_code(): 6 | kernel_code = "\n\ 7 | #include \"ggml-bitnet.h\"\n\ 8 | #define GGML_BITNET_MAX_NODES 8192\n\ 9 | static bool initialized = false;\n\ 10 | static bitnet_tensor_extra * bitnet_tensor_extras = 
nullptr;\n\ 11 | static size_t bitnet_tensor_extras_index = 0;\n\ 12 | static void * aligned_malloc(size_t size) {{\n\ 13 | #if defined(_WIN32)\n\ 14 | return _aligned_malloc(size, 64);\n\ 15 | #else\n\ 16 | void * ptr = nullptr;\n\ 17 | posix_memalign(&ptr, 64, size);\n\ 18 | return ptr;\n\ 19 | #endif\n\ 20 | }}\n\ 21 | static void aligned_free(void * ptr) {{\n\ 22 | #if defined(_WIN32)\n\ 23 | _aligned_free(ptr);\n\ 24 | #else\n\ 25 | free(ptr);\n\ 26 | #endif\n\ 27 | }}\n\ 28 | \n\ 29 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{\n\ 30 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\ 31 | bitnet_float_type* b = (bitnet_float_type*)b_;\n\ 32 | #ifdef __ARM_NEON\n\ 33 | float32x4_t temp_max = vdupq_n_f32(0);\n\ 34 | for (int i=0; i < k / 4; i++) {{\n\ 35 | float32x4_t vec_bs = vld1q_f32(b + 4 * i);\n\ 36 | float32x4_t abssum = vabsq_f32(vec_bs);\n\ 37 | temp_max = vmaxq_f32(abssum, temp_max);\n\ 38 | }}\n\ 39 | float32_t scales = 127 / vmaxvq_f32(temp_max);\n\ 40 | *lut_scales = scales;\n\ 41 | #elif defined __AVX2__\n\ 42 | __m256 max_vec = _mm256_set1_ps(0.f);\n\ 43 | const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\ 44 | // #pragma unroll\n\ 45 | for (int i = 0; i < k / 8; i++) {{\n\ 46 | __m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\ 47 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\ 48 | max_vec = _mm256_max_ps(vec_babs, max_vec);\n\ 49 | }}\n\ 50 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\ 51 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\ 52 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\ 53 | float scales = 127 / _mm_cvtss_f32(max1);\n\ 54 | *lut_scales = scales;\n\ 55 | #endif\n\ 56 | }}\n\ 57 | \n\ 58 | void partial_max_reset(void* lut_scales_) {{\n\ 59 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\ 60 | *lut_scales = 0.0;\n\ 61 | }}\n\ 62 | \n\ 63 | #ifdef __ARM_NEON\n\ 64 | inline void Transpose_8_8(\n\ 65 | int16x8_t *v0,\n\ 66 | int16x8_t *v1,\n\ 67 | int16x8_t *v2,\n\ 68 | int16x8_t *v3,\n\ 69 | int16x8_t *v4,\n\ 70 | int16x8_t *v5,\n\ 71 | int16x8_t *v6,\n\ 72 | int16x8_t *v7)\n\ 73 | {{\n\ 74 | int16x8x2_t q04 = vzipq_s16(*v0, *v4);\n\ 75 | int16x8x2_t q15 = vzipq_s16(*v1, *v5);\n\ 76 | int16x8x2_t q26 = vzipq_s16(*v2, *v6);\n\ 77 | int16x8x2_t q37 = vzipq_s16(*v3, *v7);\n\ 78 | \n\ 79 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);\n\ 80 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);\n\ 81 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);\n\ 82 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);\n\ 83 | \n\ 84 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);\n\ 85 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);\n\ 86 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);\n\ 87 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);\n\ 88 | \n\ 89 | *v0 = q_fin_0.val[0];\n\ 90 | *v1 = q_fin_0.val[1];\n\ 91 | *v2 = q_fin_1.val[0];\n\ 92 | *v3 = q_fin_1.val[1];\n\ 93 | *v4 = q_fin_2.val[0];\n\ 94 | *v5 = q_fin_2.val[1];\n\ 95 | *v6 = q_fin_3.val[0];\n\ 96 | *v7 = q_fin_3.val[1];\n\ 97 | }}\n\ 98 | #endif\n\ 99 | \n\ 100 | template\n\ 101 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{\n\ 102 | #ifdef __ARM_NEON\n\ 103 | int16x8_t vec_lut[16];\n\ 104 | float32_t scales = *lut_scales;\n\ 105 | uint8_t tbl_mask[16];\n\ 106 | tbl_mask[0] = 0;\n\ 107 | tbl_mask[1] = 2;\n\ 108 | tbl_mask[2] = 4;\n\ 
109 | tbl_mask[3] = 6;\n\ 110 | tbl_mask[4] = 8;\n\ 111 | tbl_mask[5] = 10;\n\ 112 | tbl_mask[6] = 12;\n\ 113 | tbl_mask[7] = 14;\n\ 114 | tbl_mask[8] = 1;\n\ 115 | tbl_mask[9] = 3;\n\ 116 | tbl_mask[10] = 5;\n\ 117 | tbl_mask[11] = 7;\n\ 118 | tbl_mask[12] = 9;\n\ 119 | tbl_mask[13] = 11;\n\ 120 | tbl_mask[14] = 13;\n\ 121 | tbl_mask[15] = 15;\n\ 122 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);\n\ 123 | #pragma unroll\n\ 124 | for (int k = 0; k < act_k / 16; ++k) {{\n\ 125 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);\n\ 126 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);\n\ 127 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);\n\ 128 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);\n\ 129 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);\n\ 130 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);\n\ 131 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);\n\ 132 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);\n\ 133 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);\n\ 134 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);\n\ 135 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);\n\ 136 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);\n\ 137 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);\n\ 138 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);\n\ 139 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);\n\ 140 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);\n\ 141 | vec_lut[0] = vdupq_n_s16(0);\n\ 142 | vec_lut[0] = vec_lut[0] - vec_bs_0;\n\ 143 | vec_lut[0] = vec_lut[0] - vec_bs_1;\n\ 144 | vec_lut[1] = vdupq_n_s16(0);\n\ 145 | vec_lut[1] = vec_lut[1] - vec_bs_0;\n\ 146 | vec_lut[2] = vdupq_n_s16(0);\n\ 147 | vec_lut[2] = vec_lut[2] - vec_bs_0;\n\ 148 | vec_lut[2] = vec_lut[2] + vec_bs_1;\n\ 149 | vec_lut[3] = vdupq_n_s16(0);\n\ 150 | vec_lut[3] = vec_lut[3] - vec_bs_1;\n\ 151 | vec_lut[4] = vdupq_n_s16(0);\n\ 152 | vec_lut[5] = vec_bs_1;\n\ 153 | vec_lut[6] = vec_bs_0;\n\ 154 | vec_lut[6] = vec_lut[6] - vec_bs_1;\n\ 155 | vec_lut[7] = vec_bs_0;\n\ 156 | vec_lut[8] = vec_bs_0;\n\ 157 | vec_lut[8] = vec_lut[8] + vec_bs_1;\n\ 158 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),\n\ 159 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));\n\ 160 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),\n\ 161 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));\n\ 162 | #pragma unroll\n\ 163 | for (int idx = 0; idx < 8; idx++) {{\n\ 164 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);\n\ 165 | int8x8_t q0_low = vget_low_s8(q0_s);\n\ 166 | int8x8_t q0_high = vget_high_s8(q0_s);\n\ 167 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);\n\ 168 | int8x8_t q1_low = vget_low_s8(q1_s);\n\ 169 | int8x8_t q1_high = vget_high_s8(q1_s);\n\ 170 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);\n\ 171 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);\n\ 172 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);\n\ 173 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);\n\ 174 | }}\n\ 175 | }}\n\ 176 | #endif\n\ 177 | }}\n\ 178 | \n\ 179 | static bool is_type_supported(enum ggml_type type) {{\n\ 180 | if (type == GGML_TYPE_Q4_0 ||\n\ 181 | type == GGML_TYPE_TL1) {{\n\ 182 | return true;\n\ 183 | }} else {{\n\ 184 | return false;\n\ 185 | }}\n\ 186 | }}\n\ 187 | " 188 | return kernel_code 189 | 190 | def gen_body_core_code(bm, by): 191 | length = 4 192 | all_code = "" 193 | for i in 
range(length): 194 | core_code = "\n\ 195 | uint8x16_t vec_a_{0} = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + {0} * 16);\n\ 196 | uint8x16_t vec_a{0}_top = vshrq_n_u8(vec_a_{0}, 4);\n\ 197 | uint8x16_t vec_a{0}_bot = vandq_u8(vec_a_{0}, vec_mask);\n\ 198 | int8x16_t vec_v_{0}_left_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {2}], vec_a{0}_top);\n\ 199 | int8x16_t vec_v_{0}_left_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {3}], vec_a{0}_top);\n\ 200 | int8x16_t vec_v_{0}_right_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {4}], vec_a{0}_bot);\n\ 201 | int8x16_t vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\ 202 | int8x16x2_t vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\ 203 | int8x16x2_t vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\ 204 | vec_c[{6}] += vec_v_left_{0}.val[0];\n\ 205 | vec_c[{6}] += vec_v_right_{0}.val[0];\n\ 206 | vec_c[{7}] += vec_v_left_{0}.val[1];\n\ 207 | vec_c[{7}] += vec_v_right_{0}.val[1];\n\ 208 | ".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1) 209 | 210 | all_code = "".join([all_code, core_code]) 211 | 212 | all_code = "".join([all_code, "\n }\n\n"]) 213 | 214 | for i in range(bm // 8): 215 | core_code = "\ 216 | int32x4_t vec_v_bot_low_low_{0} = vmovl_s16(vget_low_s16(vec_c[{0}]));\n\ 217 | int32x4_t vec_v_bot_low_high_{0} = vmovl_high_s16(vec_c[{0}]);\n\ 218 | vst1q_s32(c + i + {1}, vld1q_s32(c + i + {1}) + vec_v_bot_low_low_{0});\n\ 219 | vst1q_s32(c + i + {2}, vld1q_s32(c + i + {2}) + vec_v_bot_low_high_{0});\n".format(i, i * 8, i * 8 + 4) 220 | all_code = "".join([all_code, core_code]) 221 | 222 | return all_code 223 | 224 | def gen_tbl_impl(pre, BM, BK, bm, k): 225 | 226 | kernel_code = "\ 227 | #include \n\ 228 | \n\ 229 | #define BM{0} {1}\n\ 230 | #define BBK{0} {2}\n\ 231 | inline void tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\ 232 | #ifdef __ARM_NEON\n\ 233 | const int KK = BBK{0} / 2;\n\ 234 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\ 235 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);\n\ 236 | int8x16_t vec_lut[2 * KK];\n\ 237 | ".format(pre, BM, BK) 238 | 239 | kernel_code = "".join([kernel_code, " int16x8_t vec_c[{}];".format(bm // 8)]) 240 | 241 | kernel_code = "".join([kernel_code, "\n\ 242 | #pragma unroll\n\ 243 | for (int k = 0; k < 2 * KK; k++) {\n\ 244 | vec_lut[k] = vld1q_s8(lut + k * 16);\n\ 245 | }\n"]) 246 | 247 | pre_core_code = "\n\ 248 | #pragma unroll\n\ 249 | for (int i = 0; i < BM{}; i += {}) {{\n\ 250 | #pragma unroll\n\ 251 | for (int i=0; i<{}; i++) {{\n\ 252 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);\n\ 253 | }}\n".format(pre, bm, bm // 8) 254 | 255 | body_core_pre_code = "\n\ 256 | #pragma unroll\n\ 257 | for (int k = 0; k < KK / {}; k++) {{\n\ 258 | ".format(256 // bm // 2) 259 | 260 | body_core_post_code = "\n\ 261 | }\n\ 262 | \ 263 | #endif\n\ 264 | }\n" 265 | 266 | kernel_code = "".join([kernel_code, pre_core_code, body_core_pre_code, gen_body_core_code(bm, 256 // bm), body_core_post_code]) 267 | 268 | kernel_code = "".join([kernel_code, "\n\ 269 | int32_t qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ 270 | alignas({1}) uint32_t CBits[BM{0}];\n\ 271 | memset(&(CBits[0]), 0, BM{0} * sizeof(int32_t));\n\ 272 | #pragma unroll\n\ 273 | for (int32_t k_outer = 0; k_outer < {2} / BBK{0}; ++k_outer) {{\n\ 274 | tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer 
* BBK{0} / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 2 / 2 * BM{0})])));\n\ 275 | }}\n\ 276 | #pragma unroll\n\ 277 | for (int i = 0; i < BM{0}; i++) {{\n\ 278 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];\n\ 279 | }}\n\ 280 | return 0;\n\ 281 | }};\n".format(pre, min(32, BK), k)]) 282 | 283 | return kernel_code 284 | 285 | def gen_top_api(kernel_shapes): 286 | 287 | kernel_code = "void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {{\n\ 288 | if (m == {0} && k == {1}) {{\n\ 289 | preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\ 290 | }}\n\ 291 | ".format(kernel_shapes[0][0], kernel_shapes[0][1]) 292 | for i in range(1, len(kernel_shapes)): 293 | kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\ 294 | preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\ 295 | }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])]) 296 | kernel_code = "".join([kernel_code, "}\n"]) 297 | kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ 298 | if (m == {0} && k == {1}) {{\n\ 299 | qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\ 300 | }}\n\ 301 | ".format(kernel_shapes[0][0], kernel_shapes[0][1])]) 302 | for i in range(1, len(kernel_shapes)): 303 | kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\ 304 | qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\ 305 | }}\n\ 306 | ".format(kernel_shapes[i][0], kernel_shapes[i][1])]) 307 | kernel_code = "".join([kernel_code, "}\n"]) 308 | return kernel_code 309 | 310 | def gen_preprocess_code(): 311 | kernel_code = "\n\ 312 | template\n\ 313 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{\n\ 314 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));\n\ 315 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));\n\ 316 | \n\ 317 | lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));\n\ 318 | }}\n" 319 | return kernel_code 320 | 321 | def gen_transform_code(kernel_shape): 322 | kernel_code = "\n\ 323 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\ 324 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\ 325 | return;\n\ 326 | }\n\ 327 | \n\ 328 | int k = tensor->ne[0];\n\ 329 | int m = tensor->ne[1];\n\ 330 | const int lut_scales_size = 1;\n\ 331 | const int scales_size = 1;\n\ 332 | int bk = 0;\n\ 333 | int bm = 0;\n" 334 | 335 | kernel_code = "".join([kernel_code, "\n\ 336 | if (m == {0} && k == {1}) {{\n\ 337 | bm = BM{0}_{1};\n\ 338 | bk = BBK{0}_{1};\n\ 339 | }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])]) 340 | 341 | for i in range(1, len(kernel_shapes)): 342 | kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\ 343 | bm = BM{0}_{1};\n\ 344 | bk = BBK{0}_{1};\n\ 345 | }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])]) 346 | 347 | kernel_code = "".join([kernel_code, "\n\ 348 | const int n_tile_num = m / bm;\n\ 349 | const int BK = bk;\n\ 350 | uint8_t * qweights;\n\ 351 | bitnet_float_type * scales;\n\ 352 | \n\ 353 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\ 354 | qweights = (uint8_t *) tensor->data;\n\ 355 | float * i2_scales = (float * )(qweights + k * m / 4);\n\ 356 | scales[0] = (bitnet_float_type) i2_scales[0];\n\ 357 | \n\ 358 | 
tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\ 359 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\ 360 | /* .lut_scales_size = */ lut_scales_size,\n\ 361 | /* .BK = */ BK,\n\ 362 | /* .n_tile_num = */ n_tile_num,\n\ 363 | /* .qweights = */ qweights,\n\ 364 | /* .scales = */ scales\n\ 365 | };\n\ 366 | }\n"]) 367 | 368 | return kernel_code 369 | 370 | if __name__ == "__main__": 371 | ModelShapeDict = { 372 | "bitnet_b1_58-large" : [[1536, 4096], 373 | [1536, 1536], 374 | [4096, 1536]], 375 | "bitnet_b1_58-3B" : [[3200, 8640], 376 | [3200, 3200], 377 | [8640, 3200]], 378 | "Llama3-8B-1.58-100B-tokens" : [[14336, 4096], 379 | [4096, 14336], 380 | [1024, 4096], 381 | [4096, 4096]] 382 | } 383 | 384 | parser = argparse.ArgumentParser(description='gen impl') 385 | parser.add_argument('--model',default="input", type=str, dest="model", 386 | help="choose from bitnet_b1_58-large/bitnet_b1_58-3B/Llama3-8B-1.58-100B-tokens.") 387 | parser.add_argument('--BM',default="input", type=str, 388 | help="block length when cutting one weight (M, K) into M / BM weights (BM, K).") 389 | parser.add_argument('--BK',default="input", type=str, 390 | help="block length when cutting one weight (M, K) into K / BK weights (M, BK).") 391 | parser.add_argument('--bm',default="input", type=str, 392 | help="using simd instructions to compute (bm, 256 / bm) in one block") 393 | args = parser.parse_args() 394 | 395 | kernel_shapes = ModelShapeDict[args.model] 396 | 397 | BM_list = [int(item) for item in args.BM.split(',')] 398 | BK_list = [int(item) for item in args.BK.split(',')] 399 | bm_list = [int(item) for item in args.bm.split(',')] 400 | 401 | assert(len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes)), "number of BM / BK / bm shoud be {}".format(len(kernel_shapes)) 402 | 403 | for i in range(len(kernel_shapes)): 404 | assert kernel_shapes[i][0] % BM_list[i] == 0, "M %% BM should be 0" 405 | assert kernel_shapes[i][1] % BK_list[i] == 0, "K %% BK should be 0" 406 | assert bm_list[i] in [32, 64], "choose bm from [32, 64]" 407 | 408 | tbl_impl_code = [] 409 | 410 | for i in range(len(kernel_shapes)): 411 | tbl_impl_code.append( 412 | gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], kernel_shapes[i][1]) 413 | ) 414 | api_code = gen_top_api(kernel_shapes) 415 | pre_code = gen_preprocess_code() 416 | ctor_code = gen_ctor_code() 417 | trans_code = gen_transform_code(kernel_shapes) 418 | 419 | output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include") 420 | 421 | with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f: 422 | f.write(''.join("#if defined(GGML_BITNET_ARM_TL1)")) 423 | f.write(''.join(ctor_code)) 424 | for code in tbl_impl_code: 425 | f.write(''.join(code)) 426 | f.write(''.join(pre_code)) 427 | f.write(''.join(api_code)) 428 | f.write(''.join(trans_code)) 429 | f.write(''.join("#endif")) 430 | 431 | config = ConfigParser() 432 | 433 | for i in range(len(kernel_shapes)): 434 | config.add_section('Kernels_{}'.format(i)) 435 | config.set('Kernels_{}'.format(i), 'M'.format(i), str(kernel_shapes[i][0])) 436 | config.set('Kernels_{}'.format(i), 'K'.format(i), str(kernel_shapes[i][1])) 437 | config.set('Kernels_{}'.format(i), 'BM'.format(i), str(BM_list[i])) 438 | config.set('Kernels_{}'.format(i), 'BK'.format(i), str(BK_list[i])) 439 | config.set('Kernels_{}'.format(i), 'bmm'.format(i), str(bm_list[i])) 440 | 441 | with open(''.join([output_dir, 
"/kernel_config.ini"]), 'w') as configfile: 442 | config.write(configfile) -------------------------------------------------------------------------------- /utils/e2e_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import argparse 5 | import platform 6 | import subprocess 7 | 8 | def run_command(command, shell=False, log_step=None): 9 | """Run a system command and ensure it succeeds.""" 10 | if log_step: 11 | log_file = os.path.join(args.log_dir, log_step + ".log") 12 | with open(log_file, "w") as f: 13 | try: 14 | subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f) 15 | except subprocess.CalledProcessError as e: 16 | logging.error(f"Error occurred while running command: {e}, check details in {log_file}") 17 | sys.exit(1) 18 | else: 19 | try: 20 | subprocess.run(command, shell=shell, check=True) 21 | except subprocess.CalledProcessError as e: 22 | logging.error(f"Error occurred while running command: {e}") 23 | sys.exit(1) 24 | 25 | def run_benchmark(): 26 | build_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "build") 27 | if platform.system() == "Windows": 28 | bench_path = os.path.join(build_dir, "bin", "Release", "llama-bench.exe") 29 | if not os.path.exists(bench_path): 30 | bench_path = os.path.join(build_dir, "bin", "llama-bench") 31 | else: 32 | bench_path = os.path.join(build_dir, "bin", "llama-bench") 33 | if not os.path.exists(bench_path): 34 | logging.error(f"Benchmark binary not found, please build first.") 35 | sys.exit(1) 36 | command = [ 37 | f'{bench_path}', 38 | '-m', args.model, 39 | '-n', str(args.n_token), 40 | '-ngl', '0', 41 | '-b', '1', 42 | '-t', str(args.threads), 43 | '-p', str(args.n_prompt), 44 | '-r', '5' 45 | ] 46 | run_command(command) 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser(description='Setup the environment for running the inference') 50 | parser.add_argument("-m", "--model", type=str, help="Path to model file", required=True) 51 | parser.add_argument("-n", "--n-token", type=int, help="Number of generated tokens", required=False, default=128) 52 | parser.add_argument("-p", "--n-prompt", type=int, help="Prompt to generate text from", required=False, default=512) 53 | parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2) 54 | return parser.parse_args() 55 | 56 | if __name__ == "__main__": 57 | logging.basicConfig(level=logging.INFO) 58 | args = parse_args() 59 | run_benchmark() -------------------------------------------------------------------------------- /utils/kernel_tuning.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/69a20459f58343bcf185456a73d4f9b9afa9cd70/utils/kernel_tuning.py --------------------------------------------------------------------------------